blob: 0e1f87a91087df8c29e0d984fe0790e4a8ed0ed8 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/algorithm/container.h"
#include "absl/cleanup/cleanup.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace xla {
namespace spmd {
namespace {
using hlo_sharding_util::GroupedSharding;
// Checks whether all of `operand`'s partitioning lives on dimensions that are
// gathered/scattered with a slice size of 1.
//
// Walks `index_map` and collects every operand dimension whose entry in
// `slice_size` is 1. If the product of the tile counts along those dimensions
// equals the total number of tiles of the operand sharding, the collected
// dimensions (in `index_map` order) are returned. Otherwise, or when the
// operand sharding is tile-maximal, returns nullopt.
absl::optional<std::vector<int64_t>>
GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
    const PartitionedHlo& operand, absl::Span<const int64_t> index_map,
    absl::Span<const int64_t> slice_size) {
  const HloSharding& sharding = operand.sharding();
  if (sharding.IsTileMaximal()) {
    return absl::nullopt;
  }
  std::vector<int64_t> trivial_dims;
  int64_t partitions_on_trivial_dims = 1;
  for (const int64_t operand_dim : index_map) {
    if (slice_size[operand_dim] != 1) {
      continue;
    }
    partitions_on_trivial_dims *= sharding.tile_assignment().dim(operand_dim);
    trivial_dims.push_back(operand_dim);
  }
  if (partitions_on_trivial_dims != sharding.NumTiles()) {
    return absl::nullopt;
  }
  return trivial_dims;
}
// Derives a sharding for a scatter's `updates` from the sharding of its
// `indices`, for use when partitioning a scatter over indices + updates.
//
// The i-th entry of `update_scatter_dims` is paired with the i-th indices
// dimension, skipping over `index_vector_dim`. The indices sharding is then
// transposed onto the updates through that dimension mapping. Returns whatever
// TransposeShardingWithCollapsedDims yields, which may be nullopt.
absl::optional<HloSharding> ComputeUpdateShardingFromIndices(
    const PartitionedHlo& updates, const PartitionedHlo& indices,
    absl::Span<const int64_t> update_scatter_dims, int64_t index_vector_dim) {
  std::vector<int64_t> updates_to_indices(updates.base_shape().rank(), -1);
  std::vector<int64_t> indices_to_updates(indices.base_shape().rank(), -1);
  int64_t position = 0;
  for (const int64_t update_dim : update_scatter_dims) {
    // Indices dimensions at or past the index vector dimension shift by one.
    const int64_t index_dim =
        position >= index_vector_dim ? position + 1 : position;
    updates_to_indices[update_dim] = index_dim;
    indices_to_updates[index_dim] = update_dim;
    ++position;
  }
  return hlo_sharding_util::TransposeShardingWithCollapsedDims(
      indices.sharding(), indices_to_updates, updates_to_indices);
}
// Return if a scatter is of the supported kind for index+update partitioning.
bool IsSupportedScatterForIndexUpdatePartitioning(
const HloInstruction* scatter) {
auto reduction_opcode = ParseReductionComputation(scatter->to_apply());
if (!reduction_opcode.has_value()) {
return false;
}
switch (*reduction_opcode) {
case HloOpcode::kAdd:
case HloOpcode::kOr:
case HloOpcode::kMultiply:
case HloOpcode::kAnd:
case HloOpcode::kMinimum:
case HloOpcode::kMaximum:
return true;
default:
return false;
}
}
// Computes per-partition lower/upper bounds on index values for a
// gather/scatter whose operand is partitioned only on trivial slice dimensions
// (slice size 1) and whose indices are replicated.
//
// For each position i of `index_map`, with operand dimension d = index_map[i]:
//   - if d is not partitioned, the bounds are the constants 0 and the full
//     operand size along d;
//   - otherwise the lower bound is this partition's offset along d (from
//     MakePartitionOffsets, converted to the indices' element type if it
//     differs) and the upper bound is that offset plus the size of
//     operand.hlo() along d, minus 1.
// NOTE(review): the unpartitioned upper bound is the dimension size itself,
// not size - 1 as in the partitioned case; confirm this asymmetry is intended.
//
// Both bound sets are broadcast to the full indices shape, so callers can
// compare/clamp them elementwise against the indices. When `index_vector_dim`
// is a real dimension of the indices, the bounds form a vector along it;
// otherwise they are scalars.
//
// Returns {broadcast_min, broadcast_max}; all new instructions are added to
// `b`.
std::pair<HloInstruction*, HloInstruction*>
IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims(
const PartitionedHlo& operand, const PartitionedHlo& replicated_indices,
HloInstruction* partition_id, absl::Span<const int64_t> index_map,
int64_t index_vector_dim, SpmdBuilder* b) {
auto operand_offsets = MakePartitionOffsets(
operand.base_shape(), operand.sharding(), partition_id, b);
// Find the per-dimension index bounds.
std::vector<HloInstruction*> min_indices;
std::vector<HloInstruction*> max_indices;
for (int64_t i = 0; i < index_map.size(); ++i) {
int64_t dim = index_map[i];
int64_t partitions = operand.sharding().tile_assignment().dim(dim);
if (partitions == 1) {
// Unpartitioned dimension: constant bounds covering the whole dimension.
min_indices.push_back(CreateR0WithType<int32>(
replicated_indices.base_shape().element_type(), 0, b));
max_indices.push_back(CreateR0WithType<int32>(
replicated_indices.base_shape().element_type(),
operand.base_shape().dimensions(dim), b));
continue;
}
auto offset = operand_offsets[dim];
// Match the offset's element type to the indices' element type.
if (offset->shape().element_type() !=
replicated_indices.base_shape().element_type()) {
offset = b->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::MakeShape(replicated_indices.base_shape().element_type(),
{}),
offset));
}
min_indices.push_back(offset);
// operand.hlo() is the partitioned instruction, so its dimension size is
// taken here as this partition's extent along `dim`.
auto partition_size_minus_1 =
CreateR0WithType<int32>(replicated_indices.base_shape().element_type(),
operand.hlo()->shape().dimensions(dim) - 1, b);
max_indices.push_back(b->AddInstruction(HloInstruction::CreateBinary(
offset->shape(), HloOpcode::kAdd, offset, partition_size_minus_1)));
}
// Broadcast the index bounds to the same shape as the indices.
HloInstruction* broadcast_min;
HloInstruction* broadcast_max;
if (index_vector_dim < replicated_indices.base_shape().rank()) {
// The index vector is an R1, we need to reshape individual bounds to
// [1], and concat them if there are more than one.
for (int64_t i = 0; i < min_indices.size(); ++i) {
min_indices[i] = b->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(min_indices[i]->shape().element_type(), {1}),
min_indices[i]));
max_indices[i] = b->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(max_indices[i]->shape().element_type(), {1}),
max_indices[i]));
}
int64_t slice_dims = max_indices.size();
if (slice_dims > 1) {
// Concatenate into single [slice_dims] vectors, stored in element 0. The
// max concatenation reuses the (identical) shape of the min result.
min_indices[0] = b->AddInstruction(HloInstruction::CreateConcatenate(
ShapeUtil::MakeShape(min_indices[0]->shape().element_type(),
{slice_dims}),
min_indices, 0));
max_indices[0] = b->AddInstruction(HloInstruction::CreateConcatenate(
min_indices[0]->shape(), max_indices, 0));
}
broadcast_min = b->AddInstruction(HloInstruction::CreateBroadcast(
replicated_indices.base_shape(), min_indices[0], {index_vector_dim}));
broadcast_max = b->AddInstruction(HloInstruction::CreateBroadcast(
replicated_indices.base_shape(), max_indices[0], {index_vector_dim}));
} else {
// Without an explicit index vector dimension each index is a scalar, so
// exactly one operand dimension can be indexed.
CHECK_EQ(max_indices.size(), 1);
broadcast_min = b->AddInstruction(HloInstruction::CreateBroadcast(
replicated_indices.base_shape(), min_indices[0], {}));
broadcast_max = b->AddInstruction(HloInstruction::CreateBroadcast(
replicated_indices.base_shape(), max_indices[0], {}));
}
return {broadcast_min, broadcast_max};
}
// Tries to partition `gather` given the already-partitioned `operand` and
// `indices`; defined further below. It is declared ahead of its definition
// because the strategy helpers that follow call it recursively on
// per-device-group sub-problems.
StatusOr<HloInstruction*> PartitionGather(
const HloGatherInstruction* gather, PartitionedHlo& operand,
PartitionedHlo& indices, const Shape& output_shape,
const HloSharding& output_sharding, absl::Span<const int64_t> batch_dims,
absl::Span<const int64_t> slice_sizes, SpmdPartitioningVisitor* visitor);
// Partitions a Gather whose indices are sharded only on dimensions other than
// the index vector dimension.
//
// The output is sharded like the indices' batch dimensions. The operand is
// grouped along its partially-replicated last tile dimension when that
// dimension's size is a multiple of the number of index tiles, and replicated
// otherwise. Each device group then solves its sub-problem through a recursive
// PartitionGather call.
//
// Returns nullptr when the indices are tile-maximal or are sharded along the
// index vector dimension, so the caller can try another strategy.
StatusOr<HloInstruction*> PartitionIndexPassthroughPartition(
    const HloGatherInstruction* gather, const Shape& output_shape,
    const HloSharding& output_sharding, PartitionedHlo& operand,
    PartitionedHlo& indices, absl::Span<const int64_t> batch_dims,
    absl::Span<const int64_t> slice_sizes, SpmdPartitioningVisitor* visitor) {
  GatherDimensionNumbers dnums = gather->gather_dimension_numbers();
  if (!indices.sharding().IsTileMaximal() &&
      (dnums.index_vector_dim() == indices.base_shape().rank() ||
       indices.sharding().tile_assignment().dim(dnums.index_vector_dim()) ==
           1)) {
    // Map output batch dimensions to and from the indices dimensions they
    // come from. Indices dimensions skip over the index vector dimension.
    std::vector<int64_t> output_dim_to_index_dim(output_shape.rank(), -1);
    std::vector<int64_t> index_dim_to_output_dim(indices.base_shape().rank(),
                                                 -1);
    for (int64_t i = 0; i < batch_dims.size(); ++i) {
      int64_t indices_batch_dim = i < dnums.index_vector_dim() ? i : i + 1;
      output_dim_to_index_dim[batch_dims[i]] = indices_batch_dim;
      index_dim_to_output_dim[indices_batch_dim] = batch_dims[i];
    }
    absl::InlinedVector<int64_t, 4> index_group_dims;
    absl::InlinedVector<int64_t, 4> output_group_dims;
    // Collect the dimensions sharded by this function, so the recursive call
    // can be grouped over them.
    for (int64_t i = 0; i < indices.sharding().TiledDataRank(); ++i) {
      if (indices.sharding().tile_assignment().dim(i) != 1) {
        index_group_dims.push_back(i);
        output_group_dims.push_back(index_dim_to_output_dim[i]);
      }
    }
    // Compute the output sharding.
    auto pgather_sharding =
        hlo_sharding_util::TransposeShardingWithCollapsedDims(
            indices.sharding(), index_dim_to_output_dim,
            output_dim_to_index_dim);
    // Check before dereferencing: reading an empty optional is undefined.
    CHECK(pgather_sharding.has_value());
    GroupedSharding output_grouped = hlo_sharding_util::GroupShardingOnDims(
        *pgather_sharding, output_group_dims);
    const int64_t num_tiles = indices.sharding().NumTiles();
    GroupedSharding index_grouped =
        AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
                            indices.sharding(), index_group_dims),
                        output_grouped);
    absl::optional<GroupedSharding> operand_grouped;
    // Group the operand on its partially-replicated last tile dimension when
    // it splits evenly across the index tiles. Otherwise the operand is
    // replicated below.
    if (operand.sharding().ReplicateOnLastTileDim() &&
        operand.sharding().tile_assignment().dimensions().back() % num_tiles ==
            0) {
      absl::InlinedVector<int64_t, 1> group_dim_shards = {
          operand.sharding().tile_assignment().dimensions().back() / num_tiles};
      operand_grouped = AlignGroupsWith(
          hlo_sharding_util::GroupShardingOnDims(
              operand.sharding(),
              {operand.sharding().tile_assignment().num_dimensions() - 1},
              group_dim_shards),
          output_grouped);
    }
    absl::optional<HloSharding> old_operand_sharding;
    if (operand_grouped) {
      operand = operand.Reshard(UngroupSharding(*operand_grouped));
      old_operand_sharding = operand.hlo()->sharding();
      operand.hlo()->set_sharding(operand_grouped->sharding);
    } else {
      operand = operand.Replicate();
    }
    const Shape new_output_shape =
        GetPerGroupBaseShape(output_grouped, output_shape);
    auto per_group_partitioner_state = CreatePerGroupPartitioningState(
        indices.state(), index_grouped.device_groups, visitor->builder());
    // Temporarily present the grouped shardings to the recursive call; they
    // are restored right after it.
    const HloSharding old_indices_sharding = indices.hlo()->sharding();
    indices.hlo()->set_sharding(index_grouped.sharding);
    PartitionedHlo per_group_indices(
        indices.hlo(),
        GetPerGroupBaseShape(index_grouped, indices.base_shape()),
        per_group_partitioner_state);
    PartitionedHlo per_group_operand(
        operand.hlo(),
        operand_grouped
            ? GetPerGroupBaseShape(*operand_grouped, operand.base_shape())
            : operand.base_shape(),
        per_group_partitioner_state);
    TF_ASSIGN_OR_RETURN(
        HloInstruction * pgather,
        PartitionGather(gather, per_group_operand, per_group_indices,
                        new_output_shape, output_grouped.sharding, batch_dims,
                        slice_sizes, visitor));
    indices.hlo()->set_sharding(old_indices_sharding);
    if (old_operand_sharding) {
      operand.hlo()->set_sharding(*old_operand_sharding);
    }
    pgather->set_sharding(hlo_sharding_util::UngroupSharding(output_grouped));
    VLOG(5) << "[Gather partitioning]: Partitioned as index only";
    // Use the requested base shape (per-group in recursive calls), not the
    // original gather's global shape.
    return PartitionedHlo(pgather, output_shape, operand.state())
        .Reshard(output_sharding)
        .hlo();
  }
  return nullptr;
}
// Partitions a Gather whose operand is sharded on offset dimensions that are
// passed through to the output (the slice covers the whole operand
// dimension).
//
// Each device group gathers from its local operand shard, with the slice size
// of every sharded operand dimension shrunk to the shard's size. The indices
// are grouped along their partially-replicated last tile dimension when that
// divides evenly by the output tile count, and replicated otherwise.
//
// Returns nullptr when the operand is tile-maximal or no passthrough output
// sharding can be derived from the operand sharding.
StatusOr<HloInstruction*> ParititonPassthroughOperand(
    const HloGatherInstruction* gather, Shape output_shape,
    const HloSharding& output_sharding, absl::Span<const int64_t> batch_dims,
    absl::Span<const int64_t> slice_sizes, PartitionedHlo& operand,
    PartitionedHlo& indices, SpmdPartitioningVisitor* visitor) {
  if (operand.sharding().IsTileMaximal()) {
    return nullptr;
  }
  SpmdBuilder* b = visitor->builder();
  if (auto maybe_passthrough =
          hlo_sharding_util::GatherOutputShardingFromDataOperand(
              operand.sharding(), *gather, slice_sizes, output_shape,
              operand.base_shape())) {
    std::vector<int64_t> pslice_sizes(slice_sizes.begin(), slice_sizes.end());
    // Group on every operand dimension that is actually tiled.
    absl::InlinedVector<int64_t, 4> operand_grouping_dims;
    for (int64_t i = 0; i < operand.sharding().TiledDataRank(); ++i) {
      if (operand.sharding().tile_assignment().dim(i) != 1) {
        operand_grouping_dims.push_back(i);
      }
    }
    // Tile count and tiled dims of the passthrough sharding, taken before it
    // is merged with `output_sharding` below.
    const int64_t num_tiles = maybe_passthrough->NumTiles();
    absl::InlinedVector<int64_t, 4> output_grouping_dims;
    for (int64_t i = 0; i < maybe_passthrough->TiledDataRank(); ++i) {
      if (maybe_passthrough->tile_assignment().dim(i) != 1) {
        output_grouping_dims.push_back(i);
      }
    }
    // A sharded operand dimension is sliced in full locally, so its slice
    // size becomes the per-shard size of the partitioned operand.
    for (int64_t i = 0; i < pslice_sizes.size(); ++i) {
      if (operand.sharding().tile_assignment().dim(i) > 1) {
        pslice_sizes[i] = operand.hlo()->shape().dimensions(i);
      }
    }
    // Merge the sharding from the instruction with the sharding suggested from
    // the operand sharding.
    hlo_sharding_util::MergeSharding(output_sharding, &*maybe_passthrough,
                                     /*may_combine_partial_sharding=*/true);
    GroupedSharding output_grouped = hlo_sharding_util::GroupShardingOnDims(
        *maybe_passthrough, output_grouping_dims);
    GroupedSharding operand_grouped =
        AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
                            operand.sharding(), operand_grouping_dims),
                        output_grouped);
    absl::optional<GroupedSharding> indices_grouped;
    // Group the indices on their partially-replicated last tile dimension when
    // possible. Otherwise they are replicated below.
    if (indices.sharding().ReplicateOnLastTileDim() &&
        indices.sharding().tile_assignment().dimensions().back() % num_tiles ==
            0) {
      absl::InlinedVector<int64_t, 1> group_dim_shards = {
          indices.sharding().tile_assignment().dimensions().back() / num_tiles};
      indices_grouped = AlignGroupsWith(
          hlo_sharding_util::GroupShardingOnDims(
              indices.sharding(),
              {indices.sharding().tile_assignment().num_dimensions() - 1},
              group_dim_shards),
          output_grouped);
    }
    absl::optional<HloSharding> old_indices_sharding;
    if (indices_grouped) {
      indices = indices.Reshard(UngroupSharding(*indices_grouped));
      old_indices_sharding = indices.hlo()->sharding();
      indices.hlo()->set_sharding(indices_grouped->sharding);
    } else {
      indices = indices.Replicate();
    }
    auto pshape = GetPerGroupBaseShape(output_grouped, output_shape);
    auto per_group_partitioner_state = CreatePerGroupPartitioningState(
        operand.state(), operand_grouped.device_groups, b);
    // The recursive call sees the local operand shard as replicated within
    // its group; the real sharding is restored right after the call.
    HloSharding old_operand_sharding = operand.hlo()->sharding();
    operand.hlo()->set_sharding(HloSharding::Replicate());
    PartitionedHlo per_group_operand(
        operand.hlo(),
        GetPerGroupBaseShape(operand_grouped, operand.base_shape()),
        per_group_partitioner_state);
    PartitionedHlo per_group_indices(
        indices.hlo(),
        indices_grouped
            ? GetPerGroupBaseShape(*indices_grouped, indices.base_shape())
            : indices.base_shape(),
        per_group_partitioner_state);
    TF_ASSIGN_OR_RETURN(
        HloInstruction * pgather,
        PartitionGather(gather, per_group_operand, per_group_indices, pshape,
                        output_grouped.sharding, batch_dims, pslice_sizes,
                        visitor));
    operand.hlo()->set_sharding(old_operand_sharding);
    if (old_indices_sharding) {
      indices.hlo()->set_sharding(*old_indices_sharding);
    }
    pgather->set_sharding(*maybe_passthrough);
    VLOG(5) << "[Gather partitioning]: Partitioned as operand passthrough "
               "offset_dim";
    return PartitionedHlo(pgather, output_shape, operand.state())
        .Reshard(output_sharding)
        .hlo();
  }
  return nullptr;
}
// Partitions a Gather whose operand is sharded only along dimensions that are
// indexed with slice size 1 ("trivially sliced" dimensions).
//
// The indices are replicated, and every partition gathers from its local
// operand shard with the full set of indices:
//   - indices are clamped into the partition's [min, max] range and rebased
//     by subtracting the range start;
//   - results whose original index fell outside that range are replaced with
//     zero;
//   - an all-reduce (add) along the operand's sharding combines the partial
//     results.
// The gather itself does not get faster, but the operand never has to be
// replicated.
//
// Applies only when the operand is sharded solely on trivially sliced
// dimensions and the output is smaller than the operand; otherwise returns
// nullptr.
StatusOr<HloInstruction*> ParititonTrivialIndexedOperandDimension(
    const HloGatherInstruction* gather, Shape output_shape,
    const HloSharding& output_sharding, absl::Span<const int64_t> batch_dims,
    absl::Span<const int64_t> slice_sizes, PartitionedHlo& operand,
    PartitionedHlo& indices, SpmdPartitioningVisitor* visitor) {
  // HLO shardings overwritten temporarily below. They are restored on every
  // exit path, including error returns from the recursive PartitionGather, so
  // other users of these HLOs keep seeing their real sharding.
  absl::InlinedVector<std::pair<HloInstruction*, HloSharding>, 1>
      top_level_sharding_to_reset;
  auto cleaner = absl::MakeCleanup([&top_level_sharding_to_reset] {
    for (auto& to_reset : top_level_sharding_to_reset) {
      to_reset.first->set_sharding(to_reset.second);
    }
  });
  SpmdBuilder* b = visitor->builder();
  GatherDimensionNumbers dnums = gather->gather_dimension_numbers();
  std::vector<int64_t> start_index_map(dnums.start_index_map().begin(),
                                       dnums.start_index_map().end());
  absl::optional<std::vector<int64_t>> trivial_slice_dims =
      GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
          operand, start_index_map, gather->gather_slice_sizes());
  if (trivial_slice_dims &&
      ShapeSizeInBytes(output_shape) < ShapeSizeInBytes(operand.base_shape())) {
    indices = indices.Reshard(HloSharding::Replicate());
    HloInstruction* indices_min;
    HloInstruction* indices_max;
    std::tie(indices_min, indices_max) =
        IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims(
            operand, indices, operand.state().partition_id, start_index_map,
            dnums.index_vector_dim(), b);
    // Clamp the indices into this partition's range.
    auto adjusted_indices = b->AddInstruction(
        HloInstruction::CreateTernary(indices.base_shape(), HloOpcode::kClamp,
                                      indices_min, indices.hlo(), indices_max));
    // Rebase the indices onto the local shard by subtracting the offset.
    adjusted_indices = b->AddInstruction(
        HloInstruction::CreateBinary(indices.base_shape(), HloOpcode::kSubtract,
                                     adjusted_indices, indices_min));
    GroupedSharding operand_grouped = hlo_sharding_util::GroupShardingOnDims(
        operand.sharding(), *trivial_slice_dims);
    auto per_group_partitioner_state = CreatePerGroupPartitioningState(
        operand.state(), operand_grouped.device_groups, b);
    HloSharding original_operand_sharding = operand.hlo()->sharding();
    // Remember the real sharding before overwriting it, so the cleanup above
    // restores it.
    top_level_sharding_to_reset.emplace_back(operand.hlo(),
                                             original_operand_sharding);
    // The recursive call treats the local operand shard as replicated.
    operand.hlo()->set_sharding(HloSharding::Replicate());
    PartitionedHlo per_group_operand(
        operand.hlo(),
        GetPerGroupBaseShape(operand_grouped, operand.base_shape()),
        per_group_partitioner_state);
    adjusted_indices->set_sharding(HloSharding::Replicate());
    PartitionedHlo new_indices(adjusted_indices, adjusted_indices->shape(),
                               per_group_partitioner_state);
    // Gather on adjusted indices.
    TF_ASSIGN_OR_RETURN(
        HloInstruction * pgather,
        PartitionGather(gather, per_group_operand, new_indices, output_shape,
                        output_sharding, batch_dims, slice_sizes, visitor));
    // Mask out results whose original index was outside [min, max].
    auto filter = b->AddInstruction(HloInstruction::CreateCompare(
        ShapeUtil::ChangeElementType(indices.base_shape(), PRED), indices.hlo(),
        indices_min, ComparisonDirection::kLt));
    filter = b->AddInstruction(HloInstruction::CreateBinary(
        filter->shape(), HloOpcode::kOr, filter,
        b->AddInstruction(HloInstruction::CreateCompare(
            ShapeUtil::ChangeElementType(indices.base_shape(), PRED),
            indices.hlo(), indices_max, ComparisonDirection::kGt))));
    if (dnums.index_vector_dim() < indices.base_shape().rank()) {
      // Combine the per-component flags across the index vector dimension.
      // MakeBinaryAdd on PRED presumably acts as a logical OR; confirm.
      std::vector<int64_t> reduced_filter_dims;
      for (int64_t i = 0; i < filter->shape().rank(); ++i) {
        if (i != dnums.index_vector_dim()) {
          reduced_filter_dims.push_back(filter->shape().dimensions(i));
        }
      }
      filter = b->AddInstruction(HloInstruction::CreateReduce(
          ShapeUtil::MakeShape(PRED, reduced_filter_dims), filter,
          CreateR0WithType(PRED, false, b), {dnums.index_vector_dim()},
          MakeBinaryAdd(PRED, indices.state().module)));
    }
    // The per-index filter maps onto the output dimensions that are not
    // offset dimensions. The local is named so it does not shadow the
    // `batch_dims` parameter.
    std::vector<int64_t> filter_broadcast_dims;
    for (int64_t i = 0; i < pgather->shape().rank(); ++i) {
      if (!absl::c_linear_search(dnums.offset_dims(), i)) {
        filter_broadcast_dims.push_back(i);
      }
    }
    auto broadcast_filter = b->AddInstruction(HloInstruction::CreateBroadcast(
        ShapeUtil::ChangeElementType(pgather->shape(), PRED), filter,
        filter_broadcast_dims));
    auto filtered = b->AddInstruction(HloInstruction::CreateTernary(
        pgather->shape(), HloOpcode::kSelect, broadcast_filter,
        CreateZero(pgather->shape(), b), pgather));
    // All-reduce along all dims in operand sharding -- this is OK because the
    // operand is sharded only on trivially sliced dimensions.
    std::vector<int64_t> all_dims(operand.base_shape().rank());
    absl::c_iota(all_dims, 0);
    auto ar = operand.state().partitioner->AllReduceAlongShardingDims(
        b, filtered, original_operand_sharding, operand.state().next_channel_id,
        all_dims, operand.state().collective_ops_creator,
        MakeBinaryAdd(filtered->shape().element_type(),
                      operand.state().module));
    VLOG(5) << "[Gather partitioning]: Partitioned as trivial operand "
               "batch_dim slice";
    ar->set_sharding(HloSharding::Replicate());
    return PartitionedHlo(ar, output_shape, operand.state())
        .Reshard(output_sharding)
        .hlo();
  }
  return nullptr;
}
// Partitions a gather over indices dimensions that are considered parallel.
// Parallel means the indices access the operand in a monotonically increasing
// way across the respective operand dimension referenced by the index.
//
// When GetGatherBatchParallelDims finds such dimensions and the operand and
// indices can be sharded consistently across them:
//   - both are resharded accordingly;
//   - the indices are rebased by subtracting this partition's operand offsets;
//   - each device group solves its sub-problem through a recursive
//     PartitionGather call.
//
// Returns nullptr when no parallel dimensions exist, when the operands cannot
// be sharded across them, or when the recursive call yields nothing.
StatusOr<HloInstruction*> PartitionIndexParallelDimensions(
    const HloGatherInstruction* gather, Shape output_shape,
    const HloSharding& output_sharding, absl::Span<const int64_t> batch_dims,
    absl::Span<const int64_t> slice_sizes, PartitionedHlo& operand,
    PartitionedHlo& indices, SpmdPartitioningVisitor* visitor) {
  // HLO shardings overwritten temporarily below; restored on every exit path.
  absl::InlinedVector<std::pair<HloInstruction*, HloSharding>, 2>
      top_level_sharding_to_reset;
  auto cleaner = absl::MakeCleanup([&top_level_sharding_to_reset] {
    for (auto& to_reset : top_level_sharding_to_reset) {
      to_reset.first->set_sharding(to_reset.second);
    }
  });
  SpmdBuilder* b = visitor->builder();
  GatherDimensionNumbers dnums = gather->gather_dimension_numbers();
  // Only applies when the gather has parallel dimensions and both operand and
  // indices can be sharded consistently across them.
  if (absl::optional<hlo_sharding_util::GatherParallelDims> parallel_dims =
          hlo_sharding_util::GetGatherBatchParallelDims(*gather)) {
    if (auto gather_sharding = GatherOperandsShardedAcrossParallelDims(
            *operand.hlo(), *indices.hlo(), *parallel_dims)) {
      auto indices_parallel_dims = parallel_dims->indices_parallel_dims;
      auto operand_parallel_dims = parallel_dims->operand_parallel_dims;
      auto output_parallel_dims =
          hlo_sharding_util::GatherParallelOutputDims(*gather, *parallel_dims);
      HloSharding indices_sharding = gather_sharding->indices_sharding;
      HloSharding operand_sharding = gather_sharding->operand_sharding;
      GroupedSharding grouped_indices = hlo_sharding_util::GroupShardingOnDims(
          indices_sharding, indices_parallel_dims);
      GroupedSharding grouped_operand = hlo_sharding_util::GroupShardingOnDims(
          operand_sharding, operand_parallel_dims);
      const int64_t index_dim = dnums.index_vector_dim();
      // Construct the sharding of the new gather: each output parallel
      // dimension is tiled like the matching indices parallel dimension.
      absl::InlinedVector<int64_t, 4> output_tiling(
          output_shape.dimensions_size(), 1);
      for (int64_t i = 0, num_output_parallel_dims = output_parallel_dims.size();
           i < num_output_parallel_dims; ++i) {
        const int64_t output_idx = output_parallel_dims[i];
        const int64_t indices_idx = indices_parallel_dims[i];
        output_tiling[output_idx] =
            indices_sharding.tile_assignment().dim(indices_idx);
      }
      operand = operand.Reshard(operand_sharding);
      indices = indices.Reshard(indices_sharding);
      if (indices_sharding.ReplicateOnLastTileDim()) {
        output_tiling.push_back(
            indices_sharding.tile_assignment().dimensions().back());
      }
      Array<int64_t> output_tile_assignment =
          indices_sharding.tile_assignment();
      output_tile_assignment.Reshape(output_tiling);
      // New gather tiling.
      HloSharding gather_output_sharding =
          indices_sharding.ReplicateOnLastTileDim()
              ? HloSharding::PartialTile(output_tile_assignment)
              : HloSharding::Tile(output_tile_assignment);
      // Refine the output sharding from the operand. It should be inferred
      // from the operand sharding, so that the partitioned gather can either
      // 1) be created directly on the partitioned operand, or 2) be created
      // recursively without aligning the groups.
      if (auto maybe_passthrough =
              hlo_sharding_util::GatherOutputShardingFromDataOperand(
                  hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
                      operand_sharding, operand_parallel_dims),
                  *gather, slice_sizes, output_shape, operand.base_shape())) {
        hlo_sharding_util::MergeShardingIfCompatible(
            *maybe_passthrough,
            /*minimum_tiles=*/gather_output_sharding.NumTiles() + 1,
            &gather_output_sharding);
      }
      // Construct the operand offsets used to adjust the indices. Only the
      // parallel dimensions are partitioned, and operand and indices share
      // that partitioning, so the operand offsets apply to the indices.
      std::vector<HloInstruction*> operand_offsets = MakePartitionOffsets(
          operand.base_shape(), operand_sharding, operand.state().partition_id,
          b, operand_parallel_dims);
      absl::InlinedVector<HloInstruction*, 4> index_offsets;
      for (int start_idx = 0; start_idx < dnums.start_index_map_size();
           ++start_idx) {
        HloInstruction* index_offset =
            indices.base_shape().dimensions_size() > index_dim
                ? b->AddInstruction(HloInstruction::CreateReshape(
                      ShapeUtil::MakeShape(S32, {1}),
                      operand_offsets[dnums.start_index_map(start_idx)]))
                : operand_offsets[dnums.start_index_map(start_idx)];
        index_offsets.push_back(index_offset);
      }
      HloInstruction* adjusted_indices = nullptr;
      if (indices.base_shape().dimensions_size() > index_dim) {
        // Concatenate the offsets for the parallel dimensions to subtract.
        adjusted_indices = b->AddInstruction(HloInstruction::CreateConcatenate(
            ShapeUtil::MakeShape(S32,
                                 {indices.base_shape().dimensions(index_dim)}),
            index_offsets, 0));
      } else {
        CHECK_EQ(index_offsets.size(), 1);
        adjusted_indices = index_offsets[0];
      }
      if (indices.hlo()->shape().element_type() != PrimitiveType::S32) {
        adjusted_indices = b->AddInstruction(HloInstruction::CreateConvert(
            ShapeUtil::ChangeElementType(adjusted_indices->shape(),
                                         indices.hlo()->shape().element_type()),
            adjusted_indices));
      }
      if (adjusted_indices->shape().rank() == 0) {
        adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast(
            indices.hlo()->shape(), adjusted_indices, {}));
      } else {
        adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast(
            indices.hlo()->shape(), adjusted_indices, {index_dim}));
      }
      // Adjust indices by subtracting the offsets based on the partition id.
      adjusted_indices = b->AddInstruction(HloInstruction::CreateBinary(
          indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(),
          adjusted_indices));
      auto per_group_partitioner_state = CreatePerGroupPartitioningState(
          operand.state(), grouped_operand.device_groups, b);
      top_level_sharding_to_reset.emplace_back(operand.hlo(),
                                               operand.sharding());
      adjusted_indices->set_sharding(grouped_indices.sharding);
      operand.hlo()->set_sharding(grouped_operand.sharding);
      VLOG(5) << "[Gather partitioning]: Partitioned as parallel batch_dim";
      PartitionedHlo per_group_operand(
          operand.hlo(),
          GetPerGroupBaseShape(grouped_operand, operand.base_shape()),
          per_group_partitioner_state);
      PartitionedHlo per_group_indices(
          adjusted_indices,
          GetPerGroupBaseShape(grouped_indices, indices.base_shape()),
          per_group_partitioner_state);
      GroupedSharding grouped_output = hlo_sharding_util::GroupShardingOnDims(
          gather_output_sharding, output_parallel_dims);
      TF_ASSIGN_OR_RETURN(
          HloInstruction * pgather,
          PartitionGather(gather, per_group_operand, per_group_indices,
                          GetPerGroupBaseShape(grouped_output, output_shape),
                          grouped_output.sharding, batch_dims, slice_sizes,
                          visitor));
      if (pgather) {
        pgather->set_sharding(gather_output_sharding);
        return PartitionedHlo(pgather, output_shape, operand.state())
            .Reshard(output_sharding)
            .hlo();
      }
    }
  }
  return nullptr;
}
// Partitions `gather` by trying these strategies in order and returning the
// first non-null result:
//   1. operand and indices sharded across shared parallel dimensions;
//   2. operand sharded on passed-through offset dimensions;
//   3. indices sharded on a dimension other than the index vector dimension;
//   4. operand sharded only on trivially sliced (slice size 1) dimensions.
// Errors from any strategy are propagated. If none applies, the gather is
// emitted on fully replicated operand and indices, and the result is marked
// replicated.
StatusOr<HloInstruction*> PartitionGather(
    const HloGatherInstruction* gather, PartitionedHlo& operand,
    PartitionedHlo& indices, const Shape& output_shape,
    const HloSharding& output_sharding, absl::Span<const int64_t> batch_dims,
    absl::Span<const int64_t> slice_sizes, SpmdPartitioningVisitor* visitor) {
  TF_ASSIGN_OR_RETURN(
      HloInstruction * by_parallel_dims,
      PartitionIndexParallelDimensions(gather, output_shape, output_sharding,
                                       batch_dims, slice_sizes, operand,
                                       indices, visitor));
  if (by_parallel_dims != nullptr) {
    return by_parallel_dims;
  }

  TF_ASSIGN_OR_RETURN(
      HloInstruction * by_operand_passthrough,
      ParititonPassthroughOperand(gather, output_shape, output_sharding,
                                  batch_dims, slice_sizes, operand, indices,
                                  visitor));
  if (by_operand_passthrough != nullptr) {
    return by_operand_passthrough;
  }

  TF_ASSIGN_OR_RETURN(
      HloInstruction * by_index_passthrough,
      PartitionIndexPassthroughPartition(gather, output_shape,
                                         output_sharding, operand, indices,
                                         batch_dims, slice_sizes, visitor));
  if (by_index_passthrough != nullptr) {
    return by_index_passthrough;
  }

  TF_ASSIGN_OR_RETURN(
      HloInstruction * by_trivial_slices,
      ParititonTrivialIndexedOperandDimension(
          gather, output_shape, output_sharding, batch_dims, slice_sizes,
          operand, indices, visitor));
  if (by_trivial_slices != nullptr) {
    return by_trivial_slices;
  }

  // No specialized strategy applied: gather on fully replicated inputs.
  HloInstruction* replicated_gather =
      visitor->builder()->AddInstruction(HloInstruction::CreateGather(
          output_shape, operand.Replicate().hlo(), indices.Replicate().hlo(),
          gather->gather_dimension_numbers(), slice_sizes,
          gather->indices_are_sorted()));
  replicated_gather->set_sharding(HloSharding::Replicate());
  return replicated_gather;
}
} // namespace
Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
if (hlo->sharding().HasUniqueDevice()) {
return DefaultAction(hlo);
}
auto scatter = Cast<HloScatterInstruction>(hlo);
auto dnums = scatter->scatter_dimension_numbers();
auto operand = GetPartitionedHlo(scatter->operand(0));
auto indices = GetPartitionedHlo(scatter->operand(1));
auto updates = GetPartitionedHlo(scatter->operand(2));
std::vector<int64_t> slice_size(operand.base_shape().rank(), 1);
int64_t num_update_window_dims = 0;
for (int64_t i = 0; i < operand.base_shape().rank(); ++i) {
if (absl::c_linear_search(dnums.inserted_window_dims(), i)) {
continue;
}
slice_size[i] = updates.base_shape().dimensions(
dnums.update_window_dims(num_update_window_dims++));
}
std::vector<int64_t> scatter_dims_to_operand_dims(
dnums.scatter_dims_to_operand_dims().begin(),
dnums.scatter_dims_to_operand_dims().end());
std::vector<int64_t> update_scatter_dims;
for (int64_t i = 0; i < updates.base_shape().rank(); ++i) {
if (!absl::c_linear_search(dnums.update_window_dims(), i)) {
update_scatter_dims.push_back(i);
}
}
const absl::optional<HloSharding> new_updates_sharding =
ComputeUpdateShardingFromIndices(updates, indices,
absl::MakeConstSpan(update_scatter_dims),
dnums.index_vector_dim());
CHECK(new_updates_sharding.has_value());
auto maybe_passthrough = hlo_sharding_util::ScatterUpdateShardingFromOutput(
operand.sharding(), *hlo);
const bool should_shard_index_and_update =
!indices.sharding().IsTileMaximal() &&
(dnums.index_vector_dim() == indices.base_shape().rank() ||
indices.sharding().tile_assignment().dim(dnums.index_vector_dim()) == 1);
const bool should_shard_trivial_operand_slices =
GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
operand, scatter_dims_to_operand_dims, slice_size) &&
ShapeSizeInBytes(updates.base_shape()) <
ShapeSizeInBytes(scatter->shape());
// If Passthrough sharding is available the updates are sharded according
// to the *maybe_passthrough sharding, so compare with that size.
const int64_t index_and_update_partitioning_size =
(2 * ShapeSizeInBytes(operand.base_shape()) +
ShapeSizeInBytes(
MakePartitionedShape(updates.base_shape(), *new_updates_sharding)));
const int64_t operand_passthrough_parititoning_size =
!maybe_passthrough ? INT64_MAX
: (2 * ShapeSizeInBytes(operand.hlo()->shape()) +
ShapeSizeInBytes(MakePartitionedShape(
updates.base_shape(), *maybe_passthrough)));
const int64_t operand_trivial_slice_partitioning_size =
!should_shard_trivial_operand_slices
? INT64_MAX
: 2 * ShapeSizeInBytes(operand.hlo()->shape()) +
ShapeSizeInBytes(updates.base_shape()) +
ShapeSizeInBytes(indices.base_shape());
// Compare the size between doing sharding of the indices + updates vs
// sharding of the operand + updates and see which is potentially better size
// wise.
const bool is_better_to_shard_updates_and_indices =
!indices.sharding().IsTileMaximal() &&
index_and_update_partitioning_size <
operand_passthrough_parititoning_size &&
index_and_update_partitioning_size <
operand_trivial_slice_partitioning_size;
if (IsSupportedScatterForIndexUpdatePartitioning(scatter) &&
((is_better_to_shard_updates_and_indices &&
should_shard_index_and_update) ||
operand.sharding().IsTileMaximal())) {
if (should_shard_index_and_update) {
auto reduction_opcode = ParseReductionComputation(scatter->to_apply());
if (!reduction_opcode.has_value()) {
return DefaultAction(hlo);
}
operand = operand.Replicate();
HloInstruction* identity;
switch (*reduction_opcode) {
case HloOpcode::kAdd:
case HloOpcode::kOr:
identity = CreateZero(operand.hlo()->shape(), &b_);
break;
case HloOpcode::kMultiply:
case HloOpcode::kAnd:
identity = CreateOne(operand.hlo()->shape(), &b_);
break;
case HloOpcode::kMinimum:
identity = CreateConstant(
operand.hlo()->shape(),
LiteralUtil::MaxValue(hlo->shape().element_type()), &b_);
break;
case HloOpcode::kMaximum:
identity = CreateConstant(
operand.hlo()->shape(),
LiteralUtil::MinValue(hlo->shape().element_type()), &b_);
break;
default:
return DefaultAction(hlo);
}
updates = updates.Reshard(*new_updates_sharding);
// Update partition_id for partial replicate.
auto partition_id = MakePartitioningState().partition_id;
if (indices.sharding().ReplicateOnLastTileDim()) {
auto sharding_grouped = hlo_sharding_util::GroupShardingOnDims(
indices.sharding(),
{indices.sharding().tile_assignment().num_dimensions() - 1});
auto per_group_partitioner_state = CreatePerGroupPartitioningState(
indices.state(), sharding_grouped.device_groups, &b_);
partition_id = per_group_partitioner_state.partition_id;
}
// To avoid accumulating the initial operand multiple times during
// all-reduce, we use identity operands for all non-zero partitions.
auto not_partition_zero = b_.AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::MakeScalarShape(PRED), partition_id));
not_partition_zero = b_.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::ChangeElementType(identity->shape(), PRED),
not_partition_zero, {}));
auto select_operand =
b_.AddInstruction(HloInstruction::HloInstruction::CreateTernary(
identity->shape(), HloOpcode::kSelect, not_partition_zero,
identity, operand.Replicate().hlo()));
auto pscatter = b_.AddInstruction(scatter->CloneWithNewOperands(
scatter->shape(), {select_operand, indices.hlo(), updates.hlo()}));
// All-reduce along all dims in operand sharding -- this is OK because the
// operand is not sharded on index_vector_dim.
std::vector<int64_t> all_dims(indices.base_shape().rank());
absl::c_iota(all_dims, 0);
auto all_reduce = operand.state().partitioner->AllReduceAlongShardingDims(
&b_, pscatter, indices.sharding(), indices.state().next_channel_id,
all_dims, collective_ops_creator_, scatter->to_apply());
all_reduce->set_sharding(HloSharding::Replicate());
SetPartitionedHlo(hlo, [&]() {
return PartitionedHlo(all_reduce, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
}
// Handle pass through cases if we can use compatible sharding for update.
if (maybe_passthrough.has_value()) {
indices = indices.Reshard(HloSharding::Replicate());
updates = updates.Reshard(*maybe_passthrough);
auto pscatter = b_.AddInstruction(HloInstruction::CreateScatter(
operand.hlo()->shape(), operand.hlo(), indices.hlo(), updates.hlo(),
scatter->to_apply(), dnums, scatter->indices_are_sorted(),
scatter->unique_indices()));
pscatter->set_sharding(operand.sharding());
SetPartitionedHlo(hlo, [&]() {
return PartitionedHlo(pscatter, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
if (should_shard_trivial_operand_slices) {
// Operand is sharded on trivial slice dims (update slice size 1). We can
// adjust the indices on each partition by subtracting the offsets. Then
// we execute a scatter on full updated indices, and out-of-bound accesses
// will have no effect on the result as guaranteed by the scatter
// semantics.
indices = indices.Reshard(HloSharding::Replicate());
updates = updates.Reshard(HloSharding::Replicate());
HloInstruction* indices_min;
HloInstruction* indices_max_unused;
std::tie(indices_min, indices_max_unused) =
IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims(
operand, indices, MakePartitioningState().partition_id,
scatter_dims_to_operand_dims, dnums.index_vector_dim(), &b_);
auto adjusted_indices = b_.AddInstruction(HloInstruction::CreateBinary(
indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(),
indices_min));
auto pscatter = b_.AddInstruction(HloInstruction::CreateScatter(
operand.hlo()->shape(), operand.hlo(), adjusted_indices, updates.hlo(),
scatter->to_apply(), dnums, scatter->indices_are_sorted(),
scatter->unique_indices()));
pscatter->set_sharding(operand.sharding());
SetPartitionedHlo(hlo, [&]() {
return PartitionedHlo(pscatter, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
return DefaultAction(hlo);
}
// Partitions a gather instruction.
//
// When the gather's sharding pins it to a single device, no SPMD partitioning
// is attempted and the instruction is handled by DefaultAction. Otherwise the
// output dimensions that are not listed in offset_dims (the batch dimensions)
// are collected and passed to PartitionGather along with the partitioned
// operand and indices. The per-partition result is wrapped in a PartitionedHlo
// over the gather's full output shape and resharded to the gather's own
// sharding before being recorded. Any error status from PartitionGather is
// propagated to the caller.
Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
  if (hlo->sharding().HasUniqueDevice()) {
    return DefaultAction(hlo);
  }
  auto* gather_instr = Cast<HloGatherInstruction>(hlo);
  const auto& gather_dnums = gather_instr->gather_dimension_numbers();
  const auto& output_shape = gather_instr->shape();

  // Every output dimension that is not an offset dimension is a batch
  // dimension.
  std::vector<int64_t> output_batch_dims;
  for (int64_t dim = 0; dim < output_shape.rank(); ++dim) {
    const bool is_offset_dim =
        absl::c_linear_search(gather_dnums.offset_dims(), dim);
    if (is_offset_dim) {
      continue;
    }
    output_batch_dims.push_back(dim);
  }

  PartitionedHlo partitioned_operand =
      GetPartitionedHlo(gather_instr->operand(0));
  PartitionedHlo partitioned_indices =
      GetPartitionedHlo(gather_instr->operand(1));

  TF_ASSIGN_OR_RETURN(
      HloInstruction * partitioned_gather,
      PartitionGather(gather_instr, partitioned_operand, partitioned_indices,
                      output_shape, gather_instr->sharding(),
                      absl::MakeConstSpan(output_batch_dims),
                      gather_instr->gather_slice_sizes(), this));

  // Wrap the per-partition result and bring it to the gather's requested
  // sharding before registering it as the partitioned form of the gather.
  PartitionedHlo gathered(partitioned_gather, output_shape,
                          MakePartitioningState());
  SetPartitionedHlo(gather_instr, gathered.Reshard(gather_instr->sharding()));
  return Status::OK();
}
} // namespace spmd
} // namespace xla