Relax the dynamic shape constraint in scatter update dimension.
Allows a scatter update dimension to be dynamic if it's the same as the corresponding operand dimension.
PiperOrigin-RevId: 361924492
Change-Id: I0e3fa1485911761427b755a4e7097c6e65383e19
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
index f2ab2d7..4816d26 100644
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
@@ -1450,9 +1450,43 @@
if (operand_index == 2 &&
absl::c_linear_search(scatter_dims.update_window_dims(),
dimension)) {
- return Unimplemented(
- "Dynamic dimension of update window dims is not supported: %s",
- hlo->ToString());
+ // Dynamic update window dimension is only allowed if it is exactly
+ // the same as the corresponding operand dimension.
+ std::vector<int64> update_window_dims_in_operand;
+ for (int64 i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
+ if (absl::c_linear_search(scatter_dims.inserted_window_dims(), i)) {
+ continue;
+ }
+ update_window_dims_in_operand.push_back(i);
+ }
+
+ for (int64 i = 0; i < scatter_dims.update_window_dims_size(); ++i) {
+ if (scatter_dims.update_window_dims(i) == dimension) {
+ const Shape& operand_shape = hlo->operand(0)->shape();
+ const Shape& update_shape = hlo->operand(2)->shape();
+ int64 dim_in_operand = update_window_dims_in_operand[i];
+ if (operand_shape.dimensions(dim_in_operand) !=
+ update_shape.dimensions(dimension) ||
+ !operand_shape.is_dynamic_dimension(dim_in_operand)) {
+ return Unimplemented(
+ "Dynamic dimension of update window dims that are not the "
+ "same as corresponding operand dim is not supported: "
+ "%s",
+ hlo->ToString());
+ }
+ HloInstruction* base_dynamic_size = parent_->GetDynamicSize(
+ hlo->mutable_operand(0), {}, dim_in_operand);
+ if (base_dynamic_size != operand_dynamic_size) {
+ return Unimplemented(
+ "Dynamic dimension size of update window dims that are not "
+ "the same as corresponding operand dim is not supported: "
+ "%s.\n Dynamic dim size of base: %s, dynamic dim size of "
+ "update: %s",
+ hlo->ToString(), base_dynamic_size->ToString(),
+ operand_dynamic_size->ToString());
+ }
+ }
+ }
}
// The dynamic dimension is collapsed and won't show up in the output.
// Do nothing here.
diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc
index e720670..584625c 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc
@@ -458,6 +458,47 @@
EXPECT_EQ(padded, not_padded);
}
+XLA_TEST_F(ExecutionTest, ScatterUpdateWindowDim) {
+ const string hlo_text = R"(
+HloModule ScatterUpdateWindowDim
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[1,2,3] parameter(0)
+ indices = s32[1] parameter(1)
+ updates = s32[2,3,1] parameter(2)
+ dynamic_size = s32[] constant(1)
+ operand_dynamic = s32[1, <=2, 3] set-dimension-size(operand, dynamic_size),
+ dimensions={1}
+ updates_dynamic = s32[<=2, 3, 1] set-dimension-size(updates, dynamic_size),
+ dimensions={0}
+ ROOT scatter = s32[1, <=2, 3] scatter(operand_dynamic, indices, updates_dynamic),
+ to_apply=update_s32,
+ update_window_dims={0, 1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+
+}
+)";
+ auto hlo_module = GetHloModule(hlo_text);
+
+ Literal operand = LiteralUtil::CreateR3<int32>({{{0, 0, 0}, {0, 0, 0}}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0});
+ Literal updates =
+ LiteralUtil::CreateR3<int32>({{{10}, {20}, {30}}, {{70}, {80}, {90}}});
+
+ Literal padded = PadAndExecute(std::move(hlo_module),
+ {&operand, &scatter_indices, &updates}, false);
+ Literal expected =
+ LiteralUtil::CreateR3<int32>({{{10, 20, 30}, {70, 80, 90}}});
+ EXPECT_EQ(padded, expected);
+}
+
XLA_TEST_F(ExecutionTest, ScatterUpdateF32) {
// Test that scattering on indices=[2] is same as scattering on indices=[4]
// and dynamic dimension = 2