Improve consistency of error messages between the TF and XLA scatter_nd ops.

PiperOrigin-RevId: 408926419
Change-Id: Ib9efad3adac6af5fbe9c7058b63d3460e95bf01b
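
Illustrative sketch (not part of this patch; the variable, index, and update
shapes below are made up, and the jit_compile usage is only one way to reach
the XLA kernel): with ValidateScatterNdUpdateShape shared between the dense
kernel in scatter_nd_op.cc and the XLA ResourceScatterOp in variable_ops.cc,
a mismatched updates shape should produce the same InvalidArgument wording on
both paths, roughly:

    import tensorflow as tf

    var = tf.Variable(tf.zeros([4, 4]))
    indices = tf.constant([[0], [2]])  # shape [2, 1]: batch_dim=1, slice_dim=1
    updates = tf.zeros([2, 3])         # inner dim 3 does not match params dim 4

    # Dense TF kernel path (scatter_nd_op.cc).
    try:
      var.scatter_nd_update(indices, updates)
    except tf.errors.InvalidArgumentError as e:
      # "Dimensions [1,2) of input[shape=[4,4]] must match dimensions [1,2)
      #  of updates[shape=[2,3]]"
      print(e.message)

    # XLA kernel path (variable_ops.cc) now runs the same shape check.
    @tf.function(jit_compile=True)
    def xla_update():
      return var.scatter_nd_update(indices, updates)

    try:
      xla_update()
    except tf.errors.InvalidArgumentError as e:
      print(e.message)  # expected to match the wording above
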
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 6938e13..7b82842 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -202,6 +202,7 @@
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/kernels:scatter_nd_util",
         "//tensorflow/core/kernels:stateful_random_ops_header",
         "//tensorflow/core/kernels:stateless_random_ops_v2_header",
         "//tensorflow/core/tpu:tpu_defs",
diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
index d26df34..b60f374 100644
--- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
@@ -24,6 +24,7 @@
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/scatter_nd_util.h"
 
 namespace tensorflow {
 namespace {
@@ -164,6 +165,12 @@
     xla::XlaOp var_value;
     OP_REQUIRES_OK(
         context, context->ReadVariableInput(0, dtype, &var_shape, &var_value));
+    // This check is only required for ScatterNdOps.
+    if (indices_are_vectors_) {
+      OP_REQUIRES_OK(context, ValidateScatterNdUpdateShape(
+                                  var_shape, context->InputShape(1),
+                                  context->InputShape(2)));
+    }
 
     const xla::XlaOp indices = context->Input(1);
     const xla::XlaOp updates = context->Input(2);
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index bff380d..cf8d079 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -5171,6 +5171,15 @@
     ],
 )
 
+cc_library(
+    name = "scatter_nd_util",
+    srcs = ["scatter_nd_util.cc"],
+    hdrs = ["scatter_nd_util.h"],
+    deps = [
+        "//tensorflow/core:framework",
+    ],
+)
+
 tf_kernel_library(
     name = "scatter_nd_op",
     srcs = [
@@ -5196,6 +5205,7 @@
         "scatter_nd_op_gpu.cu.cc",
     ],
     deps = STATE_DEPS + [
+        ":scatter_nd_util",
         ":dense_update_functor",
         ":training_op_helpers",
         ":variable_ops",
@@ -6241,6 +6251,7 @@
         "scan_ops.h",
         "scatter_functor.h",
         "scatter_nd_op.h",
+        "scatter_nd_util.h",
         "segment_reduction_ops.h",
         "segment_reduction_ops_impl.h",
         "softplus_op.h",
@@ -6521,6 +6532,7 @@
         "scatter_nd_op_cpu_impl_5.cc",
         "scatter_nd_op_cpu_impl_6.cc",
         "scatter_nd_op_cpu_impl_7.cc",
+        "scatter_nd_util.cc",
         "segment_reduction_ops_impl_1.cc",
         "segment_reduction_ops_impl_2.cc",
         "segment_reduction_ops_impl_3.cc",
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index 82f2421..e24eed3 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -21,8 +21,6 @@
 #include "tensorflow/core/platform/stream_executor.h"
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
-#include "tensorflow/core/kernels/scatter_nd_op.h"
-
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -32,6 +30,8 @@
 #include "tensorflow/core/kernels/dense_update_functor.h"
 #include "tensorflow/core/kernels/fill_functor.h"
 #include "tensorflow/core/kernels/inplace_ops_functor.h"
+#include "tensorflow/core/kernels/scatter_nd_op.h"
+#include "tensorflow/core/kernels/scatter_nd_util.h"
 #include "tensorflow/core/kernels/training_op_helpers.h"
 #include "tensorflow/core/kernels/variable_ops.h"
 #include "tensorflow/core/lib/strings/str_util.h"
@@ -40,7 +40,6 @@
 #include "tensorflow/core/util/determinism.h"
 #include "tensorflow/core/util/util.h"
 
-
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -773,47 +772,6 @@
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 namespace functor {
-// Check whether updates.shape = indices.shape[:batch_dim] +
-// params_shape[slice_dim:]
-Status ValidateUpdateShape(const TensorShape& params_shape,
-                           const Tensor& indices, const Tensor& updates) {
-  const int64_t slice_dim =
-      (indices.dims() > 1) ? indices.dim_size(indices.dims() - 1) : 1;
-  const int64_t batch_dim = (indices.dims() > 1) ? indices.dims() - 1 : 1;
-
-  auto shape_err_prefix = [&]() {
-    return errors::InvalidArgument(
-        "Dimensions [0,", batch_dim,
-        ") of indices[shape=", indices.shape().DebugString(),
-        "] must match dimensions [0,", batch_dim,
-        ") of updates[shape=", updates.shape().DebugString(), "]");
-  };
-  auto shape_err_suffix = [&]() {
-    return errors::InvalidArgument(
-        "Dimensions [", slice_dim, ",", params_shape.dims(),
-        ") of input[shape=", params_shape.DebugString(),
-        "] must match dimensions [", slice_dim, ",", updates.dims(),
-        ") of updates[shape=", updates.shape().DebugString(), "]");
-  };
-
-  if (updates.dims() < batch_dim) return shape_err_prefix();
-  if (params_shape.dims() < slice_dim + (updates.dims() - batch_dim)) {
-    return shape_err_suffix();
-  }
-  if (updates.dims() != batch_dim + params_shape.dims() - slice_dim) {
-    return shape_err_suffix();
-  }
-  for (int d = 0; d < batch_dim; ++d) {
-    if (updates.dim_size(d) != indices.dim_size(d)) return shape_err_prefix();
-  }
-  for (int d = 0; d < updates.dims() - batch_dim; ++d) {
-    if (updates.dim_size(d + batch_dim) !=
-        params_shape.dim_size(d + slice_dim)) {
-      return shape_err_suffix();
-    }
-  }
-  return Status::OK();
-}
 
 template <typename Index>
 Status PrepareAndValidateInputs(const TensorShape& params_shape,
@@ -842,7 +800,8 @@
         "] = ", indices.dim_size(0), " must match dimensions [0,1) of updates[",
         "shape=", updates_shape.DebugString(), "] = ", updates.dim_size(0));
   }
-  TF_RETURN_IF_ERROR(ValidateUpdateShape(params_shape, indices, updates));
+  TF_RETURN_IF_ERROR(ValidateScatterNdUpdateShape(params_shape, indices.shape(),
+                                                  updates.shape()));
 
   // Check that we have enough index space
   const int64_t N_big = indices.NumElements();
diff --git a/tensorflow/core/kernels/scatter_nd_util.cc b/tensorflow/core/kernels/scatter_nd_util.cc
new file mode 100644
index 0000000..a0ad236
--- /dev/null
+++ b/tensorflow/core/kernels/scatter_nd_util.cc
@@ -0,0 +1,67 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/scatter_nd_util.h"
+
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+
+Status ValidateScatterNdUpdateShape(const TensorShape& params_shape,
+                                    const TensorShape& indices_shape,
+                                    const TensorShape& updates_shape) {
+  const int64_t slice_dim =
+      (indices_shape.dims() > 1)
+          ? indices_shape.dim_size(indices_shape.dims() - 1)
+          : 1;
+  const int64_t batch_dim =
+      (indices_shape.dims() > 1) ? indices_shape.dims() - 1 : 1;
+
+  auto shape_err_prefix = [&]() {
+    return errors::InvalidArgument(
+        "Dimensions [0,", batch_dim,
+        ") of indices[shape=", indices_shape.DebugString(),
+        "] must match dimensions [0,", batch_dim,
+        ") of updates[shape=", updates_shape.DebugString(), "]");
+  };
+  auto shape_err_suffix = [&]() {
+    return errors::InvalidArgument(
+        "Dimensions [", slice_dim, ",", params_shape.dims(),
+        ") of input[shape=", params_shape.DebugString(),
+        "] must match dimensions [", slice_dim, ",", updates_shape.dims(),
+        ") of updates[shape=", updates_shape.DebugString(), "]");
+  };
+
+  if (updates_shape.dims() < batch_dim) return shape_err_prefix();
+  if (params_shape.dims() < slice_dim + (updates_shape.dims() - batch_dim)) {
+    return shape_err_suffix();
+  }
+  if (updates_shape.dims() != batch_dim + params_shape.dims() - slice_dim) {
+    return shape_err_suffix();
+  }
+  for (int d = 0; d < batch_dim; ++d) {
+    if (updates_shape.dim_size(d) != indices_shape.dim_size(d))
+      return shape_err_prefix();
+  }
+  for (int d = 0; d < updates_shape.dims() - batch_dim; ++d) {
+    if (updates_shape.dim_size(d + batch_dim) !=
+        params_shape.dim_size(d + slice_dim)) {
+      return shape_err_suffix();
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/scatter_nd_util.h b/tensorflow/core/kernels/scatter_nd_util.h
new file mode 100644
index 0000000..c2285f8
--- /dev/null
+++ b/tensorflow/core/kernels/scatter_nd_util.h
@@ -0,0 +1,30 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_ND_UTIL_H_
+#define TENSORFLOW_CORE_KERNELS_SCATTER_ND_UTIL_H_
+
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+
+// Validates the input shapes for the ScatterNdUpdateOp<scatter_nd_op::UpdateOp>
+Status ValidateScatterNdUpdateShape(const TensorShape& params_shape,
+                                    const TensorShape& indices_shape,
+                                    const TensorShape& updates_shape);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_SCATTER_ND_UTIL_H_
diff --git a/tensorflow/python/kernel_tests/array_ops/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/array_ops/scatter_nd_ops_test.py
index c5452ed..2e7d2eb 100644
--- a/tensorflow/python/kernel_tests/array_ops/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops/scatter_nd_ops_test.py
@@ -102,6 +102,7 @@
   return _NumpyScatterNd(ref, indices, updates, np.maximum)
 
 
+@test_util.with_eager_op_as_function
 class StatefulScatterNdTest(test.TestCase):
 
   def _VariableRankTest(self,
@@ -271,6 +272,7 @@
   #     session.run([update0, update1])
   #     self.assertAllEqual([False, True], self.evaluate(var))
 
+  @test_util.disable_xla("b/205330448")
   def testScatterOutOfRangeCpu(self):
     # TODO(simister): Re-enable once binary size increase due to
     # scatter_nd ops is under control.
@@ -483,6 +485,7 @@
       self.assertAllEqual(val, val2)
 
 
+@test_util.with_eager_op_as_function
 class ScatterNdTest(test.TestCase, parameterized.TestCase):
   non_aliasing_add_test = False
 
@@ -796,6 +799,11 @@
     # Not supported yet.
     pass
 
+  # TODO(testString): Enable this test when the above testString is enabled.
+  def testStringWithEagerOpAsFunctionEnabled(self):
+    # Not supported yet.
+    pass
+
 
 class ScatterNdDeterminismTest(ScatterNdTest):