blob: 1b484e018baeb514d10b99404b2c604babc2d07f [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include <float.h>
#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/numbers.h"
namespace xla {
namespace spmd {
// Produces the "memory during transformation" report. The registered entries
// are sorted in place by descending size, and at most
// `report_instruction_count_` of them are rendered, largest first.
string SpmdLogger::MakeReport() {
  std::sort(entries_.begin(), entries_.end(),
            [](const auto& lhs, const auto& rhs) {
              return lhs.first > rhs.first;
            });

  string report;
  absl::StrAppend(&report,
                  "\n\n***** SPMD memory during transformation *****\n");

  // Never report more entries than were actually registered.
  const int64 num_reported =
      std::min<int64>(report_instruction_count_, entries_.size());
  for (int64 idx = 0; idx < num_reported; ++idx) {
    const auto& entry = entries_[idx];
    absl::StrAppend(&report, "\n ",
                    tensorflow::strings::HumanReadableNumBytes(entry.first),
                    " : ", entry.second, "\n");
  }
  return report;
}
// Records one log entry for `hlo`. The entry text is `hlo`'s string form
// followed by one line per array-shaped instruction in `group`; the entry key
// is the largest byte size among those array-shaped instructions (-1 if there
// are none).
void SpmdLogger::RegisterLogEntry(HloInstruction* hlo,
                                  const std::vector<HloInstruction*>& group) {
  int64 largest_bytes = -1;
  string description = hlo->ToString();
  for (HloInstruction* derived : group) {
    // Only array-shaped instructions contribute to the entry.
    if (derived->shape().IsArray()) {
      largest_bytes =
          std::max<int64>(largest_bytes, ShapeSizeInBytes(derived->shape()));
      absl::StrAppend(&description, " * ", derived->ToString(), "\n");
    }
  }
  entries_.push_back(std::make_pair(largest_bytes, description));
}
// Builds the pre-partition memory report for `module`. It has two sections:
// instructions that have no sharding or a replicated sharding, and then all
// instructions. Each section is produced by ReportMemoryUsage() with
// `report_instruction_count` forwarded unchanged.
/* static */ string SpmdLogger::ReportBeforePartition(
    const HloModule& module, int64 report_instruction_count) {
  const auto unsharded_or_replicated = [](const HloInstruction* hlo) {
    return !hlo->has_sharding() || hlo->sharding().IsReplicated();
  };
  const auto accept_all = [](const HloInstruction* hlo) { return true; };

  string report;
  absl::StrAppend(&report,
                  "\n\n***** SPMD memory usage before partition *****\n");

  absl::StrAppend(&report, "\n ** Replicated instructions\n");
  absl::StrAppend(&report,
                  ReportMemoryUsage(module, unsharded_or_replicated,
                                    report_instruction_count));

  absl::StrAppend(&report, "\n ** All instructions\n");
  absl::StrAppend(
      &report, ReportMemoryUsage(module, accept_all, report_instruction_count));
  return report;
}
// Builds the post-partition memory report for `module`: a banner followed by
// the ReportMemoryUsage() output over every instruction (no filtering).
/* static */ string SpmdLogger::ReportAfterPartition(
    const HloModule& module, int64 report_instruction_count) {
  const auto accept_all = [](const HloInstruction* hlo) { return true; };

  string report;
  absl::StrAppend(&report,
                  "\n\n***** SPMD memory usage after partition *****\n");
  absl::StrAppend(
      &report, ReportMemoryUsage(module, accept_all, report_instruction_count));
  return report;
}
// Lists the instructions of `module` that pass `filter`, one per line,
// ordered by descending shape size in bytes and truncated to at most
// `report_instruction_count` lines.
//
// Instructions inside fusion computations are skipped, as are instructions
// whose shape is a tuple or an effective scalar.
template <typename F>
/* static */ string SpmdLogger::ReportMemoryUsage(
    const HloModule& module, const F& filter, int64 report_instruction_count) {
  // Gather the candidate instructions.
  std::vector<HloInstruction*> candidates;
  candidates.reserve(module.instruction_count());
  for (auto computation : module.computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    for (auto hlo : computation->instructions()) {
      const bool uninteresting_shape =
          hlo->shape().IsTuple() || ShapeUtil::IsEffectiveScalar(hlo->shape());
      if (!uninteresting_shape && filter(hlo)) {
        candidates.push_back(hlo);
      }
    }
  }

  // Largest shapes first.
  std::sort(candidates.begin(), candidates.end(),
            [](const HloInstruction* lhs, const HloInstruction* rhs) {
              return ShapeSizeInBytes(lhs->shape()) >
                     ShapeSizeInBytes(rhs->shape());
            });

  string report;
  const int64 num_reported =
      std::min<int64>(report_instruction_count, candidates.size());
  for (int64 idx = 0; idx < num_reported; ++idx) {
    const HloInstruction* inst = candidates[idx];
    absl::StrAppend(&report, " ",
                    tensorflow::strings::HumanReadableNumBytes(
                        ShapeSizeInBytes(inst->shape())),
                    " : ", inst->ToString(), "\n");
  }
  return report;
}
namespace {
// Builds `num_replicas` singleton replica groups: the i-th group contains
// only replica id i.
std::vector<ReplicaGroup> CreateReplicaGroups(int64 num_replicas) {
  std::vector<ReplicaGroup> singleton_groups(num_replicas);
  int64 replica_id = 0;
  for (ReplicaGroup& group : singleton_groups) {
    group.add_replica_ids(replica_id);
    ++replica_id;
  }
  return singleton_groups;
}
// Decides whether resharding from `source` to `target` moves the partitioning
// from exactly one dimension to exactly one other dimension, leaving every
// remaining dimension's tile count unchanged. On success returns the pair
// (dimension partitioned only in `source`, dimension partitioned only in
// `target`); otherwise returns absl::nullopt.
absl::optional<std::pair<int64, int64>> GetReshardAllToAllSourceTargetDims(
    const HloSharding& source, const HloSharding& target) {
  if (source.IsTileMaximal() || target.IsTileMaximal()) {
    return absl::nullopt;
  }
  const auto& source_tiles = source.tile_assignment();
  const auto& target_tiles = target.tile_assignment();
  if (source_tiles.num_dimensions() != target_tiles.num_dimensions()) {
    return absl::nullopt;
  }

  int64 source_dim = -1;
  int64 target_dim = -1;
  for (int64 dim = 0; dim < source_tiles.num_dimensions(); ++dim) {
    const auto source_count = source_tiles.dim(dim);
    const auto target_count = target_tiles.dim(dim);
    if (source_count == target_count) {
      // Dimension is tiled the same way on both sides; nothing to move.
      continue;
    }
    if (source_count > 1 && target_count == 1) {
      // Partitioned in the source only. Allow just one such dimension.
      if (source_dim != -1) {
        return absl::nullopt;
      }
      source_dim = dim;
    } else if (source_count == 1 && target_count > 1) {
      // Partitioned in the target only. Allow just one such dimension.
      if (target_dim != -1) {
        return absl::nullopt;
      }
      target_dim = dim;
    } else {
      // Tile counts differ in a way that is not a pure move.
      return absl::nullopt;
    }
  }

  const bool found_both = source_dim != -1 && target_dim != -1;
  if (!found_both || source_dim == target_dim) {
    return absl::nullopt;
  }
  return std::pair<int64, int64>(source_dim, target_dim);
}
// Returns true when `source` and `target` are both tiled (not tile-maximal),
// their tile assignments have equal dimensions, and the assignments
// themselves differ.
bool CanReshardWithCollectivePermute(const HloSharding& source,
                                     const HloSharding& target) {
  if (source.IsTileMaximal() || target.IsTileMaximal()) {
    return false;
  }
  const auto& source_tiles = source.tile_assignment();
  const auto& target_tiles = target.tile_assignment();
  if (!(source_tiles.dimensions() == target_tiles.dimensions())) {
    return false;
  }
  return source_tiles != target_tiles;
}
// Removes the sharding attribute from the instructions of `module`. Intended
// to run only once all SPMD transformation has finished.
//
// Two kinds of instructions keep their sharding because
// HloReplicationAnalysis uses it later (for ArCrsCombiner): Infeed
// instructions, and parameters of the entry computation.
Status ClearShardingAttributes(HloModule* module) {
  for (HloComputation* computation : module->computations()) {
    const bool is_entry = computation == module->entry_computation();
    for (HloInstruction* hlo : computation->instructions()) {
      const bool is_infeed = hlo->opcode() == HloOpcode::kInfeed;
      const bool is_entry_parameter =
          hlo->opcode() == HloOpcode::kParameter && is_entry;
      if (!is_infeed && !is_entry_parameter) {
        hlo->clear_sharding();
      }
    }
  }
  return Status::OK();
}
} // namespace
// Adds `instruction` through the base HloComputation::Builder. If an HLO is
// currently being visited, the new instruction is also appended to that
// HLO's list in `instructions_`.
HloInstruction* SpmdBuilder::AddInstruction(
    std::unique_ptr<HloInstruction> instruction) {
  HloInstruction* const added =
      HloComputation::Builder::AddInstruction(std::move(instruction));
  if (!visiting_hlo_) {
    return added;
  }
  instructions_[visiting_hlo_].push_back(added);
  return added;
}
// Returns this HLO resharded to `target`, reusing a cached result when one
// exists for this HLO and `target`. On a miss the result of ReshardNoCache()
// is cached, and a reverse entry (result HLO -> this HLO under the current
// sharding) is recorded as well.
PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target) {
  auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
  const auto cached = absl::c_find_if(
      cache, [&](const auto& entry) { return entry.first == target; });
  if (cached != cache.end()) {
    return cached->second;
  }

  cache.emplace_back(target, ReshardNoCache(target));

  // Remember how to get back from the resharded HLO to this one.
  auto& reverse_cache =
      state_.reshard_cache->per_hlo_cache[cache.back().second.hlo()]
          .reshard_cache;
  reverse_cache.emplace_back(sharding(), *this);
  return cache.back().second;
}
// Computes this HLO resharded to `target` without looking at the reshard
// cache for the top-level request (nested calls to Reshard() may still use
// it).
//
// Strategy, in order:
//   * tuple shapes: reshard every leaf and rebuild the tuple;
//   * identical sharding or TOKEN element type: return *this;
//   * collective-permute, when CanReshardWithCollectivePermute() allows it;
//   * all-to-all, when GetReshardAllToAllSourceTargetDims() finds dimensions;
//   * otherwise replicate first, then either copy (tile-maximal target) or
//     dynamic-slice the padded replicated data (tiled target).
PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
  VLOG(2) << "Resharding " << hlo_->ToString() << " from "
          << hlo_->sharding().ToString() << " to " << target.ToString();
  const Shape& shape = hlo_->shape();
  CHECK(shape.IsTuple() || !target.IsTuple());

  if (shape.IsTuple()) {
    // A non-tuple sharding on a tuple-shaped instruction means the same
    // sharding applies to every leaf; expand it and go through Reshard().
    if (!target.IsTuple()) {
      return Reshard(target.GetTupleSharding(shape).ValueOrDie());
    }

    // Reshard each leaf separately and reassemble the tuple.
    std::vector<HloInstruction*> resharded_leaves;
    for (int64 leaf = 0; leaf < ShapeUtil::TupleElementCount(shape); ++leaf) {
      auto leaf_shape = ShapeUtil::GetTupleElementShape(shape, leaf);
      auto leaf_hlo = state_.b->AddInstruction(
          HloInstruction::CreateGetTupleElement(leaf_shape, hlo(), leaf));
      leaf_hlo->set_sharding(sharding().GetSubSharding(shape, {leaf}));
      PartitionedHlo partitioned_leaf(
          leaf_hlo, ShapeUtil::GetTupleElementShape(base_shape_, leaf), state_);
      resharded_leaves.push_back(
          partitioned_leaf.Reshard(target.GetSubSharding(shape, {leaf})).hlo());
    }
    auto rebuilt_tuple =
        state_.b->AddInstruction(HloInstruction::CreateTuple(resharded_leaves));
    rebuilt_tuple->set_sharding(target);
    return PartitionedHlo(rebuilt_tuple, base_shape_, state_);
  }

  // Nothing to do for an unchanged sharding or for tokens.
  if (sharding() == target || shape.element_type() == TOKEN) {
    return *this;
  }

  if (CanReshardWithCollectivePermute(sharding(), target)) {
    return ReshardWithCollectivePermute(target);
  }

  const auto all_to_all_dims =
      GetReshardAllToAllSourceTargetDims(sharding(), target);
  if (all_to_all_dims) {
    return ReshardWithAllToAll(target, all_to_all_dims->first,
                               all_to_all_dims->second);
  }

  // The two remaining implementations start from replicated data, so
  // replicate first if necessary and retry.
  if (!sharding().IsReplicated()) {
    return Replicate().Reshard(target);
  }

  // 'Replicated' to 'SingleDevice': a copy carrying the target sharding.
  if (target.IsTileMaximal()) {
    auto copy = state_.b->AddInstruction(
        HloInstruction::CreateUnary(hlo_->shape(), HloOpcode::kCopy, hlo_));
    copy->set_sharding(target);
    return PartitionedHlo(copy, base_shape_, state_);
  }

  // 'Replicated' to 'Tiled': pad for uneven tiling, then slice out this
  // partition's shard.
  auto padded_hlo =
      PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b);
  auto shard_shape = MakePartitionedShape(shape, target);
  auto shard_offsets =
      MakePartitionOffsets(shape, target, state_.partition_id, state_.b);
  auto shard = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
      shard_shape, padded_hlo, shard_offsets, shard_shape.dimensions()));
  shard->set_sharding(target);
  return PartitionedHlo(shard, base_shape_, state_);
}
// Returns a version of this HLO in which shard elements that fall outside
// the base shape are replaced by `pad_value`. This is a no-op (returns
// *this) when the sharding is replicated or partitions the base shape evenly.
//
// For a dimension listed in `left_padded_dims` the invalid region is taken to
// be at the low end of the dimension (elements are kept when their global
// index is >= padded size - base size); for all other unevenly partitioned
// dimensions it is at the high end (elements are kept when their global index
// is < base size).
PartitionedHlo PartitionedHlo::PadWithValue(
    HloInstruction* pad_value, absl::Span<const int64> left_padded_dims) const {
  const HloSharding& hlo_sharding = hlo_->sharding();
  const Shape& shape = hlo_->shape();
  CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
  if (hlo_sharding.IsReplicated() ||
      EvenlyPartitions(base_shape_, hlo_sharding)) {
    return *this;
  }
  CHECK(!hlo_sharding.IsTileMaximal());

  auto index_shape = ShapeUtil::ChangeElementType(shape, S32);
  auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED);

  // Builds the PRED mask of valid elements along `dim`, where `start_index`
  // is this shard's offset in that dimension.
  auto build_dim_mask = [&](int64 dim, HloInstruction* start_index) {
    // global index = iota along `dim` + start_index
    auto local_index =
        state_.b->AddInstruction(HloInstruction::CreateIota(index_shape, dim));
    auto start_bcast = state_.b->AddInstruction(
        HloInstruction::CreateBroadcast(index_shape, start_index, {}));
    auto global_index = state_.b->AddInstruction(HloInstruction::CreateBinary(
        index_shape, HloOpcode::kAdd, local_index, start_bcast));

    const bool is_left_padded = absl::c_linear_search(left_padded_dims, dim);
    const ComparisonDirection direction =
        is_left_padded ? ComparisonDirection::kGe : ComparisonDirection::kLt;
    const int64 base_size = base_shape_.dimensions(dim);
    const int64 index_limit =
        is_left_padded
            ? index_shape.dimensions(dim) *
                      hlo_sharding.tile_assignment().dim(dim) -
                  base_size
            : base_size;

    auto limit_scalar = state_.b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR0<int32>(index_limit)));
    auto limit_bcast = state_.b->AddInstruction(
        HloInstruction::CreateBroadcast(index_shape, limit_scalar, {}));
    return state_.b->AddInstruction(HloInstruction::CreateCompare(
        mask_shape, global_index, limit_bcast, direction));
  };

  auto offsets = MakePartitionOffsets(base_shape_, hlo_sharding,
                                      state_.partition_id, state_.b);

  // AND together the masks of all unevenly partitioned dimensions.
  HloInstruction* mask = nullptr;
  for (int64 dim = 0; dim < shape.rank(); ++dim) {
    const bool evenly_split =
        base_shape_.dimensions(dim) % hlo_sharding.tile_assignment().dim(dim) ==
        0;
    if (evenly_split) {
      continue;
    }
    HloInstruction* dim_mask = build_dim_mask(dim, offsets[dim]);
    if (mask == nullptr) {
      mask = dim_mask;
    } else {
      mask = state_.b->AddInstruction(HloInstruction::CreateBinary(
          mask->shape(), HloOpcode::kAnd, mask, dim_mask));
    }
  }
  if (mask == nullptr) {
    return *this;
  }

  auto pad_bcast = state_.b->AddInstruction(
      HloInstruction::CreateBroadcast(shape, pad_value, {}));
  auto masked = state_.b->AddInstruction(HloInstruction::CreateTernary(
      shape, HloOpcode::kSelect, mask, hlo_, pad_bcast));
  masked->set_sharding(hlo_sharding);
  return PartitionedHlo(masked, base_shape_, state_);
}
// Reshards this HLO so that it can be used as the input of a windowed
// operation described by `window`, with the data tiled according to `target`.
//
// For every partitioned dimension the code computes how many windows each
// shard handles (ceil(window_count / shard_count)) and, from that, the start
// and limit offsets of the input region each shard needs.
//
// Returns a WindowedInputShardReturnValue built from: the per-partition input
// HLO, the window adjusted for a shard (`shard_window`), and, only when some
// dimension needs it, per-dimension dynamic-slice offsets to apply to the
// output. Returns absl::nullopt for the unsupported cases logged/commented
// below.
//
// `pad_value` is used for explicit padding and is passed to the halo
// exchange; `mask_invalid_region` is forwarded to
// ExchangeHaloAndGetValidData() only.
//
// Results are cached per HLO, keyed by (`target`, `window`).
// NOTE(review): the cache key does not include `pad_value` or
// `mask_invalid_region`, so a cached result is returned regardless of those
// two arguments -- confirm this is intended.
absl::optional<PartitionedHlo::WindowedInputShardReturnValue>
PartitionedHlo::ReshardAsWindowedInput(const Window& window,
const HloSharding& target,
HloInstruction* pad_value,
bool mask_invalid_region) {
// Cache lookup: reuse a previous result for the same target sharding and
// window.
auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].window_reshard_cache;
for (auto& entry : cache) {
if (std::get<0>(entry) == target &&
protobuf_util::ProtobufEquals(std::get<1>(entry), window)) {
return std::get<2>(entry);
}
}
// Stores `result` in this HLO's cache and returns the cached copy.
auto update_cache = [&](WindowedInputShardReturnValue result) {
cache.emplace_back(target, window, std::move(result));
return std::get<2>(cache.back());
};
VLOG(2) << "ReshardAsWindowedInput()\n"
<< "\twindow:" << window_util::ToString(window)
<< "\ttarget sharding:" << target.ToString();
CHECK(!target.IsTileMaximal());
auto partition_ordinals =
MakeTiledPartitionOrdinals(target, state_.partition_id, state_.b);
auto shard_shape = base_shape_;
// Per-dimension state filled in by the loop below. Entries for
// non-partitioned dimensions keep their default-constructed/initial values,
// except offsets_on_padded_shape which is set to a zero constant.
std::vector<MultiplyAddDivideOffsetCalculation> start_on_padded_calculations(
base_shape_.rank());
std::vector<MultiplyAddDivideOffsetCalculation> limit_on_padded_calculations(
base_shape_.rank());
std::vector<HloInstruction*> dynamic_slice_offset_on_output(
base_shape_.rank(), nullptr);
Window shard_window = window;
auto padded_shape = base_shape_;
std::vector<HloInstruction*> offsets_on_padded_shape(base_shape_.rank());
std::vector<int64> per_shard_window_counts(base_shape_.rank());
std::vector<int64> explicit_left_padding(base_shape_.rank());
for (int64 i = 0; i < base_shape_.rank(); ++i) {
// Do not pad non-partitioned dimensions.
int64 shard_count = target.tile_assignment().dim(i);
if (shard_count == 1) {
offsets_on_padded_shape[i] = state_.b->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
continue;
}
const auto& wd = window.dimensions(i);
// Window extent after applying window dilation.
const auto dilated_size = 1 + (wd.size() - 1) * wd.window_dilation();
// Base extent after applying base dilation and low/high padding.
int64 full_size =
base_shape_.dimensions(i) +
(wd.base_dilation() - 1) * (base_shape_.dimensions(i) - 1) +
wd.padding_high() + wd.padding_low();
if (full_size < dilated_size) {
VLOG(2) << "Failed to reshard window operand because the window size is "
"larger than padded base size";
return absl::nullopt;
}
int64 window_count = (full_size - dilated_size) / wd.stride() + 1;
per_shard_window_counts[i] = CeilOfRatio(window_count, shard_count);
if (wd.stride() != 1 &&
(wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() != 0) {
// TODO(yuanzx): Support this case.
VLOG(2) << "Failed to reshard window operand due to non-trivial dilation";
return absl::nullopt;
}
// We use explicit padding for full dilations, then use padding_low and
// padding_high on the sharded op for the remaining. padding_low and
// padding_high are now given initial values, which will be later updated if
// dilation is not 1.
auto swd = shard_window.mutable_dimensions(i);
explicit_left_padding[i] = wd.padding_low() / wd.base_dilation();
swd->set_padding_low(wd.padding_low() % wd.base_dilation());
swd->set_padding_high(0);
// Calculation for the first element needed on the 'padded-but-not-dilated'
// shape. The start on the dilated shape could be a hole, so we add
// wd.base_dilation() - 1 to the constant term to skip the leading holes.
start_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation(
wd.stride() * per_shard_window_counts[i],
wd.base_dilation() - 1 - swd->padding_low(), wd.base_dilation());
int64 dilated_shard_size =
wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
limit_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation(
wd.stride() * per_shard_window_counts[i],
dilated_shard_size + wd.base_dilation() - 1 - swd->padding_low(),
wd.base_dilation());
offsets_on_padded_shape[i] = start_on_padded_calculations[i].Calculate(
partition_ordinals[i], state_.b);
// The shard size is the largest (limit - start) over all shard ordinals.
auto shard_size_function =
limit_on_padded_calculations[i] - start_on_padded_calculations[i];
int64 max_shard_size = shard_size_function.MaxInRange(0, shard_count);
shard_shape.set_dimensions(i, max_shard_size);
// The padded shape extends to the limit needed by the last shard.
padded_shape.set_dimensions(
i, limit_on_padded_calculations[i].Calculate(shard_count - 1));
// For base dilation, calculate the needed padding_low and padding_high, as
// well as the offset for the output if a dynamic slice is needed after the
// sharded op.
if (wd.base_dilation() != 1) {
// Returns the offset of a shard's first valid element in the dilated
// shard.
auto get_first_valid_element_offset_on_dilated_shard =
[&](int64 shard_ordinal) {
return start_on_padded_calculations[i].Calculate(shard_ordinal) *
wd.base_dilation() +
swd->padding_low() -
wd.stride() * per_shard_window_counts[i] * shard_ordinal;
};
CHECK_EQ(get_first_valid_element_offset_on_dilated_shard(0),
swd->padding_low());
// Determine swd->padding_high.
for (int64 shard_ordinal = 0; shard_ordinal < shard_count;
++shard_ordinal) {
int64 wanted_limit_on_dilated_shard =
wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
int64 actual_limit_on_dilated_shard_without_pad_high =
get_first_valid_element_offset_on_dilated_shard(shard_ordinal) +
(max_shard_size - 1) * wd.base_dilation() + 1;
swd->set_padding_high(std::max<int64>(
swd->padding_high(),
wanted_limit_on_dilated_shard -
actual_limit_on_dilated_shard_without_pad_high));
}
// Determine swd->padding_low and output dynamic slice index.
if (wd.stride() == 1) {
// Find the largest first-valid-element offset over all shards, and
// whether every shard has the same offset as shard 0.
int64 max_pad_low = get_first_valid_element_offset_on_dilated_shard(0);
bool all_same = true;
for (int64 shard_ordinal = 1; shard_ordinal < shard_count;
++shard_ordinal) {
int64 start =
get_first_valid_element_offset_on_dilated_shard(shard_ordinal);
if (start != swd->padding_low()) {
all_same = false;
}
max_pad_low = std::max(max_pad_low, start);
}
if (!all_same) {
auto start_on_padded_input =
start_on_padded_calculations[i].Calculate(partition_ordinals[i],
state_.b);
// We will calculate
// max_pad_low - (first_window - required_first_window)
// which equals
// required_first_window - (first_window - max_pad_low)
auto first_window_minus_max_pad_low =
MultiplyAddDivideOffsetCalculation(
wd.base_dilation(), swd->padding_low() - max_pad_low, 1)
.Calculate(start_on_padded_input, state_.b);
auto required_first_window =
MultiplyAddDivideOffsetCalculation(per_shard_window_counts[i], 0,
1)
.Calculate(partition_ordinals[i], state_.b);
dynamic_slice_offset_on_output[i] =
state_.b->AddInstruction(HloInstruction::CreateBinary(
required_first_window->shape(), HloOpcode::kSubtract,
required_first_window, first_window_minus_max_pad_low));
}
swd->set_padding_low(max_pad_low);
} else {
// NOTE(review): this is the same condition that already caused an early
// return above when wd.stride() != 1, so this branch looks unreachable
// here -- confirm before relying on it.
if ((wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() !=
0) {
// General base dilation not yet implemented.
return absl::nullopt;
}
// padding_low on all shards should equal the initially assigned
// swd->padding_low(), i.e., the padding_low() on the original window.
}
}
}
// Returns the output dynamic slice offset when needed, and absl::nullopt
// otherwise.
auto get_dynamic_slice_offset_on_output_if_needed =
[&]() -> absl::optional<std::vector<HloInstruction*>> {
if (absl::c_all_of(
dynamic_slice_offset_on_output,
[](HloInstruction* offset) { return offset == nullptr; })) {
return absl::nullopt;
}
// At least one dimension needs an offset; use a zero constant for the
// dimensions that do not.
auto zero = state_.b->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
for (int64 i = 0; i < dynamic_slice_offset_on_output.size(); ++i) {
if (dynamic_slice_offset_on_output[i] == nullptr) {
dynamic_slice_offset_on_output[i] = zero;
}
}
return dynamic_slice_offset_on_output;
};
// If the current HLO is replicated, pad then slice.
if (sharding().IsReplicated()) {
PaddingConfig padding_config;
for (int64 i = 0; i < base_shape_.rank(); ++i) {
auto padding_config_dim = padding_config.add_dimensions();
padding_config_dim->set_interior_padding(0);
// Do not pad non-partitioned dimensions.
if (target.tile_assignment().dim(i) == 1) {
padding_config_dim->set_edge_padding_low(0);
padding_config_dim->set_edge_padding_high(0);
continue;
}
padding_config_dim->set_edge_padding_low(explicit_left_padding[i]);
padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) -
explicit_left_padding[i] -
base_shape_.dimensions(i));
}
// Skip the pad instruction when the padded shape is compatible with the
// base shape.
auto padded_hlo = ShapeUtil::Compatible(padded_shape, base_shape_)
? hlo_
: state_.b->AddInstruction(HloInstruction::CreatePad(
padded_shape, hlo_, pad_value, padding_config));
auto sharded_input =
state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
shard_shape, padded_hlo, offsets_on_padded_shape,
shard_shape.dimensions()));
return update_cache(WindowedInputShardReturnValue{
sharded_input, shard_window,
get_dynamic_slice_offset_on_output_if_needed()});
}
// Not replicated and not yet in the target sharding: reshard first, then
// retry on the resharded HLO.
// NOTE(review): `mask_invalid_region` is not forwarded here, so the
// recursive call uses the parameter's default -- confirm this is intended.
if (target != sharding()) {
return Reshard(target).ReshardAsWindowedInput(window, target, pad_value);
}
// Halo exchange.
HloInstruction* visiting_hlo = hlo_;
auto original_shard_shape = MakePartitionedShape(base_shape_, target);
std::vector<OffsetCalculation> left_halo_size_functions(base_shape_.rank());
std::vector<OffsetCalculation> right_halo_size_functions(base_shape_.rank());
// TODO(yuanzx): We are concatenating on each sharded dimension one at time,
// and in the second dimension (and beyond) we create halos by slicing the
// concat in the previous dimension, which is not optimal. We should generate
// halos only concating slices, instead of slicing concats.
for (int dim = 0; dim < base_shape_.rank(); ++dim) {
int64 shard_count = target.tile_assignment().dim(dim);
if (shard_count == 1) {
continue;
}
int64 input_shard_size =
CeilOfRatio(base_shape_.dimensions(dim), shard_count);
// Left halo. The size of the halo is derived by subtracting the first read
// element offset of the i'th partition from the limit of the (i-1)'th
// partition.
MultiplyAddDivideOffsetCalculation shard_limit_of_previous_on_padded(
input_shard_size, explicit_left_padding[dim], 1);
left_halo_size_functions[dim] =
shard_limit_of_previous_on_padded - start_on_padded_calculations[dim];
// Right halo.
MultiplyAddDivideOffsetCalculation shard_start_of_next_on_padded(
input_shard_size, input_shard_size + explicit_left_padding[dim], 1);
right_halo_size_functions[dim] =
limit_on_padded_calculations[dim] - shard_start_of_next_on_padded;
auto resharded = ExchangeHaloAndGetValidData(
visiting_hlo, base_shape_, left_halo_size_functions[dim],
right_halo_size_functions[dim], explicit_left_padding[dim],
padded_shape.dimensions(dim), shard_shape.dimensions(dim), dim, target,
offsets_on_padded_shape[dim], pad_value, partition_ordinals[dim],
state_.collective_ops_creator, state_.next_channel_id, state_.b,
mask_invalid_region);
// If the halo exchange cannot be done, fall back to replicating this HLO
// and taking the pad-then-slice path above.
if (!resharded) {
VLOG(1) << "ReshardAsWindowedInput failed without replicate first: halo "
"is beyond the neighbor.";
return Replicate().ReshardAsWindowedInput(window, target, pad_value);
}
// The next dimension's exchange operates on this dimension's result.
visiting_hlo = *resharded;
}
return update_cache(WindowedInputShardReturnValue{
visiting_hlo, shard_window,
get_dynamic_slice_offset_on_output_if_needed()});
}
// Returns a replicated version of this HLO (array, non-TOKEN shapes only).
//
// Already-replicated HLOs are returned unchanged, and a previously computed
// replicated form is reused from the reshard cache. Otherwise:
//   * a tile-maximal (single device) HLO is replicated through Broadcast();
//   * a tiled HLO is gathered with a cross-partition all-gather when the
//     creator provides one, or else by writing the shard into a zero buffer
//     of the padded shape and all-reducing it; the result is sliced back to
//     the base shape when padding was involved.
// New results are stored in the cache in both directions.
PartitionedHlo PartitionedHlo::Replicate() {
  const HloSharding& sharding = hlo_->sharding();
  const Shape& shape = hlo_->shape();
  CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
  if (sharding.IsReplicated()) {
    return *this;
  }

  auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
  const auto cached = absl::c_find_if(
      cache, [](const auto& entry) { return entry.first.IsReplicated(); });
  if (cached != cache.end()) {
    return cached->second;
  }

  // Records `resharded` as the replicated form of this HLO, plus the reverse
  // mapping from `resharded` back to this HLO under the original sharding.
  auto update_cache = [&](PartitionedHlo resharded) {
    auto& reverse_cache =
        state_.reshard_cache->per_hlo_cache[resharded.hlo()].reshard_cache;
    reverse_cache.emplace_back(sharding, *this);
    cache.emplace_back(HloSharding::Replicate(), std::move(resharded));
    return cache.back().second;
  };

  // 'Single Device' to 'Replicated'.
  if (sharding.IsTileMaximal()) {
    return update_cache(Broadcast());
  }

  // 'Tiled' to 'Replicated'.
  // Shape obtained by laying all shards side by side (may exceed base shape).
  Shape padded_base_shape = shape;
  for (int64 dim = 0; dim < padded_base_shape.rank(); ++dim) {
    padded_base_shape.set_dimensions(
        dim, shape.dimensions(dim) * sharding.tile_assignment().dim(dim));
  }

  HloInstruction* replicated = nullptr;
  if (state_.collective_ops_creator.create_cross_partition_all_gather) {
    replicated = state_.partitioner->AllGatherShards(state_.b, hlo_, sharding,
                                                     NewChannel());
  }
  if (replicated == nullptr) {
    // Fallback: place this shard into a zero-filled padded buffer at its
    // partition offset, then sum the buffers across partitions.
    auto zero = state_.b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::Zero(shape.element_type())));
    auto zero_bcast = state_.b->AddInstruction(
        HloInstruction::CreateBroadcast(padded_base_shape, zero, {}));
    auto shard_offsets = MakePartitionOffsets(padded_base_shape, sharding,
                                              state_.partition_id, state_.b);
    auto placed_shard =
        state_.b->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
            padded_base_shape, zero_bcast, hlo_, shard_offsets));
    HloComputation* reduction =
        MakeBinaryAdd(shape.element_type(), state_.module);
    replicated = state_.collective_ops_creator.create_cross_partition_all_reduce(
        state_.b, placed_shard, reduction, NewChannel());
  }

  // Drop the padding if the gathered shape is not the base shape.
  if (!ShapeUtil::Compatible(base_shape_, padded_base_shape)) {
    const std::vector<int64> start_indices(shape.rank(), 0);
    const std::vector<int64> strides(shape.rank(), 1);
    replicated = state_.b->AddInstruction(HloInstruction::CreateSlice(
        base_shape_, replicated, start_indices, base_shape_.dimensions(),
        strides));
  }
  replicated->set_sharding(HloSharding::Replicate());
  return update_cache(PartitionedHlo(replicated, base_shape_, state_));
}
// Replicates an HLO whose sharding names a single device: every partition
// selects the real data if it is that device and zeros otherwise, and the
// selections are summed with a cross-partition all-reduce. The result carries
// a replicated sharding. Requires an array, non-TOKEN shape.
PartitionedHlo PartitionedHlo::Broadcast() const {
  const Shape& shape = hlo_->shape();
  const HloSharding& sharding = hlo_->sharding();
  CHECK(sharding.HasUniqueDevice());
  CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);

  // Predicate "this partition is the source device", broadcast to the data
  // shape.
  auto src_core_id = state_.b->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR0<uint32>(sharding.GetUniqueDevice())));
  Shape pred_shape = ShapeUtil::ChangeElementType(shape, PRED);
  auto is_src_core_scalar =
      state_.b->AddInstruction(HloInstruction::CreateCompare(
          ShapeUtil::MakeShape(PRED, {}), state_.partition_id, src_core_id,
          ComparisonDirection::kEq));
  auto is_src_core = state_.b->AddInstruction(
      HloInstruction::CreateBroadcast(pred_shape, is_src_core_scalar, {}));

  // Non-source partitions contribute zeros to the sum.
  auto zero = state_.b->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
  auto zero_bcast = state_.b->AddInstruction(
      HloInstruction::CreateBroadcast(shape, zero, {}));
  auto contribution = state_.b->AddInstruction(HloInstruction::CreateTernary(
      shape, HloOpcode::kSelect, is_src_core, hlo(), zero_bcast));

  HloComputation* reduction =
      MakeBinaryAdd(shape.element_type(), state_.module);
  auto summed = state_.collective_ops_creator.create_cross_partition_all_reduce(
      state_.b, contribution, reduction, NewChannel());
  summed->set_sharding(HloSharding::Replicate());
  return PartitionedHlo(summed, base_shape_, state_);
}
// Reshards from a sharding partitioned along `source_dim` to `target`, which
// is partitioned along `target_dim`, using a cross-partition all-to-all.
//
// If the current device order is not the transpose of `target` on those two
// dimensions, it is first fixed with a collective permute. The data is then
// padded for `target`, split along `target_dim` into `group_size` pieces,
// exchanged, and the pieces are merged into `source_dim`. A final slice trims
// the result when its shape differs from the partitioned target shape.
PartitionedHlo PartitionedHlo::ReshardWithAllToAll(const HloSharding& target,
                                                   int64 source_dim,
                                                   int64 target_dim) const {
  const auto& target_tiles = target.tile_assignment();
  const int64 group_size = sharding().tile_assignment().dim(source_dim);

  // If the device order is different in the target, fix the order with
  // ReshardWithCollectivePermute.
  std::vector<int64> xpose_dims(target_tiles.num_dimensions());
  std::iota(xpose_dims.begin(), xpose_dims.end(), 0);
  xpose_dims[source_dim] = target_dim;
  xpose_dims[target_dim] = source_dim;
  auto expected_source_sharding =
      hlo_sharding_util::TransposeSharding(target, xpose_dims);
  if (expected_source_sharding != sharding()) {
    return ReshardWithCollectivePermute(expected_source_sharding)
        .ReshardWithAllToAll(target, source_dim, target_dim);
  }

  auto padded_hlo =
      PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b);

  // Devices that share every tile index except the one along `target_dim`
  // form one group. The order of ids in the group must follow the target
  // sharding.
  std::vector<ReplicaGroup> groups(target_tiles.num_elements() / group_size);
  target_tiles.Each([&](absl::Span<const int64> indices, int64 device) {
    int64 group_id = 0;
    for (int64 dim = 0; dim < indices.size(); ++dim) {
      if (dim != target_dim) {
        group_id = group_id * target_tiles.dim(dim) + indices[dim];
      }
    }
    groups[group_id].add_replica_ids(device);
  });

  // Split along the split dimension (target_dim) of the all-to-all output:
  // [..., size, ...] -> [..., group_size, size / group_size, ...].
  std::vector<int64> split_dimensions;
  for (int64 dim = 0; dim < base_shape_.rank(); ++dim) {
    const int64 padded_size = padded_hlo->shape().dimensions(dim);
    if (dim == target_dim) {
      split_dimensions.push_back(group_size);
      split_dimensions.push_back(padded_size / group_size);
    } else {
      split_dimensions.push_back(padded_size);
    }
  }
  auto split = state_.b->AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(base_shape_.element_type(), split_dimensions),
      padded_hlo));

  // After the reshape, it is guaranteed to have at least 3 dimensions.
  auto all_to_all =
      state_.collective_ops_creator.create_cross_partition_all_to_all(
          state_.b, {split}, groups, (*state_.next_channel_id)++, target_dim);

  // Reorder the split dimension of the reshape to be located in front of the
  // input partition dimension, so the two dimensions can be combined. The
  // source dimension shifts by one if the inserted dimension precedes it.
  const int64 new_source_dim =
      (target_dim < source_dim) ? source_dim + 1 : source_dim;
  std::vector<int64> permutation;
  for (int64 dim = 0; dim < all_to_all->shape().rank(); ++dim) {
    if (dim == target_dim) {
      continue;
    }
    if (dim == new_source_dim) {
      permutation.push_back(target_dim);
    }
    permutation.push_back(dim);
  }
  auto transposed = state_.b->AddInstruction(HloInstruction::CreateTranspose(
      ShapeInference::InferTransposeShape(all_to_all->shape(), permutation)
          .ValueOrDie(),
      all_to_all, permutation));

  // Combine the split dimension and the input partition dimension.
  auto merged_shape = ShapeInference::InferAllToAllShape(
                          padded_hlo->shape(), target_dim, source_dim,
                          group_size)
                          .ValueOrDie();
  HloInstruction* resharded = state_.b->AddInstruction(
      HloInstruction::CreateReshape(merged_shape, transposed));

  // Trim anything beyond the partitioned target shape.
  const Shape expected_shape = MakePartitionedShape(base_shape_, target);
  if (expected_shape != resharded->shape()) {
    resharded = state_.b->AddInstruction(HloInstruction::CreateSlice(
        expected_shape, resharded,
        std::vector<int64>(expected_shape.rank(), 0),
        expected_shape.dimensions(),
        std::vector<int64>(expected_shape.rank(), 1)));
  }
  resharded->set_sharding(target);
  return PartitionedHlo(resharded, base_shape_, state_);
}
// Reshards to `target` by moving each shard from the device that holds it now
// to the device that `target` assigns to the same tile index, using a
// cross-partition collective permute. CHECK-fails unless
// CanReshardWithCollectivePermute() holds for the current sharding and
// `target`.
PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
    const HloSharding& target) const {
  CHECK(CanReshardWithCollectivePermute(sharding(), target))
      << sharding().ToString() << " to " << target.ToString();

  // One (source, destination) pair per tile whose owner changes.
  std::vector<std::pair<int64, int64>> src_dst_pairs;
  const auto& target_tiles = target.tile_assignment();
  sharding().tile_assignment().Each(
      [&](absl::Span<const int64> tile_index, int64 src_device) {
        const int64 dst_device = target_tiles(tile_index);
        if (src_device == dst_device) {
          return;
        }
        src_dst_pairs.emplace_back(src_device, dst_device);
      });

  auto permuted =
      state_.collective_ops_creator.create_cross_partition_collective_permute(
          state_.b, hlo(), src_dst_pairs, (*state_.next_channel_id)++);
  permuted->set_sharding(target);
  return PartitionedHlo(permuted, base_shape_, state_);
}
// Constructs a visitor that partitions `computation`.
//
// The constructor only stores its arguments and derives three members:
//   * module_ from computation->parent();
//   * b_, an SpmdBuilder named "<computation name>_spmd" with no HLO being
//     visited yet;
//   * partition_id_, created through the collective ops creator on b_.
SpmdPartitioningVisitor::SpmdPartitioningVisitor(
HloComputation* computation, int64 num_partitions, int64 num_replicas,
const SPMDCollectiveOpsCreator& collective_ops_creator,
int64* next_channel_id, SpmdLogger* logger, SpmdPartitionerOptions options,
SpmdPartitioner* partitioner)
: changed_(false),
module_(computation->parent()),
num_partitions_(num_partitions),
num_replicas_(num_replicas),
collective_ops_creator_(collective_ops_creator),
next_channel_id_(next_channel_id),
b_(SpmdBuilder(computation->name() + "_spmd", /*hlo=*/nullptr)),
// Uses collective_ops_creator_ and b_, so both must already be initialized.
// NOTE(review): initialization follows member declaration order in the
// header (not visible here) -- confirm they are declared before
// partition_id_.
partition_id_(collective_ops_creator_.create_partition_id(&b_)),
logger_(logger),
options_(std::move(options)),
partitioner_(partitioner) {}
// Fallback for instructions that have no dedicated partitioning rule.
//
// Side-effecting instructions are rejected with an Unimplemented error, and
// elementwise instructions with at least one operand are forwarded to
// HandleElementwise. Everything else is computed on fully replicated operands
// and the replicated result is then resharded to the instruction's sharding.
Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) {
  if (hlo->HasSideEffect()) {
    return Unimplemented("Side-effect ops cannot be replicated: %s",
                         hlo->ToString());
  }

  if (hlo->IsElementwise() && hlo->operand_count() > 0) {
    return HandleElementwise(hlo);
  }

  // Leave a trace when an instruction that asked for a tiled sharding ends up
  // on the replicate-and-reshard path.
  if (!hlo->sharding().IsTileMaximal()) {
    VLOG(1) << "Not partitioned in SPMD mode (DefaultAction):"
            << hlo->ToString();
    int64 operand_index = 0;
    for (const HloInstruction* operand : hlo->operands()) {
      VLOG(1) << " operand " << operand_index
              << " sharding:" << operand->sharding().ToString();
      ++operand_index;
    }
  }

  // Replicate every operand, in operand order.
  const int64 num_operands = hlo->operand_count();
  std::vector<HloInstruction*> replicated_operands;
  replicated_operands.reserve(num_operands);
  for (int64 i = 0; i < num_operands; ++i) {
    replicated_operands.push_back(GetPartitionedHlo(hlo->operand(i))
                                      .Reshard(HloSharding::Replicate())
                                      .hlo());
  }

  // Run the unmodified instruction on the replicated operands.
  HloInstruction* replicated_clone = b_.AddInstruction(
      hlo->CloneWithNewOperands(hlo->shape(), replicated_operands));
  replicated_clone->set_metadata(hlo->metadata());
  replicated_clone->set_sharding(HloSharding::Replicate());

  PartitionedHlo replicated_result(replicated_clone, hlo->shape(),
                                   MakePartitioningState());
  SetPartitionedHlo(hlo, replicated_result.Reshard(hlo->sharding()));
  return Status::OK();
}
// Called before an instruction is handled: records `hlo` as the instruction
// currently being visited, both on the builder and on the visitor itself.
Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) {
  b_.set_visiting_hlo(hlo);
  visiting_hlo_ = hlo;
  return Status::OK();
}
// Called after an instruction is handled: hands the partitioned instruction
// and the instructions the builder derived from `hlo` to the logger, then
// clears the "currently visiting" marker on the builder and the visitor.
Status SpmdPartitioningVisitor::Postprocess(HloInstruction* hlo) {
  HloInstruction* partitioned = GetPartitionedHlo(hlo).hlo();
  logger_->RegisterLogEntry(partitioned, b_.derived_instructions(hlo));

  b_.set_visiting_hlo(nullptr);
  visiting_hlo_ = nullptr;
  return Status::OK();
}
// Partitions an elementwise instruction: every operand is resharded to the
// instruction's own sharding and the instruction is cloned onto those shards
// with the per-partition output shape.
Status SpmdPartitioningVisitor::HandleElementwise(HloInstruction* hlo) {
  const HloSharding& sharding = hlo->sharding();

  const int64 num_operands = hlo->operand_count();
  std::vector<HloInstruction*> operand_shards;
  operand_shards.reserve(num_operands);
  for (int64 i = 0; i < num_operands; ++i) {
    operand_shards.push_back(
        GetPartitionedHlo(hlo->operand(i)).Reshard(sharding).hlo());
  }

  const Shape shard_shape = MakePartitionedShape(hlo->shape(), sharding);
  SetPartitionedHlo(hlo, [&] {
    return b_.AddInstruction(
        hlo->CloneWithNewOperands(shard_shape, operand_shards));
  });
  return Status::OK();
}
// Partitions a concatenate instruction.
//
// * Tile-maximal sharding: handled by DefaultAction.
// * Concatenate dimension not partitioned: operands are resharded like the
//   output and concatenated locally.
// * Concatenate dimension partitioned across all partitions: every partition
//   writes its operand shards into a zero-initialized buffer that spans the
//   whole (padded) concatenate dimension, the buffers are summed with a
//   cross-partition all-reduce, and each partition slices out its own shard.
// * Anything else (the concatenate dimension is partitioned, but not by all
//   partitions): handled by DefaultAction, because subgroup all-reduce along
//   partitions is not supported here.
Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) {
  const HloSharding& sharding = hlo->sharding();
  if (sharding.IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  const int64 concat_dim = hlo->concatenate_dimension();
  const int64 concat_dim_shards = sharding.tile_assignment().dim(concat_dim);
  const Shape shard_shape = MakePartitionedShape(hlo->shape(), sharding);

  if (concat_dim_shards == 1) {
    // The concatenation itself is purely local.
    std::vector<HloInstruction*> operand_shards;
    for (int64 i = 0; i < hlo->operand_count(); ++i) {
      operand_shards.push_back(
          GetPartitionedHlo(hlo->operand(i)).Reshard(sharding).hlo());
    }
    SetPartitionedHlo(hlo, [&] {
      return b_.AddInstruction(
          hlo->CloneWithNewOperands(shard_shape, operand_shards));
    });
    return Status::OK();
  }

  if (concat_dim_shards != num_partitions_) {
    return DefaultAction(hlo);
  }

  // Shard shape with the concatenate dimension widened back to the full size,
  // padded to a multiple of the shard count.
  Shape padded_full_shape = shard_shape;
  padded_full_shape.set_dimensions(
      concat_dim, shard_shape.dimensions(concat_dim) * concat_dim_shards);
  HloInstruction* assembled = CreateZero(padded_full_shape, &b_);

  // Position of the current operand along the concatenate dimension of the
  // unpartitioned output.
  int64 operand_start = 0;
  for (HloInstruction* operand : hlo->operands()) {
    HloInstruction* operand_shard =
        GetPartitionedHlo(operand).Reshard(sharding).hlo();
    HloInstruction* zero_index = b_.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
    std::vector<HloInstruction*> write_start(hlo->shape().rank(), zero_index);

    // This partition's shard of the operand starts at
    // ordinal * shard_size + operand_start.
    auto partition_ordinals =
        MakeTiledPartitionOrdinals(sharding, partition_id_, &b_);
    write_start[concat_dim] =
        MultiplyAddDivideOffsetCalculation(
            operand_shard->shape().dimensions(concat_dim), operand_start, 1)
            .Calculate(partition_ordinals[concat_dim], &b_);

    assembled = b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
        padded_full_shape, assembled, operand_shard, write_start));
    operand_start += operand->shape().dimensions(concat_dim);
  }

  // Every element was written by exactly one partition (the rest hold zero),
  // so summing the buffers yields the full concatenation everywhere.
  auto all_reduce = collective_ops_creator_.create_cross_partition_all_reduce(
      &b_, assembled, MakeBinaryAdd(hlo->shape().element_type(), module_),
      NewChannel());

  SetPartitionedHlo(hlo, [&] {
    auto read_start =
        MakeTiledPartitionOrdinals(hlo->sharding(), partition_id_, &b_);
    read_start[concat_dim] =
        MultiplyAddDivideOffsetCalculation(shard_shape.dimensions(concat_dim),
                                           0, 1)
            .Calculate(read_start[concat_dim], &b_);
    return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
        shard_shape, all_reduce, read_start, shard_shape.dimensions()));
  });
  return Status::OK();
}
// Derives the sharding of a gather output (or scatter update) from the sharding
// of the gather/scatter operand, for the case where the operand is partitioned
// only along "passthrough" dimensions: offset/window dimensions whose slice
// covers the whole operand dimension.
//
// Returns the operand sharding unchanged when it is tile-maximal. Returns
// nullopt when the operand is partitioned along a collapsed/inserted or indexed
// dimension, along a dimension that is only partially sliced, or when the
// offset/window dimensions are not in increasing order (transposed outputs are
// not supported). Otherwise returns the operand's tile assignment reshaped so
// that each operand dimension's partition count sits on its matching
// output/update dimension.
absl::optional<HloSharding> PassthroughOperandToGatherOutputOrScatterUpdate(
    const PartitionedHlo& operand, const Shape& update_or_gather_shape,
    absl::Span<const int64> collapsed_or_inserted_dims,
    absl::Span<const int64> index_map,
    absl::Span<const int64> offset_or_window_dims,
    absl::Span<const int64> slice_size) {
  const HloSharding& operand_sharding = operand.sharding();
  if (operand_sharding.IsTileMaximal()) {
    return operand_sharding;
  }

  const Shape& operand_shape = operand.base_shape();
  std::vector<int64> passthrough_tile(update_or_gather_shape.rank(), 1);

  // Number of window (non-collapsed, non-indexed) operand dimensions seen so
  // far, i.e. the position of the next entry in `offset_or_window_dims`.
  int64 window_pos = 0;
  for (int64 operand_dim = 0; operand_dim < operand_shape.rank();
       ++operand_dim) {
    const int64 dim_partitions =
        operand_sharding.tile_assignment().dim(operand_dim);

    const bool has_no_window_dim =
        absl::c_linear_search(collapsed_or_inserted_dims, operand_dim) ||
        absl::c_linear_search(index_map, operand_dim);
    if (has_no_window_dim) {
      if (dim_partitions > 1) {
        return absl::nullopt;
      }
      continue;
    }

    // A partitioned dimension must be passed through in full.
    if (slice_size[operand_dim] != operand_shape.dimensions(operand_dim) &&
        dim_partitions > 1) {
      return absl::nullopt;
    }

    const int64 offset_dim = offset_or_window_dims[window_pos];
    if (window_pos > 0 &&
        offset_dim < offset_or_window_dims[window_pos - 1]) {
      // Output offsets are transposed, we do not support this case.
      return absl::nullopt;
    }
    passthrough_tile[offset_dim] = dim_partitions;
    ++window_pos;
  }

  Array<int64> reshaped_assignment = operand_sharding.tile_assignment();
  reshaped_assignment.Reshape(passthrough_tile);
  return HloSharding::Tile(reshaped_assignment);
}
// Returns whether partitioning in the operand only happens in dimensions with
// gather/scatter slice size 1.
bool GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
const PartitionedHlo& operand, absl::Span<const int64> index_map,
absl::Span<const int64> slice_size, int64 num_partitions) {
if (operand.sharding().IsTileMaximal()) {
return false;
}
int64 trivial_slice_dims_partitions = 1;
for (int64 dim : index_map) {
if (slice_size[dim] == 1) {
trivial_slice_dims_partitions *=
operand.sharding().tile_assignment().dim(dim);
}
}
return trivial_slice_dims_partitions == num_partitions;
}
// Returns the min and max for the indices (replicated) in a scatter/gather
// which has the operand partitioned on trivial slice dimensions (slice size 1).
//
// For each operand dimension listed in `index_map`, a scalar lower and upper
// bound is built in the element type of `replicated_indices`:
//   * dimension not partitioned: [0, full dimension size];
//   * dimension partitioned: [this partition's offset,
//     offset + shard dimension size - 1].
// The per-dimension bounds are then broadcast to the shape of
// `replicated_indices` and returned as {min, max}.
std::pair<HloInstruction*, HloInstruction*>
IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims(
const PartitionedHlo& operand, const PartitionedHlo& replicated_indices,
HloInstruction* partition_id, absl::Span<const int64> index_map,
int64 index_vector_dim, SpmdBuilder* b) {
// Offsets of this partition's shard in the unpartitioned operand, one entry
// per operand dimension.
auto operand_offsets = MakePartitionOffsets(
operand.base_shape(), operand.sharding(), partition_id, b);
// Find the per-dimension index bounds.
std::vector<HloInstruction*> min_indices;
std::vector<HloInstruction*> max_indices;
for (int64 i = 0; i < index_map.size(); ++i) {
int64 dim = index_map[i];
int64 partitions = operand.sharding().tile_assignment().dim(dim);
if (partitions == 1) {
// Unpartitioned dimension: every partition sees the whole range.
// NOTE(review): the upper bound here is the full dimension size, whereas the
// partitioned case below uses an inclusive bound (offset + size - 1); confirm
// that callers tolerate the difference.
min_indices.push_back(CreateR0WithType<int32>(
replicated_indices.base_shape().element_type(), 0, b));
max_indices.push_back(CreateR0WithType<int32>(
replicated_indices.base_shape().element_type(),
operand.base_shape().dimensions(dim), b));
continue;
}
auto offset = operand_offsets[dim];
// Bring the offset to the element type of the indices so that it can be
// combined with them.
if (offset->shape().element_type() !=
replicated_indices.base_shape().element_type()) {
offset = b->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::MakeShape(replicated_indices.base_shape().element_type(),
{}),
offset));
}
min_indices.push_back(offset);
// The last index owned by this partition is offset + shard size - 1; the
// shard size is read from the partitioned operand instruction.
auto partition_size_minus_1 =
CreateR0WithType<int32>(replicated_indices.base_shape().element_type(),
operand.hlo()->shape().dimensions(dim) - 1, b);
max_indices.push_back(b->AddInstruction(HloInstruction::CreateBinary(
offset->shape(), HloOpcode::kAdd, offset, partition_size_minus_1)));
}
// Broadcast the index bounds to the same shape as the indices.
HloInstruction* broadcast_min;
HloInstruction* broadcast_max;
if (index_vector_dim < replicated_indices.base_shape().rank()) {
// The indices carry an explicit index-vector dimension, so the bounds form an
// R1 vector: reshape each scalar bound to [1] and concatenate them if there is
// more than one.
for (int64 i = 0; i < min_indices.size(); ++i) {
min_indices[i] = b->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(min_indices[i]->shape().element_type(), {1}),
min_indices[i]));
max_indices[i] = b->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(max_indices[i]->shape().element_type(), {1}),
max_indices[i]));
}
int64 slice_dims = max_indices.size();
if (slice_dims > 1) {
// The concatenated vectors are stored back into element 0, which is what gets
// broadcast below.
min_indices[0] = b->AddInstruction(HloInstruction::CreateConcatenate(
ShapeUtil::MakeShape(min_indices[0]->shape().element_type(),
{slice_dims}),
min_indices, 0));
max_indices[0] = b->AddInstruction(HloInstruction::CreateConcatenate(
min_indices[0]->shape(), max_indices, 0));
}
// Lay the bound vector out along the index-vector dimension.
broadcast_min = b->AddInstruction(HloInstruction::CreateBroadcast(
replicated_indices.base_shape(), min_indices[0], {index_vector_dim}));
broadcast_max = b->AddInstruction(HloInstruction::CreateBroadcast(
replicated_indices.base_shape(), max_indices[0], {index_vector_dim}));
} else {
// No explicit index-vector dimension: only a single indexed dimension is
// accepted, and its scalar bounds are broadcast to the whole indices shape.
CHECK_EQ(max_indices.size(), 1);
broadcast_min = b->AddInstruction(HloInstruction::CreateBroadcast(
replicated_indices.base_shape(), min_indices[0], {}));
broadcast_max = b->AddInstruction(HloInstruction::CreateBroadcast(
replicated_indices.base_shape(), max_indices[0], {}));
}
return {broadcast_min, broadcast_max};
}
// Partitions a scatter instruction.
//
// Two cases are partitioned directly when the operand is tiled:
//  1) Passthrough: the operand is partitioned only along window dimensions
//     that the updates cover in full. The updates are resharded to the
//     matching sharding, the indices are replicated, and each partition
//     scatters into its own operand shard.
//  2) Trivial slice dimensions: the operand is partitioned only along indexed
//     dimensions with update slice size 1, and the updates are smaller than
//     the scatter result. Indices and updates are replicated, the indices are
//     shifted by the partition's offset, and each partition scatters into its
//     own operand shard; out-of-bound updates are dropped by scatter.
// In both cases the partitioned scatter has the operand's shard shape and is
// therefore tagged with the operand's sharding before being resharded to the
// sharding requested for `hlo`. Everything else goes through DefaultAction.
Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
  auto scatter = Cast<HloScatterInstruction>(hlo);
  auto dnums = scatter->scatter_dimension_numbers();
  auto operand = GetPartitionedHlo(scatter->operand(0));
  auto indices = GetPartitionedHlo(scatter->operand(1));
  auto updates = GetPartitionedHlo(scatter->operand(2));

  // Per-operand-dimension size of one update window: 1 for inserted window
  // dimensions, otherwise the size of the corresponding update window
  // dimension.
  std::vector<int64> slice_size(operand.base_shape().rank(), 1);
  int64 num_update_window_dims = 0;
  for (int64 i = 0; i < operand.base_shape().rank(); ++i) {
    if (absl::c_linear_search(dnums.inserted_window_dims(), i)) {
      continue;
    }
    slice_size[i] = updates.base_shape().dimensions(
        dnums.update_window_dims(num_update_window_dims++));
  }

  // Copy the dimension numbers into vectors so they can be passed as spans.
  std::vector<int64> inserted_window_dims(dnums.inserted_window_dims().begin(),
                                          dnums.inserted_window_dims().end());
  std::vector<int64> scatter_dims_to_operand_dims(
      dnums.scatter_dims_to_operand_dims().begin(),
      dnums.scatter_dims_to_operand_dims().end());
  std::vector<int64> update_window_dims(dnums.update_window_dims().begin(),
                                        dnums.update_window_dims().end());

  if (!operand.sharding().IsTileMaximal()) {
    auto maybe_passthrough = PassthroughOperandToGatherOutputOrScatterUpdate(
        operand, updates.base_shape(), inserted_window_dims,
        scatter_dims_to_operand_dims, update_window_dims, slice_size);
    // Handle pass through cases if we can use compatible sharding for update.
    if (maybe_passthrough.has_value()) {
      indices = indices.Reshard(HloSharding::Replicate());
      updates = updates.Reshard(*maybe_passthrough);
      auto pscatter = b_.AddInstruction(HloInstruction::CreateScatter(
          operand.hlo()->shape(), operand.hlo(), indices.hlo(), updates.hlo(),
          scatter->to_apply(), dnums, scatter->indices_are_sorted(),
          scatter->unique_indices()));
      // `*maybe_passthrough` describes the updates tensor (it has the updates'
      // rank). The scatter result is shaped like the operand shard, so it
      // carries the operand's sharding.
      pscatter->set_sharding(operand.sharding());
      SetPartitionedHlo(hlo, [&]() {
        return PartitionedHlo(pscatter, hlo->shape(), MakePartitioningState())
            .Reshard(hlo->sharding())
            .hlo();
      });
      return Status::OK();
    }
    if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
            operand, scatter_dims_to_operand_dims, slice_size,
            num_partitions_) &&
        ShapeSizeInBytes(updates.base_shape()) <
            ShapeSizeInBytes(scatter->shape())) {
      // Operand is sharded on trivial slice dims (update slice size 1). We can
      // adjust the indices on each partition by subtracting the offsets. Then
      // we execute a scatter on full updated indices, and out-of-bound accesses
      // will have no effect on the result as guaranteed by the scatter
      // semantics.
      indices = indices.Reshard(HloSharding::Replicate());
      updates = updates.Reshard(HloSharding::Replicate());
      HloInstruction* indices_min;
      // Only the lower bound is needed to shift the indices.
      HloInstruction* indices_max_unused;
      std::tie(indices_min, indices_max_unused) =
          IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims(
              operand, indices, partition_id_, scatter_dims_to_operand_dims,
              dnums.index_vector_dim(), &b_);
      auto adjusted_indices = b_.AddInstruction(HloInstruction::CreateBinary(
          indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(),
          indices_min));
      auto pscatter = b_.AddInstruction(HloInstruction::CreateScatter(
          operand.hlo()->shape(), operand.hlo(), adjusted_indices,
          updates.hlo(), scatter->to_apply(), dnums,
          scatter->indices_are_sorted(), scatter->unique_indices()));
      pscatter->set_sharding(operand.sharding());
      SetPartitionedHlo(hlo, [&]() {
        return PartitionedHlo(pscatter, hlo->shape(), MakePartitioningState())
            .Reshard(hlo->sharding())
            .hlo();
      });
      return Status::OK();
    }
  }
  return DefaultAction(hlo);
}
// Partitions a slice instruction.
//
// The slice is expressed as a window of size 1 whose negative low/high padding
// and stride encode the slice start, limit and stride. The operand is resharded
// as a windowed input for that window; if that is not possible, or the sharding
// is tile-maximal, the instruction goes through DefaultAction. A local slice is
// emitted only when the resharded operand still contains elements that the
// partition's output does not keep.
Status SpmdPartitioningVisitor::HandleSlice(HloInstruction* hlo) {
  const HloSharding& sharding = hlo->sharding();
  if (sharding.IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  auto operand = GetPartitionedHlo(hlo->operand(0)).Reshard(sharding);
  const int64 rank = hlo->shape().rank();

  // Create a window config to represent the slice.
  Window window;
  for (int64 d = 0; d < rank; ++d) {
    WindowDimension* window_dim = window.add_dimensions();
    window_dim->set_size(1);
    window_dim->set_stride(hlo->slice_strides(d));
    window_dim->set_base_dilation(1);
    window_dim->set_window_dilation(1);
    window_dim->set_window_reversal(false);
    // Negative padding trims the elements outside [start, limit).
    window_dim->set_padding_low(-hlo->slice_starts(d));
    window_dim->set_padding_high(hlo->slice_limits(d) -
                                 hlo->operand(0)->shape().dimensions(d));
  }

  HloInstruction* pad_value =
      CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_);
  auto reshard_operand = operand.ReshardAsWindowedInput(
      window, sharding, pad_value, /*mask_invalid_region=*/false);
  if (!reshard_operand.has_value()) {
    return DefaultAction(hlo);
  }
  TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value());

  // Translate the per-shard window back into slice bounds on the resharded
  // operand.
  const Shape& resharded_shape = reshard_operand->sharded_input->shape();
  std::vector<int64> slice_begin = hlo->slice_starts();
  std::vector<int64> slice_end = hlo->slice_limits();
  const std::vector<int64> slice_step = hlo->slice_strides();
  bool need_slice = false;
  for (int64 d = 0; d < rank; ++d) {
    const auto& shard_dim = reshard_operand->shard_window.dimensions(d);
    slice_begin[d] = -shard_dim.padding_low();
    slice_end[d] = resharded_shape.dimensions(d) + shard_dim.padding_high();

    const bool keeps_whole_dim =
        slice_begin[d] == 0 && slice_step[d] == 1 &&
        slice_end[d] == resharded_shape.dimensions(d);
    need_slice = need_slice || !keeps_whole_dim;
  }

  SetPartitionedHlo(hlo, [&]() -> HloInstruction* {
    if (!need_slice) {
      return reshard_operand->sharded_input;
    }
    const Shape shard_shape = MakePartitionedShape(hlo->shape(), sharding);
    return b_.AddInstruction(HloInstruction::CreateSlice(
        shard_shape, reshard_operand->sharded_input, slice_begin, slice_end,
        slice_step));
  });
  return Status::OK();
}
// Partitions a sort instruction.
//
// Two strategies are implemented:
//  * TopK pattern with the first operand partitioned along the sort dimension
//    (detected by GetKValueInTopKWhenPartitionSortDim): sort each shard
//    locally, keep the first k entries per partition, replicate them, and run
//    a final replicated sort over the partition_count * k candidates.
//  * Otherwise, if no sort dimension is partitioned (and, for tuple results,
//    all tuple elements are sharded identically), reshard the operands like the
//    output and sort locally.
// All remaining cases go through DefaultAction.
Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) {
HloSharding sharding = hlo->sharding();
// Special handling for sort in TopK when first operand partitioned at
// sort dimension.
auto k = GetKValueInTopKWhenPartitionSortDim(hlo);
if (k.has_value()) {
// When the first operand partitioned at sort dimension:
// 1. Partition sort computation to different partitions;
// 2. Slice TopK value and index from different partitions;
// 3. Gather and replicate value and index from different partitions,
// the shape of replicated value and index will be
// [batch_size, ..., partition_count * k, ...];
// 4. Final sort uses replicated value and index from different partitions
// as input.
// GetTupleElement and Slice after the non-partitioned sort won't change
// at this point, as HandleGetTupleElement and HandleSlice will update them.
HloSortInstruction* sort = DynCast<HloSortInstruction>(hlo);
const int64 sort_dim = sort->sort_dimension();
auto input = hlo->operand(0);
auto index = hlo->operand(1);
const HloSharding& input_sharding = input->sharding();
const int64 partition_count =
input_sharding.tile_assignment().dim(sort_dim);
const int64 input_size = input->shape().dimensions(sort_dim);
// Shard size along the sort dimension, rounded up (the last shard may contain
// padding).
const int64 per_partition_size = CeilOfRatio(input_size, partition_count);
const auto element_type = input->shape().element_type();
const auto index_type = index->shape().element_type();
// Partition and pad input and index.
// Pad input with minimal value.
auto partitioned_input = GetPartitionedHlo(input).PadWithValue(
CreateFirstWithType(element_type, &b_));
// Pad index with max value.
auto partitioned_index =
GetPartitionedHlo(index)
.Reshard(input_sharding)
.PadWithValue(CreateLastWithType(index_type, &b_));
// Each partition needs to do TopK separately, thus the base shape
// becomes the padded shape.
std::vector<int64> replicated_dimensions(
input->shape().dimensions().begin(), input->shape().dimensions().end());
replicated_dimensions[sort_dim] = per_partition_size * partition_count;
const Shape replicated_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(element_type, replicated_dimensions),
ShapeUtil::MakeShape(index_type, replicated_dimensions)});
// Partition original topk to different shards.
auto topk_sharding =
input_sharding.GetTupleSharding(replicated_shape).ValueOrDie();
auto shard_shape = MakePartitionedShape(replicated_shape, topk_sharding);
auto topk = b_.AddInstruction(hlo->CloneWithNewOperands(
shard_shape, {partitioned_input.hlo(), partitioned_index.hlo()}));
// Get value from first sort.
HloInstruction* value_gte =
b_.AddInstruction(HloInstruction::CreateGetTupleElement(
topk->shape().tuple_shapes(0), topk, 0));
// Get index from first sort.
HloInstruction* index_gte =
b_.AddInstruction(HloInstruction::CreateGetTupleElement(
topk->shape().tuple_shapes(1), topk, 1));
// Slice top K value from the first partitioned sort.
// From here on `replicated_dimensions` describes the gathered candidates:
// k entries from each partition along the sort dimension.
replicated_dimensions[sort_dim] = k.value() * partition_count;
auto slice_input = SliceFirstK(value_gte, &b_, sort_dim, k.value());
slice_input->set_sharding(input_sharding);
PartitionedHlo partitioned_slice_input(
slice_input, ShapeUtil::MakeShape(element_type, replicated_dimensions),
MakePartitioningState());
// Reshard value to be replicated.
auto replicated_slice_input =
partitioned_slice_input.Reshard(HloSharding::Replicate()).hlo();
// Slice top K index from the first partitioned sort.
auto slice_index = SliceFirstK(index_gte, &b_, sort_dim, k.value());
slice_index->set_sharding(input_sharding);
PartitionedHlo partitioned_slice_index(
slice_index, ShapeUtil::MakeShape(index_type, replicated_dimensions),
MakePartitioningState());
// Reshard index to be replicated.
auto replicated_slice_index =
partitioned_slice_index.Reshard(HloSharding::Replicate()).hlo();
// Creates replicated sort to do TopK, the input is value and index pairs
// from all the partitions.
const Shape final_topk_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(element_type, replicated_dimensions),
ShapeUtil::MakeShape(index_type, replicated_dimensions)});
auto final_sort = b_.AddInstruction(HloInstruction::CreateSort(
final_topk_shape, sort_dim,
{replicated_slice_input, replicated_slice_index}, sort->to_apply(),
sort->is_stable()));
final_sort->set_sharding(HloSharding::Replicate()
.GetTupleSharding(final_sort->shape())
.ValueOrDie());
PartitionedHlo replicated_sort(final_sort, final_topk_shape,
MakePartitioningState());
SetPartitionedHlo(hlo, replicated_sort.Reshard(hlo->sharding()));
return Status::OK();
}
if (hlo->shape().IsTuple()) {
// Check that all elements are sharded in the same way.
if (hlo->shape().tuple_shapes_size() == 0) {
return DefaultAction(hlo);
}
// Use the sharding of element 0 as the common sharding and compare the
// remaining elements against it.
sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
for (int64 i = 1; i < hlo->operand_count(); ++i) {
if (sharding != hlo->sharding().GetSubSharding(hlo->shape(), {i})) {
return DefaultAction(hlo);
}
}
}
if (sharding.IsTileMaximal()) {
return DefaultAction(hlo);
}
// A local sort is only valid when no sort dimension is partitioned.
for (int64 dim : hlo->dimensions()) {
if (sharding.tile_assignment().dim(dim) > 1) {
return DefaultAction(hlo);
}
}
// Reshard operands to the same as the output.
std::vector<HloInstruction*> new_operands;
for (HloInstruction* operand : hlo->operands()) {
new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
}
SetPartitionedHlo(hlo, [&] {
return b_.AddInstruction(hlo->CloneWithNewOperands(
MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands));
});
return Status::OK();
}
// Partitions a custom call. Three targets get special treatment:
//  * "SPMDFullToShardShape": emits a copy of the (padded, if necessary)
//    partitioned operand.
//  * "SPMDShardToFullShape": emits a copy of the partitioned operand.
//  * "TopK" whose operand is partitioned along dimension 1 and whose k is
//    smaller than the per-partition size: runs TopK on every shard, replicates
//    the per-partition values and (offset-adjusted) indices, sorts the
//    candidates with a replicated sort and keeps the first k entries.
// Every other custom call, and every TopK that does not meet the conditions
// above, goes through DefaultAction.
Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
if (hlo->custom_call_target() == "SPMDFullToShardShape") {
// This op switches from auto partitioning to manual partitioning.
auto input_partitioned = GetPartitionedHlo(hlo->operand(0));
if (!EvenlyPartitions(hlo->shape(), input_partitioned.sharding())) {
// Pad with zeros when the shape is not evenly divided by the sharding.
input_partitioned = input_partitioned.PadWithValue(
CreateR0WithType(hlo->shape().element_type(), 0, &b_));
}
auto input = input_partitioned.hlo();
CHECK(hlo->sharding().IsReplicated());
// The custom call's declared shape must already be the per-partition shape.
CHECK(ShapeUtil::Compatible(input->shape(), hlo->shape()));
auto copy = b_.AddInstruction(
HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
SetPartitionedHlo(hlo, [&] { return copy; });
return Status::OK();
}
if (hlo->custom_call_target() == "SPMDShardToFullShape") {
// This op switches from manual partitioning to auto partitioning.
auto input = GetPartitionedHlo(hlo->operand(0)).hlo();
CHECK(input->sharding().IsReplicated());
auto copy = b_.AddInstruction(
HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
// The operand must already have the shard shape implied by the output
// sharding.
CHECK(ShapeUtil::Compatible(
copy->shape(), MakePartitionedShape(hlo->shape(), hlo->sharding())));
SetPartitionedHlo(hlo, [&] { return copy; });
return Status::OK();
}
if (hlo->custom_call_target() != "TopK") {
return DefaultAction(hlo);
}
if (!hlo->operand(0)->has_sharding()) {
return DefaultAction(hlo);
}
const HloSharding& sharding = hlo->operand(0)->sharding();
if (sharding.IsTileMaximal() || sharding.IsReplicated()) {
return DefaultAction(hlo);
}
// This handler treats dimension 1 of the operand as the TopK dimension and
// dimension 0 as the batch dimension.
const int64 sort_dim = 1;
const int64 shard_count = sharding.tile_assignment().dim(sort_dim);
if (shard_count <= 1) {
return DefaultAction(hlo);
}
const int64 input_size = hlo->operand(0)->shape().dimensions(sort_dim);
const int64 batch_size = hlo->shape().tuple_shapes(0).dimensions(0);
const int64 k = hlo->shape().tuple_shapes(0).dimensions(sort_dim);
const int64 per_partition_size = CeilOfRatio(input_size, shard_count);
// The per-partition TopK is only used when k is smaller than a shard.
if (k >= per_partition_size) {
return DefaultAction(hlo);
}
auto input = hlo->operand(0);
const auto element_type = input->shape().element_type();
// Pad the input shards with the value produced by CreateFirstWithType.
auto partitioned_input = GetPartitionedHlo(input).PadWithValue(
CreateFirstWithType(element_type, &b_));
// Each partition needs to do TopK separately, thus the base shape
// becomes [batch_size, k * shard_count].
const Shape replicated_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(hlo->operand(0)->shape().element_type(),
{batch_size, k * shard_count}),
ShapeUtil::MakeShape(S32, {batch_size, k * shard_count})});
auto custom_call_sharding =
sharding.GetTupleSharding(replicated_shape).ValueOrDie();
auto shard_shape =
MakePartitionedShape(replicated_shape, custom_call_sharding);
auto topk = b_.AddInstruction(
hlo->CloneWithNewOperands(shard_shape, {partitioned_input.hlo()}));
topk->set_sharding(custom_call_sharding);
// Partition customcall.
PartitionedHlo partitioned_topk(topk, replicated_shape,
MakePartitioningState());
topk = partitioned_topk.hlo();
// Get value from TopK.
HloInstruction* value_gte =
b_.AddInstruction(HloInstruction::CreateGetTupleElement(
topk->shape().tuple_shapes(0), topk, 0));
value_gte->set_sharding(sharding);
// Partition GetTupleElement of value.
PartitionedHlo value_partitioned_gte(
value_gte, partitioned_topk.base_shape().tuple_shapes(0),
MakePartitioningState());
// Reshard value to be replicated.
auto replicated_value_gte =
value_partitioned_gte.Reshard(HloSharding::Replicate()).hlo();
// Get index from TopK.
HloInstruction* index_gte =
b_.AddInstruction(HloInstruction::CreateGetTupleElement(
topk->shape().tuple_shapes(1), topk, 1));
// The partition id is converted to S32 so that it can be combined with the
// S32 indices.
auto partition_id_s32 = b_.AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::MakeShape(S32, partition_id_->shape().dimensions()),
partition_id_));
// Add per partition offset to index, index returned from CustomCall always
// starts from 0.
// The offset is partition_id * per_partition_size, broadcast to the index
// shape.
auto index_offset = b_.AddInstruction(HloInstruction::CreateBroadcast(
index_gte->shape(),
b_.AddInstruction(HloInstruction::CreateBinary(
partition_id_s32->shape(), HloOpcode::kMultiply, partition_id_s32,
b_.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<int32>(per_partition_size))))),
{}));
index_gte = b_.AddInstruction(HloInstruction::CreateBinary(
index_offset->shape(), HloOpcode::kAdd, index_gte, index_offset));
index_gte->set_sharding(sharding);
// Partition GetTupleElement of index.
PartitionedHlo index_partitioned_gte(
index_gte, partitioned_topk.base_shape().tuple_shapes(1),
MakePartitioningState());
// Reshard index to be replicated.
auto replicated_index_gte =
index_partitioned_gte.Reshard(HloSharding::Replicate()).hlo();
// Creates replicated sort to do TopK, the input is value and index pairs
// from all the partitions. The reason to use Sort instead of CustomCall TopK
// is CustomCall only takes value as input. There will be an extra Gather
// to get the correct index if CustomCall is used here.
// Create comparator for the sort.
// The comparator is built with the XLA client builder using (Gt, Lt) for
// (value, index), converted into an HloModule, and its entry computation is
// cloned into the module being partitioned.
XlaBuilder b("Sort.Compare");
XlaComputation comparator = CreateScalarComparisonComputation(
"compare-value-and-index", {input->shape().element_type(), S32}, {Gt, Lt},
&b);
TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape());
HloModuleConfig config(program_shape);
TF_ASSIGN_OR_RETURN(auto new_module,
HloModule::CreateFromProto(comparator.proto(), config));
HloCloneContext context(module_);
auto compare_computation =
module_->DeepCloneComputation(new_module->entry_computation(), &context);
// The trailing `true` is the sort's is_stable argument (compare the
// CreateSort call in HandleSort).
auto sort = b_.AddInstruction(HloInstruction::CreateSort(
replicated_shape, sort_dim, {replicated_value_gte, replicated_index_gte},
compare_computation, true));
sort->set_sharding(
HloSharding::Replicate().GetTupleSharding(sort->shape()).ValueOrDie());
PartitionedHlo replicated_sort(sort, replicated_shape,
MakePartitioningState());
// Slice value and index from top-k for output.
HloInstruction* sort_value_gte =
b_.AddInstruction(HloInstruction::CreateGetTupleElement(
replicated_sort.hlo()->shape().tuple_shapes(0), replicated_sort.hlo(),
0));
HloInstruction* sort_index_gte =
b_.AddInstruction(HloInstruction::CreateGetTupleElement(
replicated_sort.hlo()->shape().tuple_shapes(1), replicated_sort.hlo(),
1));
// Slice value from final sort.
HloInstruction* slice_sort_value =
SliceFirstK(sort_value_gte, &b_, sort_dim, k);
// Slice index from final sort.
HloInstruction* slice_index_value =
SliceFirstK(sort_index_gte, &b_, sort_dim, k);
// Reassemble the (values, indices) tuple as a replicated result and reshard it
// to the sharding requested for the original custom call.
auto create_tuple = b_.AddInstruction(
HloInstruction::CreateTuple({slice_sort_value, slice_index_value}));
create_tuple->set_sharding(HloSharding::Replicate());
SetPartitionedHlo(hlo, PartitionedHlo(create_tuple, create_tuple->shape(),
MakePartitioningState())
.Reshard(hlo->sharding()));
return Status::OK();
}
// Partitions a transpose instruction: the output sharding is mapped back
// through the inverse permutation to get the sharding the operand needs, the
// operand is resharded accordingly, and the transpose is then applied to the
// local shard. Tile-maximal shardings go through DefaultAction.
Status SpmdPartitioningVisitor::HandleTranspose(HloInstruction* hlo) {
  const HloSharding& sharding = hlo->sharding();
  if (sharding.IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  // hlo->dimensions(out) names the operand dimension that feeds output
  // dimension `out`; invert that mapping.
  const int64 rank = hlo->shape().rank();
  std::vector<int64> output_dim_of_operand_dim(rank);
  for (int64 out_dim = 0; out_dim < rank; ++out_dim) {
    output_dim_of_operand_dim[hlo->dimensions(out_dim)] = out_dim;
  }

  const auto operand_sharding = hlo_sharding_util::TransposeSharding(
      sharding, output_dim_of_operand_dim);
  HloInstruction* operand_shard =
      GetPartitionedHlo(hlo->operand(0)).Reshard(operand_sharding).hlo();

  const Shape shard_shape = MakePartitionedShape(hlo->shape(), sharding);
  SetPartitionedHlo(hlo, [&] {
    return b_.AddInstruction(
        hlo->CloneWithNewOperands(shard_shape, {operand_shard}));
  });
  return Status::OK();
}
Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
const HloSharding& sharding = hlo->sharding();
if (sharding.IsTileMaximal()) {
return DefaultAction(hlo);
}
auto operand = GetPartitionedHlo(hlo->operand(0));
// The output shape is the source and the operand shape is the target to get
// the aligned sharding for the operand.
auto desired_operand_sharding = hlo_sharding_util::ReshapeSharding(
hlo->shape(), hlo->operand(0)->shape(), hlo->sharding());
if (desired_operand_sharding.has_value()) {
auto operand_hlo = operand.Reshard(*desired_operand_sharding).hlo();
SetPartitionedHlo(hlo, [&] {
return b_.AddInstruction(hlo->CloneWithNewOperands(
MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand_hlo}));
});
return Status::OK();
}
// Try use halo exchange for certain split-dim/merge-dims cases.
// ReshapeSharding failed in these cases probably due to uneven partitioning,
// where halo exchange could help. Specifically we check the following
// conditions to detect supported cases:
// 1) Both input and output are partitioned on one dimension.
// 2) The combined size of dimensions before the partitioned dimension are the
// same on input and output. This means we don't need to consider the major
// dimensions.
// 3) Let A = the input size on the partitioned dimension, and
// B = the output size on the partitioned dimension; then
// either A % B == 0 (split dim) or B % A == 0 (merge dims).
auto maybe_input_sharded_dim = UniqueTiledDim(operand.sharding());
auto maybe_output_sharded_dim = UniqueTiledDim(sharding);
if (!maybe_input_sharded_dim || !maybe_output_sharded_dim) {
return DefaultAction(hlo);
}
int64 input_sharded_dim = *maybe_input_sharded_dim;
int64 output_sharded_dim = *maybe_output_sharded_dim;
// Check that the major dims before the sharded dim have the same total size
// for input and output.
int64 input_major_dims_size = 1;
for (int64 i = 0; i < input_sharded_dim; ++i) {
input_major_dims_size *= operand.base_shape().dimensions(i);
}
int64 output_major_dims_size = 1;
for (int64 i = 0; i < output_sharded_dim; ++i) {
output_major_dims_size *= hlo->shape().dimensions(i);
}
if (input_major_dims_size != output_major_dims_size) {
return DefaultAction(hlo);
}
// Fix potential device ordering mismatch in tile assignment.
Array<int64> new_input_tile_assignment = sharding.tile_assignment();
new_input_tile_assignment.Reshape(
operand.sharding().tile_assignment().dimensions());
operand = operand.Reshard(HloSharding::Tile(new_input_tile_assignment));
int64 input_dim_size = operand.base_shape().dimensions(input_sharded_dim);
int64 output_dim_size = hlo->shape().dimensions(output_sharded_dim);
auto input_shard_shape =
MakePartitionedShape(operand.base_shape(), operand.sharding());
auto output_shard_shape = MakePartitionedShape(hlo->shape(), sharding);
if (input_dim_size % output_dim_size == 0) {
// Split dim.
int64 split_factor = input_dim_size / output_dim_size;
int64 output_shard_size = output_shard_shape.dimensions(output_sharded_dim);
// Use halo exchange to fix misaligned data.
Window window;
for (int64 i = 0; i < hlo->shape().rank(); ++i) {
WindowDimension* dim = window.add_dimensions();
dim->set_size(1);
dim->set_stride(1);
dim->set_window_dilation(1);
dim->set_window_reversal(false);
dim->set_base_dilation(1);
dim->set_padding_low(0);
if (i == input_sharded_dim) {
dim->set_padding_high(output_shard_size * split_factor *
num_partitions_ -
input_dim_size);
} else {
dim->set_padding_high(0);
}
}
auto reshard_operand = operand.ReshardAsWindowedInput(
window, operand.sharding(),
CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
/*mask_invalid_region=*/false);
if (!reshard_operand.has_value()) {
return DefaultAction(hlo);
}
TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value());
CHECK_EQ(
reshard_operand->sharded_input->shape().dimensions(input_sharded_dim),
output_shard_size * split_factor);
SetPartitionedHlo(hlo, [&] {
// Do a local reshape.
return b_.AddInstruction(HloInstruction::CreateReshape(
output_shard_shape, reshard_operand->sharded_input));
});
return Status::OK();
} else if (output_dim_size % input_dim_size == 0) {
// Merge dims.
int64 merge_factor = output_dim_size / input_dim_size;
// First reshape locally. (The sharded dimension could include padded data.)
auto tmp_shard_shape = output_shard_shape;
tmp_shard_shape.set_dimensions(
output_sharded_dim,
input_shard_shape.dimensions(input_sharded_dim) * merge_factor);
auto tmp_reshape = b_.AddInstruction(
HloInstruction::CreateReshape(tmp_shard_shape, operand.hlo()));
tmp_reshape->set_metadata(hlo->metadata());
tmp_reshape->set_sharding(hlo->sharding());
auto tmp_full_shape = tmp_shard_shape;
tmp_full_shape.set_dimensions(
output_sharded_dim,
tmp_shard_shape.dimensions(output_sharded_dim) * num_partitions_);
auto tmp_output =
PartitionedHlo(tmp_reshape, tmp_full_shape, MakePartitioningState());
// Use halo exchange to fix misaligned data.
Window window;
for (int64 i = 0; i < tmp_shard_shape.rank(); ++i) {
WindowDimension* dim = window.add_dimensions();
dim->set_size(1);
dim->set_stride(1);
dim->set_window_dilation(1);
dim->set_window_reversal(false);
dim->set_base_dilation(1);
dim->set_padding_low(0);
if (i == output_sharded_dim) {
dim->set_padding_high(output_dim_size -
tmp_shard_shape.dimensions(output_sharded_dim) *
num_partitions_);
} else {
dim->set_padding_high(0);
}
}
auto reshard_output = tmp_output.ReshardAsWindowedInput(
window, sharding,
CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
/*mask_invalid_region=*/false);
if (!reshard_output.has_value()) {
return DefaultAction(hlo);
}
TF_RET_CHECK(!reshard_output->dynamic_slice_index_on_output.has_value());
CHECK_EQ(
reshard_output->sharded_input->shape().dimensions(output_sharded_dim),
output_shard_shape.dimensions(output_sharded_dim));
SetPartitionedHlo(hlo, [&] { return reshard_output->sharded_input; });
return Status::OK();
}
return DefaultAction(hlo);
}
// Partitions an iota instruction.
//
// Tile-maximal shardings are delegated to DefaultAction. Otherwise a
// shard-shaped iota is emitted; if the iota dimension is split across more than
// one tile, the local iota is shifted by
//   (ordinal from MakeTiledPartitionOrdinals for that dimension) *
//   (shard extent along that dimension).
Status SpmdPartitioningVisitor::HandleIota(HloInstruction* hlo) {
  const HloSharding& sharding = hlo->sharding();
  if (sharding.IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  SetPartitionedHlo(hlo, [&]() -> HloInstruction* {
    const int64 iota_dim = Cast<HloIotaInstruction>(hlo)->iota_dimension();
    HloInstruction* local_iota = b_.AddInstruction(HloInstruction::CreateIota(
        MakePartitionedShape(hlo->shape(), sharding), iota_dim));

    // When the iota dimension is not split, the local iota is used unchanged.
    if (sharding.tile_assignment().dim(iota_dim) <= 1) {
      return local_iota;
    }

    const Shape& local_shape = local_iota->shape();
    auto ordinals = MakeTiledPartitionOrdinals(sharding, partition_id_, &b_);

    // Scalar S32 offset: ordinal along the iota dimension times shard extent.
    HloInstruction* shard_extent =
        b_.AddInstruction(HloInstruction::CreateConstant(
            LiteralUtil::CreateR0<int32>(local_shape.dimensions(iota_dim))));
    HloInstruction* start = b_.AddInstruction(HloInstruction::CreateBinary(
        ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply,
        ordinals[iota_dim], shard_extent));

    // The offset is computed in S32; convert when the iota uses another type.
    if (local_shape.element_type() != S32) {
      start = b_.AddInstruction(HloInstruction::CreateConvert(
          ShapeUtil::MakeShape(local_shape.element_type(), {}), start));
    }

    HloInstruction* start_everywhere = b_.AddInstruction(
        HloInstruction::CreateBroadcast(local_shape, start, {}));
    return b_.AddInstruction(HloInstruction::CreateBinary(
        local_shape, HloOpcode::kAdd, local_iota, start_everywhere));
  });
  return Status::OK();
}
// Lowers an instruction whose sharding names exactly one device.
//
// All operands are resharded to that device and packed into a tuple. A
// conditional keyed on `partition_id_ == device` then either runs a clone of
// `hlo` on the unpacked operands (true branch) or yields CreateZero of the
// result shape (false branch).
//
// Returns an error status if the sharding does not have a unique device.
Status SpmdPartitioningVisitor::HandleSingleDevice(const HloInstruction* hlo) {
  TF_RET_CHECK(hlo->sharding().HasUniqueDevice());
  const int64 target_device = hlo->sharding().GetUniqueDevice();
  const HloSharding device_sharding = HloSharding::AssignDevice(target_device);

  // Reshard every operand to the target device and remember its full shape.
  std::vector<HloInstruction*> device_operands;
  std::vector<Shape> full_operand_shapes;
  for (const HloInstruction* original_operand : hlo->operands()) {
    HloInstruction* moved =
        GetPartitionedHlo(original_operand).Reshard(device_sharding).hlo();
    device_operands.push_back(moved);
    full_operand_shapes.push_back(original_operand->shape());
  }

  // Both branches take the same tuple of operands as their single argument.
  HloInstruction* packed_operands =
      b_.AddInstruction(HloInstruction::CreateTuple(device_operands));
  const Shape packed_shape = ShapeUtil::MakeTupleShape(full_operand_shapes);

  // Predicate: is this partition the target device?
  HloInstruction* device_constant =
      b_.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<uint32>(target_device)));
  HloInstruction* is_target_device =
      b_.AddInstruction(HloInstruction::CreateCompare(
          ShapeUtil::MakeShape(PRED, {}), partition_id_, device_constant,
          ComparisonDirection::kEq));

  // True branch: unpack the tuple and run a clone of the original instruction.
  SpmdBuilder true_b("true_computation", visiting_hlo_);
  HloInstruction* true_param =
      true_b.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/0, packed_shape, "true_branch_param"));
  std::vector<HloInstruction*> unpacked_operands;
  for (int64 idx = 0; idx < device_operands.size(); ++idx) {
    unpacked_operands.push_back(
        true_b.AddInstruction(HloInstruction::CreateGetTupleElement(
            full_operand_shapes[idx], true_param, idx)));
  }
  HloInstruction* true_root = true_b.AddInstruction(
      hlo->CloneWithNewOperands(hlo->shape(), unpacked_operands));
  HloComputation* true_computation =
      module_->AddEmbeddedComputation(true_b.Build(true_root));

  // False branch: ignore the argument and produce CreateZero of the result
  // shape.
  SpmdBuilder false_b("false_computation", visiting_hlo_);
  false_b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, packed_shape, "false_branch_param"));
  auto false_root = CreateZero(hlo->shape(), &false_b);
  HloComputation* false_computation =
      module_->AddEmbeddedComputation(false_b.Build(false_root));

  SetPartitionedHlo(hlo, [&]() {
    return b_.AddInstruction(HloInstruction::CreateConditional(
        hlo->shape(), is_target_device, packed_operands, true_computation,
        packed_operands, false_computation));
  });
  return Status::OK();
}
// Partitions an all-reduce. A cross-replica all-reduce with exactly one
// operand is handled by HandleElementwise; anything else goes to
// DefaultAction.
Status SpmdPartitioningVisitor::HandleAllReduce(HloInstruction* hlo) {
  const bool single_operand_cross_replica =
      hlo->IsCrossReplicaAllReduce() && hlo->operand_count() == 1;
  return single_operand_cross_replica ? HandleElementwise(hlo)
                                      : DefaultAction(hlo);
}
// Partitions a broadcast whose output is tiled (tile-maximal outputs go to
// DefaultAction).
//
// Two cases:
//  * No dimension introduced by the broadcast is tiled: the operand is
//    resharded so each operand dimension is tiled like the output dimension it
//    maps to, then broadcast to the output shard shape.
//  * Some introduced dimension is tiled: the operand is replicated, padded,
//    and (if needed) dynamically sliced at the offsets of this partition's
//    output shard before being broadcast.
Status SpmdPartitioningVisitor::HandleBroadcast(HloInstruction* hlo) {
  const HloSharding& out_sharding = hlo->sharding();
  if (out_sharding.IsTileMaximal()) {
    return DefaultAction(hlo);
  }
  auto& operand = GetPartitionedHlo(hlo->operand(0));
  const auto& out_tiles = out_sharding.tile_assignment();
  const int64 operand_rank = operand.base_shape().rank();
  const int64 output_rank = hlo->shape().rank();

  // Tile count each operand dimension inherits from its output dimension.
  std::vector<int64> operand_tile_counts(operand_rank);
  for (int64 d = 0; d < operand_rank; ++d) {
    operand_tile_counts[d] = out_tiles.dim(hlo->dimensions(d));
  }

  // Output dimensions that do not come from the operand and are tiled.
  std::vector<int64> tiled_new_dims;
  for (int64 d = 0; d < output_rank; ++d) {
    const bool comes_from_operand =
        absl::c_linear_search(hlo->dimensions(), d);
    if (!comes_from_operand && out_tiles.dim(d) > 1) {
      tiled_new_dims.push_back(d);
    }
  }

  if (tiled_new_dims.empty()) {
    // Introduced dimensions are not tiled, so the adjustment can be done on
    // the input: read the device for each operand tile from the output tile
    // assignment, using index 0 on the introduced dimensions.
    Array<int64> operand_tiles(operand_tile_counts);
    operand_tiles.Each(
        [&](absl::Span<const int64> operand_index, int64* device) {
          std::vector<int64> output_index(output_rank, 0);
          for (int64 d = 0; d < operand_rank; ++d) {
            output_index[hlo->dimensions(d)] = operand_index[d];
          }
          *device = out_tiles(output_index);
        });
    SetPartitionedHlo(hlo, [&] {
      HloInstruction* resharded_operand =
          operand.Reshard(HloSharding::Tile(operand_tiles)).hlo();
      return b_.AddInstruction(hlo->CloneWithNewOperands(
          MakePartitionedShape(hlo->shape(), out_sharding),
          {resharded_operand}));
    });
    return Status::OK();
  }

  // Pad and shard the replicated input first, then broadcast to the final
  // shard shape.
  HloInstruction* replicated_input =
      operand.Reshard(HloSharding::Replicate()).hlo();
  auto output_offsets =
      MakePartitionOffsets(hlo->shape(), out_sharding, partition_id_, &b_);
  std::vector<HloInstruction*> input_offsets(operand_rank);
  const Shape output_shard_shape =
      MakePartitionedShape(hlo->shape(), out_sharding);
  Shape input_shard_shape = replicated_input->shape();
  Shape padded_input_shape = replicated_input->shape();
  for (int64 d = 0; d < operand_rank; ++d) {
    const int64 out_dim = hlo->dimensions(d);
    input_offsets[d] = output_offsets[out_dim];
    input_shard_shape.set_dimensions(d,
                                     output_shard_shape.dimensions(out_dim));
    padded_input_shape.set_dimensions(
        d, out_tiles.dim(out_dim) * input_shard_shape.dimensions(d));
  }
  auto padded_input = PadToShape(replicated_input, padded_input_shape, &b_);

  // Slice only when the padded input is not already shard-shaped.
  HloInstruction* input_shard = padded_input;
  if (!ShapeUtil::Compatible(input_shard_shape, padded_input->shape())) {
    input_shard = b_.AddInstruction(HloInstruction::CreateDynamicSlice(
        input_shard_shape, padded_input, input_offsets,
        input_shard_shape.dimensions()));
  }
  SetPartitionedHlo(hlo, [&] {
    return b_.AddInstruction(
        hlo->CloneWithNewOperands(output_shard_shape, {input_shard}));
  });
  return Status::OK();
}
// Partitions a constant by slicing its literal down to the shard shape,
// starting at index 0 in every dimension.
//
// Falls back to DefaultAction for tuple-shaped literals, and for tiled
// shardings unless the sharding evenly partitions the shape and
// literal.IsAllFirst() holds.
Status SpmdPartitioningVisitor::HandleConstant(HloInstruction* hlo) {
  const Literal& literal = hlo->literal();
  if (literal.shape().IsTuple()) {
    return DefaultAction(hlo);
  }
  const HloSharding& sharding = hlo->sharding();
  if (!sharding.IsTileMaximal()) {
    // NOTE(review): presumably taking the first shard is only valid for all
    // partitions because IsAllFirst() means every element is identical —
    // confirm against Literal::IsAllFirst.
    const bool can_slice_locally =
        EvenlyPartitions(hlo->shape(), sharding) && literal.IsAllFirst();
    if (!can_slice_locally) {
      return DefaultAction(hlo);
    }
  }

  SetPartitionedHlo(hlo, [&]() {
    const Shape shard_shape = MakePartitionedShape(hlo->shape(), sharding);
    const std::vector<int64> origin(hlo->shape().rank(), 0);
    HloInstruction* shard_constant =
        b_.AddInstruction(HloInstruction::CreateConstant(
            literal.Slice(origin, shard_shape.dimensions())));
    // Overwrite the instruction shape with the shard shape.
    *shard_constant->mutable_shape() = shard_shape;
    return shard_constant;
  });
  return Status::OK();
}
// Partitions a dynamic-slice when no tiled dimension is actually sliced.
//
// A tiled dimension is acceptable only if the slice covers the full output
// extent in that dimension and its start index is a constant zero; otherwise
// (or for tile-maximal shardings) DefaultAction is used. The input is
// resharded to the output sharding and all indices are replicated.
Status SpmdPartitioningVisitor::HandleDynamicSlice(HloInstruction* hlo) {
  const HloSharding& sharding = hlo->sharding();
  if (sharding.IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  const int64 rank = hlo->shape().rank();
  for (int64 dim = 0; dim < rank; ++dim) {
    if (sharding.tile_assignment().dim(dim) == 1) {
      continue;
    }
    // We currently do not partition the sliced dimensions.
    if (hlo->dynamic_slice_sizes()[dim] != hlo->shape().dimensions(dim)) {
      return DefaultAction(hlo);
    }
    const HloInstruction* start_index = hlo->operand(dim + 1);
    if (!start_index->IsConstant() || !start_index->literal().IsZero({})) {
      return DefaultAction(hlo);
    }
  }

  HloInstruction* resharded_input =
      GetPartitionedHlo(hlo->operand(0)).Reshard(sharding).hlo();

  // Replicate the indices.
  std::vector<HloInstruction*> replicated_indices(rank);
  for (int64 dim = 0; dim < rank; ++dim) {
    replicated_indices[dim] = GetPartitionedHlo(hlo->operand(dim + 1))
                                  .Reshard(HloSharding::Replicate())
                                  .hlo();
  }

  SetPartitionedHlo(hlo, [&]() {
    const Shape shard_shape = MakePartitionedShape(hlo->shape(), sharding);
    return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
        shard_shape, resharded_input, replicated_indices,
        shard_shape.dimensions()));
  });
  return Status::OK();
}
// Partitions a dynamic-update-slice when no tiled dimension is actually
// sliced.
//
// A tiled dimension is acceptable only if the update has the same extent as
// the result in that dimension and its start index is a constant zero;
// otherwise (or for tile-maximal shardings) DefaultAction is used. Input and
// update are resharded to the output sharding; indices are replicated.
Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(HloInstruction* hlo) {
  const HloSharding& sharding = hlo->sharding();
  if (sharding.IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  const int64 rank = hlo->shape().rank();
  for (int64 dim = 0; dim < rank; ++dim) {
    if (sharding.tile_assignment().dim(dim) == 1) {
      continue;
    }
    // We currently do not partition the sliced dimensions.
    if (hlo->operand(1)->shape().dimensions(dim) !=
        hlo->shape().dimensions(dim)) {
      return DefaultAction(hlo);
    }
    const HloInstruction* start_index = hlo->operand(dim + 2);
    if (!start_index->IsConstant() || !start_index->literal().IsZero({})) {
      return DefaultAction(hlo);
    }
  }

  HloInstruction* resharded_input =
      GetPartitionedHlo(hlo->operand(0)).Reshard(sharding).hlo();
  HloInstruction* resharded_update =
      GetPartitionedHlo(hlo->operand(1)).Reshard(sharding).hlo();

  // Replicate the indices.
  std::vector<HloInstruction*> replicated_indices(rank);
  for (int64 dim = 0; dim < rank; ++dim) {
    replicated_indices[dim] = GetPartitionedHlo(hlo->operand(dim + 2))
                                  .Reshard(HloSharding::Replicate())
                                  .hlo();
  }

  SetPartitionedHlo(hlo, [&]() {
    const Shape shard_shape = MakePartitionedShape(hlo->shape(), sharding);
    return b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
        shard_shape, resharded_input, resharded_update, replicated_indices));
  });
  return Status::OK();
}
// Partitions a gather instruction.
//
// When the operand is not tile-maximal, two strategies are tried in order:
//   1) Passthrough: if PassthroughOperandToGatherOutputOrScatterUpdate yields
//      a sharding, the indices are replicated and a gather is run on the local
//      operand shard, with the slice size on every tiled operand dimension
//      replaced by the local shard extent. The result is tagged with the
//      returned sharding and then resharded to hlo->sharding().
//   2) Trivial-slice-dim partitioning: if
//      GatherScatterOperandPartitionedOnlyOnTrivialSliceDims holds and the
//      gather output is smaller in bytes than the operand, every partition
//      gathers with clamped, offset-adjusted indices, replaces results whose
//      original indices were outside [indices_min, indices_max] with zero, and
//      the per-partition results are combined with a cross-partition
//      all-reduce built from MakeBinaryAdd.
// In every other case the instruction is handled by DefaultAction.
Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
auto gather = Cast<HloGatherInstruction>(hlo);
const auto& dnums = gather->gather_dimension_numbers();
auto operand = GetPartitionedHlo(gather->operand(0));
auto indices = GetPartitionedHlo(gather->operand(1));
// Copy the dimension-number fields into vectors for the helper calls below.
std::vector<int64> collapsed_slice_dims(dnums.collapsed_slice_dims().begin(),
dnums.collapsed_slice_dims().end());
std::vector<int64> start_index_map(dnums.start_index_map().begin(),
dnums.start_index_map().end());
std::vector<int64> offset_dims(dnums.offset_dims().begin(),
dnums.offset_dims().end());
if (!operand.sharding().IsTileMaximal()) {
// Strategy 1: passthrough of the operand sharding to the gather output.
auto maybe_passthrough = PassthroughOperandToGatherOutputOrScatterUpdate(
operand, gather->shape(), collapsed_slice_dims, start_index_map,
offset_dims, gather->gather_slice_sizes());
if (maybe_passthrough.has_value()) {
indices = indices.Reshard(HloSharding::Replicate());
auto pshape = MakePartitionedShape(gather->shape(), *maybe_passthrough);
// On tiled operand dimensions the slice size becomes the local shard extent.
std::vector<int64> pslice_sizes(gather->gather_slice_sizes().begin(),
gather->gather_slice_sizes().end());
for (int64 i = 0; i < pslice_sizes.size(); ++i) {
if (operand.sharding().tile_assignment().dim(i) > 1) {
pslice_sizes[i] = operand.hlo()->shape().dimensions(i);
}
}
auto pgather = b_.AddInstruction(HloInstruction::CreateGather(
pshape, operand.hlo(), indices.hlo(), dnums, pslice_sizes,
gather->indices_are_sorted()));
pgather->set_sharding(*maybe_passthrough);
// Reshard from the passthrough sharding to the sharding requested on `hlo`.
SetPartitionedHlo(hlo, [&]() {
return PartitionedHlo(pgather, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
// Strategy 2: only taken when the gather output is smaller (in bytes) than
// the full operand.
if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
operand, start_index_map, gather->gather_slice_sizes(),
num_partitions_) &&
ShapeSizeInBytes(gather->shape()) <
ShapeSizeInBytes(gather->operand(0)->shape())) {
indices = indices.Reshard(HloSharding::Replicate());
// Now the operand is partitioned in trivial slice dimensions, and the
// indices are replicated. We execute a gather on partitioned operand,
// with full number of indices, where out-of-bounds indices are clamped,
// and masked out with 0 in the result; then we use all-reduce to combine
// results. Although gather will not get faster, we avoided the need to
// replicate the operand.
HloInstruction* indices_min;
HloInstruction* indices_max;
std::tie(indices_min, indices_max) =
IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims(
operand, indices, partition_id_, start_index_map,
dnums.index_vector_dim(), &b_);
// Clamp the indices.
auto adjusted_indices = b_.AddInstruction(HloInstruction::CreateTernary(
indices.base_shape(), HloOpcode::kClamp, indices_min, indices.hlo(),
indices_max));
// Adjust the indices by subtracting the offset.
adjusted_indices = b_.AddInstruction(HloInstruction::CreateBinary(
indices.base_shape(), HloOpcode::kSubtract, adjusted_indices,
indices_min));
// Gather on adjusted indices.
auto pgather = b_.AddInstruction(HloInstruction::CreateGather(
gather->shape(), operand.hlo(), adjusted_indices, dnums,
gather->gather_slice_sizes(), gather->indices_are_sorted()));
// Mask out invalid results.
// filter = (indices < indices_min) OR (indices > indices_max), elementwise.
auto filter = b_.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::ChangeElementType(indices.base_shape(), PRED),
indices.hlo(), indices_min, ComparisonDirection::kLt));
filter = b_.AddInstruction(HloInstruction::CreateBinary(
filter->shape(), HloOpcode::kOr, filter,
b_.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::ChangeElementType(indices.base_shape(), PRED),
indices.hlo(), indices_max, ComparisonDirection::kGt))));
// If the indices carry an explicit index-vector dimension, reduce the filter
// over it (init value false, reducer MakeBinaryAdd(PRED, ...)).
if (dnums.index_vector_dim() < indices.base_shape().rank()) {
std::vector<int64> reduced_filter_dims;
for (int64 i = 0; i < filter->shape().rank(); ++i) {
if (i != dnums.index_vector_dim()) {
reduced_filter_dims.push_back(filter->shape().dimensions(i));
}
}
filter = b_.AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeShape(PRED, reduced_filter_dims), filter,
CreateR0WithType(PRED, false, &b_), {dnums.index_vector_dim()},
MakeBinaryAdd(PRED, module_)));
}
// Output dimensions that are not offset dims are the ones the filter is
// broadcast along.
std::vector<int64> batch_dims;
for (int64 i = 0; i < pgather->shape().rank(); ++i) {
if (!absl::c_linear_search(dnums.offset_dims(), i)) {
batch_dims.push_back(i);
}
}
auto broadcast_filter = b_.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::ChangeElementType(pgather->shape(), PRED), filter,
batch_dims));
// Select zero where the filter is set, the gathered value elsewhere.
auto filtered = b_.AddInstruction(HloInstruction::CreateTernary(
pgather->shape(), HloOpcode::kSelect, broadcast_filter,
CreateZero(pgather->shape(), &b_), pgather));
// Combine from different partitions.
auto ar = collective_ops_creator_.create_cross_partition_all_reduce(
&b_, filtered,
MakeBinaryAdd(filtered->shape().element_type(), module_),
NewChannel());
ar->set_sharding(HloSharding::Replicate());
SetPartitionedHlo(hlo, [&]() {
return PartitionedHlo(ar, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
}
return DefaultAction(hlo);
}
// Partitions a get-tuple-element: extracts the element from the partitioned
// tuple, tags it with the tuple's sub-sharding at that index, and reshards it
// to the sharding requested on `hlo`.
Status SpmdPartitioningVisitor::HandleGetTupleElement(HloInstruction* hlo) {
  const auto& tuple = GetPartitionedHlo(hlo->operand(0));
  const int64 element_index = hlo->tuple_index();

  HloInstruction* element =
      b_.AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetTupleElementShape(tuple.hlo()->shape(), element_index),
          tuple.hlo(), element_index));

  SetPartitionedHlo(hlo, [&]() {
    // The extracted element starts out with the sharding its slot has in the
    // tuple operand.
    element->set_sharding(
        tuple.sharding().GetSubSharding(tuple.base_shape(), {element_index}));
    PartitionedHlo element_as_sharded_in_tuple(element, hlo->shape(),
                                               MakePartitioningState());
    return element_as_sharded_in_tuple.Reshard(hlo->sharding()).hlo();
  });
  return Status::OK();
}
// Partitions an infeed instruction.
//
// Cases, in order:
//   * The data shape (tuple element 0 of the infeed shape) has no leaves: the
//     infeed is re-created with that same data shape on the partitioned token.
//   * The data sub-sharding evenly partitions the data shape: a single infeed
//     of the shard shape is emitted.
//   * The sharding has a unique device: handled by HandleSingleDevice.
//   * Otherwise: a conditional is emitted with one branch per distinct
//     non-padded per-partition shape. Branches whose shape is not compatible
//     with the shard shape pad their data leaves up to it, so the conditional
//     result is (shard_shape, token shape).
Status SpmdPartitioningVisitor::HandleInfeed(HloInstruction* hlo) {
const Shape& shape = ShapeUtil::GetTupleElementShape(hlo->shape(), 0);
auto token = GetPartitionedHlo(hlo->operand(0)).hlo();
if (ShapeUtil::GetLeafCount(shape) == 0) {
// TODO(b/155819021): HloSharding has issues with tuple-shaped sharding: it
// requires one element for an empty tuple, but leaf-count number of
// elements for non-empty tuple. So if it has a nested empty tuple, we
// cannot invoke GetSubSharding() since it expects a sharding for the empty
// tuple. This is a workaround for that case.
SetPartitionedHlo(hlo, [&]() {
return b_.AddInstruction(
HloInstruction::CreateInfeed(shape, token, hlo->infeed_config()));
});
return Status::OK();
}
// Sharding and shard shape of the data part (tuple index 0) only.
auto sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
auto shard_shape = MakePartitionedShape(shape, sharding);
if (EvenlyPartitions(shape, sharding)) {
SetPartitionedHlo(hlo, [&]() {
return b_.AddInstruction(HloInstruction::CreateInfeed(
shard_shape, token, hlo->infeed_config()));
});
return Status::OK();
}
if (hlo->sharding().HasUniqueDevice()) {
return HandleSingleDevice(hlo);
}
// Create a branch for each unique partitioned shape.
// conditional_branch_indices[p] is the branch that partition p must take.
std::vector<Shape> per_branch_partitioned_shapes;
std::vector<int32> conditional_branch_indices(num_partitions_);
for (int64 i = 0; i < num_partitions_; ++i) {
auto partitioned_shape =
MakeNonPaddedShapeForGivenPartition(shape, sharding, i);
// Linear search for an already-registered compatible shape.
int64 matching_existing_index = 0;
for (; matching_existing_index < per_branch_partitioned_shapes.size();
++matching_existing_index) {
if (ShapeUtil::Compatible(
partitioned_shape,
per_branch_partitioned_shapes[matching_existing_index])) {
break;
}
}
if (matching_existing_index < per_branch_partitioned_shapes.size()) {
conditional_branch_indices[i] = matching_existing_index;
} else {
conditional_branch_indices[i] = per_branch_partitioned_shapes.size();
per_branch_partitioned_shapes.push_back(std::move(partitioned_shape));
}
}
HloInstruction* branch_index;
if (per_branch_partitioned_shapes.size() == num_partitions_) {
// Use partition ID as the branch index if each partition has its own
// branch.
branch_index = partition_id_;
// PartitionId's output is U32 but conditional requires S32.
if (branch_index->shape().element_type() != S32) {
branch_index = b_.AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::ChangeElementType(branch_index->shape(), S32),
branch_index));
}
} else {
// Otherwise, use a constant table to look up the branch index.
auto branch_index_table = b_.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<int32>(conditional_branch_indices)));
// Slice out the single entry at position partition_id_, then reshape it to
// a scalar.
branch_index = b_.AddInstruction(HloInstruction::CreateDynamicSlice(
ShapeUtil::MakeShape(S32, {1}), branch_index_table, {partition_id_},
{1}));
branch_index = b_.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {}), branch_index));
}
// Build one embedded computation per distinct shape; each takes the token as
// its only parameter and performs an infeed of that branch's shape.
std::vector<HloComputation*> branches(per_branch_partitioned_shapes.size());
for (int64 i = 0; i < branches.size(); ++i) {
SpmdBuilder branch_b(absl::StrCat("infeed_branch_", i), visiting_hlo_);
auto param = branch_b.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, token->shape(), "infeed_token_param"));
auto infeed = branch_b.AddInstruction(HloInstruction::CreateInfeed(
per_branch_partitioned_shapes[i], param, hlo->infeed_config()));
branches[i] = module_->AddEmbeddedComputation(branch_b.Build(infeed));
// Branches whose shape differs from the shard shape get a new root in which
// the mismatching data leaves are padded to the shard shape.
if (!ShapeUtil::Compatible(per_branch_partitioned_shapes[i], shard_shape)) {
TF_ASSIGN_OR_RETURN(
auto padded,
branches[i]->DeepCopyInstructionWithCustomCopier(
infeed, [&](HloInstruction* leaf, const ShapeIndex& leaf_index,
HloComputation* comp) {
// Index {1} corresponds to the token.
if (leaf_index.empty() || leaf_index[0] != 0) {
return leaf;
}
ShapeIndexView subindex(leaf_index, 1);
if (ShapeUtil::Compatible(
ShapeUtil::GetSubshape(per_branch_partitioned_shapes[i],
subindex),
ShapeUtil::GetSubshape(shard_shape, subindex))) {
return leaf;
}
return PadToShape(leaf,
ShapeUtil::GetSubshape(shard_shape, subindex),
nullptr, comp);
}));
branches[i]->set_root_instruction(padded,
/*accept_different_shape=*/true);
}
}
// Every branch receives the same token as its argument.
SetPartitionedHlo(hlo, [&]() {
return b_.AddInstruction(HloInstruction::CreateConditional(
ShapeUtil::MakeTupleShape({shard_shape, token->shape()}), branch_index,
branches, std::vector<HloInstruction*>(branches.size(), token)));
});
return Status::OK();
}
// Partitions a pad whose output is tiled (tile-maximal outputs go to
// DefaultAction).
//
// The padding config is expressed as a size-1, stride-1 window (edge padding
// as window padding, interior padding + 1 as base dilation) so that
// ReshardAsWindowedInput can produce the needed input shard. The per-shard
// window it returns is converted back into a padding config; a local pad is
// emitted only if that config is non-trivial, followed by a dynamic slice when
// one is requested on the output.
Status SpmdPartitioningVisitor::HandlePad(HloInstruction* hlo) {
  if (hlo->sharding().IsTileMaximal()) {
    return DefaultAction(hlo);
  }
  auto lhs = GetPartitionedHlo(hlo->operand(0));
  const int64 rank = hlo->shape().rank();

  // Create a window config to represent the pad.
  Window pad_window;
  for (int64 d = 0; d < rank; ++d) {
    const auto& pad_dim = hlo->padding_config().dimensions(d);
    WindowDimension* window_dim = pad_window.add_dimensions();
    window_dim->set_size(1);
    window_dim->set_stride(1);
    window_dim->set_window_dilation(1);
    window_dim->set_window_reversal(false);
    window_dim->set_padding_low(pad_dim.edge_padding_low());
    window_dim->set_padding_high(pad_dim.edge_padding_high());
    window_dim->set_base_dilation(pad_dim.interior_padding() + 1);
  }

  // The padding value is needed in full on every partition.
  HloInstruction* replicated_rhs = GetPartitionedHlo(hlo->operand(1))
                                       .Reshard(HloSharding::Replicate())
                                       .hlo();
  auto windowed_input =
      lhs.ReshardAsWindowedInput(pad_window, hlo->sharding(), replicated_rhs,
                                 /*mask_invalid_region=*/false);
  if (!windowed_input.has_value()) {
    return DefaultAction(hlo);
  }

  // Translate the per-shard window back into a padding config and note
  // whether it pads anything at all.
  PaddingConfig shard_padding;
  bool needs_local_pad = false;
  for (int64 d = 0; d < rank; ++d) {
    const auto& shard_window_dim = windowed_input->shard_window.dimensions(d);
    auto* shard_pad_dim = shard_padding.add_dimensions();
    shard_pad_dim->set_edge_padding_low(shard_window_dim.padding_low());
    shard_pad_dim->set_edge_padding_high(shard_window_dim.padding_high());
    shard_pad_dim->set_interior_padding(shard_window_dim.base_dilation() - 1);
    needs_local_pad = needs_local_pad ||
                      shard_window_dim.padding_low() != 0 ||
                      shard_window_dim.padding_high() != 0 ||
                      shard_window_dim.base_dilation() != 1;
  }

  HloInstruction* shard_result = windowed_input->sharded_input;
  if (needs_local_pad) {
    TF_ASSIGN_OR_RETURN(auto local_pad_shape,
                        ShapeInference::InferPadShape(shard_result->shape(),
                                                      replicated_rhs->shape(),
                                                      shard_padding));
    shard_result = b_.AddInstruction(hlo->CreatePad(
        local_pad_shape, shard_result, replicated_rhs, shard_padding));
  }

  SetPartitionedHlo(hlo, [&]() -> HloInstruction* {
    if (windowed_input->dynamic_slice_index_on_output.has_value()) {
      const Shape out_shard_shape =
          MakePartitionedShape(hlo->shape(), hlo->sharding());
      return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
          out_shard_shape, shard_result,
          *windowed_input->dynamic_slice_index_on_output,
          out_shard_shape.dimensions()));
    }
    return shard_result;
  });
  return Status::OK();
}
// Partitions a parameter: emits a parameter with the same number and the
// shard shape, named "param", carrying over the
// replicated-at-leaf-buffers annotation when the original has one.
Status SpmdPartitioningVisitor::HandleParameter(HloInstruction* hlo) {
  SetPartitionedHlo(hlo, [&]() {
    HloInstruction* shard_param =
        b_.AddInstruction(HloInstruction::CreateParameter(
            hlo->parameter_number(),
            MakePartitionedShape(hlo->shape(), hlo->sharding()), "param"));
    const auto& replicated_at_leaves =
        hlo->parameter_replicated_at_leaf_buffers();
    if (replicated_at_leaves) {
      shard_param->set_parameter_replicated_at_leaf_buffers(
          *replicated_at_leaves);
    }
    return shard_param;
  });
  return Status::OK();
}
// Partitions a (possibly variadic, tuple-shaped) reduce.
//
// All inputs are brought to the sharding of the first input and, when that
// sharding is tiled, padded with their init value; init values are replicated.
// A local reduce is then emitted on the shards.
//   * If a tiled dimension is reduced, the local results are combined with a
//     cross-partition all-reduce using the reduce computation. This is only
//     done for single-input reduces where every tiled dimension is reduced;
//     otherwise DefaultAction is used.
//   * If no tiled dimension is reduced, the local result is tagged with the
//     input tile assignment reshaped to drop the reduced dimensions.
// In both cases the result is finally resharded to hlo->sharding().
Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) {
int64 input_count = 1;
auto per_input_sharding = hlo->sharding();
if (hlo->shape().IsTuple()) {
input_count = hlo->shape().tuple_shapes_size();
CHECK_GT(input_count, 0);
per_input_sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
}
// Operands are laid out as [inputs..., inits...], input_count of each.
std::vector<PartitionedHlo> inputs;
std::vector<HloInstruction*> inits;
for (int64 operand_id = 0; operand_id < input_count; ++operand_id) {
inits.push_back(GetPartitionedHlo(hlo->operand(operand_id + input_count))
.Reshard(HloSharding::Replicate())
.hlo());
inputs.push_back(GetPartitionedHlo(hlo->operand(operand_id)));
if (operand_id > 0) {
// Make sure all operands are sharded in the same way.
inputs.back() = inputs.back().Reshard(inputs[0].sharding());
}
// For tiled inputs, pad with the matching init value.
// NOTE(review): presumably so padded elements do not affect the reduction —
// confirm against PartitionedHlo::PadWithValue.
if (!inputs[0].sharding().IsTileMaximal()) {
inputs.back() = inputs.back().PadWithValue(inits[operand_id]);
}
}
bool reduce_sharded_dimension = false;
if (!inputs[0].sharding().IsTileMaximal()) {
reduce_sharded_dimension = absl::c_any_of(hlo->dimensions(), [&](int64 i) {
return inputs[0].sharding().tile_assignment().dim(i) > 1;
});
// reduce_sharded_dimension is not supported for tuple-shaped reduces.
if (reduce_sharded_dimension && input_count > 1) {
return DefaultAction(hlo);
}
// Currently we only support reducing all or none of the sharded
// dimensions.
if (reduce_sharded_dimension) {
for (int64 i = 0; i < inputs[0].base_shape().rank(); ++i) {
if (inputs[0].sharding().tile_assignment().dim(i) > 1 &&
absl::c_count(hlo->dimensions(), i) == 0) {
return DefaultAction(hlo);
}
}
}
}
// Collect the (shard) shapes of inputs followed by inits for shape inference.
std::vector<Shape*> new_operand_shapes(input_count * 2);
for (int64 i = 0; i < input_count; ++i) {
new_operand_shapes[i] = inputs[i].hlo()->mutable_shape();
new_operand_shapes[i + input_count] = inits[i]->mutable_shape();
}
// Create the shard shape of the reduce result.
TF_ASSIGN_OR_RETURN(
auto reduce_shape,
ShapeInference::InferReduceShape(new_operand_shapes, hlo->dimensions(),
hlo->to_apply()->ComputeProgramShape()));
// Keep the layout of the original reduce result.
*reduce_shape.mutable_layout() = hlo->shape().layout();
std::vector<HloInstruction*> input_hlos(input_count);
for (int64 i = 0; i < input_count; ++i) {
input_hlos[i] = inputs[i].hlo();
}
auto local_reduce = b_.AddInstruction(HloInstruction::CreateReduce(
reduce_shape, input_hlos, inits, hlo->dimensions(), hlo->to_apply()));
local_reduce->set_metadata(hlo->metadata());
SetPartitionedHlo(hlo, [&]() {
HloInstruction* reduce;
if (reduce_sharded_dimension) {
// Combine the per-partition partial results; the all-reduce output is
// marked replicated.
CHECK(local_reduce->shape().IsArray());
reduce = collective_ops_creator_.create_cross_partition_all_reduce(
&b_, local_reduce, hlo->to_apply(), NewChannel());
reduce->set_sharding(HloSharding::Replicate());
} else {
reduce = local_reduce;
if (inputs[0].sharding().IsTileMaximal()) {
reduce->set_sharding(inputs[0].sharding());
} else {
// Remove tile assignment dimensions that are reduced.
std::vector<int64> tile_dimensions;
for (int64 i = 0; i < input_hlos[0]->shape().rank(); ++i) {
if (absl::c_count(hlo->dimensions(), i) == 0) {
tile_dimensions.push_back(
inputs[0].sharding().tile_assignment().dim(i));
}
}
Array<int64> new_tile = inputs[0].sharding().tile_assignment();
new_tile.Reshape(tile_dimensions);
auto sharding = HloSharding::Tile(new_tile);
// Tuple-shaped results use the same tile sharding for every element.
if (input_count > 1) {
std::vector<HloSharding> tuple(input_count, sharding);
sharding = HloSharding::Tuple(hlo->shape(), tuple);
}
reduce->set_sharding(sharding);
}
}
// Reshard from the sharding set above to the one requested on `hlo`.
return PartitionedHlo(reduce, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
// Partitions a reverse whose output is tiled (tile-maximal outputs go to
// DefaultAction).
//
// The operand is resharded to the output sharding reversed along the reverse
// dimensions, then HaloExchangeToPadOnLeft is applied on those dimensions.
// If that returns null the handler falls back to DefaultAction; otherwise the
// reverse is cloned onto the result.
Status SpmdPartitioningVisitor::HandleReverse(HloInstruction* hlo) {
  auto reverse = Cast<HloReverseInstruction>(hlo);
  if (reverse->sharding().IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  PartitionedHlo mirrored_operand =
      GetPartitionedHlo(reverse->operand(0))
          .Reshard(hlo_sharding_util::ReverseSharding(reverse->sharding(),
                                                      reverse->dimensions()));
  auto left_padded =
      HaloExchangeToPadOnLeft(mirrored_operand, reverse->dimensions());
  if (!left_padded) {
    return DefaultAction(hlo);
  }

  SetPartitionedHlo(hlo, [&] {
    return b_.AddInstruction(
        hlo->CloneWithNewOperands(left_padded->shape(), {left_padded}));
  });
  return Status::OK();
}
// Partitions a while loop.
//
// The parameters of the condition and body are given the while's sharding.
// The condition is partitioned with a replicated root so that all partitions
// follow the same control flow, and the body is partitioned with the while's
// sharding as its root sharding. A new while is then emitted on the
// resharded init operand.
Status SpmdPartitioningVisitor::HandleWhile(HloInstruction* hlo) {
  const HloSharding& sharding = hlo->sharding();
  HloComputation* condition = hlo->while_condition();
  HloComputation* body = hlo->while_body();

  // Shardings for the body parameter, body root, and cond parameter must be
  // the same.
  condition->parameter_instruction(0)->set_sharding(sharding);
  body->parameter_instruction(0)->set_sharding(sharding);

  TF_RETURN_IF_ERROR(
      partitioner_
          ->PartitionComputation(condition, HloSharding::Replicate(),
                                 next_channel_id_, logger_)
          .status());
  TF_RETURN_IF_ERROR(
      partitioner_
          ->PartitionComputation(body, sharding, next_channel_id_, logger_)
          .status());

  SetPartitionedHlo(hlo, [&] {
    HloInstruction* init =
        GetPartitionedHlo(hlo->operand(0)).Reshard(sharding).hlo();
    return b_.AddInstruction(HloInstruction::CreateWhile(
        MakePartitionedShape(hlo->shape(), sharding), condition, body, init));
  });
  return Status::OK();
}
// Partitions a conditional.
//
// Each branch computation's parameter takes the sharding of its argument, and
// each branch is partitioned with the conditional's sharding as its root
// sharding. The branch selector (operand 0) is replicated so that all
// partitions follow the same control flow.
Status SpmdPartitioningVisitor::HandleConditional(HloInstruction* hlo) {
  const int64 num_branches = hlo->branch_count();

  // Shardings of the branch computation parameter and its argument must be
  // the same.
  std::vector<HloInstruction*> branch_args;
  for (int64 branch = 0; branch < num_branches; ++branch) {
    const HloInstruction* argument = hlo->operand(branch + 1);
    hlo->branch_computation(branch)->parameter_instruction(0)->set_sharding(
        argument->sharding());
    branch_args.push_back(GetPartitionedHlo(argument).hlo());
  }

  // The root of the branch computations must follow the sharding of the
  // conditional instruction.
  for (int64 branch = 0; branch < num_branches; ++branch) {
    TF_RETURN_IF_ERROR(
        partitioner_
            ->PartitionComputation(hlo->branch_computation(branch),
                                   hlo->sharding(), next_channel_id_, logger_)
            .status());
  }

  SetPartitionedHlo(hlo, [&] {
    HloInstruction* replicated_selector = GetPartitionedHlo(hlo->operand(0))
                                              .Reshard(HloSharding::Replicate())
                                              .hlo();
    return b_.AddInstruction(HloInstruction::CreateConditional(
        MakePartitionedShape(hlo->shape(), hlo->sharding()),
        replicated_selector, hlo->called_computations(), branch_args));
  });
  return Status::OK();
}
// Partitions an outfeed. Only shardings that name a unique device are
// accepted (an error status is returned otherwise); the lowering itself is
// delegated to HandleSingleDevice.
Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) {
TF_RET_CHECK(hlo->sharding().HasUniqueDevice());
return HandleSingleDevice(hlo);
}
// Partitions an RNG instruction.
//
//  * Unique-device sharding: delegated to HandleSingleDevice.
//  * Replicated sharding: the RNG is run on device 0 only (operands resharded
//    to device 0) and its result is resharded to replicated.
//  * Otherwise the sharding must not be tile-maximal (error status if it is):
//    operands are replicated and an RNG of the shard shape is emitted on
//    every partition.
Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) {
  if (hlo->sharding().HasUniqueDevice()) {
    return HandleSingleDevice(hlo);
  }

  // Reshards every operand of `hlo` to `target`, in operand order.
  auto reshard_operands = [&](const HloSharding& target) {
    std::vector<HloInstruction*> resharded;
    for (const HloInstruction* rng_operand : hlo->operands()) {
      resharded.push_back(GetPartitionedHlo(rng_operand).Reshard(target).hlo());
    }
    return resharded;
  };

  if (hlo->sharding().IsReplicated()) {
    SetPartitionedHlo(hlo, [&] {
      // Run on a single device (0) and distribute the data to all other cores.
      const HloSharding on_device_zero = HloSharding::AssignDevice(0);
      std::vector<HloInstruction*> device_zero_operands =
          reshard_operands(on_device_zero);
      HloInstruction* device_zero_rng = b_.AddInstruction(
          hlo->CloneWithNewOperands(hlo->shape(), device_zero_operands));
      device_zero_rng->set_sharding(HloSharding::AssignDevice(0));
      return PartitionedHlo(device_zero_rng, hlo->shape(),
                            MakePartitioningState())
          .Reshard(HloSharding::Replicate())
          .hlo();
    });
    return Status::OK();
  }

  TF_RET_CHECK(!hlo->sharding().IsTileMaximal());
  SetPartitionedHlo(hlo, [&] {
    // Replicate the operands and run partitioned Rng on all devices.
    std::vector<HloInstruction*> replicated_operands =
        reshard_operands(HloSharding::Replicate());
    return b_.AddInstruction(HloInstruction::CreateRng(
        MakePartitionedShape(hlo->shape(), hlo->sharding()),
        hlo->random_distribution(), replicated_operands));
  });
  return Status::OK();
}
// Partitions a ReduceWindow instruction. Tile-maximal shardings, and inputs
// that cannot be resharded as a windowed input, are left to DefaultAction().
// Otherwise the init value is replicated, the operand is resharded as a
// windowed input, and a per-shard ReduceWindow is emitted; when the reshard
// reports a dynamic slice index on the output, the result is additionally
// sliced down to the shard shape.
Status SpmdPartitioningVisitor::HandleReduceWindow(HloInstruction* hlo) {
  auto& input = GetPartitionedHlo(hlo->operand(0));
  if (hlo->sharding().IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  // The init value has to be identical on every partition.
  auto init_replicated = GetPartitionedHlo(hlo->mutable_operand(1))
                             .Reshard(HloSharding::Replicate());
  auto windowed_input = input.ReshardAsWindowedInput(
      hlo->window(), hlo->sharding(), init_replicated.hlo());
  if (!windowed_input.has_value()) {
    return DefaultAction(hlo);
  }

  // Infer the shape of the per-shard ReduceWindow, then borrow the layout of
  // the partitioned output shape.
  TF_ASSIGN_OR_RETURN(
      Shape sharded_rw_shape,
      ShapeInference::InferReduceWindowShape(
          windowed_input->sharded_input->shape(),
          init_replicated.hlo()->shape(), windowed_input->shard_window,
          hlo->to_apply()->ComputeProgramShape()));
  auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
  *sharded_rw_shape.mutable_layout() = shard_shape.layout();

  SetPartitionedHlo(hlo, [&]() {
    auto sharded_rw = b_.AddInstruction(HloInstruction::CreateReduceWindow(
        sharded_rw_shape, windowed_input->sharded_input,
        init_replicated.hlo(), windowed_input->shard_window,
        hlo->to_apply()));
    if (windowed_input->dynamic_slice_index_on_output.has_value()) {
      // Slice the per-shard result down to the shard shape at the reported
      // offsets.
      return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
          shard_shape, sharded_rw,
          *windowed_input->dynamic_slice_index_on_output,
          shard_shape.dimensions()));
    }
    // Without a dynamic slice the result must already have the shard shape.
    CHECK(ShapeUtil::Compatible(shard_shape, sharded_rw->shape()));
    return sharded_rw;
  });
  return Status::OK();
}
// Partitions a SelectAndScatter instruction.
//
// Falls back to DefaultAction() when:
//   * the sharding is tile-maximal,
//   * the element type is neither F32 nor BF16,
//   * the select computation's root is not a compare of its two parameters
//     with direction Ge/Gt/Le/Lt, or
//   * one of the halo exchanges below does not produce a result.
//
// Otherwise the operand (data) and the source are resharded to the
// instruction's sharding, halos are exchanged on every partitioned dimension,
// a per-shard SelectAndScatter is created with the window padding removed on
// partitioned dimensions, and, if its shape is not compatible with the shard
// shape, the result is dynamic-sliced down to the shard shape.
Status SpmdPartitioningVisitor::HandleSelectAndScatter(HloInstruction* hlo) {
if (hlo->sharding().IsTileMaximal()) {
return DefaultAction(hlo);
}
// Bring both the data operand and the source to the output sharding.
auto operand = GetPartitionedHlo(hlo->operand(0));
auto source = GetPartitionedHlo(hlo->mutable_operand(1));
if (hlo->sharding() != operand.sharding()) {
operand = operand.Reshard(hlo->sharding());
}
if (hlo->sharding() != source.sharding()) {
source = source.Reshard(hlo->sharding());
}
// For F32 and BF16 types, we can use NaN padding to workaround the issue with
// low/high padding, since comparison will return false with NaN input.
// NOTE(review): the pad value actually chosen below is +/-infinity, not NaN.
if (hlo->shape().element_type() != F32 &&
hlo->shape().element_type() != BF16) {
return DefaultAction(hlo);
}
// Only select computations whose root directly compares the two parameters
// (in either order) are supported: the parameter numbers must be 0 and 1.
auto select = hlo->called_computations()[0];
auto select_root = select->root_instruction();
if (select_root->opcode() != HloOpcode::kCompare ||
select_root->operand(0)->opcode() != HloOpcode::kParameter ||
select_root->operand(1)->opcode() != HloOpcode::kParameter ||
select_root->operand(0)->parameter_number() +
select_root->operand(1)->parameter_number() !=
1) {
return DefaultAction(hlo);
}
// Choose the sign of the infinite pad value from the comparison direction and
// from which parameter is the compare's first operand (presumably so that a
// padded element never wins the selection).
float float_pad_value;
if (select_root->comparison_direction() == ComparisonDirection::kGe ||
select_root->comparison_direction() == ComparisonDirection::kGt) {
if (select_root->operand(0)->parameter_number() == 0) {
float_pad_value = -std::numeric_limits<float>::infinity();
} else {
float_pad_value = std::numeric_limits<float>::infinity();
}
} else if (select_root->comparison_direction() == ComparisonDirection::kLe ||
select_root->comparison_direction() == ComparisonDirection::kLt) {
if (select_root->operand(0)->parameter_number() == 0) {
float_pad_value = std::numeric_limits<float>::infinity();
} else {
float_pad_value = -std::numeric_limits<float>::infinity();
}
} else {
return DefaultAction(hlo);
}
// Materialize the pad value as a scalar constant of the output element type.
auto pad_value = b_.AddInstruction(HloInstruction::CreateConstant(
hlo->shape().element_type() == BF16
? LiteralUtil::CreateR0<bfloat16>(
static_cast<bfloat16>(float_pad_value))
: LiteralUtil::CreateR0<float>(float_pad_value)));
// Replicate init
auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(2))
.Reshard(HloSharding::Replicate());
auto partition_ordinals =
MakeTiledPartitionOrdinals(hlo->sharding(), partition_id_, &b_);
// The first window for each dimension that overlaps with the shard area.
std::vector<MultiplyAddDivideOffsetCalculation> first_window(
hlo->shape().rank());
// The first window for each dimension that goes beyond the shard area.
std::vector<MultiplyAddDivideOffsetCalculation> limit_window(
hlo->shape().rank());
// Per-dimension halo sizes for the data operand and for the source. Entries
// for unpartitioned dimensions stay default-constructed.
std::vector<OffsetCalculation> data_left_halo_sizes(hlo->shape().rank());
std::vector<OffsetCalculation> data_right_halo_sizes(hlo->shape().rank());
std::vector<OffsetCalculation> source_left_halo_sizes(hlo->shape().rank());
std::vector<OffsetCalculation> source_right_halo_sizes(hlo->shape().rank());
auto unpadded_data_shard_shape =
MakePartitionedShape(hlo->shape(), hlo->sharding());
auto unpadded_source_shard_shape =
MakePartitionedShape(hlo->operand(1)->shape(), hlo->sharding());
// These are updated after each partitioned dimension's halo exchange, so the
// exchanges are chained dimension by dimension.
auto source_shard_hlo = source.hlo();
auto data_shard_hlo = operand.hlo();
for (int64 i = 0; i < hlo->shape().rank(); ++i) {
int64 shard_count = hlo->sharding().tile_assignment().dim(i);
if (shard_count == 1) {
continue;
}
// If stride > window_size, there will be gaps between windows. These gaps
// will also exist in the output, so we keep them during halo exchange.
//
// TODO(yuanzx): This could introduce overhead if partitions start at
// different offsets in a gap.
auto wd = hlo->window().dimensions(i);
if (wd.stride() > wd.size()) {
wd.set_size(wd.stride());
}
// shard_size * i < stride * k - pad_low + window_size =>
// k > (shard_size * i + pad_low - window_size) / stride =>
// first_k == (shard_size * i + pad_low - window_size + stride) / stride
first_window[i] = MultiplyAddDivideOffsetCalculation(
unpadded_data_shard_shape.dimensions(i),
wd.padding_low() - wd.size() + wd.stride(), wd.stride());
// shard_size * (i + 1) <= stride * k - pad_low =>
// k >= (shard_size * i + shard_size + pad_low) / stride =>
// limit_k == (shard_size * i + shard_size + pad_low + stride - 1) /
// stride
limit_window[i] = MultiplyAddDivideOffsetCalculation(
unpadded_data_shard_shape.dimensions(i),
unpadded_data_shard_shape.dimensions(i) + wd.padding_low() +
wd.stride() - 1,
wd.stride());
// Source halos: the difference between the shard's own source range
// [source_shard_size * i, source_shard_size * (i + 1)) and the window range
// [first_window, limit_window).
source_left_halo_sizes[i] =
MultiplyAddDivideOffsetCalculation(
unpadded_source_shard_shape.dimensions(i), 0, 1) -
first_window[i];
source_right_halo_sizes[i] =
limit_window[i] - MultiplyAddDivideOffsetCalculation(
unpadded_source_shard_shape.dimensions(i),
unpadded_source_shard_shape.dimensions(i), 1);
// Data halos: the same window range scaled by the stride, compared against
// the shard's own data range shifted by pad_low.
data_left_halo_sizes[i] =
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
unpadded_data_shard_shape.dimensions(i), wd.padding_low(), 1)) -
OffsetCalculation(
HloOpcode::kMultiply, first_window[i],
MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1));
data_right_halo_sizes[i] =
OffsetCalculation(
HloOpcode::kMultiply, limit_window[i],
MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1)) -
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
unpadded_data_shard_shape.dimensions(i),
unpadded_data_shard_shape.dimensions(i) + wd.stride() +
wd.padding_low() - wd.size(),
1));
// Largest number of windows any shard handles on this dimension; it is used
// as the source shard size after the halo exchange.
int64 max_windows =
(limit_window[i] - first_window[i]).MaxInRange(0, shard_count);
auto first_window_hlo =
first_window[i].Calculate(partition_ordinals[i], &b_);
// Padding on the source is filled with the init value so they do not change
// the data on overlapping windows.
auto resharded_source = ExchangeHaloAndGetValidData(
source_shard_hlo, source.base_shape(), source_left_halo_sizes[i],
source_right_halo_sizes[i], 0,
limit_window[i].Calculate(shard_count - 1), max_windows, i,
hlo->sharding(), first_window_hlo, replicated_init.hlo(),
partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_);
if (!resharded_source) {
return DefaultAction(hlo);
}
source_shard_hlo = *resharded_source;
// Start offset of this shard in the data: first window index * stride.
auto offset_start_in_data =
MultiplyAddDivideOffsetCalculation(wd.stride(), 0, 1)
.Calculate(first_window_hlo, &b_);
int64 padded_data_size =
(limit_window[i].Calculate(shard_count - 1) - 1) * wd.stride() +
wd.size();
int64 data_shard_size = (max_windows - 1) * wd.stride() + wd.size();
// The data is padded with pad_value (+/-infinity) rather than the init value.
auto resharded_data = ExchangeHaloAndGetValidData(
data_shard_hlo, operand.base_shape(), data_left_halo_sizes[i],
data_right_halo_sizes[i], wd.padding_low(), padded_data_size,
data_shard_size, i, hlo->sharding(), offset_start_in_data, pad_value,
partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_);
if (!resharded_data) {
return DefaultAction(hlo);
}
data_shard_hlo = *resharded_data;
}
// Build the window used by the per-shard instruction.
Window window_on_shard = hlo->window();
for (int64 i = 0; i < window_on_shard.dimensions_size(); ++i) {
int64 shard_count = hlo->sharding().tile_assignment().dim(i);
if (shard_count == 1) {
continue;
}
auto reshard_wd = window_on_shard.mutable_dimensions(i);
// The shards are already explicitly padded.
reshard_wd->set_padding_low(0);
reshard_wd->set_padding_high(0);
}
auto sharded_select_and_scatter =
b_.AddInstruction(HloInstruction::CreateSelectAndScatter(
data_shard_hlo->shape(), data_shard_hlo, select, window_on_shard,
source_shard_hlo, replicated_init.hlo(),
hlo->called_computations()[1]));
SetPartitionedHlo(hlo, [&]() {
auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
// No slicing is needed when the per-shard result already has the shard
// shape.
if (ShapeUtil::Compatible(sharded_select_and_scatter->shape(),
shard_shape)) {
return sharded_select_and_scatter;
}
// Otherwise slice the shard-shaped region out of the result; the slice
// offset on each partitioned dimension is derived from the left halo size.
auto zero = b_.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
std::vector<HloInstruction*> slice_offsets(shard_shape.rank(), zero);
for (int64 i = 0; i < window_on_shard.dimensions_size(); ++i) {
if (hlo->sharding().tile_assignment().dim(i) == 1) {
continue;
}
int64 pad_low = hlo->window().dimensions(i).padding_low();
auto left_halo_size =
data_left_halo_sizes[i].Calculate(partition_ordinals[i], &b_);
if (data_left_halo_sizes[i].Calculate(0) == pad_low) {
slice_offsets[i] = left_halo_size;
} else {
// The left halo computed for partition 0 differs from pad_low, so
// partition 0 uses pad_low as its offset and all other partitions use
// their left halo size.
auto is_shard0 = b_.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {}), zero, partition_ordinals[i],
ComparisonDirection::kEq));
auto pad_low_hlo = b_.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<int32>(pad_low)));
slice_offsets[i] = b_.AddInstruction(HloInstruction::CreateTernary(
zero->shape(), HloOpcode::kSelect, is_shard0, pad_low_hlo,
left_halo_size));
}
}
return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
shard_shape, sharded_select_and_scatter, slice_offsets,
shard_shape.dimensions()));
});
return Status::OK();
}
// Partitions a Tuple instruction: each operand is resharded to the
// sub-sharding that the tuple's sharding assigns to that tuple element, and a
// new tuple is built from the resharded operands.
Status SpmdPartitioningVisitor::HandleTuple(HloInstruction* hlo) {
  std::vector<HloInstruction*> elements;
  for (int64 idx = 0; idx < hlo->operand_count(); ++idx) {
    // Sharding of tuple element `idx` within the tuple's sharding.
    const auto element_sharding =
        hlo->sharding().GetSubSharding(hlo->shape(), {idx});
    HloInstruction* resharded_element =
        GetPartitionedHlo(hlo->operand(idx)).Reshard(element_sharding).hlo();
    elements.push_back(resharded_element);
  }
  SetPartitionedHlo(hlo, [&]() {
    return b_.AddInstruction(HloInstruction::CreateTuple(elements));
  });
  return Status::OK();
}
// Partitions a convolution whose LHS and RHS are both tiled (not
// tile-maximal) by exchanging halos on the LHS.
//
// The smaller operand (by base shape size in bytes) is resharded to align with
// the other one, both are padded with zero, and the LHS shard is extended with
// halos on every partitioned spatial dimension. Each partition then computes a
// partial sum of the full-shape result; the partial sums are combined with a
// cross-partition all-reduce and the replicated result is resharded to the
// instruction's sharding.
//
// Falls back to DefaultAction() when the RHS cannot be left-padded for window
// reversal, when the input batch or kernel output feature dimension is
// partitioned, when base dilation is used, or when a halo exchange does not
// produce a result.
Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs(
HloInstruction* hlo) {
TF_RET_CHECK(hlo->opcode() == HloOpcode::kConvolution);
auto lhs = GetPartitionedHlo(hlo->operand(0));
auto rhs = GetPartitionedHlo(hlo->operand(1));
TF_RET_CHECK(!lhs.sharding().IsTileMaximal() &&
!rhs.sharding().IsTileMaximal());
const auto& dnums = hlo->convolution_dimension_numbers();
// Check if the operand shardings are aligned. Also we currently don't
// support partitioning non-spatial dimensions.
// rhs_to_lhs_indices[rhs_dim] is the LHS dimension paired with rhs_dim:
// kernel output feature <-> input batch, kernel input feature <-> input
// feature, and the i-th kernel spatial <-> the i-th input spatial dimension.
std::vector<int64> rhs_to_lhs_indices(hlo->shape().rank());
rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
dnums.input_batch_dimension();
rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
dnums.input_feature_dimension();
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
dnums.input_spatial_dimensions(i);
}
// Inverse mapping of rhs_to_lhs_indices.
std::vector<int64> lhs_to_rhs_indices(hlo->shape().rank());
for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) {
lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
}
// Collect the RHS dimensions whose window is reversed.
Window window = hlo->window();
std::vector<int64> reversed_rhs_dims;
for (int64 i = 0; i < window.dimensions_size(); ++i) {
if (window.dimensions(i).window_reversal()) {
reversed_rhs_dims.push_back(dnums.kernel_spatial_dimensions(i));
}
}
if (!reversed_rhs_dims.empty()) {
// Make the reversed dims left-padded to prepare for window reversal.
auto left_padded_rhs = HaloExchangeToPadOnLeft(rhs, reversed_rhs_dims);
if (left_padded_rhs == nullptr) {
return DefaultAction(hlo);
}
left_padded_rhs->set_sharding(rhs.sharding());
rhs = PartitionedHlo(left_padded_rhs, rhs.base_shape(), rhs.state());
}
// Consider window reversal when resharding RHS or LHS. Note: this will not
// reverse the data in the shard. We use window reversal to do that.
auto aligned_rhs_sharding = hlo_sharding_util::ReverseSharding(
hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices),
reversed_rhs_dims);
auto aligned_lhs_sharding = hlo_sharding_util::TransposeSharding(
hlo_sharding_util::ReverseSharding(rhs.sharding(), reversed_rhs_dims),
lhs_to_rhs_indices);
// Returns true if the LHS input batch dimension or the RHS kernel output
// feature dimension is partitioned.
auto unsupported_sharding = [&](const HloSharding& lhs_sharding,
const HloSharding& rhs_sharding) {
return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) !=
1 ||
rhs_sharding.tile_assignment().dim(
dnums.kernel_output_feature_dimension()) != 1;
};
auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(hlo->shape().element_type())));
// Reshard whichever operand has the smaller base shape so that it aligns
// with the other one; pad both operands with zero.
if (ShapeSizeInBytes(lhs.base_shape()) < ShapeSizeInBytes(rhs.base_shape())) {
if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) {
return DefaultAction(hlo);
}
lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero);
rhs = rhs.PadWithValue(zero, reversed_rhs_dims);
} else {
if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) {
return DefaultAction(hlo);
}
lhs = lhs.PadWithValue(zero);
rhs =
rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero, reversed_rhs_dims);
}
// Reshard LHS by exchanging halo such that each shard computes the partial
// sum of the full shape result, and add AllReduce.
//
// The size of halo on each dimension can be calculated from the projection
// onto the LHS that each RHS shard i needs to read. RHS and LHS below refers
// to the shard size of RHS and LHS, WC is the number of windows, and D is the
// window dilation.
//
// * offset(i): RHS * D * i - low_padding
// * limit(i): {(RHS - 1) * D + 1} * (i + 1) + (WC - 1) * stride - low_padding
//
// Since shard i has LHS of range [i * LHS, (i + 1) * LHS)
// * left-halo: i * LHS - offset(i)
// = (LHS - RHS * D) * i + low_padding
// * right-halo: limit(i) - (i + 1) * LHS
// = [{(RHS - 1) * D + 1} - LHS] * (i + 1) + (WC - 1) * stride - low_padding
// First pass: record the shard count and the LHS/RHS shard sizes of every
// spatial dimension, bailing out on base dilation.
std::vector<int64> shard_counts(dnums.input_spatial_dimensions_size());
std::vector<int64> lhs_shard_sizes(dnums.input_spatial_dimensions_size());
std::vector<int64> rhs_shard_sizes(dnums.input_spatial_dimensions_size());
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
int64 lhs_dimension = dnums.input_spatial_dimensions(i);
int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
int64 shard_count = lhs.sharding().tile_assignment().dim(lhs_dimension);
auto wd = window.dimensions(i);
if (wd.base_dilation() != 1) {
return DefaultAction(hlo);
}
int64 lhs_shard_size =
CeilOfRatio(lhs.base_shape().dimensions(lhs_dimension), shard_count);
int64 rhs_shard_size =
CeilOfRatio(rhs.base_shape().dimensions(rhs_dimension), shard_count);
shard_counts[i] = shard_count;
lhs_shard_sizes[i] = lhs_shard_size;
rhs_shard_sizes[i] = rhs_shard_size;
}
// Halo size functions are indexed by LHS dimension.
std::vector<OffsetCalculation> left_halo_size_functions(hlo->shape().rank());
std::vector<OffsetCalculation> right_halo_size_functions(hlo->shape().rank());
Window new_window = window;
auto partition_ordinals =
MakeTiledPartitionOrdinals(lhs.sharding(), partition_id_, &b_);
// Second pass: exchange halos on each partitioned spatial dimension. The
// result of one dimension's exchange is the input of the next.
HloInstruction* lhs_with_halo = lhs.hlo();
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
int64 lhs_dimension = dnums.input_spatial_dimensions(i);
int64 lhs_shard_size = lhs_shard_sizes[i];
int64 rhs_shard_size = rhs_shard_sizes[i];
if (shard_counts[i] == 1) {
continue;
}
// Calculate the left and right halo sizes as described in the comments
// above.
auto wd = window.dimensions(i);
int64 padding_low = wd.padding_low();
int64 padding_high = wd.padding_high();
int64 base = lhs.base_shape().dimensions(lhs_dimension);
// Number of windows on the full (unpartitioned) shape for this dimension.
int64 window_count = 1 + (padding_low + padding_high + base -
(1 + (wd.size() - 1) * wd.window_dilation())) /
wd.stride();
int64 rhs_shard_size_dilated =
(rhs_shard_size - 1) * wd.window_dilation() + 1;
left_halo_size_functions[lhs_dimension] =
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
lhs_shard_size - rhs_shard_size * wd.window_dilation(), padding_low,
1));
right_halo_size_functions[lhs_dimension] =
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
rhs_shard_size_dilated - lhs_shard_size,
rhs_shard_size_dilated - lhs_shard_size +
wd.stride() * (window_count - 1) - padding_low,
1));
// Exchange halo and concatenate.
// Note: `dim` holds the same value as `lhs_dimension` above.
int64 dim = dnums.input_spatial_dimensions(i);
int64 explicit_left_padding_on_full_shape = padding_low;
int64 shard_size_with_halo =
wd.stride() * (window_count - 1) + rhs_shard_size_dilated;
// Padding is removed from the window on this dimension, and the window size
// becomes the RHS shard size.
new_window.mutable_dimensions(i)->set_padding_low(0);
new_window.mutable_dimensions(i)->set_padding_high(0);
new_window.mutable_dimensions(i)->set_size(rhs_shard_size);
// offset_on_padded_shape and padded_full_shape_size are needed only if
// we want to mask out-of-range values in ExchangeHaloAndGetValidData().
// Since the default value for both the collective-permute is zero and
// also we call PadWithValue() on both operands at the beginning, we
// don't need to mask here.
//
// TODO(hyoulkee): Consider removing one of the two PadWithValue() calls
// if it's always safe.
auto offset_on_padded_shape =
OffsetCalculation(MultiplyAddDivideOffsetCalculation());
int64 padded_full_shape_size = 0;
auto concat = ExchangeHaloAndGetValidData(
lhs_with_halo, lhs.base_shape(), left_halo_size_functions[dim],
right_halo_size_functions[dim], explicit_left_padding_on_full_shape,
padded_full_shape_size, shard_size_with_halo, dim, lhs.sharding(),
offset_on_padded_shape.Calculate(partition_ordinals[dim], &b_), zero,
partition_ordinals[dim], collective_ops_creator_, next_channel_id_, &b_,
/*mask_invalid_region=*/false);
if (!concat) {
return DefaultAction(hlo);
}
lhs_with_halo = *concat;
}
// Each partition convolves its halo-extended LHS shard with its RHS shard to
// produce a full-shape partial result; the partial results are summed with an
// all-reduce and the replicated sum is resharded to the requested sharding.
SetPartitionedHlo(hlo, [&]() {
auto conv = b_.AddInstruction(HloInstruction::CreateConvolve(
hlo->shape(), lhs_with_halo, rhs.hlo(), hlo->feature_group_count(),
hlo->batch_group_count(), new_window,
hlo->convolution_dimension_numbers(), hlo->precision_config()));
auto ar = collective_ops_creator_.create_cross_partition_all_reduce(
&b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_),
NewChannel());
ar->set_sharding(HloSharding::Replicate());
return PartitionedHlo(ar, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) {
auto dot_dnums = dot_as_convolution_util::ParseDotGeneralFromConvolution(hlo);
if (dot_dnums) {
// Use HandleDotHelper() for convs that are actually einsums.
spmd::DotGeneralDimsMapping mapping;
for (const auto& dims : dot_dnums->batch_dims) {
mapping.batch_dims.emplace_back();
mapping.batch_dims.back().lhs = dims.lhs;
mapping.batch_dims.back().rhs = dims.rhs;
mapping.batch_dims.back().output = dims.output;
}
for (const auto& dims : dot_dnums->contracting_dims) {
mapping.contracting_dims.emplace_back();
mapping.contracting_dims.back().lhs = dims.lhs;
mapping.contracting_dims.back().rhs = dims.rhs;
mapping.contracting_dims.back().output = dims.output;
}
for (const auto& dims : dot_dnums->lhs_non_contracting_dims) {
mapping.lhs_non_contracting_dims.emplace_back();
mapping.lhs_non_contracting_dims.back().lhs = dims.lhs;
mapping.lhs_non_contracting_dims.back().rhs = dims.rhs;
mapping.lhs_non_contracting_dims.back().output = dims.output;
}
for (const auto& dims : dot_dnums->rhs_non_contracting_dims) {
mapping.rhs_non_contracting_dims.emplace_back();
mapping.rhs_non_contracting_dims.back().lhs = dims.lhs;
mapping.rhs_non_contracting_dims.back().rhs = dims.rhs;
mapping.rhs_non_contracting_dims.back().output = dims.output;
}
auto create_sharded_conv =
[&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo,
spmd::SpmdBuilder* b) -> StatusOr<HloInstruction*> {
TF_ASSIGN_OR_RETURN(
auto sharded_conv,
dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
*hlo, *dot_dnums, lhs_hlo, rhs_hlo));
return b->AddInstruction(std::move(sharded_conv));
};
return HandleDotHelper(hlo, mapping, create_sharded_conv);
}
auto lhs = GetPartitionedHlo(hlo->operand(0));
auto rhs = GetPartitionedHlo(hlo->operand(1));
const HloSharding& sharding = hlo->sharding();
const auto& dnums = hlo->convolution_dimension_numbers();
std::vector<int64> rhs_to_lhs_indices(hlo->shape().rank());
rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
dnums.input_batch_dimension();
rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
dnums.input_feature_dimension();
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
dnums.input_spatial_dimensions(i);
}
std::vector<int64> lhs_to_rhs_indices(hlo->shape().rank());
for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) {
lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
}
auto aligned_rhs_sharding =
hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices);
auto aligned_lhs_sharding =
hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);
// Handling cases where both operands' shardings are aligned. We check that
// the LHS batch dimension is not partitioned because it is mapped to the
// output feature dimension in aligned_rhs_sharding, which are not the same
// dimension.
if (!lhs.sharding().IsTileMaximal() && !rhs.sharding().IsTileMaximal()) {
if (options_.conv_halo_exchange_always_on_lhs) {
return HandleConvolutionTiledLhsAndRhs(hlo);
} else {
// Reshard RHS so that each shard computes the partial sum of the full
// shape result, and add AllReduce. See HandleConvolutionTiledLhsAndRhs()
// that reshards LHS.
//
// The size of halo on each dimension can be calculated from the
// projection onto the RHS that shard i needs to read. RHS and LHS below
// refers to the shard size of RHS and LHS, WC is the number of windows,
// and D is the window dilation.
//
// * offset(i): LHS * i + low_padding - (WC - 1) * stride
// * limit(i): LHS * (i + 1) + low_padding
//
// Since shard i has RHS of range [i * RHS * D, (i + 1) * RHS * D)
// * left-halo: i * RHS - offset(i)
// = i * (RHS * D - LHS) + (WC - 1) * stride - low_padding
// * right-halo: limit(i) - (i + 1) * RHS
// = (i + 1) * (LHS - RHS * D) + low_padding
auto unsupported_sharding = [&](const HloSharding& lhs_sharding,
const HloSharding& rhs_sharding) {
// We currently don't support partitioning input batch or output feature
// dimensions.
return lhs_sharding.tile_assignment().dim(
dnums.input_batch_dimension()) != 1 ||
rhs_sharding.tile_assignment().dim(
dnums.kernel_output_feature_dimension()) != 1;
};
auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(hlo->shape().element_type())));
if (ShapeSizeInBytes(lhs.base_shape()) <
ShapeSizeInBytes(rhs.base_shape())) {
if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) {
return DefaultAction(hlo);
}
lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero);
rhs = rhs.PadWithValue(zero);
} else {
if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) {
return DefaultAction(hlo);
}
lhs = lhs.PadWithValue(zero);
rhs = rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero);
}
Window window = hlo->window();
std::vector<int64> shard_counts(dnums.input_spatial_dimensions_size());
std::vector<int64> lhs_shard_sizes(dnums.input_spatial_dimensions_size());
std::vector<int64> rhs_shard_sizes(dnums.input_spatial_dimensions_size());
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
int64 lhs_dimension = dnums.input_spatial_dimensions(i);
int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
int64 shard_count = rhs.sharding().tile_assignment().dim(rhs_dimension);
auto wd = window.dimensions(i);
if (wd.base_dilation() != 1 || wd.window_reversal()) {
return DefaultAction(hlo);
}
int64 lhs_shard_size = CeilOfRatio(
lhs.base_shape().dimensions(lhs_dimension), shard_count);
int64 rhs_shard_size = CeilOfRatio(
rhs.base_shape().dimensions(rhs_dimension), shard_count);
shard_counts[i] = shard_count;
lhs_shard_sizes[i] = lhs_shard_size;
rhs_shard_sizes[i] = rhs_shard_size;
}
std::vector<OffsetCalculation> left_halo_size_functions(
hlo->shape().rank());
std::vector<OffsetCalculation> right_halo_size_functions(
hlo->shape().rank());
Window new_window = window;
// Data structures needed for Pad and DynamicSlice on LHS if needed.
bool need_dynamic_slice_lhs = false;
auto partition_ordinals =
MakeTiledPartitionOrdinals(lhs.sharding(), partition_id_, &b_);
std::vector<int64> zero_padding(hlo->shape().rank());
PaddingConfig pad_config =
window_util::MakeSymmetricPadding(zero_padding);
auto zero_s32 = b_.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
std::vector<HloInstruction*> dynamic_slice_start_indices(
hlo->shape().rank(), zero_s32);
Shape dynamic_slice_shape = lhs.hlo()->shape();
Shape pad_shape = lhs.hlo()->shape();
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
int64 lhs_dimension = dnums.input_spatial_dimensions(i);
int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
int64 lhs_shard_size = lhs_shard_sizes[i];
int64 rhs_shard_size = rhs_shard_sizes[i];
if (shard_counts[i] == 1) {
continue;
}
// Calculate the left and right halo sizes as described in the comments
// above. It calculates the halo sizes with dilation, so we apply
// CeilOfRatio({left,right}_halo_size, window_dilation).
auto wd = window.dimensions(i);
int64 padding_low = wd.padding_low();
int64 padding_high = wd.padding_high();
int64 base = lhs.base_shape().dimensions(lhs_dimension);
int64 window_count =
1 + (padding_low + padding_high + base -
(1 + (wd.size() - 1) * wd.window_dilation())) /
wd.stride();
left_halo_size_functions[rhs_dimension] =
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
rhs_shard_size * wd.window_dilation() - lhs_shard_size,
(window_count - 1) * wd.stride() - padding_low +
wd.window_dilation() - 1,
wd.window_dilation()));
right_halo_size_functions[rhs_dimension] =
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
lhs_shard_size - rhs_shard_size * wd.window_dilation(),
lhs_shard_size - rhs_shard_size * wd.window_dilation() +
padding_low + wd.window_dilation() - 1,
wd.window_dilation()));
// New RHS window size includes the maximum of both left and right
// halos.
int64 halo_size = left_halo_size_functions[rhs_dimension].MaxInRange(
1, shard_counts[i]) +
right_halo_size_functions[rhs_dimension].MaxInRange(
0, shard_counts[i] - 1);
int64 new_window_size =
rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size;
// The amount of new low padding could be dynamic (e.g., window_dilation
// != 1), which requires pad (to the maximum) and dynamic slice on LHS.
//
// If we consider the first window, the offset of the dilated RHS that
// aligns with the first valid LHS element for shard i is 'padding_low +
// LHS * i'. When the left halo is added to RHS, the offset of the first
// RHS element is (RHS * i - left_halo) * window_dilation. The
// difference between the two values is the amount of padding_low we
// need on LHS.
auto new_padding_low_function =
OffsetCalculation(
HloOpcode::kMultiply, left_halo_size_functions[rhs_dimension],
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
0, wd.window_dilation(), 1))) -
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
rhs_shard_size * wd.window_dilation() - lhs_shard_size,
-padding_low, 1));
int64 new_padding_low_max =
new_padding_low_function.MaxInRange(0, shard_counts[i]);
int64 new_padding_low = new_padding_low_max;
int64 new_padding_high = window_count * wd.stride() +
(new_window_size - 1) * wd.window_dilation() -
new_padding_low - lhs_shard_size;
// We do pad/dynamic-slice only when the padding is dynamic.
if (!new_padding_low_function.IsConstant()) {
need_dynamic_slice_lhs = true;
new_padding_low = 0;
pad_config.mutable_dimensions(lhs_dimension)
->set_edge_padding_low(new_padding_low_max);
pad_config.mutable_dimensions(lhs_dimension)
->set_edge_padding_high(new_padding_low_max);
pad_shape.set_dimensions(lhs_dimension,
lhs_shard_size + 2 * new_padding_low_max);
dynamic_slice_start_indices[lhs_dimension] =
(OffsetCalculation(MultiplyAddDivideOffsetCalculation(
0, new_padding_low_max, 1)) -
new_padding_low_function)
.Calculate(partition_ordinals[lhs_dimension], &b_);
dynamic_slice_shape.set_dimensions(
lhs_dimension, lhs_shard_size + new_padding_low_max);
}
// Since the convolution RHS operand size increased with halos, adjust
// the window config accordingly.
new_window.mutable_dimensions(i)->set_padding_low(new_padding_low);
new_window.mutable_dimensions(i)->set_padding_high(new_padding_high);
new_window.mutable_dimensions(i)->set_size(
rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size);
}
HloInstruction* conv_lhs = lhs.hlo();
if (need_dynamic_slice_lhs) {
auto pad = b_.AddInstruction(
HloInstruction::CreatePad(pad_shape, lhs.hlo(), zero, pad_config));
conv_lhs = b_.AddInstruction(HloInstruction::CreateDynamicSlice(
dynamic_slice_shape, pad, dynamic_slice_start_indices,
dynamic_slice_shape.dimensions()));
}
// Exchange halo and concatenate.
HloInstruction* rhs_with_halo = rhs.hlo();
for (int i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
int64 dim = dnums.kernel_spatial_dimensions(i);
int64 explicit_left_padding_on_full_shape =
left_halo_size_functions[dim].Calculate(0);
int64 shard_size_with_halo = new_window.dimensions(i).size();
// offset_on_padded_shape and padded_full_shape_size are needed only if
// we want to mask out-of-range values in ExchangeHaloAndGetValidData().
// Since the default value for both the collective-permute is zero and
// also we call PadWithValue() on both operands at the beginning, we
// don't need to mask here.
//
// TODO(hyoulkee): Consider removing one of the two PadWithValue() calls
// if it's always safe.
auto offset_on_padded_shape =
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
rhs_shard_sizes[i], explicit_left_padding_on_full_shape, 1)) -
left_halo_size_functions[dim];
int64 padded_full_shape_size =
offset_on_padded_shape.Calculate(shard_counts[i] - 1) +
new_window.dimensions(i).size();
auto concat = ExchangeHaloAndGetValidData(
rhs_with_halo, rhs.base_shape(), left_halo_size_functions[dim],
right_halo_size_functions[dim], explicit_left_padding_on_full_shape,
padded_full_shape_size, shard_size_with_halo, dim, rhs.sharding(),
offset_on_padded_shape.Calculate(partition_ordinals[dim], &b_),
zero, partition_ordinals[dim], collective_ops_creator_,
next_channel_id_, &b_, /*mask_invalid_region=*/false);
if (!concat) {
return DefaultAction(hlo);
}
rhs_with_halo = *concat;
}
SetPartitionedHlo(hlo, [&]() {
auto conv = b_.AddInstruction(HloInstruction::CreateConvolve(
hlo->shape(), conv_lhs, rhs_with_halo, hlo->feature_group_count(),
hlo->batch_group_count(), new_window, dnums,
hlo->precision_config()));
auto ar = collective_ops_creator_.create_cross_partition_all_reduce(
&b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_),
NewChannel());
ar->set_sharding(HloSharding::Replicate());
return PartitionedHlo(ar, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
}
if (!sharding.IsTileMaximal()) {
// We don't currently support sharding on output feature dimension.
if (sharding.tile_assignment().dim(dnums.output_feature_dimension()) > 1) {
return DefaultAction(hlo);
}
// Check if the operand and the output sharding are aligned.
std::vector<int64> input_to_output_indices(hlo->shape().rank());
input_to_output_indices[dnums.input_batch_dimension()] =
dnums.output_batch_dimension();
input_to_output_indices[dnums.input_feature_dimension()] =
dnums.output_feature_dimension();
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
input_to_output_indices[dnums.input_spatial_dimensions(i)] =
dnums.output_spatial_dimensions(i);
}
auto target_operand_sharding =
hlo_sharding_util::TransposeSharding(sharding, input_to_output_indices);
lhs = lhs.Reshard(target_operand_sharding);
// Replicate the RHS.
rhs = rhs.Reshard(HloSharding::Replicate());
// Convolution window config does not include batch and feature dimensions,
// whereas ReshardAsWindowedInput() expects the same number of window
// dimensions as the rank of the operand. So add two more trivial
// dimensions.
std::vector<int64> ones(hlo->shape().rank(), 1);
auto operand_window = window_util::MakeWindow(ones);
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
*operand_window.mutable_dimensions(dnums.input_spatial_dimensions(i)) =
hlo->window().dimensions(i);
}
auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(hlo->shape().element_type())));
auto resharded_operand_and_window = lhs.ReshardAsWindowedInput(
operand_window, target_operand_sharding, zero);
if (!resharded_operand_and_window.has_value()) {
return DefaultAction(hlo);
}
Window new_window;
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
*new_window.add_dimensions() =
resharded_operand_and_window->shard_window.dimensions(
dnums.input_spatial_dimensions(i));
}
TF_ASSIGN_OR_RETURN(
Shape sharded_conv_shape,
ShapeInference::InferConvolveShape(
resharded_operand_and_window->sharded_input->shape(),
rhs.hlo()->shape(), hlo->feature_group_count(),
hlo->batch_group_count(), new_window, dnums));
auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
*sharded_conv_shape.mutable_layout() = shard_shape.layout();
SetPartitionedHlo(hlo, [&]() {
auto sharded_conv = b_.AddInstruction(HloInstruction::CreateConvolve(
sharded_conv_shape, resharded_operand_and_window->sharded_input,
rhs.hlo(), hlo->feature_group_count(), hlo->batch_group_count(),
new_window, dnums, hlo->precision_config()));
if (!resharded_operand_and_window->dynamic_slice_index_on_output
.has_value()) {
CHECK(ShapeUtil::Compatible(shard_shape, sharded_conv->shape()));
return sharded_conv;
}
return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
shard_shape, sharded_conv,
*resharded_operand_and_window->dynamic_slice_index_on_output,
shard_shape.dimensions()));
});
return Status::OK();
}
return DefaultAction(hlo);
}
// Partitions a dot instruction. The instruction's dot dimension numbers are
// converted into a DotGeneralDimsMapping, and the actual partitioning is
// delegated to HandleDotHelper().
//
// Output dimensions are numbered in this order: batch dimensions first, then
// the LHS non-contracting dimensions, then the RHS non-contracting dimensions.
// Contracting dimensions have no output dimension and get -1.
Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) {
  const auto& dot_dnums = hlo->dot_dimension_numbers();
  DotGeneralDimsMapping dims;
  // Next output dimension index to hand out.
  int64 output_dim_cursor = 0;

  // Appends one {lhs, rhs, output} entry to `bucket`.
  const auto append = [](auto& bucket, int64 lhs_dim, int64 rhs_dim,
                         int64 output_dim) {
    bucket.emplace_back();
    auto& entry = bucket.back();
    entry.lhs = lhs_dim;
    entry.rhs = rhs_dim;
    entry.output = output_dim;
  };

  // Batch dimensions exist in both operands and in the output.
  for (int64 b = 0; b < dot_dnums.lhs_batch_dimensions_size(); ++b) {
    append(dims.batch_dims, dot_dnums.lhs_batch_dimensions(b),
           dot_dnums.rhs_batch_dimensions(b), output_dim_cursor++);
  }

  // Contracting dimensions exist in both operands but not in the output.
  for (int64 c = 0; c < dot_dnums.lhs_contracting_dimensions_size(); ++c) {
    append(dims.contracting_dims, dot_dnums.lhs_contracting_dimensions(c),
           dot_dnums.rhs_contracting_dimensions(c), /*output_dim=*/-1);
  }

  // Every remaining LHS dimension is an LHS non-contracting dimension.
  const int64 lhs_rank = hlo->operand(0)->shape().rank();
  for (int64 d = 0; d < lhs_rank; ++d) {
    const bool is_batch_or_contracting =
        absl::c_linear_search(dot_dnums.lhs_batch_dimensions(), d) ||
        absl::c_linear_search(dot_dnums.lhs_contracting_dimensions(), d);
    if (!is_batch_or_contracting) {
      append(dims.lhs_non_contracting_dims, d, /*rhs_dim=*/-1,
             output_dim_cursor++);
    }
  }

  // Every remaining RHS dimension is an RHS non-contracting dimension.
  const int64 rhs_rank = hlo->operand(1)->shape().rank();
  for (int64 d = 0; d < rhs_rank; ++d) {
    const bool is_batch_or_contracting =
        absl::c_linear_search(dot_dnums.rhs_batch_dimensions(), d) ||
        absl::c_linear_search(dot_dnums.rhs_contracting_dimensions(), d);
    if (!is_batch_or_contracting) {
      append(dims.rhs_non_contracting_dims, /*lhs_dim=*/-1, d,
             output_dim_cursor++);
    }
  }

  // Builds the per-partition dot: same dimension numbers and precision config
  // as `hlo`, with the result shape inferred from the sharded operand shapes.
  auto make_sharded_dot =
      [&](HloInstruction* sharded_lhs, HloInstruction* sharded_rhs,
          SpmdBuilder* builder) -> StatusOr<HloInstruction*> {
    TF_ASSIGN_OR_RETURN(auto inferred_shape,
                        ShapeInference::InferDotOpShape(
                            sharded_lhs->shape(), sharded_rhs->shape(),
                            hlo->dot_dimension_numbers()));
    return builder->AddInstruction(HloInstruction::CreateDot(
        inferred_shape, sharded_lhs, sharded_rhs, hlo->dot_dimension_numbers(),
        hlo->precision_config()));
  };
  return HandleDotHelper(hlo, dims, make_sharded_dot);
}
// Partitions the dot-like instruction `hlo` across `num_partitions_`
// partitions.
//
// `dims_mapping` classifies each logical dimension as batch, contracting,
// LHS-non-contracting or RHS-non-contracting and records its index in the LHS,
// the RHS and the output (-1 where the dimension is absent).
// `create_sharded_dot` builds the per-partition dot from two (already sharded)
// operands using the given builder.
//
// The strategies below are tried in order; the first one whose preconditions
// hold is used:
//   1. LHS and RHS are batch-partitioned in the same way.
//   2. One operand and the output are batch-partitioned in the same way, and
//      the other operand can be resharded via all-to-all.
//   3. Windowed dot-general (a while loop with collective-permute) when one
//      operand matches the output along non-contracting dimensions and the
//      other operand is at least `threshold_for_windowed_einsum_mib` MiB.
//   4. Same as 2, but without requiring an all-to-all reshard.
//   5. Both operands are partitioned along contracting dimensions: local dot
//      followed by a cross-partition all-reduce.
//   6. One operand and the output share the same non-contracting
//      partitioning; the other operand is replicated.
//   7. The output is partitioned along batch / LHS non-contracting / RHS
//      non-contracting dimensions; operands are resharded to match it.
//   8. The output is replicated and one operand is partitioned along
//      contracting dimensions: align the other operand and all-reduce.
// If nothing applies, the instruction is handled by DefaultAction().
Status SpmdPartitioningVisitor::HandleDotHelper(
    HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) {
  const HloSharding& lhs_sharding = hlo->operand(0)->sharding();
  const HloSharding& rhs_sharding = hlo->operand(1)->sharding();
  // Similar to hlo_sharding_util::TransposeSharding(), but allows
  // removing/adding non-partitioned dimensions.
  //
  // `src_to_tgt[i]` is the target dimension of source dimension i (or -1 if
  // the source dimension does not exist in the target), and `tgt_to_src` is
  // the inverse mapping. Returns nullopt if a source dimension that has no
  // target counterpart is partitioned.
  auto transpose_sharding =
      [&](const HloSharding& source, absl::Span<int64 const> src_to_tgt,
          absl::Span<int64 const> tgt_to_src) -> absl::optional<HloSharding> {
    if (source.IsTileMaximal()) {
      return source;
    }
    // For each target dimension that has a source counterpart, its index
    // among the target dimensions once the target-only ones are skipped.
    std::vector<int64> tgt_dims_skipping_new(tgt_to_src.size(), -1);
    int64 skipped_tgt_dims = 0;
    for (int64 i = 0; i < tgt_to_src.size(); ++i) {
      if (tgt_to_src[i] < 0) {
        skipped_tgt_dims++;
      } else {
        tgt_dims_skipping_new[i] = i - skipped_tgt_dims;
      }
    }
    int64 skipped_src_dims = absl::c_count(src_to_tgt, -1);
    // Permutation of the source dimensions: mapped dimensions come first in
    // target order, source-only dimensions are placed at the end.
    std::vector<int64> perm(src_to_tgt.size());
    for (int64 i = 0; i < src_to_tgt.size(); ++i) {
      if (src_to_tgt[i] < 0) {
        if (source.tile_assignment().dim(i) > 1) {
          return absl::nullopt;
        }
        perm[src_to_tgt.size() - skipped_src_dims] = i;
        skipped_src_dims--;
      } else {
        perm[tgt_dims_skipping_new[src_to_tgt[i]]] = i;
      }
    }
    auto tgt_sharding = hlo_sharding_util::TransposeSharding(source, perm);
    if (skipped_tgt_dims == 0) {
      return tgt_sharding;
    }
    // Reshape the tile assignment to the target rank; target-only dimensions
    // get a tile count of 1.
    auto reshape_tiles = tgt_sharding.tile_assignment();
    std::vector<int64> tgt_tiles(tgt_to_src.size(), 1);
    for (int64 i = 0; i < tgt_tiles.size(); ++i) {
      if (tgt_to_src[i] >= 0) {
        tgt_tiles[i] = reshape_tiles.dim(tgt_dims_skipping_new[i]);
      }
    }
    reshape_tiles.Reshape(tgt_tiles);
    return HloSharding::Tile(reshape_tiles);
  };
  // Dimension index maps between LHS, RHS and output; -1 means "no
  // counterpart".
  std::vector<int64> lhs_to_rhs_indices(hlo->operand(0)->shape().rank(), -1);
  std::vector<int64> lhs_to_output_indices(hlo->operand(0)->shape().rank(), -1);
  std::vector<int64> rhs_to_lhs_indices(hlo->operand(1)->shape().rank(), -1);
  std::vector<int64> rhs_to_output_indices(hlo->operand(1)->shape().rank(), -1);
  std::vector<int64> output_to_lhs_indices(hlo->shape().rank(), -1);
  std::vector<int64> output_to_rhs_indices(hlo->shape().rank(), -1);
  auto populate_indices_mapping =
      [&](const DotGeneralDimsMapping::DimsMapping& mapping) {
        if (mapping.lhs >= 0) {
          lhs_to_rhs_indices[mapping.lhs] = mapping.rhs;
          lhs_to_output_indices[mapping.lhs] = mapping.output;
        }
        if (mapping.rhs >= 0) {
          rhs_to_lhs_indices[mapping.rhs] = mapping.lhs;
          rhs_to_output_indices[mapping.rhs] = mapping.output;
        }
        if (mapping.output >= 0) {
          output_to_lhs_indices[mapping.output] = mapping.lhs;
          output_to_rhs_indices[mapping.output] = mapping.rhs;
        }
      };
  for (const auto& mapping : dims_mapping.batch_dims) {
    populate_indices_mapping(mapping);
  }
  for (const auto& mapping : dims_mapping.contracting_dims) {
    populate_indices_mapping(mapping);
  }
  for (const auto& mapping : dims_mapping.lhs_non_contracting_dims) {
    populate_indices_mapping(mapping);
  }
  for (const auto& mapping : dims_mapping.rhs_non_contracting_dims) {
    populate_indices_mapping(mapping);
  }
  // Each sharding expressed in the dimension space of the other tensors. Any
  // of these may be nullopt (see transpose_sharding above).
  auto lhs_sharding_transposed_to_match_rhs =
      transpose_sharding(lhs_sharding, lhs_to_rhs_indices, rhs_to_lhs_indices);
  auto rhs_sharding_transposed_to_match_lhs =
      transpose_sharding(rhs_sharding, rhs_to_lhs_indices, lhs_to_rhs_indices);
  auto lhs_sharding_transposed_to_match_output = transpose_sharding(
      lhs_sharding, lhs_to_output_indices, output_to_lhs_indices);
  auto rhs_sharding_transposed_to_match_output = transpose_sharding(
      rhs_sharding, rhs_to_output_indices, output_to_rhs_indices);
  auto output_sharding_transposed_to_match_lhs = transpose_sharding(
      hlo->sharding(), output_to_lhs_indices, lhs_to_output_indices);
  auto output_sharding_transposed_to_match_rhs = transpose_sharding(
      hlo->sharding(), output_to_rhs_indices, rhs_to_output_indices);
  // Returns the product of the tile counts of `sharding` along `dims` (1 for
  // tile-maximal shardings).
  // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output.
  auto get_partitions_for_dims =
      [&](const HloSharding& sharding,
          absl::Span<const DotGeneralDimsMapping::DimsMapping> dims,
          int lhs_rhs_or_output) {
        int64 partitions = 1;
        if (sharding.IsTileMaximal()) {
          return partitions;
        }
        for (const auto& dim : dims) {
          if (lhs_rhs_or_output == 0) {
            partitions *= sharding.tile_assignment().dim(dim.lhs);
          } else if (lhs_rhs_or_output == 1) {
            partitions *= sharding.tile_assignment().dim(dim.rhs);
          } else {
            CHECK_EQ(lhs_rhs_or_output, 2);
            partitions *= sharding.tile_assignment().dim(dim.output);
          }
        }
        return partitions;
      };
  const int64 lhs_batch_partitions =
      get_partitions_for_dims(lhs_sharding, dims_mapping.batch_dims, 0);
  const int64 rhs_batch_partitions =
      get_partitions_for_dims(rhs_sharding, dims_mapping.batch_dims, 1);
  const int64 output_batch_partitions =
      get_partitions_for_dims(hlo->sharding(), dims_mapping.batch_dims, 2);
  const int64 lhs_contracting_partitions =
      get_partitions_for_dims(lhs_sharding, dims_mapping.contracting_dims, 0);
  const int64 rhs_contracting_partitions =
      get_partitions_for_dims(rhs_sharding, dims_mapping.contracting_dims, 1);
  const int64 lhs_non_contracting_partitions = get_partitions_for_dims(
      lhs_sharding, dims_mapping.lhs_non_contracting_dims, 0);
  const int64 rhs_non_contracting_partitions = get_partitions_for_dims(
      rhs_sharding, dims_mapping.rhs_non_contracting_dims, 1);
  const int64 output_lhs_non_contracting_partitions = get_partitions_for_dims(
      hlo->sharding(), dims_mapping.lhs_non_contracting_dims, 2);
  const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims(
      hlo->sharding(), dims_mapping.rhs_non_contracting_dims, 2);
  // NOTE(review): `lhs` and `rhs` are bound by reference, so the assignments
  // to them further down modify the objects returned by GetPartitionedHlo().
  // Confirm that this is intended.
  auto& lhs = GetPartitionedHlo(hlo->operand(0));
  auto& rhs = GetPartitionedHlo(hlo->operand(1));
  // LHS and RHS are partitioned the same way and only partitioned in batch
  // dimensions.
  if (lhs_batch_partitions == rhs_batch_partitions &&
      rhs_batch_partitions == num_partitions_ &&
      lhs_sharding_transposed_to_match_rhs == rhs_sharding) {
    TF_ASSIGN_OR_RETURN(auto dot,
                        create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_));
    SetPartitionedHlo(hlo, [&] {
      dot->set_sharding(*lhs_sharding_transposed_to_match_output);
      return PartitionedHlo(dot, hlo->shape(), MakePartitioningState())
          .Reshard(hlo->sharding())
          .hlo();
    });
    return Status::OK();
  }
  // Try emit batch-partitioned einsum with one operand resharded. Returns
  // whether the attempt succeeds. If may_reshard_with_allreduce is false,
  // reshard must be done using all-to-all; otherwise this attempt fails.
  auto try_emit_output_batch_partitioned_einsum_with_reshard =
      [&](bool may_reshard_with_allreduce) -> StatusOr<bool> {
    // LHS and output are batch partitioned in the same way.
    if (lhs_batch_partitions == num_partitions_ &&
        output_batch_partitions == num_partitions_ &&
        lhs_sharding_transposed_to_match_output == hlo->sharding()) {
      if (!may_reshard_with_allreduce &&
          !GetReshardAllToAllSourceTargetDims(
              rhs.sharding(), *lhs_sharding_transposed_to_match_rhs)) {
        return false;
      }
      auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs);
      TF_ASSIGN_OR_RETURN(
          auto dot, create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), &b_));
      SetPartitionedHlo(hlo, [&] { return dot; });
      return true;
    }
    // RHS and output are batch partitioned in the same way.
    if (rhs_batch_partitions == num_partitions_ &&
        output_batch_partitions == num_partitions_ &&
        rhs_sharding_transposed_to_match_output == hlo->sharding()) {
      if (!may_reshard_with_allreduce &&
          !GetReshardAllToAllSourceTargetDims(
              lhs.sharding(), *rhs_sharding_transposed_to_match_lhs)) {
        return false;
      }
      auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs);
      TF_ASSIGN_OR_RETURN(
          auto dot, create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), &b_));
      SetPartitionedHlo(hlo, [&] { return dot; });
      return true;
    }
    return false;
  };
  {
    // Try batch-parallel by resharding one operand, and not using all-reduce.
    TF_ASSIGN_OR_RETURN(
        bool emitted,
        try_emit_output_batch_partitioned_einsum_with_reshard(false));
    if (emitted) {
      return Status::OK();
    }
  }
  // Try to emit windowed DotGeneral when one operand is partitioned in the same
  // way as the output along non-contracting dimensions, but the other operand
  // is tiled in other dimensions.
  //
  // `matching_operand` / `windowing_operand` are operand indices (0 or 1) and
  // must be different. The emitted while loop is recorded in
  // `windowed_dot_general_loops_`.
  auto emit_windowed_dot_general = [&](int64 matching_operand,
                                       int64 windowing_operand,
                                       bool windowed_at_contracting_dims,
                                       bool windowed_at_batch_dims) {
    CHECK_EQ(matching_operand + windowing_operand, 1);
    CHECK(!windowed_at_batch_dims || !windowed_at_contracting_dims);
    auto unpadded_result_buffer_shape =
        MakePartitionedShape(hlo->shape(), hlo->sharding());
    auto padded_result_buffer_shape = unpadded_result_buffer_shape;
    // For windowing at batch/non-contracting dims, we produce the result one
    // partition at a time, so we need to pad the shape in case of uneven
    // partitioning in order to make dynamic-update-slice in-bound.
    if (!windowed_at_contracting_dims) {
      padded_result_buffer_shape = GetPaddedShapeForUnevenPartitioning(
          padded_result_buffer_shape,
          windowing_operand == 0 ? *lhs_sharding_transposed_to_match_output
                                 : *rhs_sharding_transposed_to_match_output);
    }
    // Mask the padding area of the windowed operand with zero if there is
    // uneven partitioning.
    if (windowed_at_contracting_dims) {
      auto& to_mask = windowing_operand == 0 ? lhs : rhs;
      to_mask =
          to_mask.PadWithValue(b_.AddInstruction(HloInstruction::CreateConstant(
              LiteralUtil::Zero(hlo->shape().element_type()))));
    }
    auto result_buffer = CreateZero(padded_result_buffer_shape, &b_);
    auto iteration = b_.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(0)));
    // Create a while loop that computes one window per iteration. During each
    // iteration, each partition sends its input window to its neighbor using
    // collective-permute for the next iteration.
    //
    // The loop state is the tuple (lhs, rhs, result buffer, iteration).
    SpmdBuilder body_b("windowed_dot_general_body", visiting_hlo_);
    auto param = body_b.AddInstruction(HloInstruction::CreateParameter(
        /*parameter_number=*/0,
        ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(),
                                   result_buffer->shape(), iteration->shape()}),
        "param"));
    auto l = body_b.AddInstruction(
        HloInstruction::CreateGetTupleElement(lhs.hlo()->shape(), param, 0));
    auto r = body_b.AddInstruction(
        HloInstruction::CreateGetTupleElement(rhs.hlo()->shape(), param, 1));
    auto o = body_b.AddInstruction(HloInstruction::CreateGetTupleElement(
        result_buffer->shape(), param, 2));
    auto i = body_b.AddInstruction(
        HloInstruction::CreateGetTupleElement(iteration->shape(), param, 3));
    // data_partition_id = (i + partition_id) % num_partitions_: the partition
    // that originally owned the window currently held by this partition.
    auto partition_id = collective_ops_creator_.create_partition_id(&body_b);
    auto data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary(
        i->shape(), HloOpcode::kAdd, i, partition_id));
    auto partition_count = body_b.AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR0<uint32>(num_partitions_)));
    data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary(
        i->shape(), HloOpcode::kRemainder, data_partition_id, partition_count));
    auto dot_lhs = l;
    auto dot_rhs = r;
    if (windowed_at_contracting_dims || windowed_at_batch_dims) {
      // Slice the matching operand according to the partitioned contracting
      // dimensions on the windowed operand. We do this by treating the matching
      // operand as replicated, and resharding it to match the windowed operand.
      auto slice_operand = matching_operand == 0 ? l : r;
      slice_operand->set_sharding(HloSharding::Replicate());
      auto state = MakePartitioningState();
      state.b = &body_b;
      state.partition_id = data_partition_id;
      auto slice = PartitionedHlo(slice_operand, slice_operand->shape(), state)
                       .Reshard(windowing_operand == 0
                                    ? *lhs_sharding_transposed_to_match_rhs
                                    : *rhs_sharding_transposed_to_match_lhs)
                       .hlo();
      // The sharding was only set temporarily to drive the reshard above.
      slice_operand->clear_sharding();
      if (matching_operand == 0) {
        dot_lhs = slice;
      } else {
        dot_rhs = slice;
      }
    }
    TF_ASSIGN_OR_RETURN(auto dot,
                        create_sharded_dot(dot_lhs, dot_rhs, &body_b));
    if (windowed_at_contracting_dims) {
      // Accumulate the partial output to the result buffer.
      o = body_b.AddInstruction(
          HloInstruction::CreateBinary(o->shape(), HloOpcode::kAdd, o, dot));
    } else {
      // The windowing operand is partitioned along batch/non-contracting
      // dimensions, so we need a dynamic-update-slice to save the partial
      // output in the result buffer.
      auto offsets = MakePartitionOffsets(
          o->shape(),
          windowing_operand == 0 ? *lhs_sharding_transposed_to_match_output
                                 : *rhs_sharding_transposed_to_match_output,
          data_partition_id, &body_b);
      o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          o->shape(), o, dot, offsets));
    }
    // ++i
    i = body_b.AddInstruction(HloInstruction::CreateBinary(
        i->shape(), HloOpcode::kAdd, i,
        body_b.AddInstruction(
            HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(1)))));
    auto has_more = body_b.AddInstruction(HloInstruction::CreateCompare(
        ShapeUtil::MakeShape(PRED, {}), i,
        body_b.AddInstruction(HloInstruction::CreateConstant(
            LiteralUtil::CreateR0<uint32>(num_partitions_))),
        ComparisonDirection::kLt));
    // Collective-permute for the next window. We don't need it for the last
    // iteration, so we use a conditional around the collective-permute.
    HloInstruction* conditional;
    {
      SpmdBuilder cp_b("window_collective_permute", visiting_hlo_);
      {
        auto p = cp_b.AddInstruction(HloInstruction::CreateParameter(
            0, windowing_operand == 0 ? l->shape() : r->shape(), "window"));
        std::vector<std::pair<int64, int64>> sd_pairs(num_partitions_);
        for (int64 source = 0; source < num_partitions_; ++source) {
          // 0 -> n-1, 1 -> 0, 2 -> 1, ...
          sd_pairs[source] = {source,
                              (source - 1 + num_partitions_) % num_partitions_};
        }
        collective_ops_creator_.create_cross_partition_collective_permute(
            &cp_b, p, sd_pairs, (*next_channel_id_)++);
      }
      // The false branch just forwards the window unchanged.
      SpmdBuilder ncp_b("last_iteration_noop", visiting_hlo_);
      {
        ncp_b.AddInstruction(HloInstruction::CreateParameter(
            0, windowing_operand == 0 ? l->shape() : r->shape(), "window"));
      }
      conditional = body_b.AddInstruction(HloInstruction::CreateConditional(
          windowing_operand == 0 ? l->shape() : r->shape(), has_more,
          windowing_operand == 0 ? l : r,
          module_->AddEmbeddedComputation(cp_b.Build()),
          windowing_operand == 0 ? l : r,
          module_->AddEmbeddedComputation(ncp_b.Build())));
    }
    if (windowing_operand == 0) {
      l = conditional;
    } else {
      r = conditional;
    }
    body_b.AddInstruction(HloInstruction::CreateTuple({l, r, o, i}));
    // Loop condition: iteration < num_partitions_.
    SpmdBuilder cond_b("windowed_dot_general_cond", visiting_hlo_);
    auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter(
        /*parameter_number=*/0,
        ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(),
                                   result_buffer->shape(), iteration->shape()}),
        "param"));
    auto cond_i = cond_b.AddInstruction(HloInstruction::CreateGetTupleElement(
        iteration->shape(), cond_param, 3));
    cond_b.AddInstruction(HloInstruction::CreateCompare(
        ShapeUtil::MakeShape(PRED, {}), cond_i,
        cond_b.AddInstruction(HloInstruction::CreateConstant(
            LiteralUtil::CreateR0<uint32>(num_partitions_))),
        ComparisonDirection::kLt));
    auto while_loop = b_.AddInstruction(HloInstruction::CreateWhile(
        cond_param->shape(), module_->AddEmbeddedComputation(cond_b.Build()),
        module_->AddEmbeddedComputation(body_b.Build()),
        b_.AddInstruction(HloInstruction::CreateTuple(
            {lhs.hlo(), rhs.hlo(), result_buffer, iteration}))));
    windowed_dot_general_loops_.push_back({while_loop, windowing_operand,
                                           windowed_at_contracting_dims,
                                           windowed_at_batch_dims});
    SetPartitionedHlo(hlo, [&] {
      auto result = b_.AddInstruction(HloInstruction::CreateGetTupleElement(
          result_buffer->shape(), while_loop, 2));
      // Slice off the padding added for uneven partitioning, if any.
      if (!ShapeUtil::Compatible(padded_result_buffer_shape,
                                 unpadded_result_buffer_shape)) {
        result = b_.AddInstruction(HloInstruction::CreateSlice(
            unpadded_result_buffer_shape, result,
            std::vector<int64>(padded_result_buffer_shape.rank(), 0),
            unpadded_result_buffer_shape.dimensions(),
            std::vector<int64>(padded_result_buffer_shape.rank(), 1)));
      }
      return result;
    });
    return Status::OK();
  };
  // LHS matches the output; window over the RHS if it is large enough.
  if (output_lhs_non_contracting_partitions == num_partitions_ &&
      output_sharding_transposed_to_match_lhs == lhs_sharding &&
      ShapeSizeInBytes(hlo->operand(1)->shape()) >=
          options_.threshold_for_windowed_einsum_mib * 1024 * 1024) {
    if (rhs_contracting_partitions == num_partitions_) {
      return emit_windowed_dot_general(0, 1, true, false);
    }
    if (rhs_non_contracting_partitions == num_partitions_) {
      return emit_windowed_dot_general(0, 1, false, false);
    }
    if (rhs_batch_partitions == num_partitions_) {
      return emit_windowed_dot_general(0, 1, false, true);
    }
  }
  // RHS matches the output; window over the LHS if it is large enough.
  if (output_rhs_non_contracting_partitions == num_partitions_ &&
      output_sharding_transposed_to_match_rhs == rhs_sharding &&
      ShapeSizeInBytes(hlo->operand(0)->shape()) >=
          options_.threshold_for_windowed_einsum_mib * 1024 * 1024) {
    if (lhs_contracting_partitions == num_partitions_) {
      return emit_windowed_dot_general(1, 0, true, false);
    }
    if (lhs_non_contracting_partitions == num_partitions_) {
      return emit_windowed_dot_general(1, 0, false, false);
    }
    if (lhs_batch_partitions == num_partitions_) {
      return emit_windowed_dot_general(1, 0, false, true);
    }
  }
  {
    // Try batch-parallel by resharding one operand, and allowing all-reduce.
    TF_ASSIGN_OR_RETURN(
        bool emitted,
        try_emit_output_batch_partitioned_einsum_with_reshard(true));
    if (emitted) {
      return Status::OK();
    }
  }
  // LHS and RHS have the same partitioned contracting dimensions.
  if (lhs_contracting_partitions == rhs_contracting_partitions &&
      lhs_contracting_partitions == num_partitions_) {
    auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::Zero(hlo->shape().element_type())));
    // Pad both sides with zero, since NaN at one side cannot be masked by zero
    // on the other side.
    // The operand with the smaller base shape is the one that is resharded.
    if (ShapeSizeInBytes(lhs.base_shape()) <
        ShapeSizeInBytes(rhs.base_shape())) {
      lhs =
          lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero);
      rhs = rhs.PadWithValue(zero);
    } else {
      lhs = lhs.PadWithValue(zero);
      rhs =
          rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
    }
    TF_ASSIGN_OR_RETURN(auto dot,
                        create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_));
    SetPartitionedHlo(hlo, [&] {
      auto ar = collective_ops_creator_.create_cross_partition_all_reduce(
          &b_, dot, MakeBinaryAdd(hlo->shape().element_type(), module_),
          NewChannel());
      ar->set_sharding(HloSharding::Replicate());
      return PartitionedHlo(ar, hlo->shape(), MakePartitioningState())
          .Reshard(hlo->sharding())
          .hlo();
    });
    return Status::OK();
  }
  // LHS and output have the same partitioned non-contracting dimensions.
  //
  // The LHS sharding must be compared in the output's dimension space: LHS and
  // output dimensions are numbered differently, so equal tile assignments do
  // not imply that the same logical dimensions are partitioned.
  if (lhs_non_contracting_partitions == num_partitions_ &&
      output_lhs_non_contracting_partitions == num_partitions_ &&
      lhs_sharding_transposed_to_match_output == hlo->sharding()) {
    auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo();
    TF_ASSIGN_OR_RETURN(auto dot,
                        create_sharded_dot(lhs.hlo(), rhs_replicated, &b_));
    SetPartitionedHlo(hlo, [&] { return dot; });
    return Status::OK();
  }
  // RHS and output have the same partitioned non-contracting dimensions.
  if (rhs_non_contracting_partitions == num_partitions_ &&
      output_rhs_non_contracting_partitions == num_partitions_ &&
      rhs_sharding_transposed_to_match_output == hlo->sharding()) {
    auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo();
    TF_ASSIGN_OR_RETURN(auto dot,
                        create_sharded_dot(lhs_replicated, rhs.hlo(), &b_));
    SetPartitionedHlo(hlo, [&] { return dot; });
    return Status::OK();
  }
  // Output is batch partitioned.
  if (output_batch_partitions == num_partitions_) {
    auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs);
    auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs);
    TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(),
                                                     resharded_rhs.hlo(), &b_));
    SetPartitionedHlo(hlo, [&] { return dot; });
    return Status::OK();
  }
  // Output is partitioned along LHS non-contracting dimensions.
  if (output_lhs_non_contracting_partitions == num_partitions_) {
    auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs);
    auto replicated_rhs = rhs.Reshard(HloSharding::Replicate());
    TF_ASSIGN_OR_RETURN(
        auto dot,
        create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), &b_));
    SetPartitionedHlo(hlo, [&] { return dot; });
    return Status::OK();
  }
  // Output is partitioned along RHS non-contracting dimensions.
  if (output_rhs_non_contracting_partitions == num_partitions_) {
    auto replicated_lhs = lhs.Reshard(HloSharding::Replicate());
    auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs);
    TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(),
                                                     resharded_rhs.hlo(), &b_));
    SetPartitionedHlo(hlo, [&] { return dot; });
    return Status::OK();
  }
  // Returns true if it is beneficial to reshard the operand at `operand_idx`
  // across the contracting dimension.
  const auto should_partition_contracting_dim = [&](int64 operand_idx) {
    if (!hlo->sharding().IsReplicated()) {
      return false;
    }
    if (operand_idx == 0) {
      // If LHS and output are replicated, we compare the cost of all-gather
      // on RHS vs all-reduce on the output.
      return (rhs_contracting_partitions == num_partitions_) &&
             lhs.sharding().IsReplicated() &&
             ShapeUtil::ElementsIn(hlo->operand(1)->shape()) >
                 ShapeUtil::ElementsIn(hlo->shape());
    } else {
      // Symmetric case: RHS and output are replicated, LHS is partitioned
      // along contracting dimensions.
      return (lhs_contracting_partitions == num_partitions_) &&
             rhs.sharding().IsReplicated() &&
             ShapeUtil::ElementsIn(hlo->operand(0)->shape()) >
                 ShapeUtil::ElementsIn(hlo->shape());
    }
  };
  // When the output is replicated and one of the operands is partitioned along
  // contracting dimension, align the other operand to be partitioned along
  // the contracting dimensions.
  if (hlo->sharding().IsReplicated() && (should_partition_contracting_dim(0) ||
                                         should_partition_contracting_dim(1))) {
    auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::Zero(hlo->shape().element_type())));
    if (should_partition_contracting_dim(0)) {
      lhs =
          lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero);
      rhs = rhs.PadWithValue(zero);
    } else {
      lhs = lhs.PadWithValue(zero);
      rhs =
          rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
    }
    TF_ASSIGN_OR_RETURN(auto dot,
                        create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_));
    SetPartitionedHlo(hlo, [&] {
      auto ar = collective_ops_creator_.create_cross_partition_all_reduce(
          &b_, dot, MakeBinaryAdd(hlo->shape().element_type(), module_),
          NewChannel());
      ar->set_sharding(HloSharding::Replicate());
      return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()).hlo();
    });
    return Status::OK();
  }
  return DefaultAction(hlo);
}
namespace {
// Walks backwards from `hlo` and collects the cluster of instructions that
// produce it, provided that the cluster only depends on small operands. The
// cluster must start at broadcasts, constants and iotas; every other node in
// it must be an elementwise op without side effects (and not an all-reduce)
// whose operands all have a shape compatible with its own, ignoring element
// type.
//
// Returns the set of cluster nodes together with the small operands, i.e. the
// de-duplicated operands of the broadcasts/constants/iotas in discovery order.
// If any visited instruction does not qualify, both results are empty.
//
// E.g., for the following graph,
//
//   a -> broadcast ---> multiply
//   iota -----> add ---/
//   constant --/
//
// FindInputNodesIfOnlyDependOnSmallOperands(multiply) will return
// <{broadcast, iota, constant, add, multiply}, [a]>.
std::pair<std::unordered_set<HloInstruction*>, std::vector<HloInstruction*>>
FindInputNodesIfOnlyDependOnSmallOperands(HloInstruction* hlo) {
  std::unordered_set<HloInstruction*> cluster;
  std::vector<HloInstruction*> small_operands;
  std::unordered_set<const HloInstruction*> seen_small_operands;

  // Instructions at which the cluster starts; their operands are the small
  // operands.
  const auto is_cluster_source = [](const HloInstruction* inst) {
    switch (inst->opcode()) {
      case HloOpcode::kBroadcast:
      case HloOpcode::kConstant:
      case HloOpcode::kIota:
        return true;
      default:
        return false;
    }
  };

  // Instructions allowed inside the cluster; the search continues through
  // their operands.
  const auto is_cluster_interior = [](const HloInstruction* inst) {
    if (!inst->IsElementwise() || inst->HasSideEffectNoRecurse() ||
        inst->opcode() == HloOpcode::kAllReduce) {
      return false;
    }
    return absl::c_all_of(
        inst->operands(), [inst](const HloInstruction* operand) {
          return ShapeUtil::CompatibleIgnoringElementType(operand->shape(),
                                                          inst->shape());
        });
  };

  // Depth-first traversal using an explicit stack.
  std::vector<HloInstruction*> pending{hlo};
  while (!pending.empty()) {
    HloInstruction* current = pending.back();
    pending.pop_back();
    if (cluster.find(current) != cluster.end()) {
      // Already visited.
      continue;
    }
    if (is_cluster_source(current)) {
      cluster.insert(current);
      for (HloInstruction* operand : current->operands()) {
        if (seen_small_operands.insert(operand).second) {
          small_operands.push_back(operand);
        }
      }
      continue;
    }
    if (is_cluster_interior(current)) {
      cluster.insert(current);
      for (HloInstruction* operand : current->operands()) {
        pending.push_back(operand);
      }
      continue;
    }
    // Found an instruction that cannot be part of the cluster: give up and
    // report nothing.
    cluster.clear();
    small_operands.clear();
    break;
  }
  return {std::move(cluster), std::move(small_operands)};
}
// Sinks the cluster of nodes that produces the non-windowed operand of a
// windowed dot-general loop on contracting dimensions into the loop body. The
// loop body dynamic-slices that operand, so once its producers live inside the
// loop, later optimization passes can merge the dynamic-slice with them, which
// reduces memory.
//
//   before:                       after:
//
//   small_operands                small_operands
//         |                             |
//    input_nodes                  loop {
//         |                         input_nodes
//   loop {                              |
//     dynamic-slice                 dynamic-slice
//     ...                           ...
//   }                             }
//
// Later optimization passes (TpuPadSliceMover) will merge the dynamic slice
// with the input nodes.
//
// `loop` is the while instruction and `non_windowed_operand_index` is the
// index of the non-windowed operand in the loop's state tuple. Nothing is
// changed when no suitable cluster is found for that operand.
Status SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions(
    HloInstruction* loop, int64 non_windowed_operand_index) {
  HloInstruction* loop_input = loop->mutable_operand(0);
  HloInstruction* original_operand =
      loop_input->mutable_operand(non_windowed_operand_index);
  auto found = FindInputNodesIfOnlyDependOnSmallOperands(original_operand);
  auto sinkable = std::move(found.first);
  auto small_operands = std::move(found.second);
  if (sinkable.empty()) {
    return Status::OK();
  }

  // Outside the loop: the tuple element that used to be the operand becomes a
  // tuple of the small operands.
  auto enclosing_computation = loop->parent();
  auto small_operands_tuple = enclosing_computation->AddInstruction(
      HloInstruction::CreateTuple(small_operands));
  TF_RETURN_IF_ERROR(loop_input->ReplaceOperandWithDifferentShape(
      non_windowed_operand_index, small_operands_tuple));

  auto loop_body = loop->while_body();
  auto body_param = loop_body->parameter_instruction(0);
  // Snapshot the users before new get-tuple-elements are added below.
  const auto original_param_users = body_param->users();

  // Patch the element shape in every instruction that carries the loop state
  // tuple.
  const std::vector<HloInstruction*> state_tuples = {
      loop_input, loop, loop->while_condition()->parameter_instruction(0),
      body_param, loop_body->root_instruction()};
  for (HloInstruction* state_tuple : state_tuples) {
    *ShapeUtil::GetMutableSubshape(state_tuple->mutable_shape(),
                                   {non_windowed_operand_index}) =
        small_operands_tuple->shape();
  }

  // Inside the loop: read the small-operand tuple and pass it through
  // unchanged to the next iteration.
  auto small_operands_tuple_in_body =
      loop_body->AddInstruction(HloInstruction::CreateGetTupleElement(
          small_operands_tuple->shape(), body_param,
          non_windowed_operand_index));
  TF_RETURN_IF_ERROR(
      loop_body->root_instruction()->ReplaceOperandWithDifferentShape(
          non_windowed_operand_index, small_operands_tuple_in_body));

  // Clone the cluster into the loop body. `cloned` maps an instruction outside
  // the loop to its counterpart inside the loop body.
  std::vector<HloInstruction*> ready;
  std::unordered_map<const HloInstruction*, HloInstruction*> cloned;
  // Queues every user of `producer` that belongs to the cluster, has not been
  // cloned yet, and has all of its operands available inside the loop.
  auto enqueue_ready_users = [&](HloInstruction* producer) {
    for (HloInstruction* user : producer->users()) {
      const bool not_cloned_yet = cloned.count(user) == 0;
      if (not_cloned_yet && sinkable.count(user) > 0 &&
          absl::c_all_of(user->operands(),
                         [&](const HloInstruction* operand) {
                           return cloned.count(operand) > 0;
                         })) {
        ready.push_back(user);
      }
    }
  };
  for (int64 idx = 0; idx < small_operands.size(); ++idx) {
    HloInstruction* small_operand = small_operands[idx];
    cloned[small_operand] =
        loop_body->AddInstruction(HloInstruction::CreateGetTupleElement(
            small_operand->shape(), small_operands_tuple_in_body, idx));
    enqueue_ready_users(small_operand);
  }

  // Cluster nodes without operands have no producer that would enqueue them,
  // so seed the work list with them. They are ordered by unique id so that the
  // result does not depend on the iteration order of the set.
  std::vector<HloInstruction*> operandless_nodes;
  for (HloInstruction* node : sinkable) {
    if (node->operand_count() == 0) {
      operandless_nodes.push_back(node);
    }
  }
  absl::c_sort(operandless_nodes,
               [](const HloInstruction* lhs, const HloInstruction* rhs) {
                 return lhs->unique_id() < rhs->unique_id();
               });
  for (HloInstruction* node : operandless_nodes) {
    ready.push_back(node);
  }

  while (!ready.empty()) {
    HloInstruction* node = ready.back();
    ready.pop_back();
    std::vector<HloInstruction*> operands_in_body(node->operand_count());
    for (int64 idx = 0; idx < node->operand_count(); ++idx) {
      operands_in_body[idx] = cloned[node->operand(idx)];
    }
    cloned[node] = loop_body->AddInstruction(
        node->CloneWithNewOperands(node->shape(), operands_in_body));
    enqueue_ready_users(node);
  }
  TF_RET_CHECK(cloned.count(original_operand) > 0);

  // Redirect the body's original reads of the operand to the clone created
  // inside the loop, then drop the now-unused get-tuple-elements.
  for (HloInstruction* param_user : original_param_users) {
    if (param_user->opcode() != HloOpcode::kGetTupleElement ||
        param_user->tuple_index() != non_windowed_operand_index) {
      continue;
    }
    TF_RETURN_IF_ERROR(
        param_user->ReplaceAllUsesWith(cloned[original_operand]));
    TF_RETURN_IF_ERROR(loop_body->RemoveInstruction(param_user));
  }
  return Status::OK();
}
// Moves a cluster of memory-reducing nodes (with reduce nodes at the end) into
// the windowed dot-general loop on non-contracting dimensions. Such a loop has
// a dynamic-update-slice at the output. If we move the user nodes into the loop
// and before the dynamic-update-slice, the user nodes can operate on smaller
// shapes, which reduces memory.
//
// small_operands small_operands
// | | => | |
// | | loop { loop { | |
// | | conv | broadcast conv
// | | | | | /
// | | dynamic-update-slice | dynamic-slice /
// | | | | | /
// | | } | | multiply-----
// |broadcast / | /
// | | / reduce
// |multiply-- |
// \ | dynamic-update-slice
// reduce }
//
// Later optimization passes (TpuPadSliceMover) will merge the dynamic slice
// with the input nodes (broadcast).
Status MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions(
HloInstruction* loop) {
// Outline of the steps below:
//   1. Starting from the loop output, collect the cluster of elementwise users
//      that ends in supported reduces (`to_move`), together with the small
//      operands the cluster needs (`new_operands`).
//   2. Replace element 2 of the loop state with a tuple holding those small
//      operands followed by one zero-filled result buffer per reduce.
//   3. Re-create the cluster inside the loop body on sliced shapes.
//   4. Re-create each reduce inside the body and write it into its buffer.
//   5. Replace the reduces outside the loop with reads of the new loop output.
// Returns OK without rewriting anything when no suitable cluster is found.
// CHECK-fails if the loop does not have exactly one user, or if that user is
// not a get-tuple-element with index 2.
CHECK_EQ(loop->user_count(), 1);
// There should be a single direct user of the while loop, which is the
// gte for element 2, i.e., the dot output.
auto user_gte = loop->users().front();
CHECK_EQ(user_gte->opcode(), HloOpcode::kGetTupleElement);
CHECK_EQ(user_gte->tuple_index(), 2);
auto computation = loop->parent();
// Find the reduce outputs and the input nodes they depend on, if input nodes
// only have small operands.
// `to_move` holds the instructions outside the loop that will be re-created
// inside it. `new_operands` (deduplicated through `new_operands_set`) holds
// the values that will be passed into the loop through the new tuple element.
std::unordered_set<HloInstruction*> to_move;
std::vector<HloInstruction*> new_operands;
std::unordered_set<const HloInstruction*> new_operands_set;
std::vector<HloInstruction*> reduce_outputs;
std::vector<HloInstruction*> worklist;
Shape padded_shape = user_gte->shape();
Shape unpadded_shape = user_gte->shape();
auto original_output = user_gte;
// If the only user of the gte is a slice, that slice is treated as the
// original output and its shape is recorded as the unpadded shape.
if (user_gte->user_count() == 1 &&
user_gte->users().back()->opcode() == HloOpcode::kSlice) {
original_output = user_gte->users().back();
unpadded_shape = original_output->shape();
}
for (auto u : original_output->users()) {
worklist.push_back(u);
}
to_move.insert(original_output);
// Walk the transitive users of the output. Any user that is neither a
// supported reduce nor an eligible elementwise op clears `to_move` and stops
// the walk, which makes the function bail out below.
while (!worklist.empty()) {
auto inst = worklist.back();
worklist.pop_back();
if (to_move.count(inst) > 0) {
continue;
}
// We only support reduces with simple reduction function, since we may need
// to accumulate across iterations manually.
if (inst->opcode() == HloOpcode::kReduce &&
inst->to_apply()->instruction_count() == 3 &&
inst->to_apply()->num_parameters() == 2 &&
inst->to_apply()->root_instruction()->IsElementwise()) {
to_move.insert(inst);
// Operand 1 of the reduce (its init value) must also be available inside
// the loop, so it is passed in as a new operand.
auto other_operand = inst->mutable_operand(1);
auto res = new_operands_set.emplace(other_operand);
if (res.second) {
new_operands.push_back(other_operand);
}
reduce_outputs.push_back(inst);
} else if (inst != computation->root_instruction() &&
inst->user_count() > 0 && inst->IsElementwise() &&
!inst->HasSideEffectNoRecurse() &&
inst->opcode() != HloOpcode::kAllReduce &&
absl::c_all_of(inst->operands(),
[inst](const HloInstruction* o) {
return ShapeUtil::CompatibleIgnoringElementType(
o->shape(), inst->shape());
})) {
// For an elementwise op, we need to make sure that they depend on only
// nodes already in to_move and nodes with small operands.
bool can_include = true;
for (auto operand : inst->operands()) {
if (to_move.count(operand) > 0) {
continue;
}
// An empty `first` is treated as "this operand does not qualify". The
// helper is defined elsewhere in this file; from its use here, `first` is
// the set of nodes to move and `second` the operands to pass into the loop.
auto find_result = FindInputNodesIfOnlyDependOnSmallOperands(operand);
if (find_result.first.empty()) {
can_include = false;
break;
}
for (auto n : find_result.first) {
to_move.insert(n);
}
for (auto new_operand : find_result.second) {
auto res = new_operands_set.insert(new_operand);
if (res.second) {
new_operands.push_back(new_operand);
}
}
}
if (!can_include) {
to_move.clear();
break;
}
to_move.insert(inst);
for (auto u : inst->users()) {
worklist.push_back(u);
}
} else {
to_move.clear();
break;
}
}
// If nothing is found, to_move could contain only original_output, or cleared
// by the above code.
if (to_move.size() <= 1) {
return Status::OK();
}
// We will replace the original loop output with reduce-shape outputs. Create
// the initial buffers before the loop.
// The buffers are appended after all other entries of `new_operands`; the
// index arithmetic further down relies on them being the last
// reduce_outputs.size() entries.
// NOTE(review): the buffers are filled with zero irrespective of the reduce's
// own init value, and are later combined with per-iteration results using the
// reducer's root opcode when accumulation is needed. Confirm that zero is a
// valid identity for every reduction accepted above (e.g. max/min).
for (auto out : reduce_outputs) {
auto padded_out_shape = out->shape();
int64 operand_dim = 0;
int64 output_dim = 0;
while (output_dim < padded_out_shape.rank()) {
if (absl::c_linear_search(out->dimensions(), operand_dim)) {
// Dimension collapsed by the reduce; it has no counterpart in the output.
++operand_dim;
continue;
}
// Kept dimensions have the same size of the padded shape.
padded_out_shape.set_dimensions(output_dim,
padded_shape.dimensions(operand_dim));
++operand_dim;
++output_dim;
}
auto broadcast =
computation->AddInstruction(HloInstruction::CreateBroadcast(
padded_out_shape,
computation->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(out->shape().element_type()))),
{}));
new_operands.push_back(broadcast);
}
auto input_tuple = loop->mutable_operand(0);
// Create the new input subtuple that contains the small operands and the
// reduce-shape result buffers.
auto new_input_subtuple =
computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
TF_RETURN_IF_ERROR(
input_tuple->ReplaceOperandWithDifferentShape(2, new_input_subtuple));
auto body = loop->while_body();
auto body_param = body->parameter_instruction(0);
auto body_root = body->root_instruction();
CHECK_EQ(body_root->opcode(), HloOpcode::kTuple);
// Update tuple shapes.
// Element 2 of every instruction that carries the loop state is patched in
// place to the shape of the new subtuple.
for (auto tuple : std::vector<HloInstruction*>{
input_tuple, loop, loop->while_condition()->parameter_instruction(0),
body_param, body_root}) {
*ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {2}) =
new_input_subtuple->shape();
}
auto new_loop_input =
body->AddInstruction(HloInstruction::CreateGetTupleElement(
new_input_subtuple->shape(), body_param, 2));
// Now create the moved nodes inside the loop body.
// `outside_to_inside` maps an instruction outside the loop to its counterpart
// inside the body.
std::unordered_map<const HloInstruction*, HloInstruction*> outside_to_inside;
worklist.clear();
// Queues every user of `inst` that is in `to_move`, has not been created
// inside yet, and whose operands all already have an inside counterpart.
auto add_users_if_available = [&](HloInstruction* inst) {
for (auto u : inst->users()) {
if (outside_to_inside.count(u) == 0 && to_move.count(u) > 0 &&
absl::c_all_of(u->operands(), [&](const HloInstruction* o) {
return outside_to_inside.count(o) > 0;
})) {
worklist.push_back(u);
}
}
};
// Each new operand is read inside the body as element i of the new subtuple.
for (int64 i = 0; i < new_operands.size(); ++i) {
outside_to_inside[new_operands[i]] =
body->AddInstruction(HloInstruction::CreateGetTupleElement(
new_operands[i]->shape(), new_loop_input, i));
add_users_if_available(new_operands[i]);
}
// The elementwise nodes will be created with sliced shape. The original loop
// output corresponds to the dynamic-update-slice's update slice.
auto dus = body_root->mutable_operand(2);
CHECK_EQ(dus->opcode(), HloOpcode::kDynamicUpdateSlice);
outside_to_inside[original_output] = dus->mutable_operand(1);
add_users_if_available(original_output);
// Operands 2.. of the dynamic-update-slice are its per-dimension start
// indices; they are reused as the offsets of the current window.
std::vector<HloInstruction*> slice_offsets(padded_shape.rank());
for (int64 i = 0; i < slice_offsets.size(); ++i) {
slice_offsets[i] = dus->mutable_operand(i + 2);
}
// Dynamic-slices `padded` at the current window offsets. The result has the
// dimensions of the update slice and the element type of `padded`.
auto get_slice = [&](HloInstruction* padded) {
return body->AddInstruction(HloInstruction::CreateDynamicSlice(
ShapeUtil::ChangeElementType(dus->operand(1)->shape(),
padded->shape().element_type()),
padded, slice_offsets, dus->operand(1)->shape().dimensions()));
};
// Helper functions to create nodes with small operands.
// Re-creates a broadcast inside the body at the padded shape and slices out
// the current window. The broadcast operand is first passed through
// PadToShape (defined elsewhere; presumably pads it up to the sizes of the
// padded dimensions it maps to — confirm against that helper).
auto add_broadcast = [&](const HloInstruction* broadcast) {
auto padded_operand_shape = broadcast->operand(0)->shape();
for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
padded_operand_shape.set_dimensions(
i, padded_shape.dimensions(broadcast->dimensions(i)));
}
auto padded_operand = PadToShape(outside_to_inside[broadcast->operand(0)],
padded_operand_shape, nullptr, body);
outside_to_inside[broadcast] =
get_slice(body->AddInstruction(broadcast->CloneWithNewOperands(
ShapeUtil::ChangeElementType(padded_shape,
padded_operand_shape.element_type()),
{padded_operand})));
};
// Re-creates an iota inside the body at the padded shape and slices out the
// current window.
auto add_iota = [&](const HloInstruction* iota) {
outside_to_inside[iota] =
get_slice(body->AddInstruction(iota->CloneWithNewOperands(
ShapeUtil::ChangeElementType(padded_shape,
iota->shape().element_type()),
{})));
};
// Clones a constant into the body, passes it through PadToShape with the
// padded shape, and slices out the current window.
auto add_constant = [&](const HloInstruction* constant) {
outside_to_inside[constant] = body->AddInstruction(constant->Clone());
outside_to_inside[constant] = get_slice(
PadToShape(outside_to_inside[constant],
ShapeUtil::ChangeElementType(
padded_shape, constant->shape().element_type()),
nullptr, body));
};
// Create the inside counterparts in dependency order: an instruction is only
// queued once all of its operands have been mapped.
while (!worklist.empty()) {
auto inst = worklist.back();
worklist.pop_back();
if (outside_to_inside.count(inst) > 0) {
continue;
}
if (inst->opcode() == HloOpcode::kBroadcast) {
add_broadcast(inst);
} else if (inst->opcode() == HloOpcode::kIota) {
add_iota(inst);
} else if (inst->opcode() == HloOpcode::kConstant) {
add_constant(inst);
} else if (inst->opcode() == HloOpcode::kReduce) {
// This is an output, for which we have special handling later.
} else {
// Any other op is cloned with the update slice's dimensions and its own
// element type.
std::vector<HloInstruction*> operands_inside(inst->operand_count());
for (int64 i = 0; i < operands_inside.size(); ++i) {
operands_inside[i] = outside_to_inside[inst->operand(i)];
}
outside_to_inside[inst] = body->AddInstruction(inst->CloneWithNewOperands(
ShapeUtil::ChangeElementType(dus->operand(1)->shape(),
inst->shape().element_type()),
operands_inside));
}
add_users_if_available(inst);
}
// By default every element of the new subtuple is passed through unchanged;
// the reduce result slots are overwritten in the loop below.
std::vector<HloInstruction*> new_outputs_inside(new_operands.size());
for (int64 i = 0; i < new_outputs_inside.size(); ++i) {
new_outputs_inside[i] = outside_to_inside[new_operands[i]];
}
// Now create the reduce outputs inside of the loop.
for (int64 i = 0; i < reduce_outputs.size(); ++i) {
auto reduce_outside = reduce_outputs[i];
CHECK_EQ(reduce_outside->opcode(), HloOpcode::kReduce);
// The result buffers occupy the last reduce_outputs.size() slots of
// `new_operands`, in the same order as `reduce_outputs`.
int64 index_in_operand = new_operands.size() - reduce_outputs.size() + i;
auto last_iter_result = outside_to_inside[new_operands[index_in_operand]];
auto operand0 = outside_to_inside[reduce_outside->operand(0)];
auto operand1 = outside_to_inside[reduce_outside->operand(1)];
TF_ASSIGN_OR_RETURN(auto reduce_shape,
ShapeInference::InferReduceShape(
{&operand0->shape(), &operand1->shape()},
reduce_outside->dimensions(),
reduce_outside->to_apply()->ComputeProgramShape()));
*reduce_shape.mutable_layout() = reduce_outside->shape().layout();
std::vector<HloInstruction*> reduce_dus_offsets;
// If any collapsed dimension is windowed, we need to accumulate with last
// iteration's result. If such a dimension has padding, we also need to mask
// off invalid data.
bool needs_accumulate = false;
std::vector<int64> dims_to_mask;
// Note: this `i` shadows the outer loop index. A reduced dimension counts as
// windowed when its size differs between the outside and inside operand.
// Offsets of the kept (non-reduced) dimensions are collected for the reduce
// result.
for (int64 i = 0; i < slice_offsets.size(); ++i) {
if (absl::c_linear_search(reduce_outside->dimensions(), i)) {
if (reduce_outside->operand(0)->shape().dimensions(i) !=
operand0->shape().dimensions(i)) {
needs_accumulate = true;
if (unpadded_shape.dimensions(i) != padded_shape.dimensions(i)) {
dims_to_mask.push_back(i);
}
}
continue;
}
reduce_dus_offsets.push_back(slice_offsets[i]);
}
// Mask off invalid data in collapsed dimensions.
// For each such dimension: select operand0 where (iota + window offset) is
// less than the outside operand's size in that dimension, and the broadcast
// reduce init value (operand1) elsewhere.
for (int64 dim : dims_to_mask) {
auto iota = body->AddInstruction(HloInstruction::CreateIota(
ShapeUtil::ChangeElementType(operand0->shape(), S32), dim));
auto add = body->AddInstruction(HloInstruction::CreateBinary(
iota->shape(), HloOpcode::kAdd, iota,
body->AddInstruction(HloInstruction::CreateBroadcast(
iota->shape(), slice_offsets[dim], {}))));
auto limit = body->AddInstruction(HloInstruction::CreateBroadcast(
iota->shape(),
body->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
reduce_outside->operand(0)->shape().dimensions(dim)))),
{}));
auto compare = body->AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::ChangeElementType(iota->shape(), PRED), add, limit,
ComparisonDirection::kLt));
operand0 = body->AddInstruction(HloInstruction::CreateTernary(
operand0->shape(), HloOpcode::kSelect, compare, operand0,
body->AddInstruction(HloInstruction::CreateBroadcast(
operand0->shape(), operand1, {}))));
}
auto output_inside =
body->AddInstruction(reduce_outside->CloneWithNewOperands(
reduce_shape, {operand0, operand1}));
// Accumulate with previous results if needed.
// The matching region of the previous iteration's buffer is combined with
// this iteration's result using the opcode of the reducer's root.
// NOTE(review): this builds a binary op from the reducer's root opcode; the
// filter above only checks that the root is elementwise — confirm that it is
// always a binary op over the two parameters.
if (needs_accumulate) {
auto input_slice =
body->AddInstruction(HloInstruction::CreateDynamicSlice(
output_inside->shape(), last_iter_result, reduce_dus_offsets,
output_inside->shape().dimensions()));
output_inside = body->AddInstruction(HloInstruction::CreateBinary(
output_inside->shape(),
reduce_outside->to_apply()->root_instruction()->opcode(),
output_inside, input_slice));
}
// Dynamic-update-slice if needed.
// When the per-iteration result is smaller than the buffer, it is written
// into the buffer at the offsets of the kept dimensions.
if (!ShapeUtil::Compatible(output_inside->shape(),
last_iter_result->shape())) {
output_inside =
body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
last_iter_result->shape(), last_iter_result, output_inside,
reduce_dus_offsets));
}
new_outputs_inside[index_in_operand] = output_inside;
}
// Body output.
// Element 2 of the body root becomes the new tuple; the old
// dynamic-update-slice is then removed together with its unused operands.
auto new_output_inside =
body->AddInstruction(HloInstruction::CreateTuple(new_outputs_inside));
TF_RETURN_IF_ERROR(
body_root->ReplaceOperandWithDifferentShape(2, new_output_inside));
TF_RETURN_IF_ERROR(body->RemoveInstructionAndUnusedOperands(dus));
// Replace uses of the reduces outside the loop.
auto new_output_gte =
computation->AddInstruction(HloInstruction::CreateGetTupleElement(
new_output_inside->shape(), loop, 2));
for (int64 i = 0; i < reduce_outputs.size(); ++i) {
int64 index_in_operand = new_operands.size() - reduce_outputs.size() + i;
auto new_output =
computation->AddInstruction(HloInstruction::CreateGetTupleElement(
new_outputs_inside[index_in_operand]->shape(), new_output_gte,
index_in_operand));
// The buffer was created with padded sizes; slice it from offset 0 back
// down to the original reduce shape when the two differ.
if (!ShapeUtil::Compatible(new_output->shape(),
reduce_outputs[i]->shape())) {
new_output = computation->AddInstruction(HloInstruction::CreateSlice(
reduce_outputs[i]->shape(), new_output,
std::vector<int64>(new_output->shape().rank(), 0),
reduce_outputs[i]->shape().dimensions(),
std::vector<int64>(new_output->shape().rank(), 1)));
}
TF_RETURN_IF_ERROR(reduce_outputs[i]->ReplaceAllUsesWith(new_output));
TF_RETURN_IF_ERROR(
computation->RemoveInstructionAndUnusedOperands(reduce_outputs[i]));
}
return Status::OK();
}
} // namespace
// Applies the windowed dot-general code-motion rewrites to every loop recorded
// in windowed_dot_general_loops_. Stops at, and returns, the first error. The
// `computation` argument is not read here.
Status SpmdPartitioningVisitor::DoCodeMotionForWindowedDotGeneralLoops(
    HloComputation* computation) {
  for (auto& windowed_loop : windowed_dot_general_loops_) {
    const bool on_contracting_dims = windowed_loop.windowed_in_contracting_dims;
    const bool on_batch_dims = windowed_loop.windowed_in_batch_dims;

    if (on_contracting_dims || on_batch_dims) {
      // A batch/contracting-dim windowed dot-general dynamic-slices its
      // non-windowed operand, so sinking that operand's
      // broadcast/iota/elementwise producers into the loop could help reduce
      // memory via fusion.
      TF_RETURN_IF_ERROR(
          SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions(
              windowed_loop.while_loop, 1 - windowed_loop.windowed_operand));
    }

    if (on_contracting_dims) {
      continue;
    }
    // A batch/non-contracting-dim windowed dot-general dynamic-update-slices
    // its output, so moving the reduce ops that consume it into the loop could
    // help reduce memory.
    TF_RETURN_IF_ERROR(
        MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions(
            windowed_loop.while_loop));
  }
  return Status::OK();
}
// Partitions `computation`: visits it with this visitor, reshards the
// partitioned root to `root_sharding`, builds the result as a new embedded
// computation, runs the windowed dot-general code motion on it, and finally
// swaps it in for the original computation in the module. Returns changed_.
StatusOr<bool> SpmdPartitioningVisitor::DoPartition(
    HloComputation* computation, const HloSharding& root_sharding) {
  VLOG(2) << "Partitioning computation " << computation->name() << " for "
          << num_replicas_ << " replicas and " << num_partitions_
          << " partitions";

  TF_RETURN_IF_ERROR(computation->Accept(this));

  HloModule* parent_module = computation->parent();
  HloInstruction* resharded_root =
      GetPartitionedHlo(computation->root_instruction())
          .Reshard(root_sharding)
          .hlo();
  auto partitioned_computation =
      parent_module->AddEmbeddedComputation(b_.Build(resharded_root));
  TF_RETURN_IF_ERROR(
      DoCodeMotionForWindowedDotGeneralLoops(partitioned_computation));

  // Swap the freshly built SPMD computation in for the original one.
  std::unordered_map<HloComputation*, HloComputation*> replacements;
  replacements[computation] = partitioned_computation;
  parent_module->ReplaceComputations(replacements);
  return changed_;
}
// Visitor hook for partition-id instructions. Always returns an Unimplemented
// status; `hlo` is not inspected.
Status SpmdPartitioningVisitor::HandlePartitionId(HloInstruction* hlo) {
return Unimplemented(
"PartitionId instruction is not supported for SPMD partitioning since "
"the meaning is ambiguous -- whether the instruction is replicated or "
"the data is replicated, and if the latter which data is replicated.");
}
// Builds the default set of collective-op factories. Each factory appends one
// HLO instruction to the given SpmdBuilder and returns it. The callbacks are
// listed in the return statement in the order: partition-id, all-reduce,
// collective-permute, all-to-all, all-gather.
SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64 num_partitions,
                                                        int64 num_replicas) {
  auto make_partition_id = [](SpmdBuilder* b) {
    return b->AddInstruction(HloInstruction::CreatePartitionId());
  };

  // All-reduce over the replica groups produced by CreateReplicaGroups.
  auto make_all_reduce = [num_replicas](SpmdBuilder* b,
                                        HloInstruction* operand,
                                        HloComputation* reduction,
                                        int64 channel_id) {
    return b->AddInstruction(HloInstruction::CreateAllReduce(
        operand->shape(), {operand}, reduction,
        CreateReplicaGroups(num_replicas),
        /*constrain_layout=*/false, channel_id,
        /*use_global_device_ids=*/false));
  };

  auto make_collective_permute =
      [](SpmdBuilder* b, HloInstruction* operand,
         std::vector<std::pair<int64, int64>>& src_dst_pairs,
         int64 channel_id) {
        return b->AddInstruction(HloInstruction::CreateCollectivePermute(
            operand->shape(), operand, src_dst_pairs, channel_id));
      };

  // The output is the first operand's shape for a single operand, otherwise a
  // tuple that repeats that shape once per operand.
  auto make_all_to_all = [](SpmdBuilder* b,
                            absl::Span<HloInstruction* const> operands,
                            const std::vector<ReplicaGroup>& replica_groups,
                            int64 channel_id,
                            absl::optional<int64> split_dimension) {
    std::vector<Shape> operand_shapes(operands.size(), operands[0]->shape());
    const Shape output_shape =
        (operand_shapes.size() == 1)
            ? operand_shapes[0]
            : ShapeUtil::MakeTupleShape(operand_shapes);
    return b->AddInstruction(HloInstruction::CreateAllToAll(
        output_shape, operands, replica_groups,
        /*constrain_layout=*/false, channel_id, split_dimension));
  };

  // Every partition subgroup is instantiated once per replica; a device id is
  // computed as replica * num_partitions + partition id.
  auto make_all_gather =
      [num_replicas, num_partitions](
          SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape,
          const std::vector<std::vector<int64>>& partition_subgroups,
          int64 channel_id, int64 all_gather_dimension) {
        std::vector<ReplicaGroup> device_groups;
        device_groups.reserve(partition_subgroups.size() * num_replicas);
        for (int64 replica = 0; replica < num_replicas; ++replica) {
          const int64 first_device = replica * num_partitions;
          for (const auto& subgroup : partition_subgroups) {
            device_groups.emplace_back();
            ReplicaGroup& device_group = device_groups.back();
            for (int64 partition : subgroup) {
              device_group.add_replica_ids(first_device + partition);
            }
          }
        }
        return b->AddInstruction(HloInstruction::CreateAllGather(
            ag_shape, operand, all_gather_dimension, device_groups,
            /*constrain_layout=*/false, channel_id,
            /*use_global_device_ids=*/true));
      };

  return {
      make_partition_id, make_all_reduce, make_collective_permute,
      make_all_to_all,   make_all_gather,
  };
}
// Convenience constructor: delegates to the constructor that also takes a
// collective-ops creator, passing the one returned by
// GetDefaultCollectiveOpsCreator for the same partition and replica counts.
SpmdPartitioner::SpmdPartitioner(int64 num_partitions, int64 num_replicas,
SpmdPartitionerOptions options)
: SpmdPartitioner(
num_partitions, num_replicas, std::move(options),
GetDefaultCollectiveOpsCreator(num_partitions, num_replicas)) {}
// Gathers the shards of `operand` from all partitions in the tile assignment
// of `sharding`: reshape to add a leading unit dimension, all-gather along it,
// split it per tiled dimension when more than one dimension is tiled,
// transpose each gathered dimension next to the dimension it partitions, and
// reshape to the final shape. `sharding` must not be tile-maximal.
HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b,
                                                 HloInstruction* operand,
                                                 const HloSharding& sharding,
                                                 int64 channel_id) {
  CHECK(!sharding.IsTileMaximal());
  const Shape& operand_shape = operand->shape();
  const auto element_type = operand_shape.element_type();
  const auto& tile_assignment = sharding.tile_assignment();

  // Prepend a unit dimension along which all partitions are gathered.
  std::vector<int64> gather_dims(1, 1);
  gather_dims.insert(gather_dims.end(), operand_shape.dimensions().begin(),
                     operand_shape.dimensions().end());
  HloInstruction* with_leading_dim =
      b->AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(element_type, gather_dims), operand));

  // A single subgroup that lists every partition in tile-assignment order.
  std::vector<std::vector<int64>> partition_subgroups(1);
  partition_subgroups[0].insert(partition_subgroups[0].end(),
                                tile_assignment.begin(),
                                tile_assignment.end());

  gather_dims[0] = tile_assignment.num_elements();
  HloInstruction* result =
      collective_ops_creator_.create_cross_partition_all_gather(
          b, with_leading_dim, ShapeUtil::MakeShape(element_type, gather_dims),
          partition_subgroups, channel_id, /*all_gather_dimension=*/0);

  // Dimensions that are actually split across partitions.
  std::vector<int64> tiled_dims;
  const int64 num_tile_dims = tile_assignment.num_dimensions();
  for (int64 dim = 0; dim < num_tile_dims; ++dim) {
    if (tile_assignment.dim(dim) > 1) {
      tiled_dims.push_back(dim);
    }
  }

  // With n > 1 partitioned dimensions, the leading dimension is split into n.
  if (tiled_dims.size() > 1) {
    std::vector<int64> split_dims;
    split_dims.reserve(tiled_dims.size() + operand_shape.rank());
    for (int64 dim : tiled_dims) {
      split_dims.push_back(tile_assignment.dim(dim));
    }
    split_dims.insert(split_dims.end(), operand_shape.dimensions().begin(),
                      operand_shape.dimensions().end());
    result = b->AddInstruction(HloInstruction::CreateReshape(
        ShapeUtil::MakeShape(element_type, split_dims), result));
  }

  // Build the permutation that places every gathered dimension directly in
  // front of the operand dimension it partitions.
  std::vector<int64> permutation(result->shape().rank());
  const int64 result_rank = permutation.size();
  const int64 num_tiled = tiled_dims.size();
  int64 gathered_dims_placed = 0;
  int64 out_pos = 0;
  while (out_pos < result_rank) {
    const int64 operand_dim = out_pos - gathered_dims_placed;
    if (tile_assignment.dim(operand_dim) == 1) {
      permutation[out_pos] = operand_dim + num_tiled;
      out_pos += 1;
    } else {
      permutation[out_pos] = gathered_dims_placed;
      permutation[out_pos + 1] = operand_dim + num_tiled;
      ++gathered_dims_placed;
      out_pos += 2;
    }
  }
  result = b->AddInstruction(HloInstruction::CreateTranspose(
      ShapeInference::InferTransposeShape(result->shape(), permutation)
          .ValueOrDie(),
      result, permutation));

  // Merge each gathered dimension into its partitioned dimension.
  auto full_shape = operand_shape;
  for (int64 dim : tiled_dims) {
    full_shape.set_dimensions(
        dim, full_shape.dimensions(dim) * tile_assignment.dim(dim));
  }
  return b->AddInstruction(HloInstruction::CreateReshape(full_shape, result));
}
// Creates a partitioning visitor for `computation` from this partitioner's
// settings and returns the result of its DoPartition call.
StatusOr<bool> SpmdPartitioner::PartitionComputation(
    HloComputation* computation, const HloSharding& root_sharding,
    int64* next_channel_id, SpmdLogger* logger) {
  // The temporary visitor lives until the returned value has been built.
  return CreateVisitor(computation, num_partitions_, num_replicas_,
                       collective_ops_creator_, next_channel_id, logger,
                       options_)
      ->DoPartition(computation, root_sharding);
}
// Factory for the visitor that partitions a single computation. All arguments
// are forwarded to the SpmdPartitioningVisitor constructor, together with a
// pointer to this partitioner.
std::unique_ptr<SpmdPartitioningVisitor> SpmdPartitioner::CreateVisitor(
    HloComputation* computation, int64 num_partitions, int64 num_replicas,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdLogger* logger,
    SpmdPartitionerOptions options) {
  auto visitor = absl::make_unique<SpmdPartitioningVisitor>(
      computation, num_partitions, num_replicas, collective_ops_creator,
      next_channel_id, logger, std::move(options), this);
  return visitor;
}
// Pass entry point. Validates and completes sharding annotations, records the
// entry parameter/root shardings on the module, flattens the call graph,
// partitions the entry computation, then either verifies that the entry
// signature is unchanged or updates the entry computation layout, runs a
// cleanup pipeline if anything changed, and clears sharding attributes.
// Returns whether the module changed.
StatusOr<bool> SpmdPartitioner::Run(HloModule* module) {
  TF_RETURN_IF_ERROR(PreprocessSharding(module));

  XLA_VLOG_LINES(1, SpmdLogger::ReportBeforePartition(
                        *module, options_.report_instruction_count));

  // Record the shardings of the entry parameters and of the entry root on the
  // module before anything is rewritten.
  HloComputation* original_entry = module->entry_computation();
  const int64 num_entry_params = original_entry->num_parameters();
  std::vector<HloSharding> entry_params_shardings;
  for (int64 i = 0; i < num_entry_params; ++i) {
    auto param = original_entry->parameter_instruction(i);
    CHECK(param->has_sharding()) << "Missing sharding in entry parameter " << i;
    entry_params_shardings.push_back(param->sharding());
  }
  module->set_spmd_parameters_shardings(entry_params_shardings);

  auto entry_root = original_entry->root_instruction();
  CHECK(entry_root->has_sharding()) << "Missing sharding in entry root.";
  module->set_spmd_output_sharding(entry_root->sharding());

  FlattenCallGraph flatten;
  TF_ASSIGN_OR_RETURN(bool module_changed, flatten.Run(module));

  SpmdLogger logger(options_.report_instruction_count);
  // Snapshot of the entry signature before partitioning, for the checks below.
  auto program_shape = module->entry_computation()->ComputeProgramShape();
  int64 next_channel_id = hlo_query::NextChannelId(*module);
  TF_ASSIGN_OR_RETURN(
      bool partition_changed,
      PartitionComputation(
          module->entry_computation(),
          module->entry_computation()->root_instruction()->sharding(),
          &next_channel_id, &logger));
  if (partition_changed) {
    module_changed = true;
  }

  auto new_program_shape = module->entry_computation()->ComputeProgramShape();
  if (options_.allow_module_signature_change) {
    // Shapes can change but the layout should still remain the same, so the
    // old entry layouts are copied onto the new program shape, which then
    // becomes the module's entry computation layout.
    const auto& old_entry_layout = module->entry_computation_layout();
    const int64 num_new_params = new_program_shape.parameters_size();
    for (int64 i = 0; i < num_new_params; ++i) {
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          old_entry_layout.parameter_shape(i),
          new_program_shape.mutable_parameters(i)));
    }
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        old_entry_layout.result_shape(), new_program_shape.mutable_result()));

    HloModuleConfig config = module->config();
    *config.mutable_entry_computation_layout() =
        ComputationLayout(new_program_shape, /*ignore_layouts=*/false);
    module->set_config(config);
  } else {
    // For the entry computation, the root instruction and the parameters must
    // preserve their signatures.
    TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
        program_shape.result(), new_program_shape.result()))
        << "Result shape changed for the entry computation";
    TF_RET_CHECK(program_shape.parameters_size() ==
                 new_program_shape.parameters_size())
        << "Parameter count changed for the entry computation";
    for (int64 i = 0; i < program_shape.parameters_size(); ++i) {
      TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
          program_shape.parameters(i), new_program_shape.parameters(i)))
          << "Parameter shape changed for the entry computation";
    }
  }

  XLA_VLOG_LINES(1, SpmdLogger::ReportAfterPartition(
                        *module, options_.report_instruction_count));
  XLA_VLOG_LINES(1, logger.MakeReport());

  if (module_changed) {
    HloPassPipeline cleanup("spmd-cleanup");
    cleanup.AddPass<TupleSimplifier>();
    cleanup.AddPass<HloDCE>();
    cleanup.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
    cleanup.AddPass<FlattenCallGraph>();
    TF_RETURN_IF_ERROR(cleanup.Run(module).status());
  }

  TF_RETURN_IF_ERROR(ClearShardingAttributes(module));
  return module_changed;
}
// Validates the sharding annotations of every instruction in `module` and
// fills in a default sharding where none is present. Unless module signature
// changes are allowed, also requires the entry root and the entry parameters
// to be either replicated or assigned to a single device.
Status SpmdPartitioner::PreprocessSharding(HloModule* module) {
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* hlo : computation->instructions()) {
      const bool is_rng = hlo->opcode() == HloOpcode::kRng;

      if (!is_rng && hlo->HasSideEffectNoRecurse()) {
        TF_RET_CHECK(hlo->has_sharding())
            << "Side-effect HLO must have sharding: " << hlo->ToString();
        TF_RET_CHECK(!HasReplicatedSharding(hlo->sharding()) ||
                     hlo->opcode() == HloOpcode::kInfeed)
            << "Non-infeed side-effect HLO cannot have a replicated sharding:"
            << hlo->ToString();
      }

      // Unassigned HLOs are annotated with a replicated sharding.
      //
      // Rng is the only side-effecting op that may come without an
      // annotation. It is currently forced onto core 0, because partitioning
      // or replicating the Rng op is not supported (the values depend on the
      // seed provided to each device).
      //
      // TODO(hyouklee): Should we also convert single-device shardings (without
      // side-effects) into replicated?
      if (!hlo->has_sharding()) {
        hlo->set_sharding(
            is_rng
                ? HloSharding::AssignDevice(0)
                : HloSharding::Single(hlo->shape(), HloSharding::Replicate()));
        continue;
      }

      if (hlo->sharding().IsTileMaximal()) {
        continue;
      }
      // A tiled sharding has to cover every partition.
      std::vector<int64> available(num_partitions_);
      std::iota(available.begin(), available.end(), 0);
      TF_RET_CHECK(num_partitions_ == hlo_sharding_util::DevicesForSharding(
                                          hlo->sharding(), available)
                                          .size())
          << "num_partitions:" << num_partitions_ << "\n"
          << "SPMD partitioner only supports tile sharding that includes all "
             "partitions. If you didn't add this sharding annotation in the "
             "model, please file a bug to XLA team.\n"
          << hlo->ToString();
    }
  }

  if (options_.allow_module_signature_change) {
    return Status::OK();
  }

  // The signature has to stay intact, so the entry root and every entry
  // parameter must be either replicated or placed on a single device.
  const HloComputation* entry = module->entry_computation();
  TF_RET_CHECK(entry->root_instruction()->has_sharding());
  const HloSharding& root_sharding = entry->root_instruction()->sharding();
  TF_RET_CHECK(root_sharding.IsReplicated() ||
               root_sharding.UniqueDevice().has_value())
      << "Unsupported entry root sharding: " << root_sharding.ToString();

  for (const HloInstruction* param : entry->parameter_instructions()) {
    TF_RET_CHECK(param->has_sharding());
    TF_RET_CHECK(param->sharding().IsReplicated() ||
                 param->sharding().UniqueDevice().has_value())
        << "Unsupported entry parameter sharding:"
        << param->sharding().ToString();
  }
  return Status::OK();
}
} // namespace spmd
} // namespace xla