Remove unnecessary bool conversion in basic test.
Add tie breaking test for axis=1.
diff --git a/tensorflow/python/kernel_tests/argmax_op_test.py b/tensorflow/python/kernel_tests/argmax_op_test.py
index 2b4431a..bf5ebe1 100644
--- a/tensorflow/python/kernel_tests/argmax_op_test.py
+++ b/tensorflow/python/kernel_tests/argmax_op_test.py
@@ -61,7 +61,7 @@
self._testArg(method, x, axis, expected_values, False, expected_err_re)
def _testBasic(self, dtype):
- x = np.arange(200, dtype=np.float32).astype(np.bool_).astype(dtype)
+ x = np.arange(200, dtype=np.float32).astype(dtype)
np.random.shuffle(x)
# Check that argmin and argmax match numpy along the primary axis
@@ -76,6 +76,12 @@
self._testBothArg(math_ops.argmax, x, 0, x.argmax())
self._testBothArg(math_ops.argmin, x, 0, x.argmin())
+ # Check that argmin and argmax match numpy along axis=1 for
+ # breaking ties.
+ x = np.array([[0, 0, 1, 1], [1, 1, 0, 0], [0, 1, 0, 1]], dtype=dtype)
+ self._testBothArg(math_ops.argmax, x, 1, x.argmax(axis=1))
+ self._testBothArg(math_ops.argmin, x, 1, x.argmin(axis=1))
+
def _testDim(self, dtype):
shape = (3, 2, 4, 5, 6, 3, 7)
x = np.arange(