# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.scatter_nd."""
import functools
from absl.testing import parameterized
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
GRADIENT_TESTS_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64)
def _AsType(v, vtype):
return v.astype(vtype) if isinstance(v, np.ndarray) else vtype(v)
def _FlatInnerDims(tensor, ndims=2):
shape = list(tensor.shape)
return tensor.reshape([
functools.reduce(lambda x, y: x * y, shape[:-ndims + 1], 1)
] + shape[-ndims + 1:])
def _FlatOuterDims(tensor, ndims=2):
shape = list(tensor.shape)
return tensor.reshape(shape[:ndims - 1] + [
functools.reduce(lambda x, y: x * y, shape[ndims - 1:], 1)
])
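# Shape examples for the flattening helpers (illustrative comments only):
# for a tensor of shape (2, 3, 4) with the default ndims=2, _FlatInnerDims
# keeps the last axis and collapses the rest, giving shape (6, 4), while
# _FlatOuterDims keeps the first axis and collapses the rest, giving (2, 12).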
def _NumpyScatterNd(ref, indices, updates, op):
ixdim = indices.shape[-1]
num_updates = indices.size // ixdim
total_nd = len(ref.shape)
slice_size = 1
for i in range(ixdim, total_nd):
slice_size *= ref.shape[i]
flat_indices = _FlatInnerDims(indices)
flat_updates = updates.reshape((num_updates, slice_size))
output_flat = _FlatOuterDims(ref, ixdim + 1)
for ix_updates, ix_output in enumerate(flat_indices):
ix_output = tuple(ix_output)
output_flat[ix_output] = op(output_flat[ix_output],
flat_updates[ix_updates])
return output_flat.reshape(ref.shape)
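# Illustrative example (comments only, not executed): with ref of shape
# (2, 2), indices = [[0, 0], [1, 1]] and updates = [5., 7.], passing
# op = lambda p, u: u writes ref[0, 0] = 5. and ref[1, 1] = 7., mirroring
# tf.scatter_nd_update. The callers below rely both on the returned array and
# on the mutation happening in place through the reshaped view of ref.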
def _NumpyUpdate(ref, indices, updates):
return _NumpyScatterNd(ref, indices, updates, lambda p, u: u)
def _NumpyAdd(ref, indices, updates):
return _NumpyScatterNd(ref, indices, updates, lambda p, u: p + u)
def _NumpySub(ref, indices, updates):
return _NumpyScatterNd(ref, indices, updates, lambda p, u: p - u)
def _NumpyMul(ref, indices, updates):
return _NumpyScatterNd(ref, indices, updates, lambda p, u: p * u)
def _NumpyDiv(ref, indices, updates):
return _NumpyScatterNd(ref, indices, updates, lambda p, u: p / u)
def _NumpyMin(ref, indices, updates):
return _NumpyScatterNd(ref, indices, updates, np.minimum)
def _NumpyMax(ref, indices, updates):
return _NumpyScatterNd(ref, indices, updates, np.maximum)
@test_util.with_eager_op_as_function
class StatefulScatterNdTest(test.TestCase):
def _VariableRankTest(self,
np_scatter,
tf_scatter,
vtype,
itype,
repeat_indices=False):
np.random.seed(8)
ref_shapes = [(3, 6), (3, 6), (3, 6, 9), (3, 6, 9), (3, 6, 9), (3, 6, 9)]
indices_shapes = [(2,), (2, 2), (2,), (2, 2), (2, 3), (2, 3, 3)]
with test_util.device(use_gpu=True):
for ref_shape, indices_shape in zip(ref_shapes, indices_shapes):
num_updates = indices_shape[0]
ixdim = indices_shape[-1]
indexable_area_shape = ()
for i in range(ixdim):
indexable_area_shape += (ref_shape[i],)
all_indices = [
list(coord)
for coord, _ in np.ndenumerate(
np.empty(indexable_area_shape, vtype))
]
np.random.shuffle(all_indices)
indices = np.array(all_indices[:num_updates])
if num_updates > 1 and repeat_indices:
indices = indices[:num_updates // 2]
for _ in range(num_updates - num_updates // 2):
indices = np.append(
indices, [indices[np.random.randint(num_updates // 2)]], axis=0)
np.random.shuffle(indices)
indices = _AsType(indices[:num_updates], itype)
updates_shape = (num_updates,)
for i in range(ixdim, len(ref_shape)):
updates_shape += (ref_shape[i],)
updates = _AsType(np.random.randn(*(updates_shape)), vtype)
ref = _AsType(np.random.randn(*(ref_shape)), vtype)
        # Scatter via NumPy.
new = ref.copy()
np_scatter(new, indices, updates)
        # Scatter via TensorFlow.
ref_var = variables.VariableV1(ref)
self.evaluate(ref_var.initializer)
self.evaluate(tf_scatter(ref_var, indices, updates))
# Compare
self.assertAllClose(new, self.evaluate(ref_var))
def _VariableRankTests(self, np_scatter, tf_scatter):
for vtype in (np.int32, np.float16, np.float32, np.float64, np.complex64,
np.complex128):
for itype in (np.int32, np.int64):
self._VariableRankTest(np_scatter, tf_scatter, vtype, itype)
def testSimple(self):
indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
for dtype in (dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64,
dtypes.complex64, dtypes.complex128):
updates = constant_op.constant([9, 10, 11, 12], dtype=dtype)
ref = variables.Variable([0, 0, 0, 0, 0, 0, 0, 0], dtype=dtype)
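      # updates[i] lands at position indices[i][0]: 9 -> 4, 10 -> 3,
      # 11 -> 1, 12 -> 7; every other position keeps its initial 0.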
expected = np.array([0, 11, 0, 10, 9, 0, 0, 12])
scatter = state_ops.scatter_nd_update(ref, indices, updates)
init = variables.global_variables_initializer()
with test_util.use_gpu():
self.evaluate(init)
result = self.evaluate(scatter)
self.assertAllClose(result, expected)
@test_util.run_in_graph_and_eager_modes
def testString(self):
ref = variables.Variable(["qq", "ww", "ee", "rr", "", "", "", ""])
indices = constant_op.constant([[4], [3], [1], [7]])
updates = constant_op.constant(["aa", "dd", "cc", "bb"])
update = state_ops.scatter_nd_update(ref, indices, updates)
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual(self.evaluate(update),
[b"qq", b"cc", b"ee", b"dd", b"aa", b"", b"", b"bb"])
def testSimpleResource(self):
indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
for dtype in (dtypes.int32, dtypes.float32):
updates = constant_op.constant([9, 10, 11, 12], dtype=dtype)
ref = resource_variable_ops.ResourceVariable([0, 0, 0, 0, 0, 0, 0, 0],
dtype=dtype)
expected = np.array([0, 11, 0, 10, 9, 0, 0, 12])
scatter = state_ops.scatter_nd_update(ref, indices, updates)
with test_util.device(use_gpu=True):
self.evaluate(ref.initializer)
self.evaluate(scatter)
self.assertAllClose(ref, expected)
def testSimple2(self):
indices = constant_op.constant([[1, 0], [1, 1]], dtype=dtypes.int32)
updates = constant_op.constant([11., 12.], dtype=dtypes.float32)
ref = variables.Variable(
[[0., 0.], [0., 0.], [0., 0.]], dtype=dtypes.float32)
expected = np.array([[0., 0.], [11., 12.], [0., 0.]])
scatter = state_ops.scatter_nd_update(ref, indices, updates)
init = variables.global_variables_initializer()
with self.session():
self.evaluate(init)
result = self.evaluate(scatter)
self.assertAllClose(result, expected)
def testSimple3(self):
indices = constant_op.constant([[1]], dtype=dtypes.int32)
updates = constant_op.constant([[11., 12.]], dtype=dtypes.float32)
ref = variables.Variable(
[[0., 0.], [0., 0.], [0., 0.]], dtype=dtypes.float32)
expected = np.array([[0., 0.], [11., 12.], [0., 0.]])
scatter = state_ops.scatter_nd_update(ref, indices, updates)
init = variables.global_variables_initializer()
with self.session():
self.evaluate(init)
result = self.evaluate(scatter)
self.assertAllClose(result, expected)
def testVariableRankUpdate(self):
self._VariableRankTests(_NumpyUpdate, state_ops.scatter_nd_update)
def testVariableRankAdd(self):
self._VariableRankTests(_NumpyAdd, state_ops.scatter_nd_add)
def testVariableRankSub(self):
self._VariableRankTests(_NumpySub, state_ops.scatter_nd_sub)
# TODO(ebrevdo): Re-enable when we need ScatterNdMul.
# def testVariableRankMul(self):
# self._VariableRankTests(_NumpyMul, state_ops.scatter_nd_mul)
# TODO(ebrevdo): Re-enable when we need ScatterNdDiv.
# def testVariableRankDiv(self):
# self._VariableRankTests(_NumpyDiv, state_ops.scatter_nd_div)
def _ScatterRepeatIndicesTest(self, np_scatter, tf_scatter):
for vtype in (np.int32, np.float16, np.float32, np.float64):
for itype in (np.int32, np.int64):
self._VariableRankTest(
np_scatter, tf_scatter, vtype, itype, repeat_indices=True)
def testScatterRepeatIndices(self):
"""This tests scatter_add using indices that repeat."""
self._ScatterRepeatIndicesTest(_NumpyAdd, state_ops.scatter_nd_add)
self._ScatterRepeatIndicesTest(_NumpySub, state_ops.scatter_nd_sub)
# TODO(ebrevdo): Re-enable when we need ScatterNdMul and ScatterNdDiv.
# self._ScatterRepeatIndicesTest(_NumpyMul, state_ops.scatter_nd_mul)
# self._ScatterRepeatIndicesTest(_NumpyDiv, state_ops.scatter_nd_div)
# TODO(simister): Re-enable once binary size increase due to
# extra templating is back under control and this op is re-enabled
# def testBooleanScatterUpdate(self):
# with self.session(use_gpu=False) as session:
# var = tf.Variable([True, False])
# update0 = tf.compat.v1.scatter_nd_update(var, [[1]], [True])
# update1 = tf.compat.v1.scatter_nd_update(
# var, tf.constant(
# [[0]], dtype=tf.int64), [False])
# self.evaluate(var.initializer)
# session.run([update0, update1])
# self.assertAllEqual([False, True], self.evaluate(var))
@test_util.disable_xla("b/205330448")
def testScatterOutOfRangeCpu(self):
# TODO(simister): Re-enable once binary size increase due to
# scatter_nd ops is under control.
# tf.scatter_nd_mul, tf.scatter_nd_div,
for op in (state_ops.scatter_nd_add, state_ops.scatter_nd_sub,
state_ops.scatter_nd_update):
params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
updates = np.array([-3, -4, -5]).astype(np.float32)
with test_util.device(use_gpu=False):
ref = variables.VariableV1(params)
self.evaluate(ref.initializer)
# Indices all in range, no problem.
indices = np.array([[2], [0], [5]])
self.evaluate(op(ref, indices, updates))
# Test some out of range errors.
indices = np.array([[-1], [0], [5]])
with self.assertRaisesOpError(
r"indices\[0\] = \[-1\] does not index into shape \[6\]"):
self.evaluate(op(ref, indices, updates))
indices = np.array([[2], [0], [6]])
with self.assertRaisesOpError(
r"indices\[2\] = \[6\] does not index into shape \[6\]"):
self.evaluate(op(ref, indices, updates))
def testRank3ValidShape(self):
indices = array_ops.zeros([2, 2, 2], dtypes.int32)
updates = array_ops.zeros([2, 2, 2], dtypes.int32)
shape = np.array([2, 2, 2])
ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
self.assertAllEqual(
state_ops.scatter_nd_update(ref, indices,
updates).get_shape().as_list(), shape)
@test_util.disable_xla("b/123337890") # Error messages differ
def testResVarInvalidOutputShape(self):
res = variables.Variable(
initial_value=lambda: array_ops.zeros(shape=[], dtype=dtypes.float32),
dtype=dtypes.float32)
with self.cached_session():
self.evaluate(res.initializer)
with self.assertRaisesOpError("Output must be at least 1-D"):
state_ops.scatter_nd_update(res, [[0]], [0.22]).eval()
def testExtraIndicesDimensions(self):
indices = array_ops.zeros([1, 1, 2], dtypes.int32)
updates = array_ops.zeros([1, 1], dtypes.int32)
shape = np.array([2, 2])
ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
scatter_update = state_ops.scatter_nd_update(ref, indices, updates)
self.assertAllEqual(scatter_update.get_shape().as_list(), shape)
expected_result = np.zeros([2, 2], dtype=np.int32)
with self.cached_session():
self.evaluate(ref.initializer)
self.assertAllEqual(expected_result, self.evaluate(scatter_update))
def testRank3InvalidShape1(self):
indices = array_ops.zeros([3, 2, 2], dtypes.int32)
updates = array_ops.zeros([2, 2, 2], dtypes.int32)
shape = np.array([2, 2, 2])
ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
with self.assertRaisesWithPredicateMatch(
(errors.InvalidArgumentError, ValueError),
r"Dimensions \[\d,\d\) of indices\[shape="):
state_ops.scatter_nd_update(ref, indices, updates)
def testRank3InvalidShape2(self):
indices = array_ops.zeros([2, 2, 1], dtypes.int32)
updates = array_ops.zeros([2, 2], dtypes.int32)
shape = np.array([2, 2, 2])
ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
with self.assertRaisesWithPredicateMatch(
(errors.InvalidArgumentError, ValueError),
r"Dimensions \[\d,\d\) of input\[shape="):
state_ops.scatter_nd_update(ref, indices, updates)
def testConcurrentUpdates(self):
num_updates = 10000
update_values = np.random.rand(num_updates)
ref = variables.Variable(np.zeros([2, 2]), dtype=dtypes.float64)
indices = constant_op.constant([[0, 1]] * num_updates, dtype=dtypes.int32)
updates = constant_op.constant(update_values, dtype=dtypes.float64)
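    # Every row of indices addresses the same element [0, 1], so
    # scatter_nd_add must accumulate all num_updates values into that slot.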
expected_result = np.zeros([2, 2], dtype=np.float64)
expected_result[0, 1] = np.sum(update_values)
scatter = state_ops.scatter_nd_add(ref, indices, updates)
init = variables.global_variables_initializer()
self.evaluate(init)
result = self.evaluate(scatter)
    self.assertAllClose(result, expected_result)
@test_util.run_in_graph_and_eager_modes
def testMin(self):
variable = variables.Variable(array_ops.ones([8], dtype=dtypes.int32))
resource_variable = resource_variable_ops.ResourceVariable(
array_ops.ones([8], dtype=dtypes.int32))
indices = constant_op.constant([4, 3, 1, 7])
updates = constant_op.constant([0, 2, -1, 2], dtype=dtypes.int32)
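    # Expected: the elementwise minimum of the initial 1 and each update,
    # e.g. min(1, 0) = 0 at index 4 and min(1, -1) = -1 at index 1.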
for ref in (variable, resource_variable):
min_result = state_ops.scatter_min(ref, indices, updates)
self.evaluate(ref.initializer)
expected_result = constant_op.constant([1, -1, 1, 1, 0, 1, 1, 1])
self.assertAllEqual(self.evaluate(min_result), expected_result)
self.assertAllEqual(self.evaluate(ref), expected_result)
@test_util.run_in_graph_and_eager_modes
def testMax(self):
variable = variables.Variable(array_ops.ones([8], dtype=dtypes.int32))
resource_variable = resource_variable_ops.ResourceVariable(
array_ops.ones([8], dtype=dtypes.int32))
indices = constant_op.constant([4, 3, 1, 7])
updates = constant_op.constant([0, 2, -1, 2], dtype=dtypes.int32)
for ref in (variable, resource_variable):
max_result = state_ops.scatter_max(ref, indices, updates)
self.evaluate(ref.initializer)
expected_result = constant_op.constant([1, 1, 1, 2, 1, 1, 1, 2])
self.assertAllEqual(self.evaluate(max_result), expected_result)
self.assertAllEqual(self.evaluate(ref), expected_result)
@test_util.run_in_graph_and_eager_modes
def testAdd(self):
variable = variables.Variable(array_ops.ones([8], dtype=dtypes.int32))
resource_variable = resource_variable_ops.ResourceVariable(
array_ops.ones([8], dtype=dtypes.int32))
indices = constant_op.constant([4, 3, 1, 7])
updates = constant_op.constant([0, 2, -1, 3], dtype=dtypes.int32)
for ref in (variable, resource_variable):
add_result = state_ops.scatter_add(ref, indices, updates)
self.evaluate(ref.initializer)
expected_result = constant_op.constant([1, 0, 1, 3, 1, 1, 1, 4])
self.assertAllEqual(self.evaluate(add_result), expected_result)
self.assertAllEqual(self.evaluate(ref), expected_result)
@test_util.run_in_graph_and_eager_modes
def testSub(self):
variable = variables.Variable(array_ops.ones([8], dtype=dtypes.int32))
resource_variable = resource_variable_ops.ResourceVariable(
array_ops.ones([8], dtype=dtypes.int32))
indices = constant_op.constant([4, 3, 1, 7])
updates = constant_op.constant([0, 2, -1, 2], dtype=dtypes.int32)
for ref in (variable, resource_variable):
sub_result = state_ops.scatter_sub(ref, indices, updates)
self.evaluate(ref.initializer)
expected_result = constant_op.constant([1, 2, 1, -1, 1, 1, 1, -1])
self.assertAllEqual(self.evaluate(sub_result), expected_result)
self.assertAllEqual(self.evaluate(ref), expected_result)
# TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU.
def _disabledTestScatterOutOfRangeGpu(self):
if not test.IsBuiltWithCuda():
return
# TODO(simister): Re-enable once binary size increase due to
# scatter_nd ops is under control.
# tf.scatter_nd_mul, tf.scatter_nd_div,
for op in (state_ops.scatter_nd_add, state_ops.scatter_nd_sub,
state_ops.scatter_nd_update):
params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
updates = np.array([-3, -4, -5]).astype(np.float32)
      # On the GPU, the kernel ignores out-of-range indices.
      # We don't test the implementation; we just check that nothing fails.
with self.cached_session(force_gpu=True):
ref = variables.Variable(params)
self.evaluate(ref.initializer)
# Indices all in range, no problem.
indices = np.array([2, 0, 5])
op(ref, indices, updates).eval()
# Indices out of range should not fail.
indices = np.array([-1, 0, 5])
op(ref, indices, updates).eval()
indices = np.array([2, 0, 6])
op(ref, indices, updates).eval()
class StatefulScatterNdDeterminismTest(StatefulScatterNdTest):
def setUp(self):
super().setUp()
config.enable_op_determinism()
def tearDown(self):
super().tearDown()
config.disable_op_determinism()
@test_util.disable_xla("Scatter ND is not deterministic with XLA")
def testDeterminism(self):
ref = variables.Variable(array_ops.zeros([1]))
indices = array_ops.zeros([100000, 1], dtypes.int32)
values = np.random.randn(100000)
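    # All 100000 updates target index 0, so which value survives depends on
    # processing order; with op determinism enabled the result must be
    # identical across runs.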
self.evaluate(variables.global_variables_initializer())
val = self.evaluate(state_ops.scatter_nd_update(ref, indices, values))
for _ in range(5):
ref2 = variables.Variable(array_ops.zeros([1]))
self.evaluate(variables.global_variables_initializer())
val2 = self.evaluate(state_ops.scatter_nd_update(ref2, indices, values))
self.assertAllEqual(val, val2)
@test_util.with_eager_op_as_function
class ScatterNdTest(test.TestCase, parameterized.TestCase):
non_aliasing_add_test = False
def scatter_nd(self, indices, updates, shape, input_=None):
del input_ # input_ is not used in scatter_nd
return array_ops.scatter_nd(indices, updates, shape)
@test_util.run_in_graph_and_eager_modes
def testBool(self):
indices = constant_op.constant(
[[4], [3], [1], [7]], dtype=dtypes.int32)
updates = constant_op.constant(
[False, True, False, True], dtype=dtypes.bool)
expected = np.array(
[False, False, False, True, False, False, False, True])
scatter = self.scatter_nd(indices, updates, shape=(8,))
result = self.evaluate(scatter)
self.assertAllEqual(expected, result)
    # The same index is updated twice with the same value.
indices = constant_op.constant(
[[4], [3], [3], [7]], dtype=dtypes.int32)
updates = constant_op.constant(
[False, True, True, True], dtype=dtypes.bool)
expected = np.array([
False, False, False, True, False, False, False, True])
scatter = self.scatter_nd(indices, updates, shape=(8,))
result = self.evaluate(scatter)
self.assertAllEqual(expected, result)
@test_util.run_in_graph_and_eager_modes
def testInvalidShape(self):
# TODO(apassos) figure out how to unify these errors
with self.assertRaises(errors.InvalidArgumentError
if context.executing_eagerly() else ValueError):
array_ops.scatter_nd(indices=[0], # this should be indices=[[0]]
updates=[0.0],
shape=[1])
def testString(self):
indices = constant_op.constant([[4], [3], [1], [7]],
dtype=dtypes.int32)
updates = constant_op.constant(["four", "three", "one", "seven"],
dtype=dtypes.string)
expected = np.array([b"", b"one", b"", b"three", b"four",
b"", b"", b"seven"])
scatter = self.scatter_nd(indices, updates, shape=(8,))
    with self.cached_session():
result = self.evaluate(scatter)
self.assertAllEqual(expected, result)
    # The same index is updated twice with the same value.
indices = constant_op.constant([[4], [3], [3], [7]],
dtype=dtypes.int32)
updates = constant_op.constant(["a", "b", "b", "c"],
dtype=dtypes.string)
expected = np.array([b"", b"", b"", b"bb", b"a", b"", b"", b"c"])
scatter = self.scatter_nd(indices, updates, shape=(8,))
    with self.cached_session():
result = self.evaluate(scatter)
self.assertAllEqual(expected, result)
    # The same index is updated twice with different values.
indices = constant_op.constant([[4], [3], [3], [7]],
dtype=dtypes.int32)
updates = constant_op.constant(["a", "b", "c", "d"],
dtype=dtypes.string)
expected = [np.array([b"", b"", b"", b"bc", b"a", b"", b"", b"d"]),
np.array([b"", b"", b"", b"cb", b"a", b"", b"", b"d"])]
scatter = self.scatter_nd(indices, updates, shape=(8,))
    with self.cached_session():
result = self.evaluate(scatter)
self.assertTrue(np.array_equal(result, expected[0]) or
np.array_equal(result, expected[1]))
def testRank3ValidShape(self):
indices = array_ops.zeros([2, 2, 2], dtypes.int32)
updates = array_ops.zeros([2, 2, 2], dtypes.int32)
shape = np.array([2, 2, 2])
self.assertAllEqual(
self.scatter_nd(indices, updates, shape).get_shape().as_list(), shape)
def testExtraIndicesDimensions(self):
indices = array_ops.zeros([1, 1, 2], dtypes.int32)
updates = array_ops.zeros([1, 1], dtypes.int32)
shape = np.array([2, 2])
scatter = self.scatter_nd(indices, updates, shape)
self.assertAllEqual(scatter.get_shape().as_list(), shape)
expected_result = np.zeros([2, 2], dtype=np.int32)
self.assertAllEqual(expected_result, self.evaluate(scatter))
def testUndefinedIndicesShape(self):
    # Placeholders are only valid in graph mode.
with ops.Graph().as_default():
indices = array_ops.placeholder(dtypes.int32, shape=None)
updates = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2])
shape = constant_op.constant([2, 2, 2], dtypes.int32)
self.scatter_nd(indices, updates, shape)
def testUndefinedUpdatesShape(self):
    # Placeholders are only valid in graph mode.
with ops.Graph().as_default():
indices = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2])
updates = array_ops.placeholder(dtypes.int32, shape=None)
shape = constant_op.constant([2, 2, 2], dtypes.int32)
self.scatter_nd(indices, updates, shape)
def testUndefinedOutputShape(self):
    # Placeholders are only valid in graph mode.
with ops.Graph().as_default():
indices = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2])
updates = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2])
shape = array_ops.placeholder(dtypes.int32, shape=[None])
self.scatter_nd(indices, updates, shape)
def testEmptyOutputShape1(self):
indices = array_ops.zeros([2, 2, 2], dtypes.int32)
updates = array_ops.zeros([2, 2, 2], dtypes.int32)
shape = constant_op.constant([0, 3, 2], dtypes.int32)
with self.assertRaisesWithPredicateMatch(
(errors.InvalidArgumentError, ValueError),
"Indices and updates specified for empty"):
self.scatter_nd(indices, updates, shape)
def testEmptyOutputShape2(self):
with ops.Graph().as_default():
indices = array_ops.placeholder(dtypes.int32, shape=None)
updates = array_ops.placeholder(dtypes.int32, shape=None)
shape = constant_op.constant([0, 3, 2], dtypes.int32)
with self.cached_session():
with self.assertRaisesOpError(
"Indices and updates specified for empty (input|output)"):
self.scatter_nd(indices, updates, shape).eval(
feed_dict={
indices: np.zeros([2, 2, 2], dtype=np.int32),
updates: np.zeros([2, 2, 2], dtype=np.int32)
})
def testEmptyOutputShape3(self):
indices = array_ops.zeros([0], dtypes.int32)
updates = array_ops.zeros([0], dtypes.int32)
shape = constant_op.constant([0], dtypes.int32)
scatter = self.scatter_nd(indices, updates, shape)
with self.cached_session():
self.assertEqual(self.evaluate(scatter).size, 0)
def testRank3InvalidShape1(self):
indices = array_ops.zeros([3, 2, 2], dtypes.int32)
updates = array_ops.zeros([2, 2, 2], dtypes.int32)
shape = np.array([2, 2, 2])
with self.assertRaisesWithPredicateMatch(
(errors.InvalidArgumentError, ValueError),
r"Dimensions \[\d\,\d\) of indices\[shape="):
self.scatter_nd(indices, updates, shape)
def testRank3InvalidShape2(self):
indices = array_ops.zeros([2, 2, 1], dtypes.int32)
updates = array_ops.zeros([2, 2], dtypes.int32)
shape = np.array([2, 2, 2])
with self.assertRaisesWithPredicateMatch(
(errors.InvalidArgumentError, ValueError),
r"Dimensions \[\d\,\d\) of input\[shape="):
self.scatter_nd(indices, updates, shape)
@parameterized.parameters(set((True, context.executing_eagerly())))
def testGradientsRank2ElementUpdate(self, use_tape):
for dtype in GRADIENT_TESTS_DTYPES:
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
indices = constant_op.constant([[0, 0], [1, 1]], dtype=dtypes.int32)
updates = constant_op.constant([1, 4], dtype=dtype)
tape.watch(updates)
shape = constant_op.constant([2, 2], dtype=dtypes.int32)
input_ = array_ops.zeros(shape, dtype=dtype)
tape.watch(input_)
outputs = self.scatter_nd(indices, updates, shape, input_)
grad_vals = constant_op.constant([[1, 2], [3, 4]], dtype=dtype)
updates_grad, input_grad = tape.gradient([outputs], [updates, input_],
[grad_vals])
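        # The gradient w.r.t. updates gathers grad_vals at the scatter
        # indices: grad_vals[0, 0] = 1 and grad_vals[1, 1] = 4.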
expected_updates_grad = np.array([1, 4], dtype=dtype.as_numpy_dtype())
expected_input_grad = np.array([[1, 2], [3, 4]],
dtype=dtype.as_numpy_dtype())
self.assertAllEqual(expected_updates_grad, self.evaluate(updates_grad))
if self.non_aliasing_add_test:
self.assertAllEqual(expected_input_grad, self.evaluate(input_grad))
@parameterized.parameters(set((True, context.executing_eagerly())))
def testGradientsRank2SliceUpdate(self, use_tape):
for dtype in GRADIENT_TESTS_DTYPES:
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
indices = constant_op.constant([[1], [0]], dtype=dtypes.int32)
updates = constant_op.constant([[3, 4], [1, 2]], dtype=dtype)
tape.watch(updates)
shape = constant_op.constant([2, 2], dtype=dtypes.int32)
input_ = array_ops.zeros(shape, dtype=dtype)
tape.watch(input_)
outputs = self.scatter_nd(indices, updates, shape, input_)
grad_vals = constant_op.constant([[3, 4], [1, 2]], dtype=dtype)
updates_grad, input_grad = tape.gradient([outputs], [updates, input_],
[grad_vals])
expected_updates_grad = np.array([[1, 2], [3, 4]],
dtype=dtype.as_numpy_dtype())
expected_input_grad = np.array([[3, 4], [1, 2]],
dtype=dtype.as_numpy_dtype())
self.assertAllEqual(expected_updates_grad, self.evaluate(updates_grad))
if self.non_aliasing_add_test:
self.assertAllEqual(expected_input_grad, self.evaluate(input_grad))
@parameterized.parameters(set((True, context.executing_eagerly())))
def testGradientsRank3SliceUpdate(self, use_tape):
for dtype in GRADIENT_TESTS_DTYPES:
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
indices = constant_op.constant([[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
dtype=dtypes.int32)
updates = constant_op.constant([[[5, 7], [2, 4]], [[1, 3], [6, 8]]],
dtype=dtype)
tape.watch(updates)
shape = constant_op.constant([2, 2, 2], dtype=dtypes.int32)
input_ = array_ops.zeros(shape, dtype=dtype)
tape.watch(input_)
outputs = self.scatter_nd(indices, updates, shape, input_)
grad_vals = constant_op.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
dtype=dtype)
updates_grad, input_grad = tape.gradient([outputs], [updates, input_],
[grad_vals])
expected_updates_grad = np.array([[[3, 4], [5, 6]], [[1, 2], [7, 8]]],
dtype=dtype.as_numpy_dtype())
expected_input_grad = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
dtype=dtype.as_numpy_dtype())
self.assertAllEqual(expected_updates_grad, self.evaluate(updates_grad))
if self.non_aliasing_add_test:
self.assertAllEqual(expected_input_grad, self.evaluate(input_grad))
@parameterized.parameters(set((True, context.executing_eagerly())))
def testGradientsRank7SliceUpdate(self, use_tape):
for dtype in GRADIENT_TESTS_DTYPES:
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
indices = constant_op.constant(
[[[[[[[0, 0, 0, 0, 0, 1], [0, 0, 1, 0, 0, 0]]]],
[[[[0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 1]]]]]]],
dtype=dtypes.int32)
updates = constant_op.constant(
[[[[[[[5, 6], [2, 4]]]], [[[[1, 3], [6, 8]]]]]]], dtype=dtype)
tape.watch(updates)
shape = constant_op.constant([1, 1, 2, 1, 1, 2, 2], dtype=dtypes.int32)
input_ = array_ops.zeros(shape, dtype=dtype)
tape.watch(input_)
outputs = self.scatter_nd(indices, updates, shape, input_)
grad_vals = constant_op.constant(
[[[[[[[1, 2], [3, 4]]]], [[[[5, 6], [7, 8]]]]]]], dtype=dtype)
updates_grad, input_grad = tape.gradient([outputs], [updates, input_],
[grad_vals])
expected_updates_grad = np.array(
[[[[[[[3, 4], [5, 6]]]], [[[[1, 2], [7, 8]]]]]]],
dtype=dtype.as_numpy_dtype())
expected_input_grad = np.array(
[[[[[[[1, 2], [3, 4]]]], [[[[5, 6], [7, 8]]]]]]],
dtype=dtype.as_numpy_dtype())
self.assertAllEqual(expected_updates_grad, self.evaluate(updates_grad))
if self.non_aliasing_add_test:
self.assertAllEqual(expected_input_grad, self.evaluate(input_grad))
def testScatterNdRepeatedIndicesAdd(self):
indices = array_ops.zeros([100000, 1], dtypes.int32)
values = np.random.randn(100000)
shape = [1]
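    # scatter_nd sums contributions at duplicate indices, so all 100000
    # values accumulate into the single output element.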
val = self.evaluate(self.scatter_nd(indices, values, shape))
self.assertAllClose([np.sum(values)], val)
def testSmokeScatterNdBatch2DSliceDim2(self):
indices = array_ops.zeros([3, 5, 2], dtype=dtypes.int32)
values = array_ops.zeros([3, 5, 7])
shape = [4, 6, 7]
self.evaluate(self.scatter_nd(indices, values, shape))
def testSmokeScatterNdBatch1DSliceDim2(self):
indices = array_ops.zeros([0, 2], dtype=dtypes.int32)
values = array_ops.zeros([0, 7])
shape = [4, 6, 7]
self.evaluate(self.scatter_nd(indices, values, shape))
def testSmokeScatterNdBatch1DSliceDim3ShapeRank7(self):
indices = array_ops.zeros([1, 3], dtype=dtypes.int32)
values = array_ops.zeros([1, 6, 7, 8, 9])
shape = [3, 4, 5, 6, 7, 8, 9]
self.evaluate(self.scatter_nd(indices, values, shape))
def testSmokeScatterNdBatch2DSliceDim3ShapeRank7(self):
indices = array_ops.zeros([1, 2, 3], dtype=dtypes.int32)
values = array_ops.zeros([1, 2, 6, 7, 8, 9])
shape = [3, 4, 5, 6, 7, 8, 9]
self.evaluate(self.scatter_nd(indices, values, shape))
class ScatterNdNonAliasingAddTest(ScatterNdTest):
non_aliasing_add_test = True
def scatter_nd(self, indices, updates, shape, input_=None):
input_ = (input_ if input_ is not None else array_ops.zeros(
shape, dtype=updates.dtype))
return array_ops.scatter_nd_non_aliasing_add(input_, indices, updates)
def testString(self):
# Not supported yet.
pass
# TODO(testString): Enable this test when the above testString is enabled.
def testStringWithEagerOpAsFunctionEnabled(self):
# Not supported yet.
pass
class ScatterNdDeterminismTest(ScatterNdTest):
def setUp(self):
super().setUp()
config.enable_op_determinism()
def tearDown(self):
super().tearDown()
config.disable_op_determinism()
@test_util.disable_xla("Scatter ND is not deterministic with XLA")
def testDeterminism(self):
indices = array_ops.zeros([100000, 1], dtypes.int32)
values = np.random.randn(100000)
shape = [1]
val = self.evaluate(self.scatter_nd(indices, values, shape))
for _ in range(5):
val2 = self.evaluate(self.scatter_nd(indices, values, shape))
self.assertAllEqual(val, val2)
class ScatterNdNonAliasingAddDeterminismTest(ScatterNdDeterminismTest,
ScatterNdNonAliasingAddTest):
pass
class ScatterNdTensorTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testUpdateAddSub(self):
for dtype in (dtypes.int32, dtypes.float32):
indices = constant_op.constant([[4], [3], [1], [7]])
updates = constant_op.constant([9, 10, 11, 12], dtype=dtype)
t = array_ops.ones([8], dtype=dtype)
assigned = array_ops.tensor_scatter_update(t, indices, updates)
added = array_ops.tensor_scatter_add(t, indices, updates)
subbed = array_ops.tensor_scatter_sub(t, indices, updates)
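      # The tensor_scatter_* ops are functional: each returns a new tensor
      # computed from the same unchanged input t.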
self.assertAllEqual(assigned,
constant_op.constant([1, 11, 1, 10, 9, 1, 1, 12]))
self.assertAllEqual(added,
constant_op.constant([1, 12, 1, 11, 10, 1, 1, 13]))
self.assertAllEqual(subbed,
constant_op.constant([1, -10, 1, -9, -8, 1, 1, -11]))
def testUpdateAddSubGradients(self):
with self.cached_session():
indices = constant_op.constant([[3], [1]])
updates = constant_op.constant([9, 10], dtype=dtypes.float32)
x = array_ops.ones([4], dtype=dtypes.float32)
theoretical, numerical = gradient_checker_v2.compute_gradient(
lambda x: array_ops.tensor_scatter_update(x, indices, updates), [x])
self.assertAllClose(theoretical, numerical, 5e-4, 5e-4)
theoretical, numerical = gradient_checker_v2.compute_gradient(
lambda x: array_ops.tensor_scatter_add(x, indices, updates), [x])
self.assertAllClose(theoretical, numerical, 5e-4, 5e-4)
theoretical, numerical = gradient_checker_v2.compute_gradient(
lambda x: array_ops.tensor_scatter_sub(x, indices, updates), [x])
self.assertAllClose(theoretical, numerical, 5e-4, 5e-4)
theoretical, numerical = gradient_checker_v2.compute_gradient(
lambda updates: array_ops.tensor_scatter_update(x, indices, updates),
[updates])
self.assertAllClose(theoretical, numerical, 5e-4, 5e-4)
theoretical, numerical = gradient_checker_v2.compute_gradient(
lambda updates: array_ops.tensor_scatter_add(x, indices, updates),
[updates])
self.assertAllClose(theoretical, numerical, 5e-4, 5e-4)
theoretical, numerical = gradient_checker_v2.compute_gradient(
lambda updates: array_ops.tensor_scatter_sub(x, indices, updates),
[updates])
self.assertAllClose(theoretical, numerical, 5e-4, 5e-4)
@test_util.run_in_graph_and_eager_modes
def testUpdateMinMax(self):
for dtype in (dtypes.int32, dtypes.float32):
indices = constant_op.constant([[4], [3], [1], [7]])
updates = constant_op.constant([0, 2, -1, 2], dtype=dtype)
t = array_ops.ones([8], dtype=dtype)
assigned = array_ops.tensor_scatter_update(t, indices, updates)
min_result = array_ops.tensor_scatter_min(t, indices, updates)
max_result = array_ops.tensor_scatter_max(t, indices, updates)
self.assertAllEqual(assigned,
constant_op.constant([1, -1, 1, 2, 0, 1, 1, 2]))
self.assertAllEqual(min_result,
constant_op.constant([1, -1, 1, 1, 0, 1, 1, 1]))
self.assertAllEqual(max_result,
constant_op.constant([1, 1, 1, 2, 1, 1, 1, 2]))
def testUpdateMinMaxGradients(self):
with self.cached_session():
x = array_ops.ones([4], dtype=dtypes.float32)
indices = constant_op.constant([[1], [2], [3], [3]])
updates = constant_op.constant([2.0, 0.5, 1.0, 1.0], dtype=dtypes.float32)
theoretical, _ = gradient_checker_v2.compute_gradient(
lambda x: array_ops.tensor_scatter_max(x, indices, updates), [x])
# Numerical gradient doesn't work for degenerate values because the
# derivative is not continuous. The manually entered gradient divides
# the gradient among all contributing elements at the discontinuity.
manual = array_ops.reshape(
array_ops.matrix_diag([1.0, 0.0, 1.0, 0.3333]), (1, 4, 4))
self.assertAllClose(theoretical, manual, 5e-4, 5e-4)
theoretical, _ = gradient_checker_v2.compute_gradient(
lambda x: array_ops.tensor_scatter_min(x, indices, updates), [x])
manual = array_ops.reshape(
array_ops.matrix_diag([1.0, 1.0, 0.0, 0.3333]), (1, 4, 4))
self.assertAllClose(theoretical, manual, 5e-4, 5e-4)
theoretical, _ = gradient_checker_v2.compute_gradient(
lambda updates: array_ops.tensor_scatter_max(x, indices, updates),
[updates])
manual = constant_op.constant(
[[[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.3333, 0.3333]]],
dtype=dtypes.float32)
self.assertAllClose(theoretical, manual, 5e-4, 5e-4)
theoretical, _ = gradient_checker_v2.compute_gradient(
lambda updates: array_ops.tensor_scatter_min(x, indices, updates),
[updates])
manual = constant_op.constant(
[[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 0.3333, 0.3333]]],
dtype=dtypes.float32)
self.assertAllClose(theoretical, manual, 5e-4, 5e-4)
def testTensorScatterUpdateWithForwarding(self):
for dtype in (dtypes.int32, dtypes.float32):
@def_function.function
def _TestFn():
indices = constant_op.constant([[4], [3], [1], [7]])
updates = constant_op.constant([9, 10, 11, 12], dtype=dtype) # pylint: disable=cell-var-from-loop
t = array_ops.ones([8], dtype=dtype) # pylint: disable=cell-var-from-loop
return array_ops.tensor_scatter_update(t, indices, updates)
self.assertAllEqual(_TestFn(), [1, 11, 1, 10, 9, 1, 1, 12])
@test_util.run_in_graph_and_eager_modes
def testTensorScatterUpdateWithStrings(self):
indices = constant_op.constant([[4], [3], [1], [7]])
updates = constant_op.constant(["there", "there", "there", "12"],
dtype=dtypes.string)
tensor = constant_op.constant([
"hello", "hello", "hello", "hello", "hello", "hello", "hello", "hello"
],
dtype=dtypes.string)
updated = array_ops.tensor_scatter_update(tensor, indices, updates)
self.assertAllEqual(
updated,
constant_op.constant([
"hello", "there", "hello", "there", "there", "hello", "hello", "12"
]))
@test_util.run_in_graph_and_eager_modes
def testUpdateRepeatedIndices1D(self):
if test_util.is_gpu_available():
self.skipTest("Duplicate indices scatter is non-deterministic on GPU")
a = array_ops.zeros([10, 1])
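    # Both updates address row 5; the assertion below expects the later
    # value (8.) to win, which the CPU kernel exercised here provides.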
b = array_ops.tensor_scatter_update(a, [[5], [5]], [[4], [8]])
self.assertAllEqual(
b,
constant_op.constant([[0.], [0.], [0.], [0.], [0.], [8.], [0.], [0.],
[0.], [0.]]))
@test_util.run_in_graph_and_eager_modes
def testUpdateRepeatedIndices2D(self):
if test_util.is_gpu_available():
self.skipTest("Duplicate indices scatter is non-deterministic on GPU")
a = array_ops.zeros([10, 10])
b = array_ops.tensor_scatter_update(
a, [[5], [6], [6]],
[math_ops.range(10),
math_ops.range(11, 21),
math_ops.range(10, 20)])
self.assertAllEqual(
b[6],
constant_op.constant(
[10., 11., 12., 13., 14., 15., 16., 17., 18., 19.]))
class ScatterNdTensorDeterminismTest(ScatterNdTensorTest):
def setUp(self):
super().setUp()
config.enable_op_determinism()
def tearDown(self):
super().tearDown()
config.disable_op_determinism()
@test_util.disable_xla("Scatter ND is not deterministic with XLA")
def testDeterminism(self):
a = array_ops.zeros([1])
indices = array_ops.zeros([100000, 1], dtypes.int32)
values = np.random.randn(100000)
val = self.evaluate(array_ops.tensor_scatter_update(a, indices, values))
for _ in range(5):
val2 = self.evaluate(array_ops.tensor_scatter_update(a, indices, values))
self.assertAllEqual(val, val2)
if __name__ == "__main__":
test.main()