[XLA:SPMD] Support partitioning convolution input in both the spatial
dimensions and the input feature dimension.

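The partitioning helpers in convolution_handler.cc no longer build the
sharded convolution instruction themselves; each one now receives a
create_sharded_conv callback of type
  std::function<StatusOr<HloInstruction*>(
      HloInstruction*, HloInstruction*, SpmdBuilder*, const Window&)>
and calls it where it previously emitted HloInstruction::CreateConvolve.
The lambda installed in convolution_handler.cc and passed to
HandleDotHelper decides whether the sharded op is materialized as a
dot-like convolution (CreateShardedConvForDotGeneralConvolution) or as a
real spatial convolution with an adjusted window
(CreateShardedConvConvolution), based on dims_info.conv_spatial_dims.

Illustrative sketch of the delegation pattern, in standalone C++ with
simplified stand-in types (Instr, Builder, Window here are hypothetical
placeholders, not the real XLA classes):

  #include <functional>
  #include <iostream>
  #include <memory>
  #include <string>
  #include <vector>

  struct Window { int size; };          // stand-in for xla::Window
  struct Instr  { std::string name; };  // stand-in for HloInstruction
  struct Builder {                      // stand-in for SpmdBuilder
    std::vector<std::unique_ptr<Instr>> instrs;
    Instr* Add(std::unique_ptr<Instr> i) {
      instrs.push_back(std::move(i));
      return instrs.back().get();
    }
  };

  // The callback every partitioning helper receives instead of creating
  // the convolution instruction itself.
  using CreateShardedConv =
      std::function<Instr*(Instr*, Instr*, Builder*, const Window&)>;

  // A helper: reshard / halo-exchange operands (elided), then delegate
  // instruction creation back to the caller-supplied callback.
  Instr* PartitionWithHaloExchange(Instr* lhs, Instr* rhs, const Window& w,
                                   const CreateShardedConv& create_conv,
                                   Builder* b) {
    // ... halo exchange and resharding of lhs/rhs would happen here ...
    return create_conv(lhs, rhs, b, w);
  }

  int main() {
    Builder b;
    Instr lhs{"lhs"}, rhs{"rhs"};
    Window conv_window{3};
    bool has_spatial_dims = true;  // e.g. !dims_info.conv_spatial_dims.empty()

    CreateShardedConv create_conv =
        [&](Instr* l, Instr* r, Builder* builder,
            const Window& w) -> Instr* {
      // The caller picks how to materialize the sharded op.
      std::string kind = has_spatial_dims ? "spatial_conv" : "dot_as_conv";
      auto instr = std::make_unique<Instr>();
      instr->name = kind + "(" + l->name + ", " + r->name +
                    ", window_size=" + std::to_string(w.size) + ")";
      return builder->Add(std::move(instr));
    };

    Instr* conv =
        PartitionWithHaloExchange(&lhs, &rhs, conv_window, create_conv, &b);
    std::cout << conv->name << "\n";  // spatial_conv(lhs, rhs, window_size=3)
  }
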
PiperOrigin-RevId: 332589884
Change-Id: If4a2e0802922829e2d0178b6a888fc56c22b2d04
diff --git a/tensorflow/compiler/xla/service/spmd/convolution_handler.cc b/tensorflow/compiler/xla/service/spmd/convolution_handler.cc
index 81419c5..f2d996c 100644
--- a/tensorflow/compiler/xla/service/spmd/convolution_handler.cc
+++ b/tensorflow/compiler/xla/service/spmd/convolution_handler.cc
@@ -40,8 +40,12 @@
 // Partition convolution with batch group count.
 StatusOr<HloInstruction*> PartitionConvolutionWithBatchGroupCount(
     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
-    const HloSharding& output_sharding, const Window& conv_window,
-    HloInstruction* original_hlo, int64 num_partitions, SpmdBuilder* b) {
+    const HloSharding& output_sharding,
+    const std::function<StatusOr<HloInstruction*>(
+        HloInstruction*, HloInstruction*, SpmdBuilder*,
+        const Window& conv_window)>& create_sharded_conv,
+    const Window& conv_window, HloInstruction* original_hlo,
+    int64 num_partitions, SpmdBuilder* b) {
   TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
   if (original_hlo->batch_group_count() == 1 ||
       original_hlo->batch_group_count() < num_partitions) {
@@ -115,21 +119,10 @@
   auto aligned_output_sharding = hlo_sharding_util::TransposeSharding(
       lhs.sharding(), lhs_to_output_indices);
 
-  // Get LHS and RHS sharded shape.
-  auto lhs_shard_shape = MakePartitionedShape(lhs.base_shape(), lhs.sharding());
-  auto rhs_shard_shape = MakePartitionedShape(rhs.base_shape(), rhs.sharding());
-  const int64 batch_group_count =
-      CeilOfRatio(original_hlo->batch_group_count(), num_partitions);
   // Create partitioned convolution.
   TF_ASSIGN_OR_RETURN(
-      Shape sharded_conv_shape,
-      ShapeInference::InferConvolveShape(
-          lhs_shard_shape, rhs_shard_shape, original_hlo->feature_group_count(),
-          batch_group_count, conv_window, dnums));
-  auto sharded_conv = b->AddInstruction(HloInstruction::CreateConvolve(
-      sharded_conv_shape, lhs.hlo(), rhs.hlo(),
-      original_hlo->feature_group_count(), batch_group_count, conv_window,
-      dnums, original_hlo->precision_config()));
+      auto sharded_conv,
+      create_sharded_conv(lhs.hlo(), rhs.hlo(), b, conv_window));
   sharded_conv->set_sharding(aligned_output_sharding);
   return PartitionedHlo(sharded_conv, output_base_shape, lhs.state())
       .Reshard(output_sharding)
@@ -139,8 +132,12 @@
 // Partition convolution with feature group count.
 StatusOr<HloInstruction*> PartitionConvolutionWithFeatureGroupCount(
     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
-    const HloSharding& output_sharding, const Window& conv_window,
-    HloInstruction* original_hlo, int64 num_partitions, SpmdBuilder* b) {
+    const HloSharding& output_sharding,
+    const std::function<StatusOr<HloInstruction*>(
+        HloInstruction*, HloInstruction*, SpmdBuilder*,
+        const Window& conv_window)>& create_sharded_conv,
+    const Window& conv_window, HloInstruction* original_hlo,
+    int64 num_partitions, SpmdBuilder* b) {
   TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
   if (original_hlo->feature_group_count() == 1 ||
       original_hlo->feature_group_count() < num_partitions) {
@@ -215,20 +212,9 @@
   auto aligned_output_sharding = hlo_sharding_util::TransposeSharding(
       lhs.sharding(), lhs_to_output_indices);
 
-  auto lhs_shard_shape = MakePartitionedShape(lhs.base_shape(), lhs.sharding());
-  auto rhs_shard_shape = MakePartitionedShape(rhs.base_shape(), rhs.sharding());
-  int64 feature_group_count =
-      CeilOfRatio(original_hlo->feature_group_count(), num_partitions);
-
   TF_ASSIGN_OR_RETURN(
-      Shape sharded_conv_shape,
-      ShapeInference::InferConvolveShape(
-          lhs_shard_shape, rhs_shard_shape, feature_group_count,
-          original_hlo->batch_group_count(), conv_window, dnums));
-  auto sharded_conv = b->AddInstruction(HloInstruction::CreateConvolve(
-      sharded_conv_shape, lhs.hlo(), rhs.hlo(), feature_group_count,
-      original_hlo->batch_group_count(), conv_window, dnums,
-      original_hlo->precision_config()));
+      auto sharded_conv,
+      create_sharded_conv(lhs.hlo(), rhs.hlo(), b, conv_window));
   sharded_conv->set_sharding(aligned_output_sharding);
   return PartitionedHlo(sharded_conv, output_base_shape, lhs.state())
       .Reshard(output_sharding)
@@ -240,9 +226,12 @@
 StatusOr<HloInstruction*>
 PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
-    const HloSharding& output_sharding, const Window& conv_window,
-    HloInstruction* original_hlo, HloInstruction* partition_id,
-    HloModule* module, SpmdBuilder* b) {
+    const HloSharding& output_sharding,
+    const std::function<StatusOr<HloInstruction*>(
+        HloInstruction*, HloInstruction*, SpmdBuilder*,
+        const Window& conv_window)>& create_sharded_conv,
+    const Window& conv_window, HloInstruction* original_hlo,
+    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
   TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
   TF_RET_CHECK(!lhs.sharding().IsTileMaximal() &&
                !rhs.sharding().IsTileMaximal());
@@ -491,10 +480,9 @@
     rhs_with_halo = *concat;
   }
 
-  auto conv = b->AddInstruction(HloInstruction::CreateConvolve(
-      output_base_shape, conv_lhs, rhs_with_halo,
-      original_hlo->feature_group_count(), original_hlo->batch_group_count(),
-      new_window, dnums, original_hlo->precision_config()));
+  TF_ASSIGN_OR_RETURN(
+      auto conv, create_sharded_conv(conv_lhs, rhs_with_halo, b, new_window));
+
   auto ar = collective_ops_creator.create_cross_partition_all_reduce(
       b, conv, MakeBinaryAdd(original_hlo->shape().element_type(), module), {},
       (*lhs.state().next_channel_id)++);
@@ -509,9 +497,12 @@
 StatusOr<HloInstruction*>
 PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
-    const HloSharding& output_sharding, const Window& conv_window,
-    HloInstruction* original_hlo, HloInstruction* partition_id,
-    HloModule* module, SpmdBuilder* b) {
+    const HloSharding& output_sharding,
+    const std::function<StatusOr<HloInstruction*>(
+        HloInstruction*, HloInstruction*, SpmdBuilder*,
+        const Window& conv_window)>& create_sharded_conv,
+    const Window& conv_window, HloInstruction* original_hlo,
+    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
   TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
   TF_RET_CHECK(!lhs.sharding().IsTileMaximal() &&
                !rhs.sharding().IsTileMaximal());
@@ -583,7 +574,6 @@
     rhs =
         rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero, reversed_rhs_dims);
   }
-
   // Reshard LHS by exchanging halo such that each shard computes the partial
   // sum of the full shape result, and add AllReduce.
   //
@@ -701,11 +691,8 @@
     lhs_with_halo = *concat;
   }
 
-  auto conv = b->AddInstruction(HloInstruction::CreateConvolve(
-      output_base_shape, lhs_with_halo, rhs.hlo(),
-      original_hlo->feature_group_count(), original_hlo->batch_group_count(),
-      new_window, original_hlo->convolution_dimension_numbers(),
-      original_hlo->precision_config()));
+  TF_ASSIGN_OR_RETURN(
+      auto conv, create_sharded_conv(lhs_with_halo, rhs.hlo(), b, new_window));
   auto ar =
       lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
           b, conv, MakeBinaryAdd(output_base_shape.element_type(), module), {},
@@ -720,8 +707,11 @@
 // RHS.
 StatusOr<HloInstruction*> PartitionConvolutionTiledOutput(
     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
-    const HloSharding& output_sharding, const Window& conv_window,
-    HloInstruction* original_hlo, SpmdBuilder* b) {
+    const HloSharding& output_sharding,
+    const std::function<StatusOr<HloInstruction*>(
+        HloInstruction*, HloInstruction*, SpmdBuilder*,
+        const Window& conv_window)>& create_sharded_conv,
+    const Window& conv_window, HloInstruction* original_hlo, SpmdBuilder* b) {
   TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
   const auto& dnums = original_hlo->convolution_dimension_numbers();
   TF_RET_CHECK(!output_sharding.IsTileMaximal());
@@ -772,19 +762,13 @@
         resharded_operand_and_window->shard_window.dimensions(
             dnums.input_spatial_dimensions(i));
   }
+
   TF_ASSIGN_OR_RETURN(
-      Shape sharded_conv_shape,
-      ShapeInference::InferConvolveShape(
-          resharded_operand_and_window->sharded_input->shape(),
-          rhs.hlo()->shape(), original_hlo->feature_group_count(),
-          original_hlo->batch_group_count(), new_window, dnums));
+      auto sharded_conv,
+      create_sharded_conv(resharded_operand_and_window->sharded_input,
+                          rhs.hlo(), b, new_window));
+
   auto shard_shape = MakePartitionedShape(output_base_shape, output_sharding);
-  *sharded_conv_shape.mutable_layout() = shard_shape.layout();
-  auto sharded_conv = b->AddInstruction(HloInstruction::CreateConvolve(
-      sharded_conv_shape, resharded_operand_and_window->sharded_input,
-      rhs.hlo(), original_hlo->feature_group_count(),
-      original_hlo->batch_group_count(), new_window, dnums,
-      original_hlo->precision_config()));
   if (!resharded_operand_and_window->dynamic_slice_index_on_output
            .has_value()) {
     CHECK(ShapeUtil::Compatible(shard_shape, sharded_conv->shape()));
@@ -799,29 +783,34 @@
 // Partition convolution with only one kind of dims partitioned.
 StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
-    const HloSharding& output_sharding, const Window& conv_window,
-    HloInstruction* original_hlo, int64 num_partitions,
-    const SpmdPartitionerOptions& options, HloInstruction* partition_id,
-    HloModule* module, SpmdBuilder* b) {
+    const HloSharding& output_sharding,
+    const std::function<StatusOr<HloInstruction*>(
+        HloInstruction*, HloInstruction*, SpmdBuilder*,
+        const Window& conv_window)>& create_sharded_conv,
+    const Window& conv_window, HloInstruction* original_hlo,
+    int64 num_partitions, const SpmdPartitionerOptions& options,
+    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
   TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
 
   // Case 1: Handle depthwise convolution with batch group count or
   // feature group count.
   if (original_hlo->batch_group_count() > 1) {
-    TF_ASSIGN_OR_RETURN(auto parallel_partitioned_conv,
-                        PartitionConvolutionWithBatchGroupCount(
-                            lhs, rhs, output_base_shape, output_sharding,
-                            conv_window, original_hlo, num_partitions, b));
+    TF_ASSIGN_OR_RETURN(
+        auto parallel_partitioned_conv,
+        PartitionConvolutionWithBatchGroupCount(
+            lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
+            conv_window, original_hlo, num_partitions, b));
     if (parallel_partitioned_conv) {
       return parallel_partitioned_conv;
     }
   }
 
   if (original_hlo->feature_group_count() > 1) {
-    TF_ASSIGN_OR_RETURN(auto parallel_partitioned_conv,
-                        PartitionConvolutionWithFeatureGroupCount(
-                            lhs, rhs, output_base_shape, output_sharding,
-                            conv_window, original_hlo, num_partitions, b));
+    TF_ASSIGN_OR_RETURN(
+        auto parallel_partitioned_conv,
+        PartitionConvolutionWithFeatureGroupCount(
+            lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
+            conv_window, original_hlo, num_partitions, b));
     if (parallel_partitioned_conv) {
       return parallel_partitioned_conv;
     }
@@ -837,8 +826,8 @@
       TF_ASSIGN_OR_RETURN(
           auto partitioned_conv,
           PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
-              lhs, rhs, output_base_shape, output_sharding, conv_window,
-              original_hlo, partition_id, module, b));
+              lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
+              conv_window, original_hlo, partition_id, module, b));
       if (partitioned_conv) {
         return partitioned_conv;
       }
@@ -846,8 +835,8 @@
       TF_ASSIGN_OR_RETURN(
           auto partitioned_conv,
           PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
-              lhs, rhs, output_base_shape, output_sharding, conv_window,
-              original_hlo, partition_id, module, b));
+              lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
+              conv_window, original_hlo, partition_id, module, b));
 
       if (partitioned_conv) {
         return partitioned_conv;
@@ -860,7 +849,7 @@
     TF_ASSIGN_OR_RETURN(auto partitioned_conv,
                         PartitionConvolutionTiledOutput(
                             lhs, rhs, output_base_shape, output_sharding,
-                            conv_window, original_hlo, b));
+                            create_sharded_conv, conv_window, original_hlo, b));
 
     if (partitioned_conv) {
       return partitioned_conv;
@@ -869,22 +858,92 @@
   return nullptr;
 }
 
+StatusOr<std::unique_ptr<HloInstruction>> CreateShardedConvConvolution(
+    const HloInstruction& conv,
+    const dot_as_convolution_util::DotConvolutionDimsInfo& dot_dnums,
+    HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo,
+    const Window& conv_window) {
+  CHECK_EQ(conv.opcode(), HloOpcode::kConvolution);
+  const auto& conv_dnums = conv.convolution_dimension_numbers();
+  auto window = conv.window();
+  for (const auto& dim : dot_dnums.batch_dims) {
+    auto wd = window.mutable_dimensions(dim.spatial_dim);
+    wd->set_size(sharded_lhs_hlo->shape().dimensions(
+        conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
+    wd->set_stride(std::max<int64>(1, wd->size() - 1));
+    wd->set_base_dilation(wd->size());
+  }
+  for (const auto& dim : dot_dnums.contracting_dims) {
+    if (dim.spatial_dim < 0) {
+      continue;
+    }
+    auto wd = window.mutable_dimensions(dim.spatial_dim);
+    wd->set_size(sharded_lhs_hlo->shape().dimensions(
+        conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
+  }
+  for (const auto& dim : dot_dnums.rhs_non_contracting_dims) {
+    if (dim.spatial_dim < 0) {
+      continue;
+    }
+    auto wd = window.mutable_dimensions(dim.spatial_dim);
+    wd->set_size(sharded_rhs_hlo->shape().dimensions(
+        conv_dnums.kernel_spatial_dimensions(dim.spatial_dim)));
+    wd->set_padding_high(wd->size() - 1);
+    wd->set_padding_low(wd->size() - 1);
+  }
+
+  for (const auto& dim : dot_dnums.conv_spatial_dims) {
+    auto wd = window.mutable_dimensions(dim.spatial_dim);
+    const auto& new_window_dimension = conv_window.dimensions(dim.spatial_dim);
+    wd->set_size(new_window_dimension.size());
+    wd->set_padding_high(new_window_dimension.padding_high());
+    wd->set_padding_low(new_window_dimension.padding_low());
+  }
+
+  int64 feature_group_count = conv.feature_group_count();
+  if (feature_group_count > 1) {
+    feature_group_count = sharded_lhs_hlo->shape().dimensions(
+                              conv_dnums.input_feature_dimension()) /
+                          sharded_rhs_hlo->shape().dimensions(
+                              conv_dnums.kernel_input_feature_dimension());
+  }
+
+  int64 batch_group_count = conv.batch_group_count();
+  if (batch_group_count > 1) {
+    batch_group_count =
+        sharded_lhs_hlo->shape().dimensions(conv_dnums.input_batch_dimension());
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      Shape sharded_conv_shape,
+      ShapeInference::InferConvolveShape(
+          sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
+          feature_group_count, batch_group_count, window, conv_dnums));
+  *sharded_conv_shape.mutable_layout() = conv.shape().layout();
+  return HloInstruction::CreateConvolve(
+      sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo, feature_group_count,
+      batch_group_count, window, conv_dnums, conv.precision_config());
+}
+
 }  // namespace
 
 // Partition convolution.
 StatusOr<HloInstruction*> PartitionConvolution(
     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
+    const std::function<StatusOr<HloInstruction*>(
+        HloInstruction*, HloInstruction*, SpmdBuilder*,
+        const Window& conv_window)>& create_sharded_conv,
     const Window& conv_window, HloInstruction* original_hlo,
     int64 num_partitions, const SpmdPartitionerOptions& options,
     HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
   TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
 
-  TF_ASSIGN_OR_RETURN(
-      auto try_partitioned_conv,
-      PartitionConvolutionBaseCase(lhs, rhs, output_base_shape, output_sharding,
-                                   conv_window, original_hlo, num_partitions,
-                                   options, partition_id, module, b));
+  TF_ASSIGN_OR_RETURN(auto try_partitioned_conv,
+                      PartitionConvolutionBaseCase(
+                          lhs, rhs, output_base_shape, output_sharding,
+                          create_sharded_conv, conv_window, original_hlo,
+                          num_partitions, options, partition_id, module, b));
   if (try_partitioned_conv) {
     return try_partitioned_conv;
   }
@@ -932,13 +991,22 @@
   }
   auto create_sharded_conv =
       [&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo,
-          spmd::SpmdBuilder* b) -> StatusOr<HloInstruction*> {
-    TF_ASSIGN_OR_RETURN(
-        auto sharded_conv,
-        dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
-            *hlo, dims_info, lhs_hlo, rhs_hlo));
-    return b->AddInstruction(std::move(sharded_conv));
+          spmd::SpmdBuilder* b,
+          const Window& conv_window) -> StatusOr<HloInstruction*> {
+    if (dims_info.conv_spatial_dims.empty()) {
+      TF_ASSIGN_OR_RETURN(
+          auto sharded_conv,
+          dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
+              *hlo, dims_info, lhs_hlo, rhs_hlo));
+      return b->AddInstruction(std::move(sharded_conv));
+    } else {
+      TF_ASSIGN_OR_RETURN(auto sharded_conv,
+                          CreateShardedConvConvolution(*hlo, dims_info, lhs_hlo,
+                                                       rhs_hlo, conv_window));
+      return b->AddInstruction(std::move(sharded_conv));
+    }
   };
+
   return HandleDotHelper(hlo, mapping, create_sharded_conv);
 }
 
diff --git a/tensorflow/compiler/xla/service/spmd/convolution_handler.h b/tensorflow/compiler/xla/service/spmd/convolution_handler.h
index dced14a..2d929da 100644
--- a/tensorflow/compiler/xla/service/spmd/convolution_handler.h
+++ b/tensorflow/compiler/xla/service/spmd/convolution_handler.h
@@ -29,6 +29,9 @@
 StatusOr<HloInstruction*> PartitionConvolution(
     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
+    const std::function<StatusOr<HloInstruction*>(
+        HloInstruction*, HloInstruction*, SpmdBuilder*,
+        const Window& conv_window)>& create_sharded_conv,
     const Window& conv_window, HloInstruction* original_hlo,
     int64 num_partitions, const SpmdPartitionerOptions& options,
     HloInstruction* partition_id, HloModule* module, SpmdBuilder* b);
diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
index 25c21ba..65a7b3f 100644
--- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc
+++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
@@ -19,8 +19,10 @@
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
 #include "tensorflow/compiler/xla/service/shape_inference.h"
@@ -29,6 +31,7 @@
 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/platform/numbers.h"
@@ -72,8 +75,9 @@
     mapping.rhs_non_contracting_dims.back().rhs = i;
     mapping.rhs_non_contracting_dims.back().output = next_output_dim++;
   }
-  auto create_sharded_dot = [&](HloInstruction* l, HloInstruction* r,
-                                SpmdBuilder* b) -> StatusOr<HloInstruction*> {
+  auto create_sharded_dot =
+      [&](HloInstruction* l, HloInstruction* r, SpmdBuilder* b,
+          const Window& conv_window) -> StatusOr<HloInstruction*> {
     TF_ASSIGN_OR_RETURN(
         auto sharded_dot_shape,
         ShapeInference::InferDotOpShape(l->shape(), r->shape(),
@@ -92,11 +96,13 @@
     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
     int64 num_partitions,
     const std::function<StatusOr<HloInstruction*>(
-        HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
-    HloModule* module, HloInstruction* original_hlo, int64 lhs_batch_partitions,
-    int64 rhs_batch_partitions, int64 output_batch_partitions,
-    int64 lhs_contracting_partitions, int64 rhs_contracting_partitions,
-    int64 lhs_non_contracting_partitions, int64 rhs_non_contracting_partitions,
+        HloInstruction*, HloInstruction*, SpmdBuilder*,
+        const Window& conv_window)>& create_sharded_dot,
+    const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
+    int64 lhs_batch_partitions, int64 rhs_batch_partitions,
+    int64 output_batch_partitions, int64 lhs_contracting_partitions,
+    int64 rhs_contracting_partitions, int64 lhs_non_contracting_partitions,
+    int64 rhs_non_contracting_partitions,
     int64 output_lhs_non_contracting_partitions,
     int64 output_rhs_non_contracting_partitions,
     const SpmdPartitionerOptions& options, SpmdBuilder* b,
@@ -170,7 +176,8 @@
   if (lhs_batch_partitions == rhs_batch_partitions &&
       rhs_batch_partitions == num_partitions &&
       lhs_sharding_transposed_to_match_rhs == rhs_sharding) {
-    TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b));
+    TF_ASSIGN_OR_RETURN(
+        auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
     dot->set_sharding(*lhs_sharding_transposed_to_match_output);
     return PartitionedHlo(dot, output_base_shape, lhs.state())
         .Reshard(output_sharding)
@@ -196,7 +203,8 @@
       }
       auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs);
       TF_ASSIGN_OR_RETURN(
-          auto dot, create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b));
+          auto dot,
+          create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b, conv_window));
       return dot;
     }
     // RHS and output are batch partitioned in the same way.
@@ -212,7 +220,8 @@
       }
       auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs);
       TF_ASSIGN_OR_RETURN(
-          auto dot, create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b));
+          auto dot,
+          create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b, conv_window));
       return dot;
     }
     return nullptr;
@@ -310,8 +319,8 @@
         dot_rhs = slice;
       }
     }
-    TF_ASSIGN_OR_RETURN(auto dot,
-                        create_sharded_dot(dot_lhs, dot_rhs, &body_b));
+    TF_ASSIGN_OR_RETURN(
+        auto dot, create_sharded_dot(dot_lhs, dot_rhs, &body_b, conv_window));
     if (windowed_at_contracting_dims) {
       // Accumulate the partial output to the result buffer.
       o = body_b.AddInstruction(
@@ -465,7 +474,8 @@
       rhs =
           rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
     }
-    TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b));
+    TF_ASSIGN_OR_RETURN(
+        auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
     auto ar =
         lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
             b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), {},
@@ -481,8 +491,8 @@
       output_lhs_non_contracting_partitions == num_partitions &&
       lhs_sharding_transposed_to_match_output == output_sharding) {
     auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo();
-    TF_ASSIGN_OR_RETURN(auto dot,
-                        create_sharded_dot(lhs.hlo(), rhs_replicated, b));
+    TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs_replicated,
+                                                     b, conv_window));
     return dot;
   }
 
@@ -491,8 +501,8 @@
       output_rhs_non_contracting_partitions == num_partitions &&
       rhs_sharding_transposed_to_match_output == output_sharding) {
     auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo();
-    TF_ASSIGN_OR_RETURN(auto dot,
-                        create_sharded_dot(lhs_replicated, rhs.hlo(), b));
+    TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs_replicated, rhs.hlo(),
+                                                     b, conv_window));
     return dot;
   }
 
@@ -503,8 +513,9 @@
           lhs.Reshard(*output_sharding_transposed_to_match_lhs);
       auto resharded_rhs =
           rhs.Reshard(*output_sharding_transposed_to_match_rhs);
-      TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(),
-                                                       resharded_rhs.hlo(), b));
+      TF_ASSIGN_OR_RETURN(
+          auto dot, create_sharded_dot(resharded_lhs.hlo(), resharded_rhs.hlo(),
+                                       b, conv_window));
       return dot;
     }
     // Output is partitioned along LHS non-contracting dimensions.
@@ -513,8 +524,8 @@
           lhs.Reshard(*output_sharding_transposed_to_match_lhs);
       auto replicated_rhs = rhs.Reshard(HloSharding::Replicate());
       TF_ASSIGN_OR_RETURN(
-          auto dot,
-          create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), b));
+          auto dot, create_sharded_dot(resharded_lhs.hlo(),
+                                       replicated_rhs.hlo(), b, conv_window));
       return dot;
     }
     // Output is partitioned along RHS non-contracting dimensions.
@@ -522,8 +533,9 @@
       auto replicated_lhs = lhs.Reshard(HloSharding::Replicate());
       auto resharded_rhs =
           rhs.Reshard(*output_sharding_transposed_to_match_rhs);
-      TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(),
-                                                       resharded_rhs.hlo(), b));
+      TF_ASSIGN_OR_RETURN(
+          auto dot, create_sharded_dot(replicated_lhs.hlo(),
+                                       resharded_rhs.hlo(), b, conv_window));
       return dot;
     }
   }
@@ -566,7 +578,8 @@
       rhs =
           rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
     }
-    TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b));
+    TF_ASSIGN_OR_RETURN(
+        auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
     return lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
         b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), {},
         (*lhs.state().next_channel_id)++);
@@ -579,8 +592,9 @@
     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
     int64 num_partitions,
     const std::function<StatusOr<HloInstruction*>(
-        HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
-    HloModule* module, HloInstruction* original_hlo,
+        HloInstruction*, HloInstruction*, SpmdBuilder*,
+        const Window& conv_window)>& create_sharded_dot,
+    const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
     const SpmdPartitionerOptions& options, SpmdBuilder* b,
     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
         windowed_dot_general_loops);
@@ -592,8 +606,9 @@
     int64 rhs_contracting_partitions, int64 lhs_non_contracting_partitions,
     int64 rhs_non_contracting_partitions,
     const std::function<StatusOr<HloInstruction*>(
-        HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
-    HloModule* module, HloInstruction* original_hlo,
+        HloInstruction*, HloInstruction*, SpmdBuilder*,
+        const Window& conv_window)>& create_sharded_dot,
+    const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
     bool require_matching_devices_to_group,
     const SpmdPartitionerOptions& options, SpmdBuilder* b,
     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
@@ -808,8 +823,8 @@
                    GetPerGroupBaseShape(output_grouped, output_base_shape),
                    output_grouped.sharding, dims_mapping,
                    num_partitions / output_grouped.device_groups.size(),
-                   create_sharded_dot, module, original_hlo, options, b,
-                   windowed_dot_general_loops));
+                   create_sharded_dot, conv_window, module, original_hlo,
+                   options, b, windowed_dot_general_loops));
   dot->set_sharding(UngroupSharding(output_grouped));
   return PartitionedHlo(dot, output_base_shape, lhs.state())
       .Reshard(output_sharding)
@@ -826,8 +841,9 @@
     const Shape& output_base_shape, const HloSharding& output_sharding,
     const DotConvDimsMapping& dims_mapping, int64 num_partitions,
     const std::function<StatusOr<HloInstruction*>(
-        HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
-    HloModule* module, HloInstruction* original_hlo,
+        HloInstruction*, HloInstruction*, SpmdBuilder*,
+        const Window& conv_window)>& create_sharded_dot,
+    const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
     bool require_matching_devices_to_group,
     const SpmdPartitionerOptions& options, SpmdBuilder* b,
     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
@@ -952,8 +968,8 @@
                    GetPerGroupBaseShape(output_grouped, output_base_shape),
                    output_grouped.sharding, dims_mapping,
                    num_partitions / matching_grouped.device_groups.size(),
-                   create_sharded_dot, module, original_hlo, options, b,
-                   windowed_dot_general_loops));
+                   create_sharded_dot, conv_window, module, original_hlo,
+                   options, b, windowed_dot_general_loops));
   return dot;
 }
 
@@ -966,8 +982,9 @@
     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
     int64 num_partitions,
     const std::function<StatusOr<HloInstruction*>(
-        HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
-    HloModule* module, HloInstruction* original_hlo,
+        HloInstruction*, HloInstruction*, SpmdBuilder*,
+        const Window& conv_window)>& create_sharded_dot,
+    const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
     bool require_matching_devices_to_group,
     const SpmdPartitionerOptions& options, SpmdBuilder* b,
     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
@@ -1090,10 +1107,9 @@
           PartitionedHlo(rhs.hlo(),
                          GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
                          inner_state),
-          MakePartitionedShape(output_base_shape, outer_output_tmp_sharding),
-          inner_output_sharding, dims_mapping, num_partitions / group_count,
-          create_sharded_dot, module, original_hlo, options, b,
-          windowed_dot_general_loops));
+          output_base_shape, inner_output_sharding, dims_mapping,
+          num_partitions / group_count, create_sharded_dot, conv_window, module,
+          original_hlo, options, b, windowed_dot_general_loops));
   if (!dot) {
     return nullptr;
   }
@@ -1124,8 +1140,9 @@
     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
     int64 num_partitions,
     const std::function<StatusOr<HloInstruction*>(
-        HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
-    HloModule* module, HloInstruction* original_hlo,
+        HloInstruction*, HloInstruction*, SpmdBuilder*,
+        const Window& conv_window)>& create_sharded_dot,
+    const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
     bool require_matching_devices_to_group,
     const SpmdPartitionerOptions& options, SpmdBuilder* b,
     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
@@ -1180,35 +1197,25 @@
 
   // Try partition the purely spatially-partitioned convolution with convolution
   // spatial dimension partitioned or depthwise parallel dimension partitioned.
-  if (!dims_mapping.conv_spatial_dims.empty() &&
+  bool is_conv_spatial_dim_partitioned =
       (lhs_conv_spatial_partitions > 1 || rhs_conv_spatial_partitions > 1 ||
-       output_conv_spatial_partitions > 1 ||
-       original_hlo->batch_group_count() > 1 ||
-       original_hlo->feature_group_count() > 1)) {
-    const auto& conv_dnums = original_hlo->convolution_dimension_numbers();
-    auto window = original_hlo->window();
-
-    // TODO(wangtao): remove this hack by passing create_sharded_conv to
-    // PartitionConv.
-    // Update convolution window when it is in the recursive call for
-    // batch_dims.
-    if (original_hlo->batch_group_count() == 1 &&
-        original_hlo->feature_group_count() == 1 &&
-        !ShapeUtil::Compatible(original_hlo->shape(), output_base_shape)) {
-      for (const auto& dim : dims_mapping.batch_dims) {
-        auto wd = window.mutable_dimensions(dim.spatial);
-        wd->set_size(lhs.hlo()->shape().dimensions(
-            conv_dnums.input_spatial_dimensions(dim.spatial)));
-        wd->set_stride(std::max<int64>(1, wd->size() - 1));
-        wd->set_base_dilation(wd->size());
-      }
-    }
-
+       output_conv_spatial_partitions > 1);
+  bool is_conv_batch_or_contracting_dim_partitioned =
+      (lhs_batch_partitions > 1 || rhs_batch_partitions > 1 ||
+       output_batch_partitions > 1 ||
+       (lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1));
+  if ((!dims_mapping.conv_spatial_dims.empty() &&
+       is_conv_spatial_dim_partitioned &&
+       !is_conv_batch_or_contracting_dim_partitioned) ||
+      (original_hlo->opcode() == HloOpcode::kConvolution &&
+       (original_hlo->batch_group_count() > 1 ||
+        original_hlo->feature_group_count() > 1))) {
     TF_ASSIGN_OR_RETURN(
         auto partitioned_conv,
         PartitionConvolution(lhs, rhs, output_base_shape, output_sharding,
-                             dims_mapping, window, original_hlo, num_partitions,
-                             options, lhs.state().partition_id, module, b));
+                             dims_mapping, create_sharded_dot, conv_window,
+                             original_hlo, num_partitions, options,
+                             lhs.state().partition_id, module, b));
 
     if (partitioned_conv) {
       return partitioned_conv;
@@ -1219,7 +1226,7 @@
       auto try_partitioned_dot,
       PartitionBaseCase(
           lhs, rhs, output_base_shape, output_sharding, dims_mapping,
-          num_partitions, create_sharded_dot, module, original_hlo,
+          num_partitions, create_sharded_dot, conv_window, module, original_hlo,
           lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
           lhs_contracting_partitions, rhs_contracting_partitions,
           lhs_non_contracting_partitions, rhs_non_contracting_partitions,
@@ -1243,8 +1250,8 @@
             lhs, rhs, output_base_shape, output_sharding, dims_mapping,
             num_partitions, lhs_contracting_partitions,
             rhs_contracting_partitions, lhs_non_contracting_partitions,
-            rhs_non_contracting_partitions, create_sharded_dot, module,
-            original_hlo, require_matching_devices_to_group, options, b,
+            rhs_non_contracting_partitions, create_sharded_dot, conv_window,
+            module, original_hlo, require_matching_devices_to_group, options, b,
             windowed_dot_general_loops));
     if (dot) {
       return dot;
@@ -1268,7 +1275,6 @@
                  ShapeUtil::ByteSizeOf(rhs.hlo()->shape()) <=
              rhs_non_contracting_partitions *
                  ShapeUtil::ByteSizeOf(lhs.hlo()->shape()));
-
     TF_ASSIGN_OR_RETURN(
         auto dot,
         PartitionDotGroupOnNonContracting(
@@ -1284,7 +1290,7 @@
             lhs_matching ? output_rhs_non_contracting_partitions
                          : output_lhs_non_contracting_partitions,
             output_base_shape, output_sharding, dims_mapping, num_partitions,
-            create_sharded_dot, module, original_hlo,
+            create_sharded_dot, conv_window, module, original_hlo,
             require_matching_devices_to_group, options, b,
             windowed_dot_general_loops));
     if (dot) {
@@ -1304,15 +1310,15 @@
     }
     if (!matching_dims.empty()) {
       TF_ASSIGN_OR_RETURN(
-          auto dot,
-          PartitionDotGroupOnNonContracting(
-              /*lhs_matching=*/true, lhs, rhs, lhs_contracting_partitions,
-              rhs_contracting_partitions, matching_dims,
-              rhs_non_contracting_partitions,
-              output_rhs_non_contracting_partitions, output_base_shape,
-              output_sharding, dims_mapping, num_partitions, create_sharded_dot,
-              module, original_hlo, require_matching_devices_to_group, options,
-              b, windowed_dot_general_loops));
+          auto dot, PartitionDotGroupOnNonContracting(
+                        /*lhs_matching=*/true, lhs, rhs,
+                        lhs_contracting_partitions, rhs_contracting_partitions,
+                        matching_dims, rhs_non_contracting_partitions,
+                        output_rhs_non_contracting_partitions,
+                        output_base_shape, output_sharding, dims_mapping,
+                        num_partitions, create_sharded_dot, conv_window, module,
+                        original_hlo, require_matching_devices_to_group,
+                        options, b, windowed_dot_general_loops));
       if (dot) {
         return dot;
       }
@@ -1331,15 +1337,15 @@
     }
     if (!matching_dims.empty()) {
       TF_ASSIGN_OR_RETURN(
-          auto dot,
-          PartitionDotGroupOnNonContracting(
-              /*lhs_matching=*/false, rhs, lhs, rhs_contracting_partitions,
-              lhs_contracting_partitions, matching_dims,
-              lhs_non_contracting_partitions,
-              output_lhs_non_contracting_partitions, output_base_shape,
-              output_sharding, dims_mapping, num_partitions, create_sharded_dot,
-              module, original_hlo, require_matching_devices_to_group, options,
-              b, windowed_dot_general_loops));
+          auto dot, PartitionDotGroupOnNonContracting(
+                        /*lhs_matching=*/false, rhs, lhs,
+                        rhs_contracting_partitions, lhs_contracting_partitions,
+                        matching_dims, lhs_non_contracting_partitions,
+                        output_lhs_non_contracting_partitions,
+                        output_base_shape, output_sharding, dims_mapping,
+                        num_partitions, create_sharded_dot, conv_window, module,
+                        original_hlo, require_matching_devices_to_group,
+                        options, b, windowed_dot_general_loops));
       if (dot) {
         return dot;
       }
@@ -1356,7 +1362,8 @@
             output_lhs_non_contracting_partitions,
             output_rhs_non_contracting_partitions, output_base_shape,
             output_sharding, dims_mapping, num_partitions, create_sharded_dot,
-            module, original_hlo, require_matching_devices_to_group, options, b,
+            conv_window, module, original_hlo,
+            require_matching_devices_to_group, options, b,
             windowed_dot_general_loops));
     if (dot) {
       return dot;
@@ -1374,14 +1381,14 @@
     }
     if (!matching_dims.empty()) {
       TF_ASSIGN_OR_RETURN(
-          auto dot,
-          PartitionDotGroupOnContracting(
-              lhs, rhs, matching_dims, output_batch_partitions,
-              output_lhs_non_contracting_partitions,
-              output_rhs_non_contracting_partitions, output_base_shape,
-              output_sharding, dims_mapping, num_partitions, create_sharded_dot,
-              module, original_hlo, require_matching_devices_to_group, options,
-              b, windowed_dot_general_loops));
+          auto dot, PartitionDotGroupOnContracting(
+                        lhs, rhs, matching_dims, output_batch_partitions,
+                        output_lhs_non_contracting_partitions,
+                        output_rhs_non_contracting_partitions,
+                        output_base_shape, output_sharding, dims_mapping,
+                        num_partitions, create_sharded_dot, conv_window, module,
+                        original_hlo, require_matching_devices_to_group,
+                        options, b, windowed_dot_general_loops));
       if (dot) {
         return dot;
       }
@@ -1401,8 +1408,9 @@
         PartitionDot(PartitionedHlo(lhs.hlo(), lhs.base_shape(), inner_state),
                      PartitionedHlo(rhs.hlo(), rhs.base_shape(), inner_state),
                      output_base_shape, grouped_output.sharding, dims_mapping,
-                     output_sharding.NumTiles(), create_sharded_dot, module,
-                     original_hlo, options, b, windowed_dot_general_loops));
+                     output_sharding.NumTiles(), create_sharded_dot,
+                     conv_window, module, original_hlo, options, b,
+                     windowed_dot_general_loops));
     if (dot) {
       return dot;
     }
@@ -1414,7 +1422,7 @@
       auto dot,
       PartitionBaseCase(
           lhs, rhs, output_base_shape, output_sharding, dims_mapping,
-          num_partitions, create_sharded_dot, module, original_hlo,
+          num_partitions, create_sharded_dot, conv_window, module, original_hlo,
           lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
           lhs_contracting_partitions, rhs_contracting_partitions,
           lhs_non_contracting_partitions, rhs_non_contracting_partitions,
@@ -1433,8 +1441,9 @@
     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
     int64 num_partitions,
     const std::function<StatusOr<HloInstruction*>(
-        HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
-    HloModule* module, HloInstruction* original_hlo,
+        HloInstruction*, HloInstruction*, SpmdBuilder*,
+        const Window& conv_window)>& create_sharded_dot,
+    const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
     const SpmdPartitionerOptions& options, SpmdBuilder* b,
     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
         windowed_dot_general_loops) {
@@ -1444,17 +1453,18 @@
     TF_ASSIGN_OR_RETURN(
         auto try_partition,
         PartitionDot(lhs, rhs, output_base_shape, output_sharding, dims_mapping,
-                     num_partitions, create_sharded_dot, module, original_hlo,
-                     require_matching_devices_to_group, options, b,
-                     windowed_dot_general_loops));
+                     num_partitions, create_sharded_dot, conv_window, module,
+                     original_hlo, require_matching_devices_to_group, options,
+                     b, windowed_dot_general_loops));
     if (try_partition) {
       return try_partition;
     }
   }
 
   // Default action.
-  TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.Replicate().hlo(),
-                                                   rhs.Replicate().hlo(), b));
+  TF_ASSIGN_OR_RETURN(
+      auto dot, create_sharded_dot(lhs.Replicate().hlo(), rhs.Replicate().hlo(),
+                                   b, conv_window));
   dot->set_sharding(HloSharding::Replicate());
   return PartitionedHlo(dot, output_base_shape, lhs.state())
       .Reshard(output_sharding)
@@ -1466,14 +1476,20 @@
 Status SpmdPartitioningVisitor::HandleDotHelper(
     HloInstruction* hlo, const DotConvDimsMapping& dims_mapping,
     const std::function<StatusOr<HloInstruction*>(
-        HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) {
+        HloInstruction*, HloInstruction*, SpmdBuilder*,
+        const Window& conv_window)>& create_sharded_dot) {
   auto& lhs = GetPartitionedHlo(hlo->operand(0));
   auto& rhs = GetPartitionedHlo(hlo->operand(1));
+  Window conv_window;
+  if (hlo->opcode() == HloOpcode::kConvolution) {
+    conv_window = hlo->window();
+  }
+
   TF_ASSIGN_OR_RETURN(
       auto partitioned_dot,
       PartitionDot(lhs, rhs, hlo->shape(), hlo->sharding(), dims_mapping,
-                   num_partitions_, create_sharded_dot, module_, hlo, options_,
-                   &b_, &windowed_dot_general_loops_));
+                   num_partitions_, create_sharded_dot, conv_window, module_,
+                   hlo, options_, &b_, &windowed_dot_general_loops_));
   SetPartitionedHlo(hlo, [&] { return partitioned_dot; });
   return Status::OK();
 }
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
index b09ea0c..86c1a97 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
@@ -407,10 +407,11 @@
   Status HandlePartitionId(HloInstruction* hlo) override;
 
   // Implementation of dot partitioning given DotGeneralDimsMapping.
-  Status HandleDotHelper(
-      HloInstruction* hlo, const DotConvDimsMapping& dims_mapping,
-      const std::function<StatusOr<HloInstruction*>(
-          HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot);
+  Status HandleDotHelper(HloInstruction* hlo,
+                         const DotConvDimsMapping& dims_mapping,
+                         const std::function<StatusOr<HloInstruction*>(
+                             HloInstruction*, HloInstruction*, SpmdBuilder*,
+                             const Window& conv_window)>& create_sharded_dot);
 
   // Common handle for elementwise HLOs.
   Status HandleElementwise(HloInstruction* hlo);
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
index f3bd971..43e6dbf 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
@@ -5893,6 +5893,50 @@
                           op::Shape("f32[1,1,128,256]")));
 }
 
+TEST_F(SpmdPartitioningTest,
+       ConvolutionInputSpatialDimAndFeatureDimPartitioned) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %lhs = f32[8,210,210,12] parameter(0)
+  %lhs.copy = f32[8,210,210,12] copy(f32[8,210,210,12] %lhs),
+    sharding={devices=[1,2,1,2]0,1,2,3}
+  %rhs = f32[3,3,12,32] parameter(1)
+  %rhs.copy = f32[3,3,12,32] copy(f32[3,3,12,32] %rhs),
+    sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}
+  ROOT %conv = f32[8,210,210,32] convolution(
+    f32[8,210,210,12] %lhs.copy,
+    f32[3,3,12,32] %rhs.copy),
+    window={size=3x3 pad=1_1x1_1},
+    dim_labels=b01f_01io->b01f,
+    sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/4));
+  VLOG(1) << module->ToString();
+  auto root = module->entry_computation()->root_instruction();
+  auto lhs = AllOf(
+      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
+                                op::Constant(), op::Reshape())),
+      op::Shape("f32[8,105,210,6]"));
+  auto left_halo =
+      AllOf(op::CollectivePermute(op::Slice(lhs)), op::Shape("f32[8,1,210,6]"));
+  auto right_halo =
+      AllOf(op::CollectivePermute(op::Slice(lhs)), op::Shape("f32[8,1,210,6]"));
+  auto exchanged_lhs = AllOf(
+      op::Select(op::And(_, _), op::Concatenate(left_halo, lhs, right_halo),
+                 op::Broadcast(_)),
+      op::Shape("f32[8,107,210,6]"));
+  auto rhs = AllOf(
+      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
+                                op::Reshape(), op::Constant())),
+      op::Shape("f32[3,3,6,32]"));
+  EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(
+                              exchanged_lhs, op::CollectivePermute(rhs))),
+                          op::Shape("f32[8,105,210,32]")));
+}
+
 }  // namespace
 }  // namespace spmd
 }  // namespace xla