Improve consistency of error messages between TF and XLA "ScatterOps".
PiperOrigin-RevId: 408926419
Change-Id: Ib9efad3adac6af5fbe9c7058b63d3460e95bf01b
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 6938e13..7b82842 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -202,6 +202,7 @@
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/kernels:scatter_nd_util",
"//tensorflow/core/kernels:stateful_random_ops_header",
"//tensorflow/core/kernels:stateless_random_ops_v2_header",
"//tensorflow/core/tpu:tpu_defs",
diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
index d26df34..b60f374 100644
--- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
@@ -24,6 +24,7 @@
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/scatter_nd_util.h"
namespace tensorflow {
namespace {
@@ -164,6 +165,12 @@
xla::XlaOp var_value;
OP_REQUIRES_OK(
context, context->ReadVariableInput(0, dtype, &var_shape, &var_value));
+ // This check is only required for ScatterNdOps.
+ if (indices_are_vectors_) {
+ OP_REQUIRES_OK(context, ValidateScatterNdUpdateShape(
+ var_shape, context->InputShape(1),
+ context->InputShape(2)));
+ }
const xla::XlaOp indices = context->Input(1);
const xla::XlaOp updates = context->Input(2);
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index bff380d..cf8d079 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -5171,6 +5171,15 @@
],
)
+cc_library(
+ name = "scatter_nd_util",
+ srcs = ["scatter_nd_util.cc"],
+ hdrs = ["scatter_nd_util.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ ],
+)
+
tf_kernel_library(
name = "scatter_nd_op",
srcs = [
@@ -5196,6 +5205,7 @@
"scatter_nd_op_gpu.cu.cc",
],
deps = STATE_DEPS + [
+ ":scatter_nd_util",
":dense_update_functor",
":training_op_helpers",
":variable_ops",
@@ -6241,6 +6251,7 @@
"scan_ops.h",
"scatter_functor.h",
"scatter_nd_op.h",
+ "scatter_nd_util.h",
"segment_reduction_ops.h",
"segment_reduction_ops_impl.h",
"softplus_op.h",
@@ -6521,6 +6532,7 @@
"scatter_nd_op_cpu_impl_5.cc",
"scatter_nd_op_cpu_impl_6.cc",
"scatter_nd_op_cpu_impl_7.cc",
+ "scatter_nd_util.cc",
"segment_reduction_ops_impl_1.cc",
"segment_reduction_ops_impl_2.cc",
"segment_reduction_ops_impl_3.cc",
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index 82f2421..e24eed3 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -21,8 +21,6 @@
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-#include "tensorflow/core/kernels/scatter_nd_op.h"
-
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -32,6 +30,8 @@
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/inplace_ops_functor.h"
+#include "tensorflow/core/kernels/scatter_nd_op.h"
+#include "tensorflow/core/kernels/scatter_nd_util.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -40,7 +40,6 @@
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/util.h"
-
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -773,47 +772,6 @@
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace functor {
-// Check whether updates.shape = indices.shape[:batch_dim] +
-// params_shape[slice_dim:]
-Status ValidateUpdateShape(const TensorShape& params_shape,
- const Tensor& indices, const Tensor& updates) {
- const int64_t slice_dim =
- (indices.dims() > 1) ? indices.dim_size(indices.dims() - 1) : 1;
- const int64_t batch_dim = (indices.dims() > 1) ? indices.dims() - 1 : 1;
-
- auto shape_err_prefix = [&]() {
- return errors::InvalidArgument(
- "Dimensions [0,", batch_dim,
- ") of indices[shape=", indices.shape().DebugString(),
- "] must match dimensions [0,", batch_dim,
- ") of updates[shape=", updates.shape().DebugString(), "]");
- };
- auto shape_err_suffix = [&]() {
- return errors::InvalidArgument(
- "Dimensions [", slice_dim, ",", params_shape.dims(),
- ") of input[shape=", params_shape.DebugString(),
- "] must match dimensions [", slice_dim, ",", updates.dims(),
- ") of updates[shape=", updates.shape().DebugString(), "]");
- };
-
- if (updates.dims() < batch_dim) return shape_err_prefix();
- if (params_shape.dims() < slice_dim + (updates.dims() - batch_dim)) {
- return shape_err_suffix();
- }
- if (updates.dims() != batch_dim + params_shape.dims() - slice_dim) {
- return shape_err_suffix();
- }
- for (int d = 0; d < batch_dim; ++d) {
- if (updates.dim_size(d) != indices.dim_size(d)) return shape_err_prefix();
- }
- for (int d = 0; d < updates.dims() - batch_dim; ++d) {
- if (updates.dim_size(d + batch_dim) !=
- params_shape.dim_size(d + slice_dim)) {
- return shape_err_suffix();
- }
- }
- return Status::OK();
-}
template <typename Index>
Status PrepareAndValidateInputs(const TensorShape& params_shape,
@@ -842,7 +800,8 @@
"] = ", indices.dim_size(0), " must match dimensions [0,1) of updates[",
"shape=", updates_shape.DebugString(), "] = ", updates.dim_size(0));
}
- TF_RETURN_IF_ERROR(ValidateUpdateShape(params_shape, indices, updates));
+ TF_RETURN_IF_ERROR(ValidateScatterNdUpdateShape(params_shape, indices.shape(),
+ updates.shape()));
// Check that we have enough index space
const int64_t N_big = indices.NumElements();
diff --git a/tensorflow/core/kernels/scatter_nd_util.cc b/tensorflow/core/kernels/scatter_nd_util.cc
new file mode 100644
index 0000000..a0ad236
--- /dev/null
+++ b/tensorflow/core/kernels/scatter_nd_util.cc
@@ -0,0 +1,67 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/scatter_nd_util.h"
+
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+
+Status ValidateScatterNdUpdateShape(const TensorShape& params_shape,
+ const TensorShape& indices_shape,
+ const TensorShape& updates_shape) {
+ const int64_t slice_dim =
+ (indices_shape.dims() > 1)
+ ? indices_shape.dim_size(indices_shape.dims() - 1)
+ : 1;
+ const int64_t batch_dim =
+ (indices_shape.dims() > 1) ? indices_shape.dims() - 1 : 1;
+
+ auto shape_err_prefix = [&]() {
+ return errors::InvalidArgument(
+ "Dimensions [0,", batch_dim,
+ ") of indices[shape=", indices_shape.DebugString(),
+ "] must match dimensions [0,", batch_dim,
+ ") of updates[shape=", updates_shape.DebugString(), "]");
+ };
+ auto shape_err_suffix = [&]() {
+ return errors::InvalidArgument(
+ "Dimensions [", slice_dim, ",", params_shape.dims(),
+ ") of input[shape=", params_shape.DebugString(),
+ "] must match dimensions [", slice_dim, ",", updates_shape.dims(),
+ ") of updates[shape=", updates_shape.DebugString(), "]");
+ };
+
+ if (updates_shape.dims() < batch_dim) return shape_err_prefix();
+ if (params_shape.dims() < slice_dim + (updates_shape.dims() - batch_dim)) {
+ return shape_err_suffix();
+ }
+ if (updates_shape.dims() != batch_dim + params_shape.dims() - slice_dim) {
+ return shape_err_suffix();
+ }
+ for (int d = 0; d < batch_dim; ++d) {
+ if (updates_shape.dim_size(d) != indices_shape.dim_size(d))
+ return shape_err_prefix();
+ }
+ for (int d = 0; d < updates_shape.dims() - batch_dim; ++d) {
+ if (updates_shape.dim_size(d + batch_dim) !=
+ params_shape.dim_size(d + slice_dim)) {
+ return shape_err_suffix();
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/scatter_nd_util.h b/tensorflow/core/kernels/scatter_nd_util.h
new file mode 100644
index 0000000..c2285f8
--- /dev/null
+++ b/tensorflow/core/kernels/scatter_nd_util.h
@@ -0,0 +1,30 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_ND_UTIL_H_
+#define TENSORFLOW_CORE_KERNELS_SCATTER_ND_UTIL_H_
+
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+
+// Validates the input shapes for the ScatterNdUpdateOp<scatter_nd_op::UpdateOp>
+Status ValidateScatterNdUpdateShape(const TensorShape& params_shape,
+ const TensorShape& indices_shape,
+ const TensorShape& updates_shape);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_SCATTER_ND_UTIL_H_
diff --git a/tensorflow/python/kernel_tests/array_ops/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/array_ops/scatter_nd_ops_test.py
index c5452ed..2e7d2eb 100644
--- a/tensorflow/python/kernel_tests/array_ops/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops/scatter_nd_ops_test.py
@@ -102,6 +102,7 @@
return _NumpyScatterNd(ref, indices, updates, np.maximum)
+@test_util.with_eager_op_as_function
class StatefulScatterNdTest(test.TestCase):
def _VariableRankTest(self,
@@ -271,6 +272,7 @@
# session.run([update0, update1])
# self.assertAllEqual([False, True], self.evaluate(var))
+ @test_util.disable_xla("b/205330448")
def testScatterOutOfRangeCpu(self):
# TODO(simister): Re-enable once binary size increase due to
# scatter_nd ops is under control.
@@ -483,6 +485,7 @@
self.assertAllEqual(val, val2)
+@test_util.with_eager_op_as_function
class ScatterNdTest(test.TestCase, parameterized.TestCase):
non_aliasing_add_test = False
@@ -796,6 +799,11 @@
# Not supported yet.
pass
+ # TODO(testString): Enable this test when the above testString is enabled.
+ def testStringWithEagerOpAsFunctionEnabled(self):
+ # Not supported yet.
+ pass
+
class ScatterNdDeterminismTest(ScatterNdTest):