[XLA:SPMD] Support convolution with a non-contracting spatial dimension partitioned
at the batch dimension.
PiperOrigin-RevId: 321579087
Change-Id: I42f86b9c281e02f8157287653aa30a54c14a0e72
diff --git a/tensorflow/compiler/xla/service/dot_as_convolution_util.cc b/tensorflow/compiler/xla/service/dot_as_convolution_util.cc
index 576d9d4..4670ce6 100644
--- a/tensorflow/compiler/xla/service/dot_as_convolution_util.cc
+++ b/tensorflow/compiler/xla/service/dot_as_convolution_util.cc
@@ -24,6 +24,31 @@
namespace xla {
namespace dot_as_convolution_util {
+bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size) {
+ // A parallel batch dimension in DotGeneral is represented as a
+ // spatial dimension with window size B (batch dimension size),
+ // stride B - 1, and base dilation B.
+ if (lhs_size == wd.size() && lhs_size == wd.base_dilation() &&
+ ((std::max<int64>(1, lhs_size - 1) == wd.stride() &&
+ wd.window_dilation() == 1) ||
+ (std::max<int64>(1, lhs_size - 1) == wd.window_dilation() &&
+ wd.stride() == 1)) &&
+ wd.padding_high() == 0 && wd.padding_low() == 0 &&
+ !wd.window_reversal()) {
+ return true;
+ }
+
+  // Alternative representation of a batch dimension.
+ if (wd.size() == lhs_size && wd.padding_high() == lhs_size - 1 &&
+ wd.padding_low() == lhs_size - 1 && wd.window_reversal() &&
+ wd.window_dilation() == 1 && wd.stride() == lhs_size &&
+ wd.base_dilation() == lhs_size - 1) {
+ return true;
+ }
+
+ return false;
+}
+
/* static */ absl::optional<DotGeneralAsConvolutionDimsInfo>
ParseDotGeneralFromConvolution(const HloInstruction* conv) {
CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
@@ -49,22 +74,7 @@
int64 rhs_size = conv->operand(1)->shape().dimensions(rhs);
int64 output = conv_dims.output_spatial_dimensions(i);
const auto& wd = conv->window().dimensions(i);
- if (lhs_size == wd.size() && lhs_size == wd.base_dilation() &&
- ((std::max<int64>(1, lhs_size - 1) == wd.stride() &&
- wd.window_dilation() == 1) ||
- (std::max<int64>(1, lhs_size - 1) == wd.window_dilation() &&
- wd.stride() == 1)) &&
- wd.padding_high() == 0 && wd.padding_low() == 0 &&
- !wd.window_reversal()) {
- // A batch dimension in DotGeneral is represented as a spatial dimension
- // with window size B (batch dimension size), stride B - 1, and base
- // dilation B.
- dims.batch_dims.push_back({lhs, rhs, output, i});
- } else if (wd.size() == lhs_size && wd.padding_high() == lhs_size - 1 &&
- wd.padding_low() == lhs_size - 1 && wd.window_reversal() &&
- wd.window_dilation() == 1 && wd.stride() == lhs_size &&
- wd.base_dilation() == lhs_size - 1) {
- // Aternative representation of a batch dimension.
+ if (ConvSpatialDimensionIsParallel(wd, lhs_size)) {
dims.batch_dims.push_back({lhs, rhs, output, i});
} else if (lhs_size == wd.size() && wd.base_dilation() == 1 &&
wd.window_dilation() == 1 && wd.padding_high() == 0 &&
diff --git a/tensorflow/compiler/xla/service/dot_as_convolution_util.h b/tensorflow/compiler/xla/service/dot_as_convolution_util.h
index a3e829a..6a7cacf 100644
--- a/tensorflow/compiler/xla/service/dot_as_convolution_util.h
+++ b/tensorflow/compiler/xla/service/dot_as_convolution_util.h
@@ -62,6 +62,12 @@
const DotGeneralAsConvolutionDimsInfo& dot_dnums,
HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo);
+// Checks if a spatial dimension is a parallel batch dimension.
+// A parallel batch dimension in DotGeneral is represented as a spatial
+// dimension with window size B (batch dimension size), stride B - 1, and base
+// dilation B.
+bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size);
+
} // namespace dot_as_convolution_util
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
index fa28b6f..76014c8 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
@@ -3149,6 +3149,72 @@
auto aligned_lhs_sharding =
hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);
+ // Handling cases where all the partitioned dimensions are parallel
+ // dimensions.
+ int64 lhs_parallel_dim_partitions = 1;
+ int64 rhs_parallel_dim_partitions = 1;
+ std::vector<int64> parallel_spatial_dims;
+ for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
+ int64 lhs_dim = dnums.input_spatial_dimensions(i);
+ int64 lhs_size = lhs.base_shape().dimensions(lhs_dim);
+ const auto& wd = hlo->window().dimensions(i);
+ int64 rhs_dim = dnums.kernel_spatial_dimensions(i);
+ // Only non reversal window is supported right now.
+ if (!wd.window_reversal() &&
+ dot_as_convolution_util::ConvSpatialDimensionIsParallel(wd, lhs_size)) {
+ parallel_spatial_dims.emplace_back(i);
+ lhs_parallel_dim_partitions *= ShardCountAtDim(lhs.sharding(), lhs_dim);
+ rhs_parallel_dim_partitions *= ShardCountAtDim(rhs.sharding(), rhs_dim);
+ }
+ }
+ bool lhs_partition_dims_are_parallel =
+ (lhs_parallel_dim_partitions == num_partitions_);
+ bool rhs_partition_dims_are_parallel =
+ (rhs_parallel_dim_partitions == num_partitions_);
+
+ // If there is a parallel dim and all the partitioned dimensions are parallel
+ // dimensions in either LHS or RHS, simply create partitioned convolutions.
+ if (!parallel_spatial_dims.empty() &&
+ (lhs_partition_dims_are_parallel || rhs_partition_dims_are_parallel)) {
+ // Reshard LHS or RHS to partition at parallel dimensions as the other
+ // operand.
+ if (lhs_partition_dims_are_parallel) {
+ rhs = rhs.Reshard(aligned_rhs_sharding);
+ } else {
+ lhs = lhs.Reshard(aligned_lhs_sharding);
+ }
+ auto lhs_shard_shape =
+ MakePartitionedShape(lhs.base_shape(), lhs.sharding());
+ auto rhs_shard_shape =
+ MakePartitionedShape(rhs.base_shape(), rhs.sharding());
+ // Update convolution window.
+ auto new_window = hlo->window();
+ for (const auto& spatial_dim : parallel_spatial_dims) {
+ auto wd = new_window.mutable_dimensions(spatial_dim);
+ wd->set_size(lhs_shard_shape.dimensions(
+ dnums.input_spatial_dimensions(spatial_dim)));
+ wd->set_stride(std::max<int64>(1, wd->size() - 1));
+ wd->set_base_dilation(wd->size());
+ }
+ TF_ASSIGN_OR_RETURN(
+ Shape sharded_conv_shape,
+ ShapeInference::InferConvolveShape(
+ lhs_shard_shape, rhs_shard_shape, hlo->feature_group_count(),
+ hlo->batch_group_count(), new_window, dnums));
+ *sharded_conv_shape.mutable_layout() = hlo->shape().layout();
+ SetPartitionedHlo(hlo, [&]() {
+ auto sharded_conv = b_.AddInstruction(HloInstruction::CreateConvolve(
+ sharded_conv_shape, lhs.hlo(), rhs.hlo(), hlo->feature_group_count(),
+ hlo->batch_group_count(), new_window, dnums,
+ hlo->precision_config()));
+ sharded_conv->set_sharding(hlo->sharding());
+ return PartitionedHlo(sharded_conv, hlo->shape(), MakePartitioningState())
+ .Reshard(hlo->sharding())
+ .hlo();
+ });
+ return Status::OK();
+ }
+
// Handling cases where both operands' shardings are aligned. We check that
// the LHS batch dimension is not partitioned because it is mapped to the
// output feature dimension in aligned_rhs_sharding, which are not the same
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc
index 3354a9c..7c4d816 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc
@@ -877,5 +877,13 @@
output_shape, hlo, start_indices, limit_indices, strides));
}
+// Returns the number of shards along the given dimension (1 if unsharded).
+int64 ShardCountAtDim(const HloSharding& sharding, int64 dim) {
+ if (sharding.IsTileMaximal()) {
+ return 1;
+ }
+ return sharding.tile_assignment().dim(dim);
+}
+
} // namespace spmd
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h
index 5f24566..8389c2f 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h
@@ -262,6 +262,9 @@
HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
int64 slice_dim, int64 k);
+// Returns the number of shards along the given dimension (1 if unsharded).
+int64 ShardCountAtDim(const HloSharding& sharding, int64 dim);
+
} // namespace spmd
} // namespace xla