[XLA] [DynamicPadder] Support dynamic dimension in scatter indices.

- Pad Scatter's indices with -1; out-of-bound indices make the update a no-op when applied.
- Add execution test to make sure the result is the same as if not padded.

PiperOrigin-RevId: 268804543
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 51d45bf..46d014f 100755
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2290,7 +2290,7 @@
     ],
 )
 
-tf_cc_test(
+xla_test(
     name = "dynamic_padder_test",
     srcs = ["dynamic_padder_test.cc"],
     deps = [
@@ -2307,7 +2307,9 @@
         "//tensorflow/compiler/xla:test_helpers",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/tests:client_library_test_base",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/core:test",
     ],
 )
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
index ea73ca1..e02a582 100644
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
@@ -833,20 +833,13 @@
           int64 operand_index, HloInstruction* operand_dynamic_size,
           DimensionConstraint constraint) {
         if (operand_index == 0) {
-          return Unimplemented(
-              "Detects a dynamic dimension on the data input of scatter, which "
-              "is not supported: %s",
-              hlo->ToString());
-        }
-
-        const ScatterDimensionNumbers& scatter_dims =
-            hlo->scatter_dimension_numbers();
-        if (operand_index == 1) {
           parent_->SetDynamicSize(hlo, {}, dimension, operand_dynamic_size,
                                   constraint);
           return Status::OK();
         }
 
+        const ScatterDimensionNumbers& scatter_dims =
+            hlo->scatter_dimension_numbers();
         if (operand_index == 2 &&
             absl::c_linear_search(scatter_dims.update_window_dims(),
                                   dimension)) {
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h
index 12af09f..e8e89c8 100644
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h
@@ -164,6 +164,8 @@
   // by a scalar instruction `size`.
   void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim,
                       HloInstruction* size, DimensionConstraint constraint) {
+    VLOG(1) << "Set dimension inst " << inst->name() << " index "
+            << index.ToString() << "@" << dim << " to " << size->ToString();
     Shape subshape = ShapeUtil::GetSubshape(inst->shape(), index);
     CHECK(!subshape.IsTuple())
         << "Can't set a tuple shape to dynamic dimension";
diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc
index d9a97a8..0f3f1eb 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder.cc
@@ -78,9 +78,27 @@
     case HloOpcode::kSelectAndScatter: {
       return inst->mutable_operand(2);
     }
+    case HloOpcode::kScatter: {
+      PrimitiveType ptype = inst->shape().element_type();
+      if (operand_number != 1) {
+        return nullptr;
+      }
+      // Use -1 as padding for scatter indices: out-of-bound updates are not
+      // applied, so the padded updates become no-ops.
+      switch (ptype) {
+        case S32:
+          return comp->AddInstruction(
+              HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(-1)));
+        case S64:
+          return comp->AddInstruction(
+              HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(-1)));
+        default:
+          return InvalidArgument(
+              "Invalid primitive type %s",
+              primitive_util::LowercasePrimitiveTypeName(ptype));
+      }
+    }
     case HloOpcode::kParameter:
     case HloOpcode::kGather:
-    case HloOpcode::kScatter:
     case HloOpcode::kDynamicSlice:
     case HloOpcode::kDynamicUpdateSlice:
     case HloOpcode::kGetDimensionSize:
