Relax the dynamic shape constraint in scatter update dimension.

Allows a scatter update window dimension to be dynamic, provided it has the same size bound and the same dynamic size as the corresponding operand dimension.

PiperOrigin-RevId: 361924492
Change-Id: I0e3fa1485911761427b755a4e7097c6e65383e19
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
index f2ab2d7..4816d26 100644
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
@@ -1450,9 +1450,43 @@
         if (operand_index == 2 &&
             absl::c_linear_search(scatter_dims.update_window_dims(),
                                   dimension)) {
-          return Unimplemented(
-              "Dynamic dimension of update window dims is not supported: %s",
-              hlo->ToString());
+          // Dynamic update window dimension is only allowed if it is exactly
+          // the same as the corresponding operand dimension.
+          std::vector<int64> update_window_dims_in_operand;
+          for (int64 i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
+            if (absl::c_linear_search(scatter_dims.inserted_window_dims(), i)) {
+              continue;
+            }
+            update_window_dims_in_operand.push_back(i);
+          }
+
+          for (int64 i = 0; i < scatter_dims.update_window_dims_size(); ++i) {
+            if (scatter_dims.update_window_dims(i) == dimension) {
+              const Shape& operand_shape = hlo->operand(0)->shape();
+              const Shape& update_shape = hlo->operand(2)->shape();
+              int64 dim_in_operand = update_window_dims_in_operand[i];
+              if (operand_shape.dimensions(dim_in_operand) !=
+                      update_shape.dimensions(dimension) ||
+                  !operand_shape.is_dynamic_dimension(dim_in_operand)) {
+                return Unimplemented(
+                    "Dynamic dimension of update window dims that are not the "
+                    "same as corresponding operand dim is not supported: "
+                    "%s",
+                    hlo->ToString());
+              }
+              HloInstruction* base_dynamic_size = parent_->GetDynamicSize(
+                  hlo->mutable_operand(0), {}, dim_in_operand);
+              if (base_dynamic_size != operand_dynamic_size) {
+                return Unimplemented(
+                    "Dynamic dimension size of update window dims that are not "
+                    "the same as corresponding operand dim is not supported: "
+                    "%s.\n Dynamic dim size of base: %s, dynamic dim size of "
+                    "update: %s",
+                    hlo->ToString(), base_dynamic_size->ToString(),
+                    operand_dynamic_size->ToString());
+              }
+            }
+          }
         }
         // The dynamic dimension is collapsed and won't show up in the output.
         // Do nothing here.
diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc
index e720670..584625c 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc
@@ -458,6 +458,47 @@
   EXPECT_EQ(padded, not_padded);
 }
 
+XLA_TEST_F(ExecutionTest, ScatterUpdateWindowDim) {
+  const string hlo_text = R"(
+HloModule ScatterUpdateWindowDim
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+  lhs = s32[] parameter(0)
+  ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+  operand = s32[1,2,3] parameter(0)
+  indices = s32[1] parameter(1)
+  updates = s32[2,3,1] parameter(2)
+  dynamic_size = s32[] constant(1)
+  operand_dynamic = s32[1, <=2, 3] set-dimension-size(operand, dynamic_size),
+      dimensions={1}
+  updates_dynamic = s32[<=2, 3, 1] set-dimension-size(updates, dynamic_size),
+      dimensions={0}
+  ROOT scatter = s32[1, <=2, 3] scatter(operand_dynamic, indices, updates_dynamic),
+      to_apply=update_s32,
+      update_window_dims={0, 1},
+      inserted_window_dims={0},
+      scatter_dims_to_operand_dims={0},
+      index_vector_dim=1
+
+}
+)";
+  auto hlo_module = GetHloModule(hlo_text);
+
+  Literal operand = LiteralUtil::CreateR3<int32>({{{0, 0, 0}, {0, 0, 0}}});
+  Literal scatter_indices = LiteralUtil::CreateR1<int32>({0});
+  Literal updates =
+      LiteralUtil::CreateR3<int32>({{{10}, {20}, {30}}, {{70}, {80}, {90}}});
+
+  Literal padded = PadAndExecute(std::move(hlo_module),
+                                 {&operand, &scatter_indices, &updates}, false);
+  Literal expected =
+      LiteralUtil::CreateR3<int32>({{{10, 20, 30}, {70, 80, 90}}});
+  EXPECT_EQ(padded, expected);
+}
+
 XLA_TEST_F(ExecutionTest, ScatterUpdateF32) {
   // Test that scattering on indices=[2] is same as scattering on indices=[4]
   // and dynamic dimension = 2