[XLA:SPMD] Support convolution with non-contracting spatial dim partitioned at
batch dim.

PiperOrigin-RevId: 321579087
Change-Id: I42f86b9c281e02f8157287653aa30a54c14a0e72
diff --git a/tensorflow/compiler/xla/service/dot_as_convolution_util.cc b/tensorflow/compiler/xla/service/dot_as_convolution_util.cc
index 576d9d4..4670ce6 100644
--- a/tensorflow/compiler/xla/service/dot_as_convolution_util.cc
+++ b/tensorflow/compiler/xla/service/dot_as_convolution_util.cc
@@ -24,6 +24,31 @@
 namespace xla {
 namespace dot_as_convolution_util {
 
+bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size) {
+  // A parallel batch dimension in DotGeneral is represented as a
+  // spatial dimension with window size B (batch dimension size),
+  // stride B - 1, and base dilation B.
+  if (lhs_size == wd.size() && lhs_size == wd.base_dilation() &&
+      ((std::max<int64>(1, lhs_size - 1) == wd.stride() &&
+        wd.window_dilation() == 1) ||
+       (std::max<int64>(1, lhs_size - 1) == wd.window_dilation() &&
+        wd.stride() == 1)) &&
+      wd.padding_high() == 0 && wd.padding_low() == 0 &&
+      !wd.window_reversal()) {
+    return true;
+  }
+
+  // Alternative representation of a batch dimension.
+  if (wd.size() == lhs_size && wd.padding_high() == lhs_size - 1 &&
+      wd.padding_low() == lhs_size - 1 && wd.window_reversal() &&
+      wd.window_dilation() == 1 && wd.stride() == lhs_size &&
+      wd.base_dilation() == lhs_size - 1) {
+    return true;
+  }
+
+  return false;
+}
+
 /* static */ absl::optional<DotGeneralAsConvolutionDimsInfo>
 ParseDotGeneralFromConvolution(const HloInstruction* conv) {
   CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
@@ -49,22 +74,7 @@
     int64 rhs_size = conv->operand(1)->shape().dimensions(rhs);
     int64 output = conv_dims.output_spatial_dimensions(i);
     const auto& wd = conv->window().dimensions(i);
-    if (lhs_size == wd.size() && lhs_size == wd.base_dilation() &&
-        ((std::max<int64>(1, lhs_size - 1) == wd.stride() &&
-          wd.window_dilation() == 1) ||
-         (std::max<int64>(1, lhs_size - 1) == wd.window_dilation() &&
-          wd.stride() == 1)) &&
-        wd.padding_high() == 0 && wd.padding_low() == 0 &&
-        !wd.window_reversal()) {
-      // A batch dimension in DotGeneral is represented as a spatial dimension
-      // with window size B (batch dimension size), stride B - 1, and base
-      // dilation B.
-      dims.batch_dims.push_back({lhs, rhs, output, i});
-    } else if (wd.size() == lhs_size && wd.padding_high() == lhs_size - 1 &&
-               wd.padding_low() == lhs_size - 1 && wd.window_reversal() &&
-               wd.window_dilation() == 1 && wd.stride() == lhs_size &&
-               wd.base_dilation() == lhs_size - 1) {
-      // Aternative representation of a batch dimension.
+    if (ConvSpatialDimensionIsParallel(wd, lhs_size)) {
       dims.batch_dims.push_back({lhs, rhs, output, i});
     } else if (lhs_size == wd.size() && wd.base_dilation() == 1 &&
                wd.window_dilation() == 1 && wd.padding_high() == 0 &&
diff --git a/tensorflow/compiler/xla/service/dot_as_convolution_util.h b/tensorflow/compiler/xla/service/dot_as_convolution_util.h
index a3e829a..6a7cacf 100644
--- a/tensorflow/compiler/xla/service/dot_as_convolution_util.h
+++ b/tensorflow/compiler/xla/service/dot_as_convolution_util.h
@@ -62,6 +62,12 @@
     const DotGeneralAsConvolutionDimsInfo& dot_dnums,
     HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo);
 
+// Checks if a spatial dimension is a parallel batch dimension.
+// A parallel batch dimension in DotGeneral is represented as a spatial
+// dimension with window size B (batch dimension size), stride B - 1, and base
+// dilation B.
+bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size);
+
 }  // namespace dot_as_convolution_util
 }  // namespace xla
 
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
index fa28b6f..76014c8 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
@@ -3149,6 +3149,72 @@
   auto aligned_lhs_sharding =
       hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);
 
