[XLA] Fix small bug in the the add(scatter,scatter) transformation when the index_vector dim was not collapsed. Also avoid the transformation when the indices are a scalar.
PiperOrigin-RevId: 305773603
Change-Id: I55ee1f0a7741405abd4f044aa5dd84e16782e676
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index ba0795f..f605539 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -740,14 +740,14 @@
return Status::OK();
}
- int64 first_index_dim;
- int64 first_update_dim;
+ int64 first_index_dim = lhs_scatter_index->shape().rank();
+ int64 first_update_dim = lhs_scatter_update->shape().rank();
// Find a dimension where it is possible to concatenate the indices and
// updates. This is the first and only non-equal dimension or the first
// equally sized dimension.
for (int64 d = lhs_scatter_index->shape().rank() - 1,
update_dim = lhs_scatter_update->shape().rank() - 1;
- d >= 0; --d, --update_dim) {
+ d >= 0; --d) {
if (d == lhs_dnums.index_vector_dim()) {
continue;
}
@@ -758,7 +758,7 @@
if (lhs_scatter_index->shape().dimensions(d) ==
rhs_scatter_index->shape().dimensions(d)) {
first_index_dim = d;
- first_update_dim = update_dim;
+ first_update_dim = update_dim--;
continue;
}
// More than one dimension of unequal size was found, bail out.
@@ -766,12 +766,18 @@
return Status::OK();
}
index_concat_dimension = d;
- update_concat_dimension = update_dim;
+ update_concat_dimension = update_dim--;
}
if (!index_concat_dimension) {
index_concat_dimension = first_index_dim;
update_concat_dimension = first_update_dim;
}
+
+ // A scalar scatter will require additional reshapes of the index and
+ // update.
+ if (*index_concat_dimension == lhs_scatter_index->shape().rank()) {
+ return Status::OK();
+ }
const bool update_concat_is_cheap =
ShapeUtil::ElementsIn(rhs_scatter_update->shape()) +
ShapeUtil::ElementsIn(lhs_scatter_update->shape()) <
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index bda60c4..9bbf692 100755
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -6348,5 +6348,83 @@
m::Concatenate(m::Parameter(3), m::Parameter(4)))));
}
+TEST_F(AlgebraicSimplifierTest, ScatterAddCombinedWeirdDnums2) {
+ const char* hlo_string = R"(
+ HloModule m
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a, b)
+ }
+ test {
+ z = f32[] constant(0)
+ init = f32[100,4] broadcast(z), dimensions={}
+ shared = f32[100,4] parameter(0)
+ index0 = s32[4,3,1] parameter(1)
+ index1 = s32[4,5,1] parameter(2)
+ update0 = f32[4,4,3] parameter(3)
+ update1 = f32[4,4,5] parameter(4)
+ scatter.0 = f32[100,4] scatter(init, index0, update0),
+ to_apply=apply,
+ update_window_dims={0},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=2
+ scatter.1 = f32[100,4] scatter(init, index1, update1),
+ to_apply=apply,
+ update_window_dims={0},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=2
+ ROOT add.1 = f32[100,4] add(scatter.0, scatter.1)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
+ // Combine Scatters
+ ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+ // Simplify Add
+ ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+ EXPECT_THAT(
+ m->entry_computation()->root_instruction(),
+ GmockMatch(m::Scatter(m::Broadcast(),
+ m::Concatenate(m::Parameter(1), m::Parameter(2)),
+ m::Concatenate(m::Parameter(3), m::Parameter(4)))));
+}
+
+TEST_F(AlgebraicSimplifierTest, ScalarScatter) {
+ const char* hlo_string = R"(
+ HloModule m
+ apply {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT c = f32[] add(a, b)
+ }
+ test {
+ z = f32[] constant(0)
+ init = f32[100,4,20] broadcast(z), dimensions={}
+ shared = f32[100,4,20] parameter(0)
+ index0 = s32[1] parameter(1)
+ index1 = s32[1] parameter(2)
+ update0 = f32[4,20] parameter(3)
+ update1 = f32[4,20] parameter(4)
+ scatter.0 = f32[100,4,20] scatter(init, index0, update0),
+ to_apply=apply,
+ update_window_dims={0, 1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=0
+ scatter.1 = f32[100,4,20] scatter(init, index1, update1),
+ to_apply=apply,
+ update_window_dims={0, 1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=0
+ ROOT add.1 = f32[100,4,20] add(scatter.0, scatter.1)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
+ // Combine Scatters
+ ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+}
} // namespace
} // namespace xla