[XLA] [DynamicPadder] Support dynamic dimension in scatter indices.
- Pad Scatter's indices with -1; out-of-bound scatter indices are never applied, so the padded entries are no-ops.
- Add an execution test to make sure the padded computation produces the same result as the unpadded one.
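- For illustration, a minimal standalone sketch (plain C++, not XLA code; ScatterRows is a hypothetical stand-in for a scatter along dimension 0) of why -1 is a safe pad value, assuming out-of-bound updates are dropped:

    #include <cstdio>
    #include <vector>

    // Scatters updates[i] into operand[indices[i]], dropping rows whose
    // index is out of bounds (the semantics the padder relies on).
    void ScatterRows(std::vector<std::vector<int>>& operand,
                     const std::vector<int>& indices,
                     const std::vector<std::vector<int>>& updates) {
      for (size_t i = 0; i < indices.size(); ++i) {
        const int row = indices[i];
        if (row < 0 || row >= static_cast<int>(operand.size())) {
          continue;  // A padded index of -1 always lands here: a no-op.
        }
        operand[row] = updates[i];
      }
    }

    int main() {
      std::vector<std::vector<int>> operand = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
      // Two real updates plus two padded entries (index -1, garbage data).
      ScatterRows(operand, {0, 2, -1, -1},
                  {{10, 20, 30}, {70, 80, 90}, {99, 99, 99}, {99, 99, 99}});
      for (const auto& row : operand) {
        std::printf("%d %d %d\n", row[0], row[1], row[2]);  // Same as unpadded.
      }
      return 0;
    }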
PiperOrigin-RevId: 268804543
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 51d45bf..46d014f 100755
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2290,7 +2290,7 @@
],
)
-tf_cc_test(
+xla_test(
name = "dynamic_padder_test",
srcs = ["dynamic_padder_test.cc"],
deps = [
@@ -2307,7 +2307,9 @@
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:test",
],
)
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
index ea73ca1..e02a582 100644
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
@@ -833,20 +833,13 @@
int64 operand_index, HloInstruction* operand_dynamic_size,
DimensionConstraint constraint) {
if (operand_index == 0) {
- return Unimplemented(
- "Detects a dynamic dimension on the data input of scatter, which "
- "is not supported: %s",
- hlo->ToString());
- }
-
- const ScatterDimensionNumbers& scatter_dims =
- hlo->scatter_dimension_numbers();
- if (operand_index == 1) {
parent_->SetDynamicSize(hlo, {}, dimension, operand_dynamic_size,
constraint);
return Status::OK();
}
+ const ScatterDimensionNumbers& scatter_dims =
+ hlo->scatter_dimension_numbers();
if (operand_index == 2 &&
absl::c_linear_search(scatter_dims.update_window_dims(),
dimension)) {
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h
index 12af09f..e8e89c8 100644
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h
@@ -164,6 +164,8 @@
// by a scalar instruction `size`.
void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim,
HloInstruction* size, DimensionConstraint constraint) {
+ VLOG(1) << "Set dimension inst " << inst->name() << " index "
+ << index.ToString() << "@" << dim << " to " << size->ToString();
Shape subshape = ShapeUtil::GetSubshape(inst->shape(), index);
CHECK(!subshape.IsTuple())
<< "Can't set a tuple shape to dynamic dimension";
diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc
index d9a97a8..0f3f1eb 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder.cc
@@ -78,9 +78,27 @@
case HloOpcode::kSelectAndScatter: {
return inst->mutable_operand(2);
}
+ case HloOpcode::kScatter: {
+ if (operand_number != 1) {
+ return nullptr;
+ }
+ // The pad value must have the element type of the indices operand, not
+ // of the scatter output.
+ PrimitiveType ptype = inst->operand(1)->shape().element_type();
+ // Use -1 as padding for scatter indices: out-of-bound updates are not
+ // applied, so the padded index entries become no-ops.
+ switch (ptype) {
+ case S32:
+ return comp->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(-1)));
+ case S64:
+ return comp->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(-1)));
+ default:
+ return InvalidArgument(
+ "Invalid primitive type %s",
+ primitive_util::LowercasePrimitiveTypeName(ptype));
+ }
+ }
case HloOpcode::kParameter:
case HloOpcode::kGather:
- case HloOpcode::kScatter:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kGetDimensionSize:
@@ -128,17 +146,19 @@
for (HloInstruction* inst : computation->instructions()) {
for (int64 operand_num = 0; operand_num < inst->operand_count();
++operand_num) {
- HloInstruction* operand = inst->mutable_operand(operand_num);
+ HloInstruction* original_operand = inst->mutable_operand(operand_num);
+ HloInstruction* operand = original_operand;
if (!operand->shape().IsArray()) {
continue;
}
for (int64 dim = 0; dim < operand->shape().rank(); ++dim) {
HloInstruction* dynamic_size =
- dynamic_dimension_inference.GetDynamicSize(operand, {}, dim);
+ dynamic_dimension_inference.GetDynamicSize(original_operand, {},
+ dim);
if (dynamic_size == nullptr) {
continue;
}
- VLOG(1) << "Has dynamic dimension of operand" << operand_num << " @"
+ VLOG(2) << "Has dynamic dimension of operand" << operand_num << " @"
<< dim;
if (ShouldSkipPadOnOperand(inst, operand_num, dim)) {
diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc
index 4cc3673..88f2460 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc
@@ -28,7 +28,10 @@
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -212,5 +215,133 @@
EXPECT_THAT(output->operand(0), op::Parameter());
}
+// Test that computations rewritten by the dynamic padder produce the same
+// results as their unpadded counterparts.
+class ExecutionTest : public HloTestBase {
+ protected:
+ std::unique_ptr<HloModule> GetHloModule(const string& hlo_text) {
+ HloModuleConfig config;
+ config.set_debug_options(GetDebugOptionsForTest());
+ std::unique_ptr<HloModule> module =
+ ParseAndReturnUnverifiedModule(hlo_text, config).ValueOrDie();
+ return module;
+ }
+};
+
+XLA_TEST_F(ExecutionTest, ScatterUpdate) {
+ // Test that scattering with indices=[2] gives the same result as scattering
+ // with indices=[4] and a dynamic dimension size of 2.
+ const string hlo_text = R"(
+HloModule TensorFlowScatterV1
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[INDICES_BOUND] parameter(1)
+ updates = s32[INDICES_BOUND,3] parameter(2)
+ dynamic_size = s32[] parameter(3)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ const string hlo_text_not_padded =
+ absl::StrReplaceAll(hlo_text, {{"INDICES_BOUND", "2"}});
+ auto module_not_padded = GetHloModule(hlo_text_not_padded);
+
+ Literal operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal dynamic_size = LiteralUtil::CreateR0<int32>(2);
+
+ Literal not_padded =
+ ExecuteAndTransfer(std::move(module_not_padded),
+ {&operand, &scatter_indices, &updates, &dynamic_size});
+
+ // Pad input to 4.
+ const string hlo_text_padded =
+ absl::StrReplaceAll(hlo_text, {{"INDICES_BOUND", "4"}});
+ auto module_padded = GetHloModule(hlo_text_padded);
+ // Set up dynamic parameter binding.
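+ // Parameter 3 (dynamic_size) is bound as the runtime size of dimension 0
+ // of the indices (parameter 1) and of the updates (parameter 2).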
+ TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind(
+ DynamicParameterBinding::DynamicParameter{3, {}},
+ DynamicParameterBinding::DynamicDimension{1, {}, 0}));
+ TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind(
+ DynamicParameterBinding::DynamicParameter{3, {}},
+ DynamicParameterBinding::DynamicDimension{2, {}, 0}));
+ // Pad the rest of the input with garbage data.
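+ // The padder rewrites the computation to mask entries beyond
+ // dynamic_size=2 with -1, so the garbage indices {0, 4} and their
+ // updates must not affect the result.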
+ Literal scatter_indices_padded = LiteralUtil::CreateR1<int32>({0, 2, 0, 4});
+ Literal updates_padded = LiteralUtil::CreateR2<int32>(
+ {{10, 20, 30}, {70, 80, 90}, {30, 22, 11}, {-1, 20, -1}});
+ DynamicPadder padder;
+ TF_CHECK_OK(padder.Run(module_padded.get()).status());
+ Literal padded = ExecuteAndTransfer(
+ std::move(module_padded),
+ {&operand, &scatter_indices_padded, &updates_padded, &dynamic_size});
+
+ EXPECT_EQ(padded, not_padded);
+}
+
+XLA_TEST_F(ExecutionTest, TwoDimensionReduce) {
+ // Test that reducing over operand=[2,2] gives the same result as reducing
+ // over operand=[4,4] with both dynamic dimension sizes set to 2.
+ const string hlo_text = R"(
+HloModule TwoDimensionReduce
+
+add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(lhs, rhs)
+}
+
+ENTRY main {
+ param = s32[SIZE_BOUND, SIZE_BOUND] parameter(0)
+ dynamic_size = s32[] parameter(1)
+ const = s32[] constant(0)
+ ROOT reduce = s32[] reduce(param, const),
+ dimensions={0, 1},
+ to_apply=add_s32
+}
+)";
+ const string hlo_text_not_padded =
+ absl::StrReplaceAll(hlo_text, {{"SIZE_BOUND", "2"}});
+ auto module_not_padded = GetHloModule(hlo_text_not_padded);
+
+ Literal operand = LiteralUtil::CreateR2<int32>({{1, 2}, {4, 5}});
+ Literal dynamic_size = LiteralUtil::CreateR0<int32>(2);
+
+ Literal not_padded = ExecuteAndTransfer(std::move(module_not_padded),
+ {&operand, &dynamic_size});
+
+ // Pad input to 4.
+ const string hlo_text_padded =
+ absl::StrReplaceAll(hlo_text, {{"SIZE_BOUND", "4"}});
+ auto module_padded = GetHloModule(hlo_text_padded);
+ // Set up dynamic parameter binding.
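+ // Parameter 1 (dynamic_size) is bound as the runtime size of both
+ // dimensions of parameter 0.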
+ TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind(
+ DynamicParameterBinding::DynamicParameter{1, {}},
+ DynamicParameterBinding::DynamicDimension{0, {}, 0}));
+ TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind(
+ DynamicParameterBinding::DynamicParameter{1, {}},
+ DynamicParameterBinding::DynamicDimension{0, {}, 1}));
+ // Pad the rest of the input with garbage data.
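+ // The padder rewrites the computation to mask entries outside the 2x2
+ // dynamic region with 0, the identity of the add reduction.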
+ Literal operand_padded = LiteralUtil::CreateR2<int32>(
+ {{1, 2, 3, 4}, {4, 5, 6, 7}, {1, 2, 3, 4}, {4, 5, 6, 7}});
+ DynamicPadder padder;
+ TF_CHECK_OK(padder.Run(module_padded.get()).status());
+ Literal padded = ExecuteAndTransfer(std::move(module_padded),
+ {&operand_padded, &dynamic_size});
+
+ EXPECT_EQ(padded, not_padded);
+}
+
} // namespace
} // namespace xla