[XLA:SPMD] Support partitioning the convolution input along both a spatial
dimension and the input feature dimension.
PiperOrigin-RevId: 332589884
Change-Id: If4a2e0802922829e2d0178b6a888fc56c22b2d04
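
The partitioning helpers in convolution_handler.cc now receive a
create_sharded_conv callback instead of constructing the sharded convolution
inline via InferConvolveShape + CreateConvolve. A sketch of the callback that
dot_handler.cc builds and threads through (illustrative only; the exact code
is in the diff below):

    auto create_sharded_conv =
        [&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo,
            spmd::SpmdBuilder* b,
            const Window& conv_window) -> StatusOr<HloInstruction*> {
      if (dims_info.conv_spatial_dims.empty()) {
        // Dot-like convolution (no spatial dims partitioned): reuse the
        // existing dot_as_convolution_util helper.
        TF_ASSIGN_OR_RETURN(
            auto sharded_conv,
            dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
                *hlo, dims_info, lhs_hlo, rhs_hlo));
        return b->AddInstruction(std::move(sharded_conv));
      }
      // Spatially partitioned convolution: build the per-shard convolution
      // using the window already adjusted by the caller.
      TF_ASSIGN_OR_RETURN(auto sharded_conv,
                          CreateShardedConvConvolution(*hlo, dims_info, lhs_hlo,
                                                       rhs_hlo, conv_window));
      return b->AddInstruction(std::move(sharded_conv));
    };
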
diff --git a/tensorflow/compiler/xla/service/spmd/convolution_handler.cc b/tensorflow/compiler/xla/service/spmd/convolution_handler.cc
index 81419c5..f2d996c 100644
--- a/tensorflow/compiler/xla/service/spmd/convolution_handler.cc
+++ b/tensorflow/compiler/xla/service/spmd/convolution_handler.cc
@@ -40,8 +40,12 @@
// Partition convolution with batch group count.
StatusOr<HloInstruction*> PartitionConvolutionWithBatchGroupCount(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
- const HloSharding& output_sharding, const Window& conv_window,
- HloInstruction* original_hlo, int64 num_partitions, SpmdBuilder* b) {
+ const HloSharding& output_sharding,
+ const std::function<StatusOr<HloInstruction*>(
+ HloInstruction*, HloInstruction*, SpmdBuilder*,
+ const Window& conv_window)>& create_sharded_conv,
+ const Window& conv_window, HloInstruction* original_hlo,
+ int64 num_partitions, SpmdBuilder* b) {
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
if (original_hlo->batch_group_count() == 1 ||
original_hlo->batch_group_count() < num_partitions) {
@@ -115,21 +119,10 @@
auto aligned_output_sharding = hlo_sharding_util::TransposeSharding(
lhs.sharding(), lhs_to_output_indices);
- // Get LHS and RHS sharded shape.
- auto lhs_shard_shape = MakePartitionedShape(lhs.base_shape(), lhs.sharding());
- auto rhs_shard_shape = MakePartitionedShape(rhs.base_shape(), rhs.sharding());
- const int64 batch_group_count =
- CeilOfRatio(original_hlo->batch_group_count(), num_partitions);
// Create partitioned convolution.
TF_ASSIGN_OR_RETURN(
- Shape sharded_conv_shape,
- ShapeInference::InferConvolveShape(
- lhs_shard_shape, rhs_shard_shape, original_hlo->feature_group_count(),
- batch_group_count, conv_window, dnums));
- auto sharded_conv = b->AddInstruction(HloInstruction::CreateConvolve(
- sharded_conv_shape, lhs.hlo(), rhs.hlo(),
- original_hlo->feature_group_count(), batch_group_count, conv_window,
- dnums, original_hlo->precision_config()));
+ auto sharded_conv,
+ create_sharded_conv(lhs.hlo(), rhs.hlo(), b, conv_window));
sharded_conv->set_sharding(aligned_output_sharding);
return PartitionedHlo(sharded_conv, output_base_shape, lhs.state())
.Reshard(output_sharding)
@@ -139,8 +132,12 @@
// Partition convolution with feature group count.
StatusOr<HloInstruction*> PartitionConvolutionWithFeatureGroupCount(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
- const HloSharding& output_sharding, const Window& conv_window,
- HloInstruction* original_hlo, int64 num_partitions, SpmdBuilder* b) {
+ const HloSharding& output_sharding,
+ const std::function<StatusOr<HloInstruction*>(
+ HloInstruction*, HloInstruction*, SpmdBuilder*,
+ const Window& conv_window)>& create_sharded_conv,
+ const Window& conv_window, HloInstruction* original_hlo,
+ int64 num_partitions, SpmdBuilder* b) {
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
if (original_hlo->feature_group_count() == 1 ||
original_hlo->feature_group_count() < num_partitions) {
@@ -215,20 +212,9 @@
auto aligned_output_sharding = hlo_sharding_util::TransposeSharding(
lhs.sharding(), lhs_to_output_indices);
- auto lhs_shard_shape = MakePartitionedShape(lhs.base_shape(), lhs.sharding());
- auto rhs_shard_shape = MakePartitionedShape(rhs.base_shape(), rhs.sharding());
- int64 feature_group_count =
- CeilOfRatio(original_hlo->feature_group_count(), num_partitions);
-
TF_ASSIGN_OR_RETURN(
- Shape sharded_conv_shape,
- ShapeInference::InferConvolveShape(
- lhs_shard_shape, rhs_shard_shape, feature_group_count,
- original_hlo->batch_group_count(), conv_window, dnums));
- auto sharded_conv = b->AddInstruction(HloInstruction::CreateConvolve(
- sharded_conv_shape, lhs.hlo(), rhs.hlo(), feature_group_count,
- original_hlo->batch_group_count(), conv_window, dnums,
- original_hlo->precision_config()));
+ auto sharded_conv,
+ create_sharded_conv(lhs.hlo(), rhs.hlo(), b, conv_window));
sharded_conv->set_sharding(aligned_output_sharding);
return PartitionedHlo(sharded_conv, output_base_shape, lhs.state())
.Reshard(output_sharding)
@@ -240,9 +226,12 @@
StatusOr<HloInstruction*>
PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
- const HloSharding& output_sharding, const Window& conv_window,
- HloInstruction* original_hlo, HloInstruction* partition_id,
- HloModule* module, SpmdBuilder* b) {
+ const HloSharding& output_sharding,
+ const std::function<StatusOr<HloInstruction*>(
+ HloInstruction*, HloInstruction*, SpmdBuilder*,
+ const Window& conv_window)>& create_sharded_conv,
+ const Window& conv_window, HloInstruction* original_hlo,
+ HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
TF_RET_CHECK(!lhs.sharding().IsTileMaximal() &&
!rhs.sharding().IsTileMaximal());
@@ -491,10 +480,9 @@
rhs_with_halo = *concat;
}
- auto conv = b->AddInstruction(HloInstruction::CreateConvolve(
- output_base_shape, conv_lhs, rhs_with_halo,
- original_hlo->feature_group_count(), original_hlo->batch_group_count(),
- new_window, dnums, original_hlo->precision_config()));
+ TF_ASSIGN_OR_RETURN(
+ auto conv, create_sharded_conv(conv_lhs, rhs_with_halo, b, new_window));
+
auto ar = collective_ops_creator.create_cross_partition_all_reduce(
b, conv, MakeBinaryAdd(original_hlo->shape().element_type(), module), {},
(*lhs.state().next_channel_id)++);
@@ -509,9 +497,12 @@
StatusOr<HloInstruction*>
PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
- const HloSharding& output_sharding, const Window& conv_window,
- HloInstruction* original_hlo, HloInstruction* partition_id,
- HloModule* module, SpmdBuilder* b) {
+ const HloSharding& output_sharding,
+ const std::function<StatusOr<HloInstruction*>(
+ HloInstruction*, HloInstruction*, SpmdBuilder*,
+ const Window& conv_window)>& create_sharded_conv,
+ const Window& conv_window, HloInstruction* original_hlo,
+ HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
TF_RET_CHECK(!lhs.sharding().IsTileMaximal() &&
!rhs.sharding().IsTileMaximal());
@@ -583,7 +574,6 @@
rhs =
rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero, reversed_rhs_dims);
}
-
// Reshard LHS by exchanging halo such that each shard computes the partial
// sum of the full shape result, and add AllReduce.
//
@@ -701,11 +691,8 @@
lhs_with_halo = *concat;
}
- auto conv = b->AddInstruction(HloInstruction::CreateConvolve(
- output_base_shape, lhs_with_halo, rhs.hlo(),
- original_hlo->feature_group_count(), original_hlo->batch_group_count(),
- new_window, original_hlo->convolution_dimension_numbers(),
- original_hlo->precision_config()));
+ TF_ASSIGN_OR_RETURN(
+ auto conv, create_sharded_conv(lhs_with_halo, rhs.hlo(), b, new_window));
auto ar =
lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
b, conv, MakeBinaryAdd(output_base_shape.element_type(), module), {},
@@ -720,8 +707,11 @@
// RHS.
StatusOr<HloInstruction*> PartitionConvolutionTiledOutput(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
- const HloSharding& output_sharding, const Window& conv_window,
- HloInstruction* original_hlo, SpmdBuilder* b) {
+ const HloSharding& output_sharding,
+ const std::function<StatusOr<HloInstruction*>(
+ HloInstruction*, HloInstruction*, SpmdBuilder*,
+ const Window& conv_window)>& create_sharded_conv,
+ const Window& conv_window, HloInstruction* original_hlo, SpmdBuilder* b) {
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
const auto& dnums = original_hlo->convolution_dimension_numbers();
TF_RET_CHECK(!output_sharding.IsTileMaximal());
@@ -772,19 +762,13 @@
resharded_operand_and_window->shard_window.dimensions(
dnums.input_spatial_dimensions(i));
}
+
TF_ASSIGN_OR_RETURN(
- Shape sharded_conv_shape,
- ShapeInference::InferConvolveShape(
- resharded_operand_and_window->sharded_input->shape(),
- rhs.hlo()->shape(), original_hlo->feature_group_count(),
- original_hlo->batch_group_count(), new_window, dnums));
+ auto sharded_conv,
+ create_sharded_conv(resharded_operand_and_window->sharded_input,
+ rhs.hlo(), b, new_window));
+
auto shard_shape = MakePartitionedShape(output_base_shape, output_sharding);
- *sharded_conv_shape.mutable_layout() = shard_shape.layout();
- auto sharded_conv = b->AddInstruction(HloInstruction::CreateConvolve(
- sharded_conv_shape, resharded_operand_and_window->sharded_input,
- rhs.hlo(), original_hlo->feature_group_count(),
- original_hlo->batch_group_count(), new_window, dnums,
- original_hlo->precision_config()));
if (!resharded_operand_and_window->dynamic_slice_index_on_output
.has_value()) {
CHECK(ShapeUtil::Compatible(shard_shape, sharded_conv->shape()));
@@ -799,29 +783,34 @@
// Partition convolution with only one kind of dims partitioned.
StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
- const HloSharding& output_sharding, const Window& conv_window,
- HloInstruction* original_hlo, int64 num_partitions,
- const SpmdPartitionerOptions& options, HloInstruction* partition_id,
- HloModule* module, SpmdBuilder* b) {
+ const HloSharding& output_sharding,
+ const std::function<StatusOr<HloInstruction*>(
+ HloInstruction*, HloInstruction*, SpmdBuilder*,
+ const Window& conv_window)>& create_sharded_conv,
+ const Window& conv_window, HloInstruction* original_hlo,
+ int64 num_partitions, const SpmdPartitionerOptions& options,
+ HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
// Case 1: Handle depthwise convolution with batch group count or
// feature group count.
if (original_hlo->batch_group_count() > 1) {
- TF_ASSIGN_OR_RETURN(auto parallel_partitioned_conv,
- PartitionConvolutionWithBatchGroupCount(
- lhs, rhs, output_base_shape, output_sharding,
- conv_window, original_hlo, num_partitions, b));
+ TF_ASSIGN_OR_RETURN(
+ auto parallel_partitioned_conv,
+ PartitionConvolutionWithBatchGroupCount(
+ lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
+ conv_window, original_hlo, num_partitions, b));
if (parallel_partitioned_conv) {
return parallel_partitioned_conv;
}
}
if (original_hlo->feature_group_count() > 1) {
- TF_ASSIGN_OR_RETURN(auto parallel_partitioned_conv,
- PartitionConvolutionWithFeatureGroupCount(
- lhs, rhs, output_base_shape, output_sharding,
- conv_window, original_hlo, num_partitions, b));
+ TF_ASSIGN_OR_RETURN(
+ auto parallel_partitioned_conv,
+ PartitionConvolutionWithFeatureGroupCount(
+ lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
+ conv_window, original_hlo, num_partitions, b));
if (parallel_partitioned_conv) {
return parallel_partitioned_conv;
}
@@ -837,8 +826,8 @@
TF_ASSIGN_OR_RETURN(
auto partitioned_conv,
PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
- lhs, rhs, output_base_shape, output_sharding, conv_window,
- original_hlo, partition_id, module, b));
+ lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
+ conv_window, original_hlo, partition_id, module, b));
if (partitioned_conv) {
return partitioned_conv;
}
@@ -846,8 +835,8 @@
TF_ASSIGN_OR_RETURN(
auto partitioned_conv,
PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
- lhs, rhs, output_base_shape, output_sharding, conv_window,
- original_hlo, partition_id, module, b));
+ lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
+ conv_window, original_hlo, partition_id, module, b));
if (partitioned_conv) {
return partitioned_conv;
@@ -860,7 +849,7 @@
TF_ASSIGN_OR_RETURN(auto partitioned_conv,
PartitionConvolutionTiledOutput(
lhs, rhs, output_base_shape, output_sharding,
- conv_window, original_hlo, b));
+ create_sharded_conv, conv_window, original_hlo, b));
if (partitioned_conv) {
return partitioned_conv;
@@ -869,22 +858,92 @@
return nullptr;
}
+StatusOr<std::unique_ptr<HloInstruction>> CreateShardedConvConvolution(
+ const HloInstruction& conv,
+ const dot_as_convolution_util::DotConvolutionDimsInfo& dot_dnums,
+ HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo,
+ const Window& conv_window) {
+ CHECK_EQ(conv.opcode(), HloOpcode::kConvolution);
+ const auto& conv_dnums = conv.convolution_dimension_numbers();
+ auto window = conv.window();
+ for (const auto& dim : dot_dnums.batch_dims) {
+ auto wd = window.mutable_dimensions(dim.spatial_dim);
+ wd->set_size(sharded_lhs_hlo->shape().dimensions(
+ conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
+ wd->set_stride(std::max<int64>(1, wd->size() - 1));
+ wd->set_base_dilation(wd->size());
+ }
+ for (const auto& dim : dot_dnums.contracting_dims) {
+ if (dim.spatial_dim < 0) {
+ continue;
+ }
+ auto wd = window.mutable_dimensions(dim.spatial_dim);
+ wd->set_size(sharded_lhs_hlo->shape().dimensions(
+ conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
+ }
+ for (const auto& dim : dot_dnums.rhs_non_contracting_dims) {
+ if (dim.spatial_dim < 0) {
+ continue;
+ }
+ auto wd = window.mutable_dimensions(dim.spatial_dim);
+ wd->set_size(sharded_rhs_hlo->shape().dimensions(
+ conv_dnums.kernel_spatial_dimensions(dim.spatial_dim)));
+ wd->set_padding_high(wd->size() - 1);
+ wd->set_padding_low(wd->size() - 1);
+ }
+
+ for (const auto& dim : dot_dnums.conv_spatial_dims) {
+ auto wd = window.mutable_dimensions(dim.spatial_dim);
+ const auto& new_window_dimension = conv_window.dimensions(dim.spatial_dim);
+ wd->set_size(new_window_dimension.size());
+ wd->set_padding_high(new_window_dimension.padding_high());
+ wd->set_padding_low(new_window_dimension.padding_low());
+ }
+
+ int64 feature_group_count = conv.feature_group_count();
+ if (feature_group_count > 1) {
+ feature_group_count = sharded_lhs_hlo->shape().dimensions(
+ conv_dnums.input_feature_dimension()) /
+ sharded_rhs_hlo->shape().dimensions(
+ conv_dnums.kernel_input_feature_dimension());
+ }
+
+ int64 batch_group_count = conv.batch_group_count();
+ if (batch_group_count > 1) {
+ batch_group_count =
+ sharded_lhs_hlo->shape().dimensions(conv_dnums.input_batch_dimension());
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ Shape sharded_conv_shape,
+ ShapeInference::InferConvolveShape(
+ sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
+ feature_group_count, batch_group_count, window, conv_dnums));
+ *sharded_conv_shape.mutable_layout() = conv.shape().layout();
+ return HloInstruction::CreateConvolve(
+ sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo, feature_group_count,
+ batch_group_count, window, conv_dnums, conv.precision_config());
+}
+
} // namespace
// Partition convolution.
StatusOr<HloInstruction*> PartitionConvolution(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
+ const std::function<StatusOr<HloInstruction*>(
+ HloInstruction*, HloInstruction*, SpmdBuilder*,
+ const Window& conv_window)>& create_sharded_conv,
const Window& conv_window, HloInstruction* original_hlo,
int64 num_partitions, const SpmdPartitionerOptions& options,
HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
- TF_ASSIGN_OR_RETURN(
- auto try_partitioned_conv,
- PartitionConvolutionBaseCase(lhs, rhs, output_base_shape, output_sharding,
- conv_window, original_hlo, num_partitions,
- options, partition_id, module, b));
+ TF_ASSIGN_OR_RETURN(auto try_partitioned_conv,
+ PartitionConvolutionBaseCase(
+ lhs, rhs, output_base_shape, output_sharding,
+ create_sharded_conv, conv_window, original_hlo,
+ num_partitions, options, partition_id, module, b));
if (try_partitioned_conv) {
return try_partitioned_conv;
}
@@ -932,13 +991,22 @@
}
auto create_sharded_conv =
[&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo,
- spmd::SpmdBuilder* b) -> StatusOr<HloInstruction*> {
- TF_ASSIGN_OR_RETURN(
- auto sharded_conv,
- dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
- *hlo, dims_info, lhs_hlo, rhs_hlo));
- return b->AddInstruction(std::move(sharded_conv));
+ spmd::SpmdBuilder* b,
+ const Window& conv_window) -> StatusOr<HloInstruction*> {
+ if (dims_info.conv_spatial_dims.empty()) {
+ TF_ASSIGN_OR_RETURN(
+ auto sharded_conv,
+ dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
+ *hlo, dims_info, lhs_hlo, rhs_hlo));
+ return b->AddInstruction(std::move(sharded_conv));
+ } else {
+ TF_ASSIGN_OR_RETURN(auto sharded_conv,
+ CreateShardedConvConvolution(*hlo, dims_info, lhs_hlo,
+ rhs_hlo, conv_window));
+ return b->AddInstruction(std::move(sharded_conv));
+ }
};
+
return HandleDotHelper(hlo, mapping, create_sharded_conv);
}
diff --git a/tensorflow/compiler/xla/service/spmd/convolution_handler.h b/tensorflow/compiler/xla/service/spmd/convolution_handler.h
index dced14a..2d929da 100644
--- a/tensorflow/compiler/xla/service/spmd/convolution_handler.h
+++ b/tensorflow/compiler/xla/service/spmd/convolution_handler.h
@@ -29,6 +29,9 @@
StatusOr<HloInstruction*> PartitionConvolution(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
+ const std::function<StatusOr<HloInstruction*>(
+ HloInstruction*, HloInstruction*, SpmdBuilder*,
+ const Window& conv_window)>& create_sharded_conv,
const Window& conv_window, HloInstruction* original_hlo,
int64 num_partitions, const SpmdPartitionerOptions& options,
HloInstruction* partition_id, HloModule* module, SpmdBuilder* b);
diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
index 25c21ba..65a7b3f 100644
--- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc
+++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
@@ -19,8 +19,10 @@
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
@@ -29,6 +31,7 @@
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/numbers.h"
@@ -72,8 +75,9 @@
mapping.rhs_non_contracting_dims.back().rhs = i;
mapping.rhs_non_contracting_dims.back().output = next_output_dim++;
}
- auto create_sharded_dot = [&](HloInstruction* l, HloInstruction* r,
- SpmdBuilder* b) -> StatusOr<HloInstruction*> {
+ auto create_sharded_dot =
+ [&](HloInstruction* l, HloInstruction* r, SpmdBuilder* b,
+ const Window& conv_window) -> StatusOr<HloInstruction*> {
TF_ASSIGN_OR_RETURN(
auto sharded_dot_shape,
ShapeInference::InferDotOpShape(l->shape(), r->shape(),
@@ -92,11 +96,13 @@
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
- HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
- HloModule* module, HloInstruction* original_hlo, int64 lhs_batch_partitions,
- int64 rhs_batch_partitions, int64 output_batch_partitions,
- int64 lhs_contracting_partitions, int64 rhs_contracting_partitions,
- int64 lhs_non_contracting_partitions, int64 rhs_non_contracting_partitions,
+ HloInstruction*, HloInstruction*, SpmdBuilder*,
+ const Window& conv_window)>& create_sharded_dot,
+ const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
+ int64 lhs_batch_partitions, int64 rhs_batch_partitions,
+ int64 output_batch_partitions, int64 lhs_contracting_partitions,
+ int64 rhs_contracting_partitions, int64 lhs_non_contracting_partitions,
+ int64 rhs_non_contracting_partitions,
int64 output_lhs_non_contracting_partitions,
int64 output_rhs_non_contracting_partitions,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
@@ -170,7 +176,8 @@
if (lhs_batch_partitions == rhs_batch_partitions &&
rhs_batch_partitions == num_partitions &&
lhs_sharding_transposed_to_match_rhs == rhs_sharding) {
- TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b));
+ TF_ASSIGN_OR_RETURN(
+ auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
dot->set_sharding(*lhs_sharding_transposed_to_match_output);
return PartitionedHlo(dot, output_base_shape, lhs.state())
.Reshard(output_sharding)
@@ -196,7 +203,8 @@
}
auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs);
TF_ASSIGN_OR_RETURN(
- auto dot, create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b));
+ auto dot,
+ create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b, conv_window));
return dot;
}
// RHS and output are batch partitioned in the same way.
@@ -212,7 +220,8 @@
}
auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs);
TF_ASSIGN_OR_RETURN(
- auto dot, create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b));
+ auto dot,
+ create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b, conv_window));
return dot;
}
return nullptr;
@@ -310,8 +319,8 @@
dot_rhs = slice;
}
}
- TF_ASSIGN_OR_RETURN(auto dot,
- create_sharded_dot(dot_lhs, dot_rhs, &body_b));
+ TF_ASSIGN_OR_RETURN(
+ auto dot, create_sharded_dot(dot_lhs, dot_rhs, &body_b, conv_window));
if (windowed_at_contracting_dims) {
// Accumulate the partial output to the result buffer.
o = body_b.AddInstruction(
@@ -465,7 +474,8 @@
rhs =
rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
}
- TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b));
+ TF_ASSIGN_OR_RETURN(
+ auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
auto ar =
lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), {},
@@ -481,8 +491,8 @@
output_lhs_non_contracting_partitions == num_partitions &&
lhs_sharding_transposed_to_match_output == output_sharding) {
auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo();
- TF_ASSIGN_OR_RETURN(auto dot,
- create_sharded_dot(lhs.hlo(), rhs_replicated, b));
+ TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs_replicated,
+ b, conv_window));
return dot;
}
@@ -491,8 +501,8 @@
output_rhs_non_contracting_partitions == num_partitions &&
rhs_sharding_transposed_to_match_output == output_sharding) {
auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo();
- TF_ASSIGN_OR_RETURN(auto dot,
- create_sharded_dot(lhs_replicated, rhs.hlo(), b));
+ TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs_replicated, rhs.hlo(),
+ b, conv_window));
return dot;
}
@@ -503,8 +513,9 @@
lhs.Reshard(*output_sharding_transposed_to_match_lhs);
auto resharded_rhs =
rhs.Reshard(*output_sharding_transposed_to_match_rhs);
- TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(),
- resharded_rhs.hlo(), b));
+ TF_ASSIGN_OR_RETURN(
+ auto dot, create_sharded_dot(resharded_lhs.hlo(), resharded_rhs.hlo(),
+ b, conv_window));
return dot;
}
// Output is partitioned along LHS non-contracting dimensions.
@@ -513,8 +524,8 @@
lhs.Reshard(*output_sharding_transposed_to_match_lhs);
auto replicated_rhs = rhs.Reshard(HloSharding::Replicate());
TF_ASSIGN_OR_RETURN(
- auto dot,
- create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), b));
+ auto dot, create_sharded_dot(resharded_lhs.hlo(),
+ replicated_rhs.hlo(), b, conv_window));
return dot;
}
// Output is partitioned along RHS non-contracting dimensions.
@@ -522,8 +533,9 @@
auto replicated_lhs = lhs.Reshard(HloSharding::Replicate());
auto resharded_rhs =
rhs.Reshard(*output_sharding_transposed_to_match_rhs);
- TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(),
- resharded_rhs.hlo(), b));
+ TF_ASSIGN_OR_RETURN(
+ auto dot, create_sharded_dot(replicated_lhs.hlo(),
+ resharded_rhs.hlo(), b, conv_window));
return dot;
}
}
@@ -566,7 +578,8 @@
rhs =
rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
}
- TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b));
+ TF_ASSIGN_OR_RETURN(
+ auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
return lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), {},
(*lhs.state().next_channel_id)++);
@@ -579,8 +592,9 @@
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
- HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
- HloModule* module, HloInstruction* original_hlo,
+ HloInstruction*, HloInstruction*, SpmdBuilder*,
+ const Window& conv_window)>& create_sharded_dot,
+ const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
windowed_dot_general_loops);
@@ -592,8 +606,9 @@
int64 rhs_contracting_partitions, int64 lhs_non_contracting_partitions,
int64 rhs_non_contracting_partitions,
const std::function<StatusOr<HloInstruction*>(
- HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
- HloModule* module, HloInstruction* original_hlo,
+ HloInstruction*, HloInstruction*, SpmdBuilder*,
+ const Window& conv_window)>& create_sharded_dot,
+ const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
bool require_matching_devices_to_group,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
@@ -808,8 +823,8 @@
GetPerGroupBaseShape(output_grouped, output_base_shape),
output_grouped.sharding, dims_mapping,
num_partitions / output_grouped.device_groups.size(),
- create_sharded_dot, module, original_hlo, options, b,
- windowed_dot_general_loops));
+ create_sharded_dot, conv_window, module, original_hlo,
+ options, b, windowed_dot_general_loops));
dot->set_sharding(UngroupSharding(output_grouped));
return PartitionedHlo(dot, output_base_shape, lhs.state())
.Reshard(output_sharding)
@@ -826,8 +841,9 @@
const Shape& output_base_shape, const HloSharding& output_sharding,
const DotConvDimsMapping& dims_mapping, int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
- HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
- HloModule* module, HloInstruction* original_hlo,
+ HloInstruction*, HloInstruction*, SpmdBuilder*,
+ const Window& conv_window)>& create_sharded_dot,
+ const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
bool require_matching_devices_to_group,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
@@ -952,8 +968,8 @@
GetPerGroupBaseShape(output_grouped, output_base_shape),
output_grouped.sharding, dims_mapping,
num_partitions / matching_grouped.device_groups.size(),
- create_sharded_dot, module, original_hlo, options, b,
- windowed_dot_general_loops));
+ create_sharded_dot, conv_window, module, original_hlo,
+ options, b, windowed_dot_general_loops));
return dot;
}
@@ -966,8 +982,9 @@
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
- HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
- HloModule* module, HloInstruction* original_hlo,
+ HloInstruction*, HloInstruction*, SpmdBuilder*,
+ const Window& conv_window)>& create_sharded_dot,
+ const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
bool require_matching_devices_to_group,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
@@ -1090,10 +1107,9 @@
PartitionedHlo(rhs.hlo(),
GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
inner_state),
- MakePartitionedShape(output_base_shape, outer_output_tmp_sharding),
- inner_output_sharding, dims_mapping, num_partitions / group_count,
- create_sharded_dot, module, original_hlo, options, b,
- windowed_dot_general_loops));
+ output_base_shape, inner_output_sharding, dims_mapping,
+ num_partitions / group_count, create_sharded_dot, conv_window, module,
+ original_hlo, options, b, windowed_dot_general_loops));
if (!dot) {
return nullptr;
}
@@ -1124,8 +1140,9 @@
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
- HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
- HloModule* module, HloInstruction* original_hlo,
+ HloInstruction*, HloInstruction*, SpmdBuilder*,
+ const Window& conv_window)>& create_sharded_dot,
+ const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
bool require_matching_devices_to_group,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
@@ -1180,35 +1197,25 @@
// Try partition the purely spatially-partitioned convolution with convolution
// spatial dimension partitioned or depthwise parallel dimension partitioned.
- if (!dims_mapping.conv_spatial_dims.empty() &&
+ bool is_conv_spatial_dim_partitioned =
(lhs_conv_spatial_partitions > 1 || rhs_conv_spatial_partitions > 1 ||
- output_conv_spatial_partitions > 1 ||
- original_hlo->batch_group_count() > 1 ||
- original_hlo->feature_group_count() > 1)) {
- const auto& conv_dnums = original_hlo->convolution_dimension_numbers();
- auto window = original_hlo->window();
-
- // TODO(wangtao): remove this hack by passing create_sharded_conv to
- // PartitionConv.
- // Update convolution window when it is in the recursive call for
- // batch_dims.
- if (original_hlo->batch_group_count() == 1 &&
- original_hlo->feature_group_count() == 1 &&
- !ShapeUtil::Compatible(original_hlo->shape(), output_base_shape)) {
- for (const auto& dim : dims_mapping.batch_dims) {
- auto wd = window.mutable_dimensions(dim.spatial);
- wd->set_size(lhs.hlo()->shape().dimensions(
- conv_dnums.input_spatial_dimensions(dim.spatial)));
- wd->set_stride(std::max<int64>(1, wd->size() - 1));
- wd->set_base_dilation(wd->size());
- }
- }
-
+ output_conv_spatial_partitions > 1);
+ bool is_conv_batch_or_contracting_dim_partitioned =
+ (lhs_batch_partitions > 1 || rhs_batch_partitions > 1 ||
+ output_batch_partitions > 1 ||
+ (lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1));
+ if ((!dims_mapping.conv_spatial_dims.empty() &&
+ is_conv_spatial_dim_partitioned &&
+ !is_conv_batch_or_contracting_dim_partitioned) ||
+ (original_hlo->opcode() == HloOpcode::kConvolution &&
+ (original_hlo->batch_group_count() > 1 ||
+ original_hlo->feature_group_count() > 1))) {
TF_ASSIGN_OR_RETURN(
auto partitioned_conv,
PartitionConvolution(lhs, rhs, output_base_shape, output_sharding,
- dims_mapping, window, original_hlo, num_partitions,
- options, lhs.state().partition_id, module, b));
+ dims_mapping, create_sharded_dot, conv_window,
+ original_hlo, num_partitions, options,
+ lhs.state().partition_id, module, b));
if (partitioned_conv) {
return partitioned_conv;
@@ -1219,7 +1226,7 @@
auto try_partitioned_dot,
PartitionBaseCase(
lhs, rhs, output_base_shape, output_sharding, dims_mapping,
- num_partitions, create_sharded_dot, module, original_hlo,
+ num_partitions, create_sharded_dot, conv_window, module, original_hlo,
lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
lhs_contracting_partitions, rhs_contracting_partitions,
lhs_non_contracting_partitions, rhs_non_contracting_partitions,
@@ -1243,8 +1250,8 @@
lhs, rhs, output_base_shape, output_sharding, dims_mapping,
num_partitions, lhs_contracting_partitions,
rhs_contracting_partitions, lhs_non_contracting_partitions,
- rhs_non_contracting_partitions, create_sharded_dot, module,
- original_hlo, require_matching_devices_to_group, options, b,
+ rhs_non_contracting_partitions, create_sharded_dot, conv_window,
+ module, original_hlo, require_matching_devices_to_group, options, b,
windowed_dot_general_loops));
if (dot) {
return dot;
@@ -1268,7 +1275,6 @@
ShapeUtil::ByteSizeOf(rhs.hlo()->shape()) <=
rhs_non_contracting_partitions *
ShapeUtil::ByteSizeOf(lhs.hlo()->shape()));
-
TF_ASSIGN_OR_RETURN(
auto dot,
PartitionDotGroupOnNonContracting(
@@ -1284,7 +1290,7 @@
lhs_matching ? output_rhs_non_contracting_partitions
: output_lhs_non_contracting_partitions,
output_base_shape, output_sharding, dims_mapping, num_partitions,
- create_sharded_dot, module, original_hlo,
+ create_sharded_dot, conv_window, module, original_hlo,
require_matching_devices_to_group, options, b,
windowed_dot_general_loops));
if (dot) {
@@ -1304,15 +1310,15 @@
}
if (!matching_dims.empty()) {
TF_ASSIGN_OR_RETURN(
- auto dot,
- PartitionDotGroupOnNonContracting(
- /*lhs_matching=*/true, lhs, rhs, lhs_contracting_partitions,
- rhs_contracting_partitions, matching_dims,
- rhs_non_contracting_partitions,
- output_rhs_non_contracting_partitions, output_base_shape,
- output_sharding, dims_mapping, num_partitions, create_sharded_dot,
- module, original_hlo, require_matching_devices_to_group, options,
- b, windowed_dot_general_loops));
+ auto dot, PartitionDotGroupOnNonContracting(
+ /*lhs_matching=*/true, lhs, rhs,
+ lhs_contracting_partitions, rhs_contracting_partitions,
+ matching_dims, rhs_non_contracting_partitions,
+ output_rhs_non_contracting_partitions,
+ output_base_shape, output_sharding, dims_mapping,
+ num_partitions, create_sharded_dot, conv_window, module,
+ original_hlo, require_matching_devices_to_group,
+ options, b, windowed_dot_general_loops));
if (dot) {
return dot;
}
@@ -1331,15 +1337,15 @@
}
if (!matching_dims.empty()) {
TF_ASSIGN_OR_RETURN(
- auto dot,
- PartitionDotGroupOnNonContracting(
- /*lhs_matching=*/false, rhs, lhs, rhs_contracting_partitions,
- lhs_contracting_partitions, matching_dims,
- lhs_non_contracting_partitions,
- output_lhs_non_contracting_partitions, output_base_shape,
- output_sharding, dims_mapping, num_partitions, create_sharded_dot,
- module, original_hlo, require_matching_devices_to_group, options,
- b, windowed_dot_general_loops));
+ auto dot, PartitionDotGroupOnNonContracting(
+ /*lhs_matching=*/false, rhs, lhs,
+ rhs_contracting_partitions, lhs_contracting_partitions,
+ matching_dims, lhs_non_contracting_partitions,
+ output_lhs_non_contracting_partitions,
+ output_base_shape, output_sharding, dims_mapping,
+ num_partitions, create_sharded_dot, conv_window, module,
+ original_hlo, require_matching_devices_to_group,
+ options, b, windowed_dot_general_loops));
if (dot) {
return dot;
}
@@ -1356,7 +1362,8 @@
output_lhs_non_contracting_partitions,
output_rhs_non_contracting_partitions, output_base_shape,
output_sharding, dims_mapping, num_partitions, create_sharded_dot,
- module, original_hlo, require_matching_devices_to_group, options, b,
+ conv_window, module, original_hlo,
+ require_matching_devices_to_group, options, b,
windowed_dot_general_loops));
if (dot) {
return dot;
@@ -1374,14 +1381,14 @@
}
if (!matching_dims.empty()) {
TF_ASSIGN_OR_RETURN(
- auto dot,
- PartitionDotGroupOnContracting(
- lhs, rhs, matching_dims, output_batch_partitions,
- output_lhs_non_contracting_partitions,
- output_rhs_non_contracting_partitions, output_base_shape,
- output_sharding, dims_mapping, num_partitions, create_sharded_dot,
- module, original_hlo, require_matching_devices_to_group, options,
- b, windowed_dot_general_loops));
+ auto dot, PartitionDotGroupOnContracting(
+ lhs, rhs, matching_dims, output_batch_partitions,
+ output_lhs_non_contracting_partitions,
+ output_rhs_non_contracting_partitions,
+ output_base_shape, output_sharding, dims_mapping,
+ num_partitions, create_sharded_dot, conv_window, module,
+ original_hlo, require_matching_devices_to_group,
+ options, b, windowed_dot_general_loops));
if (dot) {
return dot;
}
@@ -1401,8 +1408,9 @@
PartitionDot(PartitionedHlo(lhs.hlo(), lhs.base_shape(), inner_state),
PartitionedHlo(rhs.hlo(), rhs.base_shape(), inner_state),
output_base_shape, grouped_output.sharding, dims_mapping,
- output_sharding.NumTiles(), create_sharded_dot, module,
- original_hlo, options, b, windowed_dot_general_loops));
+ output_sharding.NumTiles(), create_sharded_dot,
+ conv_window, module, original_hlo, options, b,
+ windowed_dot_general_loops));
if (dot) {
return dot;
}
@@ -1414,7 +1422,7 @@
auto dot,
PartitionBaseCase(
lhs, rhs, output_base_shape, output_sharding, dims_mapping,
- num_partitions, create_sharded_dot, module, original_hlo,
+ num_partitions, create_sharded_dot, conv_window, module, original_hlo,
lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
lhs_contracting_partitions, rhs_contracting_partitions,
lhs_non_contracting_partitions, rhs_non_contracting_partitions,
@@ -1433,8 +1441,9 @@
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
- HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
- HloModule* module, HloInstruction* original_hlo,
+ HloInstruction*, HloInstruction*, SpmdBuilder*,
+ const Window& conv_window)>& create_sharded_dot,
+ const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
windowed_dot_general_loops) {
@@ -1444,17 +1453,18 @@
TF_ASSIGN_OR_RETURN(
auto try_partition,
PartitionDot(lhs, rhs, output_base_shape, output_sharding, dims_mapping,
- num_partitions, create_sharded_dot, module, original_hlo,
- require_matching_devices_to_group, options, b,
- windowed_dot_general_loops));
+ num_partitions, create_sharded_dot, conv_window, module,
+ original_hlo, require_matching_devices_to_group, options,
+ b, windowed_dot_general_loops));
if (try_partition) {
return try_partition;
}
}
// Default action.
- TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.Replicate().hlo(),
- rhs.Replicate().hlo(), b));
+ TF_ASSIGN_OR_RETURN(
+ auto dot, create_sharded_dot(lhs.Replicate().hlo(), rhs.Replicate().hlo(),
+ b, conv_window));
dot->set_sharding(HloSharding::Replicate());
return PartitionedHlo(dot, output_base_shape, lhs.state())
.Reshard(output_sharding)
@@ -1466,14 +1476,20 @@
Status SpmdPartitioningVisitor::HandleDotHelper(
HloInstruction* hlo, const DotConvDimsMapping& dims_mapping,
const std::function<StatusOr<HloInstruction*>(
- HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) {
+ HloInstruction*, HloInstruction*, SpmdBuilder*,
+ const Window& conv_window)>& create_sharded_dot) {
auto& lhs = GetPartitionedHlo(hlo->operand(0));
auto& rhs = GetPartitionedHlo(hlo->operand(1));
+ Window conv_window;
+ if (hlo->opcode() == HloOpcode::kConvolution) {
+ conv_window = hlo->window();
+ }
+
TF_ASSIGN_OR_RETURN(
auto partitioned_dot,
PartitionDot(lhs, rhs, hlo->shape(), hlo->sharding(), dims_mapping,
- num_partitions_, create_sharded_dot, module_, hlo, options_,
- &b_, &windowed_dot_general_loops_));
+ num_partitions_, create_sharded_dot, conv_window, module_,
+ hlo, options_, &b_, &windowed_dot_general_loops_));
SetPartitionedHlo(hlo, [&] { return partitioned_dot; });
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
index b09ea0c..86c1a97 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
@@ -407,10 +407,11 @@
Status HandlePartitionId(HloInstruction* hlo) override;
// Implementation of dot partitioning given DotGeneralDimsMapping.
- Status HandleDotHelper(
- HloInstruction* hlo, const DotConvDimsMapping& dims_mapping,
- const std::function<StatusOr<HloInstruction*>(
- HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot);
+ Status HandleDotHelper(HloInstruction* hlo,
+ const DotConvDimsMapping& dims_mapping,
+ const std::function<StatusOr<HloInstruction*>(
+ HloInstruction*, HloInstruction*, SpmdBuilder*,
+ const Window& conv_window)>& create_sharded_dot);
// Common handle for elementwise HLOs.
Status HandleElementwise(HloInstruction* hlo);
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
index f3bd971..43e6dbf 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
@@ -5893,6 +5893,50 @@
op::Shape("f32[1,1,128,256]")));
}
+TEST_F(SpmdPartitioningTest,
+ ConvolutionInputSpatialDimAndFeatureDimPartitioned) {
+ const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+ %lhs = f32[8,210,210,12] parameter(0)
+ %lhs.copy = f32[8,210,210,12] copy(f32[8,210,210,12] %lhs),
+ sharding={devices=[1,2,1,2]0,1,2,3}
+ %rhs = f32[3,3,12,32] parameter(1)
+ %rhs.copy = f32[3,3,12,32] copy(f32[3,3,12,32] %rhs),
+ sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}
+ ROOT %conv = f32[8,210,210,32] convolution(
+ f32[8,210,210,12] %lhs.copy,
+ f32[3,3,12,32] %rhs.copy),
+ window={size=3x3 pad=1_1x1_1},
+ dim_labels=b01f_01io->b01f,
+ sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate}
+})";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ PartitionComputation(hlo_string, /*num_devices=*/4));
+ VLOG(1) << module->ToString();
+ auto root = module->entry_computation()->root_instruction();
+ auto lhs = AllOf(
+ op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
+ op::Constant(), op::Reshape())),
+ op::Shape("f32[8,105,210,6]"));
+ auto left_halo =
+ AllOf(op::CollectivePermute(op::Slice(lhs)), op::Shape("f32[8,1,210,6]"));
+ auto right_halo =
+ AllOf(op::CollectivePermute(op::Slice(lhs)), op::Shape("f32[8,1,210,6]"));
+ auto exchanged_lhs = AllOf(
+ op::Select(op::And(_, _), op::Concatenate(left_halo, lhs, right_halo),
+ op::Broadcast(_)),
+ op::Shape("f32[8,107,210,6]"));
+ auto rhs = AllOf(
+ op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
+ op::Reshape(), op::Constant())),
+ op::Shape("f32[3,3,6,32]"));
+ EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(
+ exchanged_lhs, op::CollectivePermute(rhs))),
+ op::Shape("f32[8,105,210,32]")));
+}
+
} // namespace
} // namespace spmd
} // namespace xla