[XLA:SPMD] Use subgroup AllToAll for resharding

Resharding from tile assignment [2,2,1] to [1,2,2] can now be done with a subgroup all-to-all between dimensions 0 and 2. Previously, AllToAll-based resharding required both shardings to be tiled along a single (unique) dimension, so multi-dimensional cases like this fell back to replicating and then slicing.
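
For example (illustrative device order, not taken from the change itself): with four devices, resharding from {devices=[2,2,1]0,1,2,3} to {devices=[1,2,2]0,2,1,3} needs no collective permute, since the device order already matches after swapping dims 0 and 2, and lowers to a single all-to-all with replica groups {0,2} and {1,3}. The new SubgroupAllToAllReshard test exercises the same idea with 8 devices and ends up with four groups. The standalone sketch below is not part of this change; it only mirrors the group-id computation in ReshardWithAllToAll to show how the groups fall out of the target tile assignment:

  #include <cstdint>
  #include <iostream>
  #include <vector>

  int main() {
    // Target tile shape [1,2,2] with row-major device assignment {0,2,1,3};
    // these values are just an illustrative choice for this example.
    const std::vector<int64_t> dims = {1, 2, 2};
    const std::vector<int64_t> devices = {0, 2, 1, 3};
    const int64_t target_dim = 2;
    // Only source_dim and target_dim swap sizes, so the group size equals the
    // target size of target_dim (2 here).
    const int64_t group_size = dims[target_dim];
    const int64_t num_devices = devices.size();
    std::vector<std::vector<int64_t>> groups(num_devices / group_size);
    for (int64_t flat = 0; flat < num_devices; ++flat) {
      // Decompose the flat index into per-dimension tile indices (row-major).
      std::vector<int64_t> idx(dims.size());
      int64_t rem = flat;
      for (int64_t d = idx.size() - 1; d >= 0; --d) {
        idx[d] = rem % dims[d];
        rem /= dims[d];
      }
      // Same group-id computation as the Each() callback in
      // ReshardWithAllToAll: fold every dimension except target_dim, so
      // devices that differ only in target_dim land in the same group.
      int64_t group_id = 0;
      for (int64_t d = 0; d < static_cast<int64_t>(dims.size()); ++d) {
        if (d == target_dim) continue;
        group_id = group_id * dims[d] + idx[d];
      }
      groups[group_id].push_back(devices[flat]);
    }
    for (const auto& g : groups) {  // Prints "0 2" then "1 3".
      for (int64_t d : g) std::cout << d << ' ';
      std::cout << '\n';
    }
    return 0;
  }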

PiperOrigin-RevId: 320720720
Change-Id: I1b63ba731b830610596c77697c5577fa9e2e0f79
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
index 7e136be..1b484e0 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
@@ -176,16 +176,45 @@
   return groups;
 }
 
-bool CanReshardWithAllToAll(const HloSharding& source,
-                            const HloSharding& target) {
-  return UniqueTiledDim(source) && UniqueTiledDim(target) &&
-         UniqueTiledDim(source) != UniqueTiledDim(target);
+absl::optional<std::pair<int64, int64>> GetReshardAllToAllSourceTargetDims(
+    const HloSharding& source, const HloSharding& target) {
+  if (source.IsTileMaximal() || target.IsTileMaximal() ||
+      source.tile_assignment().num_dimensions() !=
+          target.tile_assignment().num_dimensions()) {
+    return absl::nullopt;
+  }
+  int64 source_dim = -1;
+  int64 target_dim = -1;
+  for (int64 i = 0; i < source.tile_assignment().num_dimensions(); ++i) {
+    if (source.tile_assignment().dim(i) > 1 &&
+        target.tile_assignment().dim(i) == 1) {
+      if (source_dim != -1) {
+        return absl::nullopt;
+      }
+      source_dim = i;
+    } else if (source.tile_assignment().dim(i) == 1 &&
+               target.tile_assignment().dim(i) > 1) {
+      if (target_dim != -1) {
+        return absl::nullopt;
+      }
+      target_dim = i;
+    } else if (source.tile_assignment().dim(i) !=
+               target.tile_assignment().dim(i)) {
+      return absl::nullopt;
+    }
+  }
+  if (source_dim == -1 || target_dim == -1 || source_dim == target_dim) {
+    return absl::nullopt;
+  }
+  return std::pair<int64, int64>(source_dim, target_dim);
 }
 
 bool CanReshardWithCollectivePermute(const HloSharding& source,
                                      const HloSharding& target) {
-  return UniqueTiledDim(source) && UniqueTiledDim(target) &&
-         UniqueTiledDim(source) == UniqueTiledDim(target) && source != target;
+  return !source.IsTileMaximal() && !target.IsTileMaximal() &&
+         source.tile_assignment().dimensions() ==
+             target.tile_assignment().dimensions() &&
+         source.tile_assignment() != target.tile_assignment();
 }
 
 // Clears all sharding attributes from instructions in the module. This must be
@@ -278,8 +307,10 @@
     return ReshardWithCollectivePermute(target);
   }
 
-  if (CanReshardWithAllToAll(sharding(), target)) {
-    return ReshardWithAllToAll(target);
+  if (auto src_tgt_dims =
+          GetReshardAllToAllSourceTargetDims(sharding(), target)) {
+    return ReshardWithAllToAll(target, src_tgt_dims->first,
+                               src_tgt_dims->second);
   }
 
   // If not replicated yet, first replicate and then reshard to use one of the
@@ -745,45 +776,53 @@
   return PartitionedHlo(result, base_shape_, state_);
 }
 
-PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
-    const HloSharding& target) const {
-  int64 partition_count = sharding().tile_assignment().num_elements();
-  absl::optional<int64> input_partition_dim = UniqueTiledDim(sharding());
-  absl::optional<int64> output_partition_dim = UniqueTiledDim(target);
-  CHECK(input_partition_dim.has_value());
-  CHECK(output_partition_dim.has_value());
+PartitionedHlo PartitionedHlo::ReshardWithAllToAll(const HloSharding& target,
+                                                   int64 source_dim,
+                                                   int64 target_dim) const {
+  const int64 group_size = sharding().tile_assignment().dim(source_dim);
 
   // If the device order is different in the target, fix the order with
   // ReshardWithCollectivePermute.
-  auto input_tile_fixed_device_order = target.tile_assignment();
-  input_tile_fixed_device_order.Reshape(
-      sharding().tile_assignment().dimensions());
+  std::vector<int64> xpose_dims(target.tile_assignment().num_dimensions());
+  std::iota(xpose_dims.begin(), xpose_dims.end(), 0);
+  xpose_dims[source_dim] = target_dim;
+  xpose_dims[target_dim] = source_dim;
   auto input_sharding_fixed_device_order =
-      HloSharding::Tile(input_tile_fixed_device_order);
+      hlo_sharding_util::TransposeSharding(target, xpose_dims);
   if (input_sharding_fixed_device_order != sharding()) {
     auto fixed_order =
         ReshardWithCollectivePermute(input_sharding_fixed_device_order);
-    return fixed_order.ReshardWithAllToAll(target);
+    return fixed_order.ReshardWithAllToAll(target, source_dim, target_dim);
   }
 
   auto padded_hlo =
       PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b);
 
   // The order of ids in the group must follow the target sharding.
-  std::vector<ReplicaGroup> groups(1);
-  for (int64 device : target.tile_assignment()) {
-    groups[0].add_replica_ids(device);
-  }
+  std::vector<ReplicaGroup> groups(target.tile_assignment().num_elements() /
+                                   group_size);
+  target.tile_assignment().Each(
+      [&](absl::Span<const int64> indices, int64 device) {
+        int64 group_id = 0;
+        for (int64 dim = 0; dim < indices.size(); ++dim) {
+          if (dim == target_dim) {
+            continue;
+          }
+          group_id *= target.tile_assignment().dim(dim);
+          group_id += indices[dim];
+        }
+        groups[group_id].add_replica_ids(device);
+      });
 
   HloInstruction* result = nullptr;
 
-  // Split along the split dimension (output_partition_dim) of the all-to-all
+  // Split along the split dimension (target_dim) of the all-to-all
   // output.
   std::vector<int64> dimensions;
   for (int64 i = 0; i < base_shape_.rank(); ++i) {
-    if (i == *output_partition_dim) {
-      dimensions.push_back(partition_count);
-      dimensions.push_back(padded_hlo->shape().dimensions(i) / partition_count);
+    if (i == target_dim) {
+      dimensions.push_back(group_size);
+      dimensions.push_back(padded_hlo->shape().dimensions(i) / group_size);
     } else {
       dimensions.push_back(padded_hlo->shape().dimensions(i));
     }
@@ -794,21 +833,19 @@
   // After the reshape, it is guaranteed to have at least 3 dimensions.
   auto all_to_all =
       state_.collective_ops_creator.create_cross_partition_all_to_all(
-          state_.b, {reshape}, groups, (*state_.next_channel_id)++,
-          output_partition_dim);
+          state_.b, {reshape}, groups, (*state_.next_channel_id)++, target_dim);
 
   // Reorder the split dimension of the reshape to be located in front of the
   // input partition dimension, so the two dimensions can be combined.
-  int64 new_input_partition_dim = (*output_partition_dim < *input_partition_dim)
-                                      ? *input_partition_dim + 1
-                                      : *input_partition_dim;
+  int64 new_source_dim =
+      (target_dim < source_dim) ? source_dim + 1 : source_dim;
   std::vector<int64> permutation;
   for (int64 i = 0; i < all_to_all->shape().rank(); ++i) {
-    if (i == *output_partition_dim) {
+    if (i == target_dim) {
       continue;
     }
-    if (i == new_input_partition_dim) {
-      permutation.push_back(*output_partition_dim);
+    if (i == new_source_dim) {
+      permutation.push_back(target_dim);
     }
     permutation.push_back(i);
   }
@@ -819,8 +856,7 @@
 
   // Combine the split dimension and the input partition dimension.
   auto new_shape = ShapeInference::InferAllToAllShape(
-                       padded_hlo->shape(), *output_partition_dim,
-                       *input_partition_dim, partition_count)
+                       padded_hlo->shape(), target_dim, source_dim, group_size)
                        .ValueOrDie();
   result = state_.b->AddInstruction(
       HloInstruction::CreateReshape(new_shape, transpose));
@@ -837,7 +873,8 @@
 
 PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
     const HloSharding& target) const {
-  CHECK(CanReshardWithCollectivePermute(sharding(), target));
+  CHECK(CanReshardWithCollectivePermute(sharding(), target))
+      << sharding().ToString() << " to " << target.ToString();
   std::vector<std::pair<int64, int64>> src_dst_pairs;
   sharding().tile_assignment().Each(
       [&](absl::Span<const int64> indices, int64 src_device) {
@@ -3653,8 +3690,8 @@
         output_batch_partitions == num_partitions_ &&
         lhs_sharding_transposed_to_match_output == hlo->sharding()) {
       if (!may_reshard_with_allreduce &&
-          !CanReshardWithAllToAll(rhs.sharding(),
-                                  *lhs_sharding_transposed_to_match_rhs)) {
+          !GetReshardAllToAllSourceTargetDims(
+              rhs.sharding(), *lhs_sharding_transposed_to_match_rhs)) {
         return false;
       }
       auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs);
@@ -3668,8 +3705,8 @@
         output_batch_partitions == num_partitions_ &&
         rhs_sharding_transposed_to_match_output == hlo->sharding()) {
       if (!may_reshard_with_allreduce &&
-          !CanReshardWithAllToAll(lhs.sharding(),
-                                  *rhs_sharding_transposed_to_match_lhs)) {
+          !GetReshardAllToAllSourceTargetDims(
+              lhs.sharding(), *rhs_sharding_transposed_to_match_lhs)) {
         return false;
       }
       auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs);
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
index 52e4c90..40881b4 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
@@ -284,7 +284,8 @@
 
   // Helper function to reshard the tensor using AllToAll (instead of the
   // default of Replicate followed by Slice).
-  PartitionedHlo ReshardWithAllToAll(const HloSharding& target) const;
+  PartitionedHlo ReshardWithAllToAll(const HloSharding& target,
+                                     int64 source_dim, int64 target_dim) const;
 
   // Helper function to reshard the tensor using CollectivePermute.
   PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const;
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
index 1f0b1d0..9f3708f 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
@@ -3766,6 +3766,32 @@
                                        op::Parameter(0))));
 }
 
+TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8,8,8] parameter(0),
+    sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7}
+  ROOT %copy = f32[8,8,8,8] copy(%param0),
+    sharding={devices=[1,2,2,2]0,1,4,5,2,3,6,7}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+  VLOG(1) << module->ToString();
+
+  auto root = module->entry_computation()->root_instruction();
+  auto reshape =
+      AllOf(op::Shape("f32[4,4,2,4,4]"), op::Reshape(op::Parameter(0)));
+  auto all_to_all = AllOf(op::Shape("f32[4,4,2,4,4]"), op::AllToAll(reshape));
+  auto xpose = AllOf(op::Shape("f32[2,4,4,4,4]"), op::Transpose(all_to_all));
+  EXPECT_THAT(root,
+              op::Copy(AllOf(op::Reshape(xpose), op::Shape("f32[8,4,4,4]"))));
+  EXPECT_EQ(root->operand(0)->operand(0)->operand(0)->replica_groups().size(),
+            4);
+}
+
 }  // namespace
 }  // namespace spmd
 }  // namespace xla