[XLA] Fix small bug in the add(scatter,scatter) transformation when the index_vector dim was not collapsed. Also avoid the transformation when the indices are a scalar.

PiperOrigin-RevId: 305773603
Change-Id: I55ee1f0a7741405abd4f044aa5dd84e16782e676
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index ba0795f..f605539 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -740,14 +740,14 @@
       return Status::OK();
     }
 
-    int64 first_index_dim;
-    int64 first_update_dim;
+    int64 first_index_dim = lhs_scatter_index->shape().rank();
+    int64 first_update_dim = lhs_scatter_update->shape().rank();
     // Find a dimension where it is possible to concatenate the indices and
     // updates. This is the first and only non-equal dimension or the first
     // equally sized dimension.
     for (int64 d = lhs_scatter_index->shape().rank() - 1,
                update_dim = lhs_scatter_update->shape().rank() - 1;
-         d >= 0; --d, --update_dim) {
+         d >= 0; --d) {
       if (d == lhs_dnums.index_vector_dim()) {
         continue;
       }
@@ -758,7 +758,7 @@
       if (lhs_scatter_index->shape().dimensions(d) ==
           rhs_scatter_index->shape().dimensions(d)) {
         first_index_dim = d;
-        first_update_dim = update_dim;
+        first_update_dim = update_dim--;
         continue;
       }
       // More than one dimension of unequal size was found, bail out.
@@ -766,12 +766,18 @@
         return Status::OK();
       }
       index_concat_dimension = d;
-      update_concat_dimension = update_dim;
+      update_concat_dimension = update_dim--;
     }
     if (!index_concat_dimension) {
       index_concat_dimension = first_index_dim;
       update_concat_dimension = first_update_dim;
     }
+
+    // A scalar scatter will require additional reshapes of the index and
+    // update.
+    if (*index_concat_dimension == lhs_scatter_index->shape().rank()) {
+      return Status::OK();
+    }
     const bool update_concat_is_cheap =
         ShapeUtil::ElementsIn(rhs_scatter_update->shape()) +
             ShapeUtil::ElementsIn(lhs_scatter_update->shape()) <
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index bda60c4..9bbf692 100755
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -6348,5 +6348,83 @@
                             m::Concatenate(m::Parameter(3), m::Parameter(4)))));
 }
 
+TEST_F(AlgebraicSimplifierTest, ScatterAddCombinedWeirdDnums2) {
+  const char* hlo_string = R"(
+  HloModule m
+  apply {
+   a = f32[] parameter(0)
+   b = f32[] parameter(1)
+   ROOT c = f32[] add(a, b)
+  }
+  test {
+    z  = f32[] constant(0)
+    init = f32[100,4] broadcast(z), dimensions={}
+    shared = f32[100,4] parameter(0)
+    index0 = s32[4,3,1] parameter(1)
+    index1 = s32[4,5,1] parameter(2)
+    update0 = f32[4,4,3] parameter(3)
+    update1 = f32[4,4,5] parameter(4)
+    scatter.0 = f32[100,4] scatter(init, index0, update0),
+              to_apply=apply,
+              update_window_dims={0},
+              inserted_window_dims={0},
+              scatter_dims_to_operand_dims={0},
+              index_vector_dim=2
+    scatter.1 = f32[100,4] scatter(init, index1, update1),
+              to_apply=apply,
+              update_window_dims={0},
+              inserted_window_dims={0},
+              scatter_dims_to_operand_dims={0},
+              index_vector_dim=2
+    ROOT add.1 = f32[100,4] add(scatter.0, scatter.1)
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
+  // Combine Scatters
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  // Simplify Add
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Scatter(m::Broadcast(),
+                            m::Concatenate(m::Parameter(1), m::Parameter(2)),
+                            m::Concatenate(m::Parameter(3), m::Parameter(4)))));
+}
+
+TEST_F(AlgebraicSimplifierTest, ScalarScatter) {
+  const char* hlo_string = R"(
+  HloModule m
+  apply {
+   a = f32[] parameter(0)
+   b = f32[] parameter(1)
+   ROOT c = f32[] add(a, b)
+  }
+  test {
+    z  = f32[] constant(0)
+    init = f32[100,4,20] broadcast(z), dimensions={}
+    shared = f32[100,4,20] parameter(0)
+    index0 = s32[1] parameter(1)
+    index1 = s32[1] parameter(2)
+    update0 = f32[4,20] parameter(3)
+    update1 = f32[4,20] parameter(4)
+    scatter.0 = f32[100,4,20] scatter(init, index0, update0),
+              to_apply=apply,
+              update_window_dims={0, 1},
+              inserted_window_dims={0},
+              scatter_dims_to_operand_dims={0},
+              index_vector_dim=0
+    scatter.1 = f32[100,4,20] scatter(init, index1, update1),
+              to_apply=apply,
+              update_window_dims={0, 1},
+              inserted_window_dims={0},
+              scatter_dims_to_operand_dims={0},
+              index_vector_dim=0
+    ROOT add.1 = f32[100,4,20] add(scatter.0, scatter.1)
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
+  // Scalar indices: the scatters must not be combined.
+  ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+}
 }  // namespace
 }  // namespace xla