+  // Handling cases where all the partitioned dimensions are parallel
+  // dimensions.
+  int64 lhs_parallel_dim_partitions = 1;
+  int64 rhs_parallel_dim_partitions = 1;
+  std::vector<int64> parallel_spatial_dims;
+  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
+    int64 lhs_dim = dnums.input_spatial_dimensions(i);
+    int64 lhs_size = lhs.base_shape().dimensions(lhs_dim);
+    const auto& wd = hlo->window().dimensions(i);
+    int64 rhs_dim = dnums.kernel_spatial_dimensions(i);
+    // Only non-reversal windows are supported right now.
+    if (!wd.window_reversal() &&
+        dot_as_convolution_util::ConvSpatialDimensionIsParallel(wd, lhs_size)) {
+      parallel_spatial_dims.emplace_back(i);
+      lhs_parallel_dim_partitions *= ShardCountAtDim(lhs.sharding(), lhs_dim);
+      rhs_parallel_dim_partitions *= ShardCountAtDim(rhs.sharding(), rhs_dim);
+    }
+  }
+  bool lhs_partition_dims_are_parallel =
+      (lhs_parallel_dim_partitions == num_partitions_);
+  bool rhs_partition_dims_are_parallel =
+      (rhs_parallel_dim_partitions == num_partitions_);
+
+  // If there is a parallel dim and all the partitioned dimensions are parallel
+  // dimensions in either LHS or RHS, simply create partitioned convolutions.
+  if (!parallel_spatial_dims.empty() &&
+      (lhs_partition_dims_are_parallel || rhs_partition_dims_are_parallel)) {
+    // Reshard LHS or RHS to partition at parallel dimensions as the other
+    // operand.
+    if (lhs_partition_dims_are_parallel) {
+      rhs = rhs.Reshard(aligned_rhs_sharding);
+    } else {
+      lhs = lhs.Reshard(aligned_lhs_sharding);
+    }
+    auto lhs_shard_shape =
+        MakePartitionedShape(lhs.base_shape(), lhs.sharding());
+    auto rhs_shard_shape =
+        MakePartitionedShape(rhs.base_shape(), rhs.sharding());
+    // Update convolution window.
+    auto new_window = hlo->window();
+    for (const auto& spatial_dim : parallel_spatial_dims) {
+      auto wd = new_window.mutable_dimensions(spatial_dim);
+      wd->set_size(lhs_shard_shape.dimensions(
+          dnums.input_spatial_dimensions(spatial_dim)));
+      wd->set_stride(std::max<int64>(1, wd->size() - 1));
+      wd->set_base_dilation(wd->size());
+    }
+    TF_ASSIGN_OR_RETURN(
+        Shape sharded_conv_shape,
+        ShapeInference::InferConvolveShape(
+            lhs_shard_shape, rhs_shard_shape, hlo->feature_group_count(),
+            hlo->batch_group_count(), new_window, dnums));
+    *sharded_conv_shape.mutable_layout() = hlo->shape().layout();
+    SetPartitionedHlo(hlo, [&]() {
+      auto sharded_conv = b_.AddInstruction(HloInstruction::CreateConvolve(
+          sharded_conv_shape, lhs.hlo(), rhs.hlo(), hlo->feature_group_count(),
+          hlo->batch_group_count(), new_window, dnums,
+          hlo->precision_config()));
+      sharded_conv->set_sharding(hlo->sharding());
+      return PartitionedHlo(sharded_conv, hlo->shape(), MakePartitioningState())
+          .Reshard(hlo->sharding())
+          .hlo();
+    });
+    return Status::OK();
+  }
+
   // Handling cases where both operands' shardings are aligned. We check that
   // the LHS batch dimension is not partitioned because it is mapped to the
   // output feature dimension in aligned_rhs_sharding, which are not the same
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc
index 3354a9c..7c4d816 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc
@@ -877,5 +877,13 @@
       output_shape, hlo, start_indices, limit_indices, strides));
 }
 
+// Returns the number of shards along the given dimension (1 if the sharding
+// is tile-maximal, i.e. the dimension is not sharded).
+int64 ShardCountAtDim(const HloSharding& sharding, int64 dim) {
+  if (sharding.IsTileMaximal()) {
+    return 1;
+  }
+  return sharding.tile_assignment().dim(dim);
+}
+
 }  // namespace spmd
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h
index 5f24566..8389c2f 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h
@@ -262,6 +262,9 @@
 HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
                             int64 slice_dim, int64 k);
 
+// Returns the number of shards along the given dimension (1 if the sharding
+// is tile-maximal, i.e. the dimension is not sharded).
+int64 ShardCountAtDim(const HloSharding& sharding, int64 dim);
+
 }  // namespace spmd
 }  // namespace xla