@@ -128,17 +146,19 @@
     for (HloInstruction* inst : computation->instructions()) {
       for (int64 operand_num = 0; operand_num < inst->operand_count();
            ++operand_num) {
-        HloInstruction* operand = inst->mutable_operand(operand_num);
+        HloInstruction* original_operand = inst->mutable_operand(operand_num);
+        HloInstruction* operand = original_operand;
         if (!operand->shape().IsArray()) {
           continue;
         }
         for (int64 dim = 0; dim < operand->shape().rank(); ++dim) {
           HloInstruction* dynamic_size =
-              dynamic_dimension_inference.GetDynamicSize(operand, {}, dim);
+              dynamic_dimension_inference.GetDynamicSize(original_operand, {},
+                                                         dim);
           if (dynamic_size == nullptr) {
             continue;
           }
-          VLOG(1) << "Has dynamic dimension of operand" << operand_num << " @"
+          VLOG(2) << "Has dynamic dimension of operand" << operand_num << " @"
                   << dim;
 
           if (ShouldSkipPadOnOperand(inst, operand_num, dim)) {
diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc
index 4cc3673..88f2460 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc
@@ -28,7 +28,10 @@
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test_benchmark.h"
@@ -212,5 +215,133 @@
   EXPECT_THAT(output->operand(0), op::Parameter());
 }
 
+// Test that dynamic padder has the same result as if not padded.
+class ExecutionTest : public HloTestBase {
+ protected:
+  std::unique_ptr<HloModule> GetHloModule(const string& hlo_text) {
+    HloModuleConfig config;
+    config.set_debug_options(GetDebugOptionsForTest());
+    std::unique_ptr<HloModule> module =
+        ParseAndReturnUnverifiedModule(hlo_text, config).ValueOrDie();
+    return module;
+  }
+};
+
+XLA_TEST_F(ExecutionTest, ScatterUpdate) {
+  // Test that scattering on indices=[2] is the same as scattering on
+  // indices=[4] with dynamic dimension = 2.
+  const string hlo_text = R"(
+HloModule TensorFlowScatterV1
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+  lhs = s32[] parameter(0)
+  ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[INDICES_BOUND] parameter(1)
+  updates = s32[INDICES_BOUND,3] parameter(2)
+  dynamic_size = s32[] parameter(3)
+  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+      to_apply=update_s32,
+      update_window_dims={1},
+      inserted_window_dims={0},
+      scatter_dims_to_operand_dims={0},
+      index_vector_dim=1
+
+}
+)";
+  const string hlo_text_not_padded =
+      absl::StrReplaceAll(hlo_text, {{"INDICES_BOUND", "2"}});
+  auto module_not_padded = GetHloModule(hlo_text_not_padded);
+
+  Literal operand =
+      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+  Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+  Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+  Literal dynamic_size = LiteralUtil::CreateR0<int32>(2);
+
+  Literal not_padded =
+      ExecuteAndTransfer(std::move(module_not_padded),
+                         {&operand, &scatter_indices, &updates, &dynamic_size});
+
+  // Pad input to 4.
+  const string hlo_text_padded =
+      absl::StrReplaceAll(hlo_text, {{"INDICES_BOUND", "4"}});
+  auto module_padded = GetHloModule(hlo_text_padded);
+  // Set up dynamic parameter binding.
+  TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind(
+      DynamicParameterBinding::DynamicParameter{3, {}},
+      DynamicParameterBinding::DynamicDimension{1, {}, 0}));
+  TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind(
+      DynamicParameterBinding::DynamicParameter{3, {}},
+      DynamicParameterBinding::DynamicDimension{2, {}, 0}));
+  // Pad the rest of input with garbage data.
+  Literal scatter_indices_padded = LiteralUtil::CreateR1<int32>({0, 2, 0, 4});
+  Literal updates_padded = LiteralUtil::CreateR2<int32>(
+      {{10, 20, 30}, {70, 80, 90}, {30, 22, 11}, {-1, 20, -1}});
+  DynamicPadder padder;
+  TF_CHECK_OK(padder.Run(module_padded.get()).status());
+  Literal padded = ExecuteAndTransfer(
+      std::move(module_padded),
+      {&operand, &scatter_indices_padded, &updates_padded, &dynamic_size});
+
+  EXPECT_EQ(padded, not_padded);
+}
+
+XLA_TEST_F(ExecutionTest, TwoDimensionReduce) {
+  // Test that reducing on operand=[2,2] is the same as reducing on
+  // operand=[4,4] with dynamic dimension = 2.
+  const string hlo_text = R"(
+HloModule TensorFlowScatterV1
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+  lhs = s32[] parameter(0)
+  rhs = s32[] parameter(1)
+  ROOT add = s32[] add(lhs, rhs)
+}
+
+ENTRY main {
+  param = s32[INDICES_BOUND, INDICES_BOUND] parameter(0)
+  dynamic_size = s32[] parameter(1)
+  const = s32[] constant(0)
+  ROOT reduce = s32[] reduce(param, const),
+      dimensions={0, 1},
+      to_apply=update_s32
+}
+)";
+  const string hlo_text_not_padded =
+      absl::StrReplaceAll(hlo_text, {{"INDICES_BOUND", "2"}});
+  auto module_not_padded = GetHloModule(hlo_text_not_padded);
+
+  Literal operand = LiteralUtil::CreateR2<int32>({{1, 2}, {4, 5}});
+  Literal dynamic_size = LiteralUtil::CreateR0<int32>(2);
+
+  Literal not_padded = ExecuteAndTransfer(std::move(module_not_padded),
+                                          {&operand, &dynamic_size});
+
+  // Pad input to 4.
+  const string hlo_text_padded =
+      absl::StrReplaceAll(hlo_text, {{"INDICES_BOUND", "4"}});
+  auto module_padded = GetHloModule(hlo_text_padded);
+  // Set up dynamic parameter binding.
+  TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind(
+      DynamicParameterBinding::DynamicParameter{1, {}},
+      DynamicParameterBinding::DynamicDimension{0, {}, 0}));
+  TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind(
+      DynamicParameterBinding::DynamicParameter{1, {}},
+      DynamicParameterBinding::DynamicDimension{0, {}, 1}));
+  // Pad the rest of input with garbage data.
+  Literal operand_padded = LiteralUtil::CreateR2<int32>(
+      {{1, 2, 3, 4}, {4, 5, 6, 7}, {1, 2, 3, 4}, {4, 5, 6, 7}});
+  DynamicPadder padder;
+  TF_CHECK_OK(padder.Run(module_padded.get()).status());
+  Literal padded = ExecuteAndTransfer(std::move(module_padded),
+                                      {&operand_padded, &dynamic_size});
+
+  EXPECT_EQ(padded, not_padded);
+}
+
 }  // namespace
 }  // namespace xla