Fixing the run_deprecated_v1 decorator usage in gather_op_test.py. Also moving the test to the array_ops folder.
PiperOrigin-RevId: 324302642
Change-Id: I85d54f334537a25a4c7f6d6eaeb17721cce39e25
diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc
index 948567e..e9e6a93 100644
--- a/tensorflow/core/kernels/gather_op.cc
+++ b/tensorflow/core/kernels/gather_op.cc
@@ -78,10 +78,11 @@
}
}
+ int64 min_params_dim = axis < 0 ? -axis : axis + 1;
OP_REQUIRES(
- c, axis >= -params.dims() && axis < params.dims(),
- errors::InvalidArgument("Expected axis in the range [", -params.dims(),
- ", ", params.dims(), "), but got ", axis));
+ c, params.dims() >= min_params_dim,
+ errors::InvalidArgument("Shape must be at least rank ", min_params_dim,
+ " but is rank ", params.dims()));
if (axis < 0) {
axis = params.dims() + axis;
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 1e093af..8504052 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -2036,20 +2036,6 @@
)
cuda_py_test(
- name = "gather_op_test",
- size = "medium",
- srcs = ["gather_op_test.py"],
- deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:gradients",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-cuda_py_test(
name = "gradient_correctness_test",
size = "small",
srcs = ["gradient_correctness_test.py"],
diff --git a/tensorflow/python/kernel_tests/array_ops/BUILD b/tensorflow/python/kernel_tests/array_ops/BUILD
index df48258..bc448f3 100644
--- a/tensorflow/python/kernel_tests/array_ops/BUILD
+++ b/tensorflow/python/kernel_tests/array_ops/BUILD
@@ -46,3 +46,17 @@
"//third_party/py/numpy",
],
)
+
+cuda_py_test(
+ name = "gather_op_test",
+ size = "medium",
+ srcs = ["gather_op_test.py"],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:gradients",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/array_ops/gather_op_test.py
similarity index 74%
rename from tensorflow/python/kernel_tests/gather_op_test.py
rename to tensorflow/python/kernel_tests/array_ops/gather_op_test.py
index 0f59d10..d553b29 100644
--- a/tensorflow/python/kernel_tests/gather_op_test.py
+++ b/tensorflow/python/kernel_tests/array_ops/gather_op_test.py
@@ -107,18 +107,20 @@
expected_shape = data.shape[:axis] + (4,) + data.shape[axis + 1:]
self.assertEqual(expected_shape, gather_t.get_shape())
- @test_util.run_deprecated_v1
def testHigherRank(self):
- # We check that scalar and empty indices shapes work as well
- shape = (2, 1, 3, 2)
- for indices_shape in (), (0,), (2, 0), (2, 3):
- for dtype in _TEST_TYPES:
- for axis in range(len(shape)):
- params = self._buildParams(np.random.randn(*shape), dtype)
- indices = np.random.randint(shape[axis], size=indices_shape)
- with self.subTest(indices_shape=indices_shape, dtype=dtype, axis=axis,
- indices=indices):
- with self.cached_session(use_gpu=True) as sess:
+ with ops.Graph().as_default():
+ # We check that scalar and empty indices shapes work as well
+ shape = (2, 1, 3, 2)
+ for indices_shape in (), (0,), (2, 0), (2, 3):
+ for dtype in _TEST_TYPES:
+ for axis in range(len(shape)):
+ params = self._buildParams(np.random.randn(*shape), dtype)
+ indices = np.random.randint(shape[axis], size=indices_shape)
+ with self.subTest(
+ indices_shape=indices_shape,
+ dtype=dtype,
+ axis=axis,
+ indices=indices):
tf_params = constant_op.constant(params)
tf_indices = constant_op.constant(indices)
# Check that both positive and negative indices for axis work.
@@ -127,7 +129,7 @@
gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
gather_negative_axis = array_ops.gather(
tf_params, tf_indices, axis=tf_negative_axis)
- gather_value, gather_negative_axis_value = sess.run(
+ gather_value, gather_negative_axis_value = self.evaluate(
[gather, gather_negative_axis])
gather_np = np.take(params, indices, axis)
self.assertAllEqual(gather_np, gather_value)
@@ -144,10 +146,10 @@
gather_grad -= 1j * gather_grad
params_grad, indices_grad, axis_grad = gradients_impl.gradients(
gather, [tf_params, tf_indices, tf_axis], gather_grad)
- self.assertEqual(indices_grad, None)
- self.assertEqual(axis_grad, None)
+ self.assertIsNone(indices_grad)
+ self.assertIsNone(axis_grad)
if dtype.is_integer:
- self.assertEqual(params_grad, None)
+ self.assertIsNone(params_grad)
continue
# For axis 0, we are able to create an efficient IndexedSlices for
# the gradient.
@@ -171,47 +173,113 @@
atol=2e-6,
rtol=2e-6)
- @test_util.run_deprecated_v1
+ def testHigherRankGradientTape(self):
+ # We check that scalar and empty indices shapes work as well
+ shape = (2, 1, 3, 2)
+ for indices_shape in (), (0,), (2, 0), (2, 3):
+ for dtype in _TEST_TYPES:
+ for axis in range(len(shape)):
+ params = self._buildParams(np.random.randn(*shape), dtype)
+ indices = np.random.randint(shape[axis], size=indices_shape)
+ with self.subTest(
+ indices_shape=indices_shape,
+ dtype=dtype,
+ axis=axis,
+ indices=indices):
+ with backprop.GradientTape() as tape:
+ tf_params = constant_op.constant(params)
+ tf_indices = constant_op.constant(indices)
+ # Check that both positive and negative indices for axis work.
+ tf_axis = constant_op.constant(axis)
+ tape.watch(tf_params)
+ tape.watch(tf_indices)
+ tape.watch(tf_axis)
+ tf_negative_axis = constant_op.constant(-len(shape) + axis)
+ gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
+ gather_negative_axis = array_ops.gather(
+ tf_params, tf_indices, axis=tf_negative_axis)
+ gather_value, gather_negative_axis_value = self.evaluate(
+ [gather, gather_negative_axis])
+ gather_np = np.take(params, indices, axis)
+ self.assertAllEqual(gather_np, gather_value)
+ self.assertAllEqual(gather_np, gather_negative_axis_value)
+ expected_shape = (
+ params.shape[:axis] + indices.shape + params.shape[axis + 1:])
+ self.assertEqual(expected_shape, gather.shape)
+ self.assertEqual(expected_shape, gather_negative_axis.shape)
+
+ # Test gradients
+ gather_grad = np.random.randn(
+ *gather.get_shape().as_list()).astype(dtype.as_numpy_dtype)
+ if dtype.is_complex:
+ gather_grad -= 1j * gather_grad
+ params_grad, indices_grad, axis_grad = tape.gradient(
+ gather, [tf_params, tf_indices, tf_axis], gather_grad)
+ self.assertIsNone(indices_grad)
+ self.assertIsNone(axis_grad)
+ if dtype.is_integer:
+ self.assertIsNone(params_grad)
+ continue
+ # For axis 0, we are able to create an efficient IndexedSlices for
+ # the gradient.
+ if axis == 0:
+ self.assertEqual(type(params_grad), ops.IndexedSlices)
+ params_grad = ops.convert_to_tensor(params_grad)
+ correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype)
+ outer_dims = axis
+ inner_dims = len(shape) - axis - 1
+ gather_grad = gather_grad.reshape(shape[:axis] + (indices.size,) +
+ shape[axis + 1:])
+ for source_index, dest_index in enumerate(indices.flat):
+ dest_slice = ((slice(None),) * outer_dims + (dest_index,) +
+ (slice(None),) * inner_dims)
+ source_slice = ((slice(None),) * outer_dims + (source_index,) +
+ (slice(None),) * inner_dims)
+ correct_params_grad[dest_slice] += gather_grad[source_slice]
+ self.assertAllClose(
+ correct_params_grad,
+ self.evaluate(params_grad),
+ atol=2e-6,
+ rtol=2e-6)
+
def testString(self):
params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]])
- with self.cached_session():
- self.assertAllEqual([b"qwer", b"uiop"],
- array_ops.gather(params, 1, axis=0).eval())
- self.assertAllEqual([b"asdf", b"qwer"],
- array_ops.gather(params, 0, axis=1).eval())
+ self.assertAllEqual([b"qwer", b"uiop"], array_ops.gather(params, 1, axis=0))
+ self.assertAllEqual([b"asdf", b"qwer"], array_ops.gather(params, 0, axis=1))
- @test_util.run_deprecated_v1
def testUInt32AndUInt64(self):
for unsigned_type in (dtypes.uint32, dtypes.uint64):
with self.subTest(unsigned_type=unsigned_type):
params = self._buildParams(
np.array([[1, 2, 3], [7, 8, 9]]), unsigned_type)
with self.cached_session():
- self.assertAllEqual([7, 8, 9],
- array_ops.gather(params, 1, axis=0).eval())
- self.assertAllEqual([1, 7],
- array_ops.gather(params, 0, axis=1).eval())
+ self.assertAllEqual([7, 8, 9], array_ops.gather(params, 1, axis=0))
+ self.assertAllEqual([1, 7], array_ops.gather(params, 0, axis=1))
- @test_util.run_deprecated_v1
def testUnknownIndices(self):
- params = constant_op.constant([[0, 1, 2]])
- indices = array_ops.placeholder(dtypes.int32)
- gather_t = array_ops.gather(params, indices)
- self.assertEqual(None, gather_t.get_shape())
+ # This test is purely a test for placeholder inputs which is only applicable
+ # in graph mode.
+ with ops.Graph().as_default():
+ params = constant_op.constant([[0, 1, 2]])
+ indices = array_ops.placeholder(dtypes.int32)
+ gather_t = array_ops.gather(params, indices)
+ self.assertEqual(None, gather_t.get_shape())
- @test_util.run_deprecated_v1
def testUnknownAxis(self):
- params = constant_op.constant([[0, 1, 2]])
- indices = constant_op.constant([[0, 0], [0, 0]])
- axis = array_ops.placeholder(dtypes.int32)
- gather_t = array_ops.gather(params, indices, axis=axis)
- # Rank 2 params with rank 2 indices results in a rank 3 shape.
- self.assertEqual([None, None, None], gather_t.shape.as_list())
+ # This test is purely a test for placeholder inputs which is only applicable
+ # in graph mode.
+ with ops.Graph().as_default():
+ params = constant_op.constant([[0, 1, 2]])
+ indices = constant_op.constant([[0, 0], [0, 0]])
+ axis = array_ops.placeholder(dtypes.int32)
+ gather_t = array_ops.gather(params, indices, axis=axis)
+ # Rank 2 params with rank 2 indices results in a rank 3 shape.
+ self.assertEqual([None, None, None], gather_t.shape.as_list())
- # If indices is also unknown the result rank is unknown.
- indices = array_ops.placeholder(dtypes.int32)
- gather_t = array_ops.gather(params, indices, axis=axis)
- self.assertEqual(None, gather_t.shape)
+ # If indices is also unknown the result rank is unknown.
+ indices = array_ops.placeholder(dtypes.int32)
+ gather_t = array_ops.gather(params, indices, axis=axis)
+ self.assertEqual(None, gather_t.shape)
def testBadIndicesType(self):
with self.assertRaisesRegex(
@@ -243,45 +311,36 @@
with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
array_ops.gather(params, [[7]], axis=1).eval()
- @test_util.run_deprecated_v1
def testBadAxis(self):
- with self.session(use_gpu=True):
- params = [0, 1, 2]
- params_ph = array_ops.placeholder(dtypes.int32)
- indices = 0
- for bad_axis in (1, 2, -2):
- # Shape inference can validate axis for known params rank.
- with self.subTest(bad_axis=bad_axis):
- with self.assertRaisesWithPredicateMatch(
- ValueError, "Shape must be at least rank . but is rank 1"):
- array_ops.gather(params, indices, axis=bad_axis)
- # If params rank is unknown, an op error occurs.
- with self.assertRaisesOpError(
- r"Expected axis in the range \[-1, 1\), but got %s" % bad_axis):
- array_ops.gather(params_ph, indices, axis=bad_axis).eval(
- feed_dict={params_ph: params})
+ params = [0, 1, 2]
+ indices = 0
+ for bad_axis in (1, 2, -2):
+ # Shape inference can validate axis for known params rank.
+ with self.subTest(bad_axis=bad_axis):
+ with self.assertRaisesRegex(
+ (ValueError, errors.InvalidArgumentError),
+ "Shape must be at least rank .* but is rank 1"):
+ array_ops.gather(params, indices, axis=bad_axis)
- @test_util.run_deprecated_v1
def testEmptySlices(self):
- with self.session(use_gpu=True):
- for dtype in _TEST_TYPES:
- for itype in np.int32, np.int64:
- # Leading axis gather.
- with self.subTest(dtype=dtype, itype=itype):
- params = np.zeros((7, 0, 0), dtype=dtype.as_numpy_dtype)
- indices = np.array([3, 4], dtype=itype)
- gather = array_ops.gather(params, indices, axis=0)
- self.assertAllEqual(gather, np.zeros((2, 0, 0)))
+ for dtype in _TEST_TYPES:
+ for itype in np.int32, np.int64:
+ # Leading axis gather.
+ with self.subTest(dtype=dtype, itype=itype):
+ params = np.zeros((7, 0, 0), dtype=dtype.as_numpy_dtype)
+ indices = np.array([3, 4], dtype=itype)
+ gather = array_ops.gather(params, indices, axis=0)
+ self.assertAllEqual(gather, np.zeros((2, 0, 0)))
- # Middle axis gather.
- params = np.zeros((0, 7, 0), dtype=dtype.as_numpy_dtype)
- gather = array_ops.gather(params, indices, axis=1)
- self.assertAllEqual(gather, np.zeros((0, 2, 0)))
+ # Middle axis gather.
+ params = np.zeros((0, 7, 0), dtype=dtype.as_numpy_dtype)
+ gather = array_ops.gather(params, indices, axis=1)
+ self.assertAllEqual(gather, np.zeros((0, 2, 0)))
- # Trailing axis gather.
- params = np.zeros((0, 0, 7), dtype=dtype.as_numpy_dtype)
- gather = array_ops.gather(params, indices, axis=2)
- self.assertAllEqual(gather, np.zeros((0, 0, 2)))
+ # Trailing axis gather.
+ params = np.zeros((0, 0, 7), dtype=dtype.as_numpy_dtype)
+ gather = array_ops.gather(params, indices, axis=2)
+ self.assertAllEqual(gather, np.zeros((0, 0, 2)))
@parameterized.parameters([
# batch_dims=0 (equivalent to tf.gather)
@@ -385,20 +444,13 @@
self.assertAllEqual(expected, result)
# Test the gradients shape.
- if context.executing_eagerly():
- with backprop.GradientTape() as tape:
- zeros = array_ops.zeros_like(params, dtype=dtypes.float32)
- tape.watch(zeros)
- values = zeros * 2 + zeros
- result = array_ops.gather(
- values, indices, axis=axis, batch_dims=batch_dims)
- gradients = tape.gradient(result, zeros)
- else:
+ with backprop.GradientTape() as tape:
zeros = array_ops.zeros_like(params, dtype=dtypes.float32)
+ tape.watch(zeros)
values = zeros * 2 + zeros
result = array_ops.gather(
values, indices, axis=axis, batch_dims=batch_dims)
- gradients = gradients_impl.gradients(result, [zeros])[0]
+ gradients = tape.gradient(result, zeros)
self.assertAllEqual(array_ops.shape(params), array_ops.shape(gradients))