Add uint support to tf.sort and fix skipped tests.
For sorting unsigned ints ascending, we cannot negate values, so we need to
subtract from the max unsigned value.
Also fixed a bug for signed ints if the list contains dtype.min. We cannot
simply negate, since that would overflow and result in an incorrect sorted
order. Instead, we need to convert to unsigned.
PiperOrigin-RevId: 378072881
Change-Id: I097e068e86bdbf4f37b60d16a13361f05f0999ec
diff --git a/tensorflow/python/ops/sort_ops.py b/tensorflow/python/ops/sort_ops.py
index 55e353d..13d9477 100644
--- a/tensorflow/python/ops/sort_ops.py
+++ b/tensorflow/python/ops/sort_ops.py
@@ -21,6 +21,7 @@
import numpy as np
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as framework_ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
@@ -197,6 +198,7 @@
Returns:
The sorted values.
"""
+ # TODO(b/190410105): replace with a proper sort kernel.
k = array_ops.shape(values)[axis]
rank = array_ops.rank(values)
static_rank = values.shape.ndims
@@ -240,11 +242,51 @@
def _ascending_sort(values, axis, return_argsort=False):
- # Negate the values to get the ascending order from descending sort.
- values_or_indices = _descending_sort(-values, axis, return_argsort)
- # If not argsort, negate the values again.
- return values_or_indices if return_argsort else -values_or_indices
+ """Sorts values in ascending order.
+ Args:
+ values: Tensor of numeric values.
+ axis: Index of the axis which values should be sorted along.
+ return_argsort: If False, return the sorted values. If True, return the
+ indices that would sort the values.
+
+ Returns:
+ The sorted values.
+ """
+ # TODO(b/190410105): replace with a proper sort kernel.
+ # If values are integers, we need special handling.
+ dtype = values.dtype
+ if dtype.is_unsigned:
+ # Subtract values from dtype.max to reverse sort order.
+ offset = dtype.max
+ values_or_indices = _descending_sort(offset - values, axis, return_argsort)
+ return values_or_indices if return_argsort else offset - values_or_indices
+
+ elif dtype.is_integer:
+ # Convert to unsigned and subtract from max to avoid signed integer
+ # overflows and properly handle dtype.min. Although more complex and
+ # slightly slower than descend+reverse, this preserves stability.
+ udtype = _MAKE_UNSIGNED[dtype]
+ offset = udtype.max + dtype.min
+ values = offset - math_ops.cast(values, dtype=udtype)
+ values_or_indices = _descending_sort(values, axis, return_argsort)
+ if return_argsort:
+ return values_or_indices
+ return math_ops.cast(offset - values_or_indices, dtype=dtype)
+
+ else:
+ # Otherwise, negate the values and use descending sort.
+ values_or_indices = _descending_sort(-values, axis, return_argsort)
+ # If not argsort, negate the values again.
+ return values_or_indices if return_argsort else -values_or_indices
+
+
+_MAKE_UNSIGNED = {
+ dtypes.int8: dtypes.uint8,
+ dtypes.int16: dtypes.uint16,
+ dtypes.int32: dtypes.uint32,
+ dtypes.int64: dtypes.uint64,
+}
_SORT_IMPL = {
'ASCENDING': _ascending_sort,
diff --git a/tensorflow/python/ops/sort_ops_test.py b/tensorflow/python/ops/sort_ops_test.py
index 17ce604..27251de 100644
--- a/tensorflow/python/ops/sort_ops_test.py
+++ b/tensorflow/python/ops/sort_ops_test.py
@@ -25,7 +25,6 @@
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
-from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sort_ops
@@ -34,103 +33,144 @@
class SortTest(test.TestCase):
- @test_util.run_deprecated_v1
+ def random_array(self, shape, dtype):
+ if np.issubdtype(dtype, np.integer):
+ imin = np.iinfo(dtype).min
+ imax = np.iinfo(dtype).max
+ return np.random.randint(imin, imax, shape, dtype)
+ else:
+ return np.random.random(shape).astype(dtype)
+
+ def _test_sort(self, values, axis, direction):
+ expected = np.sort(values, axis=axis)
+ if direction == 'DESCENDING':
+ expected = np.flip(expected, axis=axis)
+ self.assertAllEqual(
+ expected,
+ sort_ops.sort(
+ constant_op.constant(values), axis=axis, direction=direction))
+
def testRandom_lowDimensionality(self):
- self._testRandom_lowDimensionality(negative_axis=False)
+ self._testRandom_lowDimensionality(
+ negative_axis=False, dtype=np.float32, direction='ASCENDING')
- @test_util.run_deprecated_v1
def testRandom_lowDimensionality_negative(self):
- self._testRandom_lowDimensionality(negative_axis=True)
+ self._testRandom_lowDimensionality(
+ negative_axis=True, dtype=np.float32, direction='ASCENDING')
- def _testRandom_lowDimensionality(self, negative_axis):
+ def _testRandom_lowDimensionality(self, negative_axis, dtype, direction):
np.random.seed(42)
for _ in range(20):
rank = np.random.randint(1, 3)
shape = [np.random.randint(0, 20) for _ in range(rank)]
- arr = np.random.random(shape)
+ arr = self.random_array(shape, dtype)
sort_axis = np.random.choice(rank)
if negative_axis:
sort_axis = -1 - sort_axis
with self.cached_session():
- self.assertAllEqual(
- np.sort(arr, axis=sort_axis),
- sort_ops.sort(constant_op.constant(arr), axis=sort_axis).eval())
+ self._test_sort(arr, sort_axis, direction)
- @test_util.run_deprecated_v1
def testRandom_highDimensionality(self):
+ self._testRandom_highDimensionality(np.float32)
+
+ def _testRandom_highDimensionality(self, dtype):
np.random.seed(100)
for _ in range(20):
rank = np.random.randint(5, 15)
shape = [np.random.randint(1, 4) for _ in range(rank)]
- arr = np.random.random(shape)
+ arr = self.random_array(shape, dtype)
sort_axis = np.random.choice(rank)
with self.cached_session():
- self.assertAllEqual(
- np.sort(arr, axis=sort_axis),
- sort_ops.sort(constant_op.constant(arr), axis=sort_axis).eval())
+ self._test_sort(arr, sort_axis, 'ASCENDING')
- @test_util.run_deprecated_v1
+ def testIntArray(self):
+ dtype = np.int64
+ self._testRandom_lowDimensionality(
+ negative_axis=False, dtype=dtype, direction='ASCENDING')
+ self._testRandom_lowDimensionality(
+ negative_axis=False, dtype=dtype, direction='DESCENDING')
+ edges = np.linspace(
+ np.iinfo(dtype).min, np.iinfo(dtype).max, 10, dtype=dtype)
+ self._test_sort(edges, 0, 'ASCENDING')
+ self._test_sort(edges, 0, 'DESCENDING')
+
+ def testUIntArray(self):
+ dtype = np.uint64
+ self._testRandom_lowDimensionality(
+ negative_axis=False, dtype=dtype, direction='ASCENDING')
+ self._testRandom_lowDimensionality(
+ negative_axis=False, dtype=dtype, direction='DESCENDING')
+ edges = np.linspace(
+ np.iinfo(dtype).min, np.iinfo(dtype).max, 10, dtype=dtype)
+ self._test_sort(edges, 0, 'ASCENDING')
+ self._test_sort(edges, 0, 'DESCENDING')
+
def testScalar(self):
# Create an empty scalar where the static shape is unknown.
zeros_length_1 = array_ops.zeros(
random_ops.random_uniform([1], minval=0, maxval=1, dtype=dtypes.int32),
dtype=dtypes.int32)
scalar = array_ops.zeros(zeros_length_1)
-
- sort = sort_ops.sort(scalar)
with self.cached_session():
- with self.assertRaises(errors.InvalidArgumentError):
- sort.eval()
+ with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
+ 'out of bounds'):
+ self.evaluate(sort_ops.sort(scalar))
- @test_util.run_deprecated_v1
def testNegativeOutOfBounds_staticShape(self):
arr = constant_op.constant([3, 4, 5])
- with self.assertRaises(ValueError):
- sort_ops.sort(arr, axis=-4)
+ with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
+ 'slice index .* out of bounds'):
+ self.evaluate(sort_ops.sort(arr, axis=-4))
- @test_util.run_deprecated_v1
def testDescending(self):
arr = np.random.random((10, 5, 5))
with self.cached_session():
self.assertAllEqual(
np.sort(arr, axis=0)[::-1],
sort_ops.sort(
- constant_op.constant(arr), axis=0, direction='DESCENDING').eval())
+ constant_op.constant(arr), axis=0, direction='DESCENDING'))
- @test_util.run_deprecated_v1
def testSort_staticallyKnownRank_constantTransposition(self):
- # The transposition array should be a constant if the rank of "values" is
- # statically known.
- tensor = random_ops.random_uniform(
- # Rank is statically known to be 5, but the dimension lengths are not
- # known.
- random_ops.random_uniform(
- shape=(5,), minval=0, maxval=10, dtype=dtypes.int32))
- sort_ops.sort(tensor, axis=1)
- transposition = (
- ops.get_default_graph().get_tensor_by_name('sort/transposition:0'))
- self.assertFalse(tensor_util.constant_value(transposition) is None)
- self.assertAllEqual(
- # Swaps "1" and "4" to put "1" at the end.
- tensor_util.constant_value(transposition),
- [0, 4, 2, 3, 1])
+ with ops.Graph().as_default():
+ # The transposition array should be a constant if the rank of "values" is
+ # statically known.
+ tensor = random_ops.random_uniform(
+ # Rank is statically known to be 5, but the dimension lengths are not
+ # known.
+ random_ops.random_uniform(
+ shape=(5,), minval=0, maxval=10, dtype=dtypes.int32))
+ sort_ops.sort(tensor, axis=1)
+ transposition = (
+ ops.get_default_graph().get_tensor_by_name('sort/transposition:0'))
+ self.assertIsNot(tensor_util.constant_value(transposition), None)
+ self.assertAllEqual(
+ # Swaps "1" and "4" to put "1" at the end.
+ tensor_util.constant_value(transposition),
+ [0, 4, 2, 3, 1])
- @test_util.run_deprecated_v1
def testArgsort_1d(self):
arr = np.random.random(42)
with self.cached_session():
self.assertAllEqual(
- np.sort(arr),
- array_ops.gather(arr, sort_ops.argsort(arr)).eval())
+ np.sort(arr), array_ops.gather(arr, sort_ops.argsort(arr)))
- @test_util.run_deprecated_v1
+ def testArgsortStable(self):
+ arr = constant_op.constant([1, 5, 2, 2, 3], dtype=dtypes.int32)
+ ascending = [0, 2, 3, 4, 1]
+ descending = [1, 4, 2, 3, 0]
+ with self.cached_session():
+ self.assertAllEqual(
+ sort_ops.argsort(arr, direction='ASCENDING', stable=True), ascending)
+ self.assertAllEqual(
+ sort_ops.argsort(arr, direction='DESCENDING', stable=True),
+ descending)
+
def testArgsort(self):
arr = np.random.random((5, 6, 7, 8))
for axis in range(4):
with self.cached_session():
self.assertAllEqual(
- np.argsort(arr, axis=axis),
- sort_ops.argsort(arr, axis=axis).eval())
+ np.argsort(arr, axis=axis), sort_ops.argsort(arr, axis=axis))
if __name__ == '__main__':