/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include <algorithm>
#include <deque>
#include <functional>
#include <map>
#include <memory>
#include <numeric>
#include <ostream>
#include <set>
#include <string>
#include <tuple>
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
namespace xla {
std::ostream& operator<<(std::ostream& out,
const LayoutConstraint& constraint) {
out << constraint.ToString();
return out;
}
BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
const LogicalBuffer& buffer,
bool mandatory, bool dfs)
: LayoutConstraint(mandatory, dfs), layout_(layout), buffer_(&buffer) {
CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok());
}
string BufferLayoutConstraint::ToString() const {
return absl::StrFormat("BufferLayoutConstraint %s: %s", buffer_->ToString(),
LayoutUtil::HumanString(layout_));
}
OperandLayoutConstraint::OperandLayoutConstraint(
const ShapeLayout& shape_layout, const HloInstruction* instruction,
int64 operand_no, bool mandatory, bool dfs)
: LayoutConstraint(mandatory, dfs),
shape_layout_(shape_layout),
instruction_(instruction),
operand_no_(operand_no) {
CHECK(shape_layout_.LayoutIsSet());
CHECK(ShapeUtil::Compatible(shape_layout.shape(),
instruction->operand(operand_no)->shape()))
<< shape_layout.shape() << " is not compatible with "
<< instruction->operand(operand_no)->shape() << " (for operand "
<< operand_no << " of instruction " << instruction->ToString() << ")";
}
string OperandLayoutConstraint::ToString() const {
return absl::StrFormat("OperandLayoutConstraint %s, operand %d: %s",
instruction_->name(), operand_no_,
shape_layout_.ToString());
}
string ResultLayoutConstraint::ToString() const {
return absl::StrFormat("ResultLayoutConstraint: %s",
shape_layout_.ToString());
}
LayoutConstraints::LayoutConstraints(
const TuplePointsToAnalysis& points_to_analysis,
HloComputation* computation)
: points_to_analysis_(points_to_analysis), computation_(computation) {
// Gather all array-shaped logical buffers into unconstrained_buffer_ids.
for (HloInstruction* inst : computation_->instructions()) {
points_to_analysis_.GetPointsToSet(inst).ForEachElement(
[&](const ShapeIndex&, const PointsToSet::BufferList& buffers) {
for (const LogicalBuffer* buffer : buffers) {
            // The points-to analysis is computed per module, so restrict
            // constraints to array buffers in this computation.
if (buffer->IsArray() &&
buffer->instruction()->parent() == computation) {
unconstrained_buffer_ids_.insert(buffer->id());
}
}
});
}
}
PointsToSet::BufferSet* LayoutConstraints::GetBufferSet(
const HloInstruction* instruction) const {
auto it = buffer_sets_cache_.find(instruction);
if (it != buffer_sets_cache_.end()) {
return it->second.get();
}
auto& buffer_set =
buffer_sets_cache_
.emplace(instruction, absl::make_unique<PointsToSet::BufferSet>())
.first->second;
const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction);
points_to_set.ForEachElement(
[&buffer_set](const ShapeIndex& /*index*/,
const PointsToSet::BufferList& buffers) {
buffer_set->insert(buffers.begin(), buffers.end());
});
return buffer_set.get();
}
bool LayoutConstraints::OperandBufferForwarded(
const HloInstruction* instruction, int64 operand_no) const {
// The operand is potentially forwarded if the intersection of points-to sets
// of the operand and the instruction is non-empty.
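  // For example, a kTuple instruction's points-to set includes the buffers of
  // its operands, so its operands are considered forwarded; a kAdd defines a
  // fresh output buffer, so its operands are not.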
PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
PointsToSet::BufferSet* operand_buffers =
GetBufferSet(instruction->operand(operand_no));
return absl::c_any_of(*output_buffers, [&](const LogicalBuffer* b) {
return operand_buffers->count(b) > 0;
});
}
Status LayoutConstraints::SetBufferLayout(const Layout& layout,
const LogicalBuffer& buffer,
bool mandatory, bool dfs) {
VLOG(3) << "SetBufferLayout : " << buffer << " : "
<< LayoutUtil::HumanString(layout);
TF_RETURN_IF_ERROR(points_to_analysis_.VerifyBuffer(buffer));
if (!buffer.IsArray()) {
return FailedPrecondition(
"Layout of buffer %s cannot be constrained because buffer is not "
"array-shaped, has shape: %s",
buffer.ToString(), ShapeUtil::HumanString(buffer.shape()));
}
TF_RETURN_IF_ERROR(
LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()));
auto iter = buffer_constraints_.find(&buffer);
if (iter != buffer_constraints_.end()) {
const BufferLayoutConstraint& curr_constraint = iter->second;
if (Layout::Equal().MinorToMajorOnly()(curr_constraint.layout(), layout)) {
// New constraint matches existing constraint. Nothing to do.
return Status::OK();
}
if (curr_constraint.mandatory()) {
if (!mandatory) {
VLOG(3) << "Buffer" << buffer
<< " already has a mandatory layout constrain, skipping";
return Status::OK();
}
return FailedPrecondition(
"Buffer %s already has the layout constraint %s, cannot add "
"incompatible constraint %s",
buffer.ToString(), LayoutUtil::HumanString(curr_constraint.layout()),
LayoutUtil::HumanString(layout));
}
iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs);
} else {
TF_RET_CHECK(unconstrained_buffer_ids_.erase(buffer.id()) == 1)
<< buffer.ToString();
iter = buffer_constraints_
.insert(std::make_pair(
&buffer,
BufferLayoutConstraint(layout, buffer, mandatory, dfs)))
.first;
}
added_constraints_.push_back(&iter->second);
return Status::OK();
}
Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
const HloInstruction* instruction,
int64 operand_no, bool mandatory,
bool dfs) {
VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand "
<< operand_no << " : "
<< ShapeUtil::HumanStringWithLayout(shape_with_layout);
const OperandLayoutConstraint* curr_shape_layout =
GetOperandLayoutConstraint(instruction, operand_no);
if (curr_shape_layout != nullptr) {
if (curr_shape_layout->shape_layout().MatchesLayoutInShape(
shape_with_layout, /*minor_to_major_only=*/true)) {
// New constraint matches existing constraint. Nothing to do.
return Status::OK();
}
if (curr_shape_layout->mandatory()) {
return FailedPrecondition(
"Operand %d of instruction %s already has a layout constraint "
"%s, cannot add incompatible constraint %s",
operand_no, instruction->name(),
curr_shape_layout->shape_layout().ToString(),
ShapeUtil::HumanStringWithLayout(shape_with_layout));
}
}
// If any buffers in the operand occur in the output of the instruction, then
// return an error. This case is not handled because such a constraint changes
// layouts beyond this immediate use and is complicated to handle.
if (OperandBufferForwarded(instruction, operand_no)) {
return FailedPrecondition(
"Cannot constraint layout of operand %d of instruction %s "
"because instruction forwards operand's LogicalBuffer(s)",
operand_no, instruction->name());
}
auto key = std::make_pair(instruction, operand_no);
auto iter = operand_constraints_.find(key);
if (iter == operand_constraints_.end()) {
auto pair = std::make_pair(
key, OperandLayoutConstraint(ShapeLayout(shape_with_layout),
instruction, operand_no, mandatory, dfs));
iter = operand_constraints_.insert(pair).first;
} else {
iter->second =
OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction,
operand_no, mandatory, dfs);
}
added_constraints_.push_back(&iter->second);
return Status::OK();
}
Status LayoutConstraints::SetArrayOperandLayout(
const Layout& layout, const HloInstruction* instruction, int64 operand_no,
bool mandatory, bool dfs) {
const HloInstruction* operand = instruction->operand(operand_no);
TF_RET_CHECK(operand->shape().IsArray());
Shape shape(operand->shape());
*shape.mutable_layout() = layout;
TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape));
return SetOperandLayout(shape, instruction, operand_no, mandatory, dfs);
}
Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout,
bool dfs) {
VLOG(3) << "SetResultLayout : "
<< ShapeUtil::HumanStringWithLayout(shape_with_layout);
const ShapeLayout* curr_shape_layout = ResultLayout();
if (curr_shape_layout != nullptr) {
if (!curr_shape_layout->MatchesLayoutInShape(
shape_with_layout, /*minor_to_major_only=*/true)) {
return FailedPrecondition(
"Result of computation %s already has the layout constraint %s, "
"cannot add incompatible constraint %s",
computation_->name(), curr_shape_layout->ToString(),
ShapeUtil::HumanStringWithLayout(shape_with_layout));
}
// New constraint matches existing constraint. Nothing to do.
return Status::OK();
}
result_constraint_.reset(
new ResultLayoutConstraint(ShapeLayout(shape_with_layout), dfs));
added_constraints_.push_back(result_constraint_.get());
return Status::OK();
}
Status LayoutConstraints::SetInstructionLayout(
const Shape& shape_with_layout, const HloInstruction* instruction,
bool mandatory, bool dfs) {
VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", "
<< ShapeUtil::HumanStringWithLayout(shape_with_layout);
if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) {
return FailedPrecondition(
"Instruction %s of shape %s cannot be assigned incompatible layout %s",
instruction->name(), ShapeUtil::HumanString(instruction->shape()),
ShapeUtil::HumanStringWithLayout(shape_with_layout));
}
// Create a BufferLayoutConstraint for each array shape in the output of the
// instruction.
return ShapeUtil::ForEachSubshapeWithStatus(
shape_with_layout,
[this, instruction, mandatory](const Shape& subshape,
const ShapeIndex& index) -> Status {
// The precondition for this method is that the instruction defines all
// buffers in its output.
auto buffers =
points_to_analysis_.GetPointsToSet(instruction).element(index);
CHECK_EQ(1, buffers.size());
CHECK_EQ(buffers[0]->instruction(), instruction);
if (subshape.IsArray() && subshape.has_layout()) {
return SetBufferLayout(subshape.layout(), *buffers[0], mandatory);
} else {
return Status::OK();
}
});
}
const Layout* LayoutConstraints::BufferLayout(
const LogicalBuffer& buffer) const {
if (const auto* constraint = GetBufferLayoutConstraint(buffer)) {
return &constraint->layout();
}
return nullptr;
}
const BufferLayoutConstraint* LayoutConstraints::GetBufferLayoutConstraint(
const LogicalBuffer& buffer) const {
auto it = buffer_constraints_.find(&buffer);
return it == buffer_constraints_.end() ? nullptr : &it->second;
}
const ShapeLayout* LayoutConstraints::OperandLayout(
const HloInstruction* instruction, int64 operand_no) const {
if (const auto* constraint =
GetOperandLayoutConstraint(instruction, operand_no)) {
return &constraint->shape_layout();
}
return nullptr;
}
const OperandLayoutConstraint* LayoutConstraints::GetOperandLayoutConstraint(
const HloInstruction* instruction, int64 operand_no) const {
auto it = operand_constraints_.find(std::make_pair(instruction, operand_no));
return it == operand_constraints_.end() ? nullptr : &it->second;
}
const ShapeLayout* LayoutConstraints::ResultLayout() const {
return result_constraint_ ? &result_constraint_->shape_layout() : nullptr;
}
string LayoutConstraints::ToString() const {
string output;
absl::StrAppend(&output, "LayoutConstraints for computation ",
computation_->name(), ":\n");
for (auto* instruction : computation_->MakeInstructionPostOrder()) {
absl::StrAppend(&output, " ", instruction->ToShortString(), "\n");
for (int64 i = 0; i < instruction->operand_count(); ++i) {
if (OperandLayout(instruction, i) != nullptr) {
absl::StrAppend(&output, " operand (", i,
"): ", OperandLayout(instruction, i)->ToString(), "\n");
}
}
for (const LogicalBuffer* buffer :
points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
if (BufferLayout(*buffer) != nullptr) {
absl::StrAppend(&output, " ", buffer->ToString(), " : ",
LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n");
}
}
}
if (ResultLayout() != nullptr) {
absl::StrAppend(&output, " => ", ResultLayout()->ToString(), "\n");
}
return output;
}
namespace {
bool IsHostSendRecv(const HloInstruction* instruction) {
const HloSendRecvInstruction* send_recv_instr =
DynCast<HloSendRecvInstruction>(instruction);
return send_recv_instr != nullptr && send_recv_instr->is_host_transfer();
}
} // namespace
Status LayoutAssignment::BuildHostChannelConstraints(
HloComputation* computation) {
for (auto* instruction : computation->instructions()) {
const HloSendRecvInstruction* send_recv_instr =
DynCast<HloSendRecvInstruction>(instruction);
if (send_recv_instr == nullptr || !send_recv_instr->is_host_transfer()) {
continue;
}
    // For host transfers, the Send and Recv instructions carry the layout.
if (instruction->opcode() == HloOpcode::kSend ||
instruction->opcode() == HloOpcode::kRecv) {
const Shape& data_shape =
ShapeUtil::GetTupleElementShape(send_recv_instr->shape(), 0);
TF_RET_CHECK(data_shape.IsArray());
TF_RET_CHECK(LayoutUtil::HasLayout(data_shape));
const Layout* prev_layout = host_channel_constraints_.ConstrainChannel(
*send_recv_instr->channel_id(), data_shape.layout());
TF_RET_CHECK(prev_layout == nullptr)
<< "Cannot constrain host transfer layout as it was set to "
<< LayoutUtil::HumanString(*prev_layout) << ": "
<< send_recv_instr->ToString();
}
}
return Status::OK();
}
namespace {
bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) {
const HloCustomCallInstruction* custom_call =
DynCast<HloCustomCallInstruction>(instruction);
return custom_call != nullptr && custom_call->layout_constrained();
}
bool IsLayoutConstrainedCollective(const HloInstruction* instruction) {
const HloCollectiveInstruction* collective =
DynCast<HloCollectiveInstruction>(instruction);
return collective != nullptr && collective->constrain_layout();
}
} // namespace
Status LayoutAssignment::AddMandatoryConstraints(
const ComputationLayout* computation_layout,
ChannelLayoutConstraints* channel_constraints, HloComputation* computation,
LayoutConstraints* constraints) {
VLOG(3) << "Adding mandatory layout constraints to computation "
<< computation->name();
auto get_channel_constraints = [&](const HloInstruction* instruction) {
return IsHostSendRecv(instruction) ? &host_channel_constraints_
: channel_constraints;
};
// Constrain layouts of instructions which define values with pre-existing
// layouts.
for (auto* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kInfeed) {
// Infeed layouts must match the layout of the original inserted
// instruction.
// TODO(b/31425034): Change infeeds to be more like parameters, with
// shapes in the ComputationLayout.
TF_RETURN_IF_ERROR(
constraints->SetInstructionLayout(instruction->shape(), instruction));
} else if (instruction->opcode() == HloOpcode::kOutfeed) {
// Constrain the input to the Outfeed instruction to be the expected
// layout of the Outfeed.
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
instruction->outfeed_shape(), instruction, 0));
} else if (instruction->opcode() == HloOpcode::kParameter) {
if (computation_layout != nullptr) {
const ShapeLayout& parameter_layout =
computation_layout->parameter_layout(
instruction->parameter_number());
// Parameter layouts must match the respective layout in
// ComputationLayout, if there is one.
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
parameter_layout.shape(), instruction));
}
} else if (IsLayoutConstrainedCustomCall(instruction)) {
const HloCustomCallInstruction* custom_call =
DynCast<HloCustomCallInstruction>(instruction);
TF_RETURN_IF_ERROR(
constraints->SetInstructionLayout(custom_call->shape(), custom_call));
for (int64 i = 0; i < custom_call->operand_count(); ++i) {
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
custom_call->operand_shapes_with_layout()[i], custom_call, i));
}
} else if (instruction->opcode() == HloOpcode::kSend ||
instruction->opcode() == HloOpcode::kRecv) {
CHECK(get_channel_constraints(instruction))
<< "Multi-module layout assignment requires ChannelLayoutConstraints";
int64 channel_id = *instruction->channel_id();
if (!get_channel_constraints(instruction)
->IsChannelConstrained(channel_id)) {
continue;
}
if (instruction->opcode() == HloOpcode::kSend) {
// TODO(b/68493863): Change to use SetOperandLayout().
const Shape send_buffer_shape = instruction->operand(0)->shape();
TF_RET_CHECK(send_buffer_shape.IsArray());
Shape new_buffer_shape =
get_channel_constraints(instruction)
->LayoutShapeForChannel(send_buffer_shape,
*instruction->channel_id());
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
new_buffer_shape, instruction->operand(0)));
} else {
const Shape recv_buffer_shape =
ShapeUtil::GetTupleElementShape(instruction->shape(), 0);
TF_RET_CHECK(recv_buffer_shape.IsArray());
TF_ASSIGN_OR_RETURN(
const LogicalBuffer* buffer,
constraints->points_to_analysis().GetBufferDefinedAt(instruction,
{0}));
Shape new_shape =
get_channel_constraints(instruction)
->LayoutShapeForChannel(recv_buffer_shape,
*instruction->channel_id());
TF_RETURN_IF_ERROR(
constraints->SetBufferLayout(new_shape.layout(), *buffer));
}
} else if (IsLayoutConstrainedCollective(instruction)) {
TF_RETURN_IF_ERROR(
constraints->SetInstructionLayout(instruction->shape(), instruction));
} else if (instruction->IsCrossModuleAllReduce()) {
CHECK(get_channel_constraints(instruction))
<< "Multi-module layout assignment requires ChannelLayoutConstraints";
int64 channel_id = instruction->channel_id().value();
if (!get_channel_constraints(instruction)
->IsChannelConstrained(channel_id)) {
continue;
}
// TODO(b/68493863): Change to use SetOperandLayout().
const Shape& buffer_shape = instruction->operand(0)->shape();
TF_RET_CHECK(buffer_shape.IsArray());
Shape new_buffer_shape =
get_channel_constraints(instruction)
->LayoutShapeForChannel(buffer_shape, channel_id);
TF_RETURN_IF_ERROR(
constraints->SetInstructionLayout(new_buffer_shape, instruction));
}
}
// Constrain layouts of instructions which call computations which have
// already been assigned layouts. Instructions which call computations in a
  // parallel element-wise context (e.g., map or reduce) do not need layout
// constraints because they operate on scalars.
for (auto* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kCall) {
// kCall instruction operands and output must match the ComputationLayout
// of the called computation.
const ComputationLayout& called_computation_layout =
FindOrDie(computation_layouts_, instruction->to_apply());
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
called_computation_layout.result_layout().shape(), instruction));
TF_RET_CHECK(instruction->operand_count() ==
called_computation_layout.parameter_count());
for (int64 i = 0; i < instruction->operand_count(); ++i) {
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
called_computation_layout.parameter_layout(i).shape(), instruction,
i));
}
} else if (instruction->opcode() == HloOpcode::kWhile) {
// Layout of input and output of kWhile instruction must be equal and must
// match both input and output of body computation. Also, the input of
// condition computation must match kWhile layout.
HloComputation* body = instruction->while_body();
HloComputation* condition = instruction->while_condition();
const HloInstruction* init = instruction->operand(0);
ComputationLayout& body_layout = FindOrDie(computation_layouts_, body);
ComputationLayout& condition_layout =
FindOrDie(computation_layouts_, condition);
// Check a few invariants irrespective of layout.
CHECK_EQ(1, instruction->operand_count());
CHECK_EQ(1, body->num_parameters());
CHECK_EQ(1, condition->num_parameters());
DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
body_layout.parameter_shape(0)));
DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
condition_layout.parameter_shape(0)));
DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape()));
if (body_layout.result_layout() != body_layout.parameter_layout(0)) {
VLOG(2) << "Reset %while body parameter layout: body=" << body->name()
<< " while=" << instruction->name()
<< " shape=" << body_layout.result_layout().ToString();
*body_layout.mutable_parameter_layout(0) = body_layout.result_layout();
}
if (condition_layout.parameter_layout(0) !=
body_layout.parameter_layout(0)) {
VLOG(2) << "Reset %while condition parameter layout: cond="
<< condition->name() << " while=" << instruction->name()
<< " shape=" << body_layout.parameter_layout(0).ToString();
*condition_layout.mutable_parameter_layout(0) =
body_layout.parameter_layout(0);
}
// Constrain the output and the operand of the while instruction to match
// the computations.
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
body_layout.result_shape(), instruction, 0));
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
body_layout.result_shape(), instruction));
} else if (instruction->opcode() == HloOpcode::kConditional) {
// Find the conditional branch with the most instructions and force all
// other computations to match that layout. A potentially better decision
      // could count the number of FLOPs or how constrained the layouts are.
int64 largest_branch = 0;
int64 largest_instruction_count =
instruction->branch_computation(0)->instruction_count();
for (int j = 1; j < instruction->branch_count(); ++j) {
const int64 instruction_count =
instruction->branch_computation(j)->instruction_count();
if (instruction_count > largest_instruction_count) {
largest_branch = j;
largest_instruction_count = instruction_count;
}
}
ComputationLayout& best_branch_computation_layout =
FindOrDie(computation_layouts_,
instruction->branch_computation(largest_branch));
for (int k = 0; k < instruction->branch_count(); ++k) {
// Visit the best branch first.
int j = (k + largest_branch) % instruction->branch_count();
TF_RET_CHECK(instruction->branch_computation(j)->num_parameters() == 1);
ComputationLayout& branch_computation_layout =
            FindOrDie(computation_layouts_, instruction->branch_computation(j));
if (!branch_computation_layout.result_layout().MatchesLayoutInShape(
best_branch_computation_layout.result_layout().shape(),
/*minor_to_major_only=*/true)) {
          computation_layouts_.erase(instruction->branch_computation(j));
InsertOrDie(&conditional_mismatch_,
                      instruction->branch_computation(j),
best_branch_computation_layout);
} else {
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
              branch_computation_layout.parameter_shape(0), instruction, j + 1,
/*mandatory=*/true));
}
}
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
best_branch_computation_layout.parameter_shape(0), instruction,
largest_branch + 1,
/*mandatory=*/true));
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
best_branch_computation_layout.result_shape(), instruction));
}
}
// Finally set the result layout to match ComputationLayout, if there is one.
if (conditional_mismatch_.count(computation) > 0) {
TF_RETURN_IF_ERROR(constraints->SetResultLayout(
FindOrDie(conditional_mismatch_, computation).result_layout().shape()));
} else if (computation_layout != nullptr) {
const ShapeLayout& result_layout = computation_layout->result_layout();
if (result_layout.LayoutIsSet()) {
TF_RETURN_IF_ERROR(constraints->SetResultLayout(result_layout.shape()));
}
}
return Status::OK();
}
namespace {
bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) {
return Layout::Equal().MinorToMajorOnly()(lhs.layout(), rhs.layout());
}
// The operands of a call must match the layouts of parameters in the
// ComputationLayout, and the call instruction itself must match the result
// layout in the ComputationLayout.
Status CheckCallLayout(HloInstruction* call,
const ComputationLayout& computation_layout) {
HloComputation* computation = call->to_apply();
TF_RET_CHECK(computation->num_parameters() == call->operand_count());
for (int64 i = 0; i < computation->num_parameters(); ++i) {
TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape(
call->operand(i)->shape(), /*minor_to_major_only=*/true));
}
TF_RET_CHECK(computation_layout.result_layout().MatchesLayoutInShape(
call->shape(), /*minor_to_major_only=*/true));
return Status::OK();
}
// Operands of layout-constrained custom calls must match the expected
// constrained layouts.
Status CheckCustomCallLayout(HloInstruction* instruction) {
if (IsLayoutConstrainedCustomCall(instruction)) {
const HloCustomCallInstruction* custom_call =
DynCast<HloCustomCallInstruction>(instruction);
for (int64 i = 0; i < custom_call->operand_count(); ++i) {
TF_RET_CHECK(
LayoutsInShapesEqual(custom_call->operand(i)->shape(),
custom_call->operand_shapes_with_layout()[i]));
}
}
return Status::OK();
}
// For a while instruction, all the following layouts must be the same:
// (1) init operand
// (2) condition computation parameter
// (3) body computation parameter
// (4) body computation result
// (5) while instruction result
Status CheckWhileLayout(HloInstruction* while_inst,
const ComputationLayout& condition_computation_layout,
const ComputationLayout& body_computation_layout) {
auto init_shape = while_inst->operand(0)->shape();
TF_RET_CHECK(
condition_computation_layout.parameter_layout(0).MatchesLayoutInShape(
init_shape, /*minor_to_major_only=*/true));
TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape(
init_shape, /*minor_to_major_only=*/true));
TF_RET_CHECK(body_computation_layout.result_layout().MatchesLayoutInShape(
init_shape, /*minor_to_major_only=*/true));
TF_RET_CHECK(LayoutsInShapesEqual(init_shape, while_inst->shape()));
return Status::OK();
}
Status CheckConditionalLayout(
HloInstruction* instruction,
absl::Span<const ComputationLayout> branch_computation_layouts) {
for (int j = 0; j < instruction->branch_count(); ++j) {
const HloInstruction* branch_operand = instruction->operand(j + 1);
TF_RET_CHECK(
branch_computation_layouts[0].result_layout().MatchesLayoutInShape(
branch_computation_layouts[j].result_layout().shape(),
/*minor_to_major_only=*/true));
TF_RET_CHECK(
branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
instruction->shape(), /*minor_to_major_only=*/true));
TF_RET_CHECK(
branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
instruction->branch_computation(j)->root_instruction()->shape(),
/*minor_to_major_only=*/true));
TF_RET_CHECK(
branch_computation_layouts[j].parameter_layout(0).MatchesLayoutInShape(
branch_operand->shape(), /*minor_to_major_only=*/true));
}
return Status::OK();
}
// Fusion parameters must match the layouts of the fusion instruction's
// operands, and the root of the fusion expression must match the layout of the
// fusion instruction.
Status CheckFusionLayout(HloInstruction* fusion) {
TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode());
TF_RET_CHECK(LayoutsInShapesEqual(fusion->shape(),
fusion->fused_expression_root()->shape()));
for (int64 i = 0; i < fusion->operand_count(); ++i) {
TF_RET_CHECK(LayoutsInShapesEqual(fusion->fused_parameter(i)->shape(),
fusion->operand(i)->shape()));
}
return Status::OK();
}
// The layout of a parameter must match the respective layout in the
// computation's ComputationLayout.
Status CheckParameterLayout(HloInstruction* parameter,
const ComputationLayout& computation_layout) {
const ShapeLayout& parameter_layout =
computation_layout.parameter_layout(parameter->parameter_number());
return ShapeUtil::ForEachSubshapeWithStatus(
parameter_layout.shape(),
[&](const Shape& subshape, const ShapeIndex& shape_index) {
if (!ShapeUtil::IsLeafIndex(parameter_layout.shape(), shape_index) ||
!subshape.has_layout()) {
return Status::OK();
}
if (!Shape::Equal().MinorToMajorOnlyInLayout().IgnoreDynamicDimension()(
subshape,
ShapeUtil::GetSubshape(parameter->shape(), shape_index))) {
return InternalError(
"parameter instruction %s does not match layout of computation "
"shape: %s",
parameter->ToString(), parameter_layout.ToString());
}
return Status::OK();
});
}
// The layout of a constant instruction must match the layout of its literal.
Status CheckConstantLayout(HloInstruction* constant) {
if (!LayoutsInShapesEqual(constant->literal().shape(), constant->shape())) {
return InternalError(
"constant instruction %s does not match the layout of its literal %s",
constant->ToString(),
ShapeUtil::HumanStringWithLayout(constant->literal().shape()));
}
return Status::OK();
}
} // namespace
StatusOr<HloInstruction*> LayoutAssignment::CreateCopyWithNewLayout(
const Shape& shape_with_layout, HloInstruction* instruction) {
TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()))
<< ShapeUtil::HumanString(shape_with_layout) << " "
<< ShapeUtil::HumanString(instruction->shape())
<< " instruction: " << instruction->ToString();
if (instruction->shape().IsTuple()) {
// Copy tuple elements which have differing layouts.
std::vector<HloInstruction*> element_copies;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
++i) {
const Shape& target_shape =
ShapeUtil::GetSubshape(shape_with_layout, {i});
const Shape& instr_shape =
ShapeUtil::GetSubshape(instruction->shape(), {i});
HloInstruction* gte = instruction->parent()->AddInstruction(
HloInstruction::CreateGetTupleElement(instr_shape, instruction, i));
if (Shape::Equal().MinorToMajorOnlyInLayout()(target_shape,
instr_shape)) {
// Shapes and layouts are equal, no need to copy.
element_copies.push_back(gte);
} else {
SetupCopiedInstruction(*instruction, gte, {i});
// Recurse to copy each element.
TF_ASSIGN_OR_RETURN(HloInstruction * element_copy,
CreateCopyWithNewLayout(target_shape, gte));
element_copies.push_back(element_copy);
}
}
// Gather element copies into a tuple with a new Tuple instruction.
HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
HloInstruction::CreateTuple(element_copies));
SetupCopiedInstruction(*instruction, tuple_copy, {});
LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
shape_with_layout, tuple_copy->mutable_shape()));
return tuple_copy;
} else if (instruction->shape().IsArray()) {
HloInstruction* copy =
instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
instruction->shape(), HloOpcode::kCopy, instruction));
RegisterAddedCopy(copy);
SetupCopiedInstruction(*instruction, copy, {});
LayoutUtil::ClearLayout(copy->mutable_shape());
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
shape_with_layout, copy->mutable_shape()));
return copy;
} else {
return FailedPrecondition(
"Can only copy array and tuple shaped instructions");
}
}
// Creates a copy of the given operand if the operand's layout does not match
// the given layout. This copy replaces the use in the given instruction. Tuple
// operands will be deep-copied.
Status LayoutAssignment::CopyOperandIfLayoutsDiffer(
const ShapeLayout& operand_layout, HloInstruction* instruction,
int64 operand_no) {
HloInstruction* operand = instruction->mutable_operand(operand_no);
TF_RET_CHECK(operand_layout.LayoutIsSet());
TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));
if (Shape::Equal().MinorToMajorOnlyInLayout()(operand_layout.shape(),
operand->shape())) {
VLOG(5) << "Operand " << operand->ToString() << " layout matches in "
<< instruction->ToString();
// Operand layout already matches our constraint. Nothing to do.
return Status::OK();
}
VLOG(4) << "Operand " << operand->ToString() << " layout does not match "
<< operand_layout.ToString() << " in " << instruction->ToString();
// If the operand is only used by a conditional, do the copy inside the branch
// to avoid overhead for other branches.
if (instruction->opcode() == HloOpcode::kConditional && operand_no > 0 &&
instruction->operand(operand_no)->user_count() == 1) {
auto branch_comp = instruction->branch_computation(operand_no - 1);
auto param = branch_comp->parameter_instruction(0);
*param->mutable_shape() = operand->shape();
auto param_users = param->users();
TF_ASSIGN_OR_RETURN(HloInstruction * param_copy,
CreateCopyWithNewLayout(operand_layout.shape(), param));
for (auto user : param_users) {
TF_RETURN_IF_ERROR(param->ReplaceUseWithDifferentShape(user, param_copy));
}
VLOG(4) << "New copy of " << operand->ToString() << " is "
<< param_copy->ToString();
if (param == branch_comp->root_instruction()) {
branch_comp->set_root_instruction(param_copy,
/*accept_different_shape=*/true);
}
*FindOrDie(computation_layouts_, branch_comp).mutable_parameter_layout(0) =
ShapeLayout(operand->shape());
return Status::OK();
}
TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
CreateCopyWithNewLayout(operand_layout.shape(), operand));
VLOG(4) << "New copy of " << operand->ToString() << " is "
<< operand_copy->ToString();
return instruction->ReplaceOperandWith(operand_no, operand_copy);
}
void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction,
HloInstruction* copy,
const ShapeIndex& index) {
if (instruction.has_sharding()) {
// If the index is empty, we want to copy the whole sharding, in case the
// sharding is a tuple sharding.
HloSharding sharding =
!index.empty() && instruction.sharding().IsTuple()
? instruction.sharding().GetSubSharding(instruction.shape(), index)
: instruction.sharding();
// We propagate the sharding to the copied instruction only if it is a
// special sharding, like tiled ones.
    // Otherwise it is preferable to leave the new instruction without a
    // device and let the automatic device placer choose the best location.
auto device = sharding.UniqueDevice();
if (!device || HloSharding::IsReservedDevice(*device)) {
copy->set_sharding(sharding);
}
}
copy->set_metadata(instruction.metadata());
}
Status LayoutAssignment::CheckLayouts(HloModule* module) {
TF_ASSIGN_OR_RETURN(auto points_to_analysis,
TuplePointsToAnalysis::Run(module));
for (auto* computation : module->MakeNonfusionComputations()) {
for (auto* instruction : computation->instructions()) {
// Verify every instruction has a layout and the layout is valid for the
// shape.
TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
// Use points-to analysis to verify that every subshape element in the
// output of the instruction matches the layout of the logical buffer
// which could be the source of the subshape value.
const PointsToSet& points_to_set =
points_to_analysis->GetPointsToSet(instruction);
TF_RETURN_IF_ERROR(points_to_set.ForEachElementWithStatus(
[&instruction](ShapeIndex index,
const PointsToSet::BufferList& buffers) -> Status {
if (ShapeUtil::IsLeafIndex(instruction->shape(), index)) {
const Shape& instruction_subshape =
ShapeUtil::GetSubshape(instruction->shape(), index);
for (const LogicalBuffer* buffer : buffers) {
if (!Shape::Equal()
.IgnoreDynamicDimension()
.MinorToMajorOnlyInLayout()(instruction_subshape,
buffer->shape())) {
return InternalError(
"Layout of instruction %s at index {%s} does not match "
"source LogicalBuffer %s: %s vs %s",
instruction->name(), absl::StrJoin(index, ","),
buffer->ToString(),
ShapeUtil::HumanStringWithLayout(instruction_subshape),
ShapeUtil::HumanStringWithLayout(buffer->shape()));
}
}
}
return Status::OK();
}));
// Verify instructions that have special layout constraints.
switch (instruction->opcode()) {
case HloOpcode::kCall:
TF_RETURN_IF_ERROR(CheckCallLayout(
instruction,
FindOrDie(computation_layouts_, instruction->to_apply())));
break;
case HloOpcode::kCustomCall:
TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction));
break;
case HloOpcode::kFusion:
TF_RETURN_IF_ERROR(CheckFusionLayout(instruction));
break;
case HloOpcode::kParameter:
TF_RETURN_IF_ERROR(CheckParameterLayout(
instruction,
FindOrDie(computation_layouts_, instruction->parent())));
break;
case HloOpcode::kConstant:
TF_RETURN_IF_ERROR(CheckConstantLayout(instruction));
break;
case HloOpcode::kWhile:
TF_RETURN_IF_ERROR(CheckWhileLayout(
instruction,
FindOrDie(computation_layouts_, instruction->while_condition()),
FindOrDie(computation_layouts_, instruction->while_body())));
break;
case HloOpcode::kConditional: {
std::vector<ComputationLayout> branch_computation_layouts;
for (auto branch_computation : instruction->branch_computations()) {
branch_computation_layouts.emplace_back(
FindOrDie(computation_layouts_, branch_computation));
}
TF_RETURN_IF_ERROR(CheckConditionalLayout(
instruction, absl::MakeSpan(branch_computation_layouts)));
break;
}
default:
break;
}
}
}
// Finally verify the result layout, if set, matches the layout of the entry
// computation root.
const ShapeLayout& result_layout =
FindOrDie(computation_layouts_, module->entry_computation())
.result_layout();
if (result_layout.LayoutIsSet()) {
TF_RET_CHECK(
Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
module->result_shape(), result_layout.shape()));
}
return Status::OK();
}
LayoutAssignment::LayoutAssignment(
ComputationLayout* entry_computation_layout,
std::function<bool(const HloInstruction*)>
instruction_can_change_layout_func,
ChannelLayoutConstraints* channel_constraints)
: entry_computation_layout_(entry_computation_layout),
saved_entry_computation_layout_(*entry_computation_layout),
channel_layout_constraints_(channel_constraints),
instruction_can_change_layout_func_(
std::move(instruction_can_change_layout_func)) {
if (channel_layout_constraints_ != nullptr) {
// Save a copy of the input ChannelLayoutConstraints so that we can reset it
// if we have to undo previous operations (ClearPreviousPassSideEffects()).
channel_constraints_ = *channel_layout_constraints_;
}
VLOG(1) << "Entry computation layout given to layout assignment: "
<< entry_computation_layout_->ToString();
}
std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
const Layout& output_layout, const HloInstruction* instruction,
int64 operand_no) {
const HloInstruction* operand = instruction->operand(operand_no);
CHECK(instruction->shape().IsArray());
CHECK(operand->shape().IsArray());
if (!ShapeUtil::IsScalar(operand->shape()) &&
operand->shape().rank() == instruction->shape().rank() &&
!instruction_can_change_layout_func_(instruction)) {
    // Propagate the result layout to the operand layout if the instruction
    // requires the same layout for the result and the operand.
//
// For elementwise operations, using the same layout for the operands and
// the result also has the following benefits:
// 1) the elementwise operation can reuse its operand's buffer, and
// 2) the input and output elements can reuse the same linear index.
return absl::make_unique<Layout>(output_layout);
}
if (instruction->opcode() == HloOpcode::kReshape) {
    // Prefer the operand layout that makes the reshape a bitcast. If any
    // dimension bound is 1 in the operand shape, there may be several such
    // layouts. So if 'output_layout' is the default layout, check whether the
    // reshape is a bitcast when using the same layout. This may avoid copy
    // operations. For similar reasons, if the operand and output have the same
    // rank, try to match the operand's layout to the output.
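    // For example (illustrative): reshaping f32[2,3] into f32[6] with output
    // layout {0} aligns to operand layout {1,0} (row-major), which makes the
    // reshape a bitcast.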
if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
ShapeUtil::TrueRank(instruction->shape()) == 1) {
// Don't assign a layout in case of R1 -> effective R1 reshape.
return nullptr;
}
const Shape& output_shape = instruction->shape();
Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
LayoutUtil::MinorToMajor(output_layout));
Shape operand_shape = operand->shape();
*operand_shape.mutable_layout() =
LayoutUtil::GetDefaultLayoutForShape(operand_shape);
auto aligned_operand_shape =
ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape);
if (aligned_operand_shape) {
auto operand_layout = aligned_operand_shape.value().layout();
TF_CHECK_OK(
LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape));
return absl::make_unique<Layout>(operand_layout);
}
}
if (instruction->opcode() == HloOpcode::kTranspose) {
// Pick the operand layout that makes the transpose a bitcast.
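    // For example (illustrative): with transpose dimensions {1, 0} and output
    // layout {0, 1}, minor output dimension 0 maps to operand dimension
    // dimensions(0) = 1, giving operand layout {1, 0}.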
int64 rank = instruction->shape().rank();
std::vector<int64> new_minor_to_major(rank);
for (int64 i = 0; i < rank; ++i) {
int64 output_dim = LayoutUtil::Minor(output_layout, i);
int64 operand_dim = instruction->dimensions(output_dim);
new_minor_to_major[i] = operand_dim;
}
Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major);
TF_CHECK_OK(
LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
return absl::make_unique<Layout>(operand_layout);
}
return nullptr;
}
std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
const Layout& operand_layout, const HloInstruction* user,
int64 operand_no) {
const HloInstruction* operand = user->operand(operand_no);
CHECK(user->shape().IsArray() && operand->shape().IsArray());
if (!ShapeUtil::IsScalar(operand->shape()) &&
operand->shape().rank() == user->shape().rank() &&
!instruction_can_change_layout_func_(user)) {
// Assign users the same layout as the operand.
return absl::make_unique<Layout>(operand_layout);
}
if (user->opcode() == HloOpcode::kReshape) {
    // Prefer the user layout that makes the reshape a bitcast. If any
    // dimension bound is 1 in the user shape, there may be several such
    // layouts. So if 'operand_layout' is the default layout, check whether the
    // reshape is a bitcast when using the same layout. This may avoid copy
    // operations. For similar reasons, if the operand and output have the same
    // rank, try to match the output's layout to the operand.
if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
ShapeUtil::TrueRank(user->shape()) == 1) {
// Don't assign a layout in case of R1 -> effective R1 reshape.
return nullptr;
}
Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
operand->shape().element_type(),
AsInt64Slice(operand->shape().dimensions()),
LayoutUtil::MinorToMajor(operand_layout));
Shape output_shape = user->shape();
*output_shape.mutable_layout() =
LayoutUtil::GetDefaultLayoutForShape(output_shape);
auto aligned_user_shape =
ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape);
if (aligned_user_shape) {
auto user_layout = aligned_user_shape.value().layout();
TF_CHECK_OK(
LayoutUtil::ValidateLayoutForShape(user_layout, output_shape));
return absl::make_unique<Layout>(user_layout);
}
}
if (user->opcode() == HloOpcode::kTranspose) {
// Pick the user layout that makes the transpose a bitcast.
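    // For example (illustrative): with transpose dimensions {1, 0} and operand
    // layout {0, 1}, InversePermutation({1, 0}) = {1, 0}, so minor operand
    // dimension 0 maps to user dimension 1, giving user layout {1, 0}.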
int64 rank = user->shape().rank();
std::vector<int64> new_minor_to_major(rank);
auto inverse_dimensions = InversePermutation(user->dimensions());
for (int64 i = 0; i < rank; ++i) {
int64 operand_dim = LayoutUtil::Minor(operand_layout, i);
int64 user_dim = inverse_dimensions[operand_dim];
new_minor_to_major[i] = user_dim;
}
Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major);
TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
return absl::make_unique<Layout>(user_layout);
}
return nullptr;
}
Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) {
// Gathers all initial constraints in a worklist and propagates them in
// depth-first order. DFS order seems to be better than BFS because a
// constraint is propagated as far as possible before propagating unrelated
  // constraints, which makes it less likely that conflicting constraints will
  // be propagated to instructions. However, we should experiment with other
  // orders too.
std::deque<const LayoutConstraint*> worklist;
// Lambda for moving newly added constraints to the worklist.
auto add_new_constraints_to_worklist = [constraints, &worklist]() {
// Add constraints to the front of the deque for DFS ordering.
for (auto* constraint : constraints->ConsumeAddedConstraints()) {
if (constraint->dfs()) {
worklist.push_front(constraint);
} else {
worklist.push_back(constraint);
}
}
};
add_new_constraints_to_worklist();
while (!worklist.empty()) {
const LayoutConstraint* layout_constraint = worklist.front();
worklist.pop_front();
VLOG(2) << "Propagating " << layout_constraint->ToString()
<< " to its neighbors.";
if (auto* buffer_constraint =
dynamic_cast<const BufferLayoutConstraint*>(layout_constraint)) {
TF_RETURN_IF_ERROR(
PropagateBufferConstraint(*buffer_constraint, constraints));
} else if (auto* operand_constraint =
dynamic_cast<const OperandLayoutConstraint*>(
layout_constraint)) {
TF_RETURN_IF_ERROR(
PropagateOperandConstraint(*operand_constraint, constraints));
} else if (auto* result_constraint =
dynamic_cast<const ResultLayoutConstraint*>(
layout_constraint)) {
TF_RETURN_IF_ERROR(
PropagateResultConstraint(*result_constraint, constraints));
} else {
LOG(FATAL) << "Invalid constraint type: " << *layout_constraint;
}
add_new_constraints_to_worklist();
}
return Status::OK();
}
namespace {
// Returns a vector containing all array-shaped uses (instruction and operand
// number) of the given logical buffer or its aliases.
std::vector<std::pair<const HloInstruction*, int64>> GetArrayUsesOfBuffer(
const LogicalBuffer& buffer,
const TuplePointsToAnalysis& points_to_analysis) {
CHECK(buffer.IsArray());
std::vector<std::pair<const HloInstruction*, int64>> uses;
for (const auto& buffer_alias : points_to_analysis.GetBufferAliases(buffer)) {
if (!buffer_alias.instruction()->shape().IsArray()) {
continue;
}
// This alias must be the top-level (index == {}) of the instruction's
// result because the instruction produces an array.
CHECK(buffer_alias.index().empty());
// Add all uses of the instruction's output.
for (const HloInstruction* user : buffer_alias.instruction()->users()) {
for (int64 operand_no :
user->OperandIndices(buffer_alias.instruction())) {
uses.emplace_back(user, operand_no);
}
}
}
return uses;
}
} // namespace
Status LayoutAssignment::PropagateUseConstraintToDefs(
const ShapeLayout& shape_layout, const HloInstruction* instruction,
LayoutConstraints* constraints) {
// Try to set all logical buffers which may be sources of the given operand to
// match the given layout.
const PointsToSet& points_to_set =
constraints->points_to_analysis().GetPointsToSet(instruction);
return points_to_set.ForEachElementWithStatus(
[&shape_layout, constraints](
const ShapeIndex& index,
const PointsToSet::BufferList& buffers) -> Status {
if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) {
for (const LogicalBuffer* buffer : buffers) {
if (constraints->BufferLayout(*buffer) == nullptr &&
buffer->shape().IsArray()) {
TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(),
*buffer, /*mandatory=*/true));
}
}
}
return Status::OK();
});
}
namespace {
// A transpose or a reshape that only changes trivial dimensions has a
// meaningful layout that is valuable to propagate in a depth-first manner to
// avoid unassigned layouts in the graph.
bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo,
bool forward_propagation = true) {
switch (hlo.opcode()) {
case HloOpcode::kFusion:
return hlo.IsCustomFusion();
case HloOpcode::kGather:
return true;
case HloOpcode::kReshape:
return hlo.operand(0)->shape().rank() == 1 ||
(forward_propagation &&
std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()));
case HloOpcode::kScatter:
case HloOpcode::kTranspose:
return true;
default:
return false;
}
}
} // namespace
Status LayoutAssignment::PropagateOperandConstraint(
const OperandLayoutConstraint& operand_constraint,
LayoutConstraints* constraints) {
// Try to set the layout of the logical buffers in the given operand to match
// the constrained layout. This avoids copies.
TF_RETURN_IF_ERROR(
PropagateUseConstraintToDefs(operand_constraint.shape_layout(),
operand_constraint.operand(), constraints));
// For array-shaped operands and user instructions try to pick a minimum cost
// layout. For example, if the operand of an elementwise instruction is
// constrained to a certain layout we want the output of the instruction to
// have the same layout.
//
  // If the user is not array-shaped, we still want to propagate the layout
  // to its siblings if the instruction can't change layout. This encodes the
  // invariant that non-layout-changing instructions should use the same layout
  // for all operands of the same rank.
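  // For example (illustrative): if one operand of a kAdd is constrained to
  // layout {0, 1}, the code below also steers the other operand and the
  // result toward {0, 1}, albeit non-mandatorily.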
const HloInstruction* operand = operand_constraint.operand();
const HloInstruction* user = operand_constraint.instruction();
if (!operand->shape().IsArray()) {
return Status::OK();
}
if (user->opcode() == HloOpcode::kAllReduce) {
const auto shape_index =
user->operand_count() == 1
? ShapeIndex()
: ShapeIndex({operand_constraint.operand_no()});
TF_ASSIGN_OR_RETURN(const LogicalBuffer* buffer,
constraints->points_to_analysis().GetBufferDefinedAt(
user, shape_index));
const BufferLayoutConstraint* constraint =
constraints->GetBufferLayoutConstraint(*buffer);
if (constraint == nullptr) {
TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
operand_constraint.shape_layout().layout(), *buffer,
/*mandatory=*/false));
}
}
if (instruction_can_change_layout_func_(user) && !user->shape().IsArray()) {
return Status::OK();
}
// Only try to choose a low cost layout if the instruction 'user' defines its
  // output (i.e., doesn't forward a buffer from elsewhere).
if (constraints->OperandBufferForwarded(user,
operand_constraint.operand_no())) {
return Status::OK();
}
int64 operand_rank = operand->shape().rank();
if (operand_rank <= 1) {
return Status::OK();
}
// Propagate layouts between operands of the same instruction. This is a
// constraint on non-layout-changing instructions.
if (!instruction_can_change_layout_func_(user)) {
    // Only propagate the layout of the largest concatenate operand.
if (user->opcode() == HloOpcode::kConcatenate) {
for (int64 operand_no = 0; operand_no < user->operand_count();
++operand_no) {
const HloInstruction* sibling = user->operand(operand_no);
if (sibling == operand) {
continue;
}
if (sibling->shape().dimensions(user->concatenate_dimension()) >
operand->shape().dimensions(user->concatenate_dimension())) {
return Status::OK();
}
}
}
// Make sure all siblings have the same layout as the operand.
for (int64 operand_no = 0; operand_no < user->operand_count();
++operand_no) {
if (user->operand(operand_no) == operand) {
continue;
}
const HloInstruction* sibling = user->operand(operand_no);
const int64 sibling_rank = sibling->shape().rank();
if (sibling_rank <= 1) {
continue;
}
if (operand_rank != sibling_rank) {
continue;
}
const OperandLayoutConstraint* constraint =
constraints->GetOperandLayoutConstraint(user, operand_no);
if (constraint != nullptr) {
// Due to the DFS of the propagation we can end up here when operand_no
// has a layout set that hasn't been propagated yet (is still on the
// stack of layouts to propagate).
// We can continue here and leave the operands with different layouts,
// as we will either:
// - overwrite the current operand when the DFS gets back to propagating
// operand(operand_no) to its siblings
// - overwrite operand(operand_no)'s layout with a mandatory layout if
// we continue to propagate our layout to the result, and then
// backwards into all operands (if the result is an array of rank > 1)
continue;
}
TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
operand_constraint.shape_layout().layout(), user, operand_no,
/*mandatory=*/false));
}
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
user->shape(),
[&](const Shape& subshape, const ShapeIndex& shape_index) {
if (subshape.IsTuple()) {
return Status::OK();
}
if (subshape.rank() <= 1) {
return Status::OK();
}
          // Assign the right layout to the input fusion of higher-rank reduce
          // operations.
if (subshape.rank() != operand->shape().rank()) {
return Status::OK();
}
if (!constraints->points_to_analysis()
.InstructionDefinesBufferAtIndex(user, shape_index)) {
return Status::OK();
}
// TODO(b/67641796): Are there cases except fusion that use this code
// path?
TF_ASSIGN_OR_RETURN(
const LogicalBuffer* buffer,
constraints->points_to_analysis().GetBufferDefinedAt(
user, shape_index));
// Make sure the output has the same layout as the operand.
const BufferLayoutConstraint* constraint =
constraints->GetBufferLayoutConstraint(*buffer);
// If we already have a constraint for the buffer it was assigned but
// hasn't propagated yet. This can happen with diamond-shaped graphs
// where one path is first evaluated in depth-first order (we're here)
// and the other path is propagated later. We don't set the layout
// here as it will always be overwritten later.
if (constraint == nullptr) {
TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
operand_constraint.shape_layout().layout(), *buffer,
/*mandatory=*/false));
}
return Status::OK();
}));
return Status::OK();
}
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
user->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) {
if (subshape.IsTuple()) {
return Status::OK();
}
if (subshape.rank() <= 1) {
return Status::OK();
}
if (!constraints->points_to_analysis().InstructionDefinesBufferAtIndex(
user, shape_index)) {
return Status::OK();
}
TF_ASSIGN_OR_RETURN(
const LogicalBuffer* buffer,
constraints->points_to_analysis().GetBufferDefinedAt(user,
shape_index));
if (constraints->BufferLayout(*buffer) == nullptr ||
!constraints->GetBufferLayoutConstraint(*buffer)->mandatory()) {
std::unique_ptr<Layout> layout = ChooseOutputLayoutFromOperandLayout(
operand_constraint.shape_layout().layout(), user,
operand_constraint.operand_no());
if (layout != nullptr) {
TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
*layout, *buffer,
/*mandatory=*/user->opcode() == HloOpcode::kReduce,
/*dfs=*/InstructionShouldPropagateDepthFirst(*user)));
}
}
return Status::OK();
}));
return Status::OK();
}
Status LayoutAssignment::PropagateBufferConstraintToOperands(
const BufferLayoutConstraint& buffer_constraint,
LayoutConstraints* constraints) {
VLOG(5) << "PropagateBufferConstraintToOperands: "
<< buffer_constraint.ToString();
const LogicalBuffer& buffer = buffer_constraint.buffer();
const HloInstruction* instruction = buffer.instruction();
if (IsAtMostRank1(instruction->shape())) {
return Status::OK();
}
if (instruction->opcode() == HloOpcode::kAllReduce) {
TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
buffer_constraint.layout(), instruction,
instruction->operand_count() == 1 ? 0 : buffer.index()[0],
/*mandatory=*/true));
return Status::OK();
}
for (int64 operand_no = 0; operand_no < instruction->operand_count();
++operand_no) {
const HloInstruction* operand = instruction->operand(operand_no);
if (IsAtMostRank1(operand->shape())) {
continue;
}
if (!instruction_can_change_layout_func_(instruction)) {
// Copy the layout to the operand.
if (buffer.IsArray() && operand->shape().IsArray() &&
operand->shape().rank() ==
LayoutUtil::MinorToMajor(buffer_constraint.layout()).size()) {
TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
buffer_constraint.layout(), instruction, operand_no,
/*mandatory=*/true));
}
} else {
if (!buffer.IsTopLevel() ||
!instruction->operand(operand_no)->shape().IsArray()) {
continue; // Don't touch buffers that are internal to a tuple.
}
VLOG(6) << "Propagating constraint to operand " << operand_no << " of "
<< instruction->ToShortString();
// Assign a layout if there is no constraint already.
const OperandLayoutConstraint* constraint =
constraints->GetOperandLayoutConstraint(instruction, operand_no);
if (constraint == nullptr || !constraint->mandatory()) {
std::unique_ptr<Layout> operand_layout =
ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(),
instruction, operand_no);
if (operand_layout != nullptr) {
TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
*operand_layout, instruction, operand_no, /*mandatory=*/false,
/*dfs=*/
InstructionShouldPropagateDepthFirst(
*instruction, /*forward_propagation=*/false)));
}
} else {
VLOG(6) << "Operand already has a constraint "
<< constraint->ToString();
}
}
}
return Status::OK();
}
Status LayoutAssignment::PropagateBufferConstraint(
const BufferLayoutConstraint& buffer_constraint,
LayoutConstraints* constraints) {
// Only propagate array layouts.
const LogicalBuffer& buffer = buffer_constraint.buffer();
if (!buffer.IsArray()) {
return Status::OK();
}
TF_RETURN_IF_ERROR(
PropagateBufferConstraintToUses(buffer_constraint, constraints));
return PropagateBufferConstraintToOperands(buffer_constraint, constraints);
}
Status LayoutAssignment::PropagateBufferConstraintToUses(
const BufferLayoutConstraint& buffer_constraint,
LayoutConstraints* constraints) {
const LogicalBuffer& buffer = buffer_constraint.buffer();
TF_RET_CHECK(buffer.IsArray());
// Propagate the layout to all array uses of the logical buffer. This skips
// uses of the buffer where the buffer is the element of a tuple.
for (const auto& user_operand_no :
GetArrayUsesOfBuffer(buffer, constraints->points_to_analysis())) {
const HloInstruction* user = user_operand_no.first;
int64 operand_no = user_operand_no.second;
    // Only add an operand constraint if the user does not forward the buffer
    // because this case is not handled in SetOperandLayout.
if (constraints->OperandLayout(user, operand_no) == nullptr &&
!constraints->OperandBufferForwarded(user, operand_no)) {
TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
buffer_constraint.layout(), user, operand_no, /*mandatory=*/false));
}
}
// Propagate to backedges of kWhile.
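  // Illustrative example: if the buffer's instruction feeds element {i} of
  // the while body's root tuple, the loop below applies its layout to element
  // {i} of the body's parameter as well, keeping the layout consistent across
  // loop iterations.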
CallGraphNode& node = call_graph_->GetNode(buffer.instruction()->parent());
if (node.caller_callsites().size() != 1) {
return Status::OK();
}
const HloInstruction* parent = node.caller_callsites()[0].instruction();
if (parent->opcode() != HloOpcode::kWhile) {
return Status::OK();
}
for (HloInstruction* user : buffer.instruction()->users()) {
if (user->parent()->root_instruction()->opcode() != HloOpcode::kTuple) {
continue;
}
if (user->parent()->root_instruction() == user) {
      VLOG(3) << "Propagating layout through backedge "
              << buffer_constraint.layout().ToString();
int64 index = user->operand_index(buffer.instruction());
TF_ASSIGN_OR_RETURN(
auto buffer, constraints->points_to_analysis().GetBufferDefinedAt(
user->parent()->parameter_instruction(0), {index}));
TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
buffer_constraint.layout(), *buffer, /*mandatory=*/false));
}
}
return Status::OK();
}
Status LayoutAssignment::PropagateResultConstraint(
const ResultLayoutConstraint& layout_constraint,
LayoutConstraints* constraints) {
// Propagate the use constraint of the root instruction up to the logical
// buffers which make up the result.
return PropagateUseConstraintToDefs(
layout_constraint.shape_layout(),
constraints->computation()->root_instruction(), constraints);
}
namespace {
// Infers the layout of the array at the given index in the given instruction's
// output using points-to analysis. Precondition: The given instruction must
// not produce this array value (that is, the array is forwarded from the
// instruction's operands).
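// For example (illustrative): for a kTuple instruction, the array at index
// {0} is forwarded from operand 0, so its layout is read from the
// LogicalBuffer(s) recorded in the points-to set at that index rather than
// being defined by the tuple itself.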
StatusOr<Layout> InferArrayLayout(
const TuplePointsToAnalysis& points_to_analysis,
HloInstruction* instruction, const ShapeIndex& index) {
// This function should only be called for array shapes which don't yet have
// layouts.
const Shape& subshape = ShapeUtil::GetSubshape(instruction->shape(), index);
TF_RET_CHECK(subshape.IsArray());
TF_RET_CHECK(!subshape.has_layout());
// The instruction should not define the buffer at this index.
TF_RET_CHECK(
!points_to_analysis.InstructionDefinesBufferAtIndex(instruction, index))
<< instruction->ToString();
const auto& source_buffers =
points_to_analysis.GetPointsToSet(instruction).element(index);
TF_RET_CHECK(!source_buffers.empty());
// Verify the layout is the same for every LogicalBuffer which this location
// ('instruction' and 'index') points to.
const Layout* first_buffer_layout = nullptr;
for (const LogicalBuffer* source_buffer : source_buffers) {
if (!source_buffer->shape().has_layout()) {
// This should not happen because we've assigned layouts to all
// instructions preceding this one.
return InternalError("LogicalBuffer %s does not have a layout",
source_buffer->ToString());
}
if (first_buffer_layout == nullptr) {
first_buffer_layout = &source_buffer->shape().layout();
} else if (!Layout::Equal().MinorToMajorOnly()(
source_buffer->shape().layout(), *first_buffer_layout)) {
// The points-to set is ambiguous for this index and the different source
// buffers have different layouts. This case is possible in valid XLA
// computations because we do not propagate BufferLayoutConstraints to all
// LogicalBuffers which may alias the constrained LogicalBuffer at some
// point in the computation.
return FailedPrecondition(
"Array at index {%s} in instruction %s aliases buffers %s "
"and %s which have different layouts",
absl::StrJoin(index, ","), instruction->name(),
source_buffers[0]->ToString(), source_buffer->ToString());
}
}
return *first_buffer_layout;
}
// For fusion instructions, set the layout of each fused parameter instruction
// to match the layout of its corresponding fusion instruction operand. Also,
// set the layout of the fused root to match the layout of the fusion
// instruction itself.
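// Illustrative sketch: for fusion f(x) whose fused computation is
//   { p0 = parameter(0); root = exp(p0) }
// p0 receives x's layout, root receives f's output layout, and any other
// fused instructions would have their layouts cleared (unless f is a custom
// fusion).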
Status SetFusionLayouts(HloInstruction* fusion) {
TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion);
for (auto* fused_instruction :
fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
if (fused_instruction->opcode() == HloOpcode::kParameter) {
const HloInstruction* fusion_operand =
fusion->operand(fused_instruction->parameter_number());
DCHECK(ShapeUtil::Compatible(fusion_operand->shape(),
fused_instruction->shape()));
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
fusion_operand->shape(), fused_instruction->mutable_shape()));
} else if (fused_instruction == fusion->fused_expression_root()) {
// The layout of the root of the fused expression must match the fusion
// instruction layout.
DCHECK(
ShapeUtil::Compatible(fusion->shape(), fused_instruction->shape()));
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
fusion->shape(), fused_instruction->mutable_shape()));
} else if (fused_instruction->opcode() == HloOpcode::kGetTupleElement) {
// A GTE inherits its layout from its operand (which should ultimately be
// a parameter).
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
fused_instruction->operand(0)->shape().tuple_shapes(
fused_instruction->tuple_index()),
fused_instruction->mutable_shape()));
} else if (fused_instruction->opcode() == HloOpcode::kConstant) {
// Give constants the layout of their literal.
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
fused_instruction->literal().shape(),
fused_instruction->mutable_shape()));
} else if (fused_instruction->opcode() == HloOpcode::kInfeed) {
// Nop; leave the infeed layout alone.
} else if (!fusion->IsCustomFusion()) {
      // Other instructions don't have layouts inside of fusion nodes, so we
      // clear them here; in custom fusion nodes the layouts are left alone.
LayoutUtil::ClearLayout(fused_instruction->mutable_shape());
}
}
return Status::OK();
}
} // namespace
Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
HloComputation* computation) {
VLOG(2) << "Assigning layouts to computation: " << computation->name();
XLA_VLOG_LINES(2, computation->ToString());
XLA_VLOG_LINES(2, constraints.ToString());
for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
LayoutUtil::ClearLayout(instruction->mutable_shape());
// Set the layouts of the array shapes this instruction defines as indicated
// by the respective BufferLayoutConstraints. Any array shapes in the output
// of the instruction which are not defined by the instruction (eg, array
// elements in a Tuple instruction) will be assigned below via inference.
for (const LogicalBuffer* buffer :
constraints.points_to_analysis().GetBuffersDefinedByInstruction(
instruction)) {
if (!buffer->shape().IsArray()) {
continue;
}
TF_RET_CHECK(buffer->instruction() == instruction);
const Layout* buffer_layout = constraints.BufferLayout(*buffer);
TF_RET_CHECK(buffer_layout != nullptr);
if (instruction->opcode() == HloOpcode::kConstant) {
// For constants, we also need to change the layout of the internal
// literal.
instruction->RelayoutConstant(*buffer_layout, buffer->index());
} else {
Shape* buffer_subshape = ShapeUtil::GetMutableSubshape(
instruction->mutable_shape(), buffer->index());
*buffer_subshape->mutable_layout() = *buffer_layout;
}
}
// Any remaining layouts in the output of the instruction must be
// inferrable using points-to analysis.
TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
instruction->mutable_shape(),
[instruction, &constraints](Shape* subshape, const ShapeIndex& index) {
if (subshape->has_layout() || !subshape->IsArray()) {
return Status::OK();
}
// Set Layout of subshape to match layout of LogicalBuffer which
// produces it.
TF_ASSIGN_OR_RETURN(*subshape->mutable_layout(),
InferArrayLayout(constraints.points_to_analysis(),
instruction, index));
return Status::OK();
}));
// Create a copy of an operand if the operand instruction's layout does not
// match the use constraint (OperandLayoutConstraint).
for (int64 operand_no = 0; operand_no < instruction->operand_count();
++operand_no) {
const ShapeLayout* operand_layout =
constraints.OperandLayout(instruction, operand_no);
if (operand_layout != nullptr) {
TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
instruction, operand_no));
}
}
// Fusion instructions require some layouts to be set on fused instructions
// inside the fusion instruction.
if (instruction->opcode() == HloOpcode::kFusion) {
TF_RETURN_IF_ERROR(SetFusionLayouts(instruction));
}
// Execute extra verification step once the layout has been finalized.
TF_RETURN_IF_ERROR(Verify(instruction));
// Shape must be valid.
TF_RETURN_IF_ERROR(
ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));
// Verify all layouts in the shape have been set.
TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
}
return Status::OK();
}
Status LayoutAssignment::CalculateComputationLayout(
HloComputation* computation) {
ComputationLayout computation_layout(computation->ComputeProgramShape(),
/*ignore_layouts=*/false);
InsertOrDie(&computation_layouts_, computation, computation_layout);
VLOG(2) << " Calculated ComputationLayout = "
<< computation_layout.ToString();
return Status::OK();
}
Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
  // Clear existing layouts of the instructions. All layouts must be assigned
  // by the LayoutAssignment pass, except for those on parameters, the
  // computation result, and a couple of special cases. The former two are
  // specified in computation_layout. Clearing the layouts here avoids hiding
  // potential bugs in the layout assignment pass that may accidentally use the
  // existing layout.
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kBitcast) {
      // Bitcasts are inherently layout-sensitive, so a bitcast instruction
      // present in the IR before layout assignment is a bug.
return InternalError(
"Unexpected bitcast operation seen during layout assignment: %s.",
instruction->ToString());
}
// Some instructions carry mandatory layouts in their shape.
if (instruction->opcode() != HloOpcode::kInfeed &&
!IsLayoutConstrainedCustomCall(instruction) &&
!IsLayoutConstrainedCollective(instruction)) {
LayoutUtil::ClearLayout(instruction->mutable_shape());
}
}
return Status::OK();
}
Status LayoutAssignment::RunOnComputation(
ComputationLayout* computation_layout, HloComputation* computation,
ChannelLayoutConstraints* channel_constraints) {
VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name()
<< ")";
// Must be run before clearing layouts.
TF_RETURN_IF_ERROR(BuildHostChannelConstraints(computation));
TF_RETURN_IF_ERROR(ClearComputationLayouts(computation));
if (computation_layout != nullptr) {
auto it = computation_layouts_.find(computation);
if (it == computation_layouts_.end()) {
VLOG(2) << " New ComputationLayout = " << computation_layout->ToString();
computation_layouts_.emplace(computation, *computation_layout);
} else {
TF_RET_CHECK(computation_layout == &it->second ||
computation_layout == entry_computation_layout_);
VLOG(2) << " Existing ComputationLayout = "
<< computation_layout->ToString();
}
} else {
VLOG(2) << " No ComputationLayout specified (will be calculated)";
}
// Construct LayoutConstraints with all layout constraints of the computation.
LayoutConstraints constraints(*points_to_analysis_, computation);
// Add constraints required for correctness on all backends (eg, entry
// parameter layout constraints).
TF_RETURN_IF_ERROR(AddMandatoryConstraints(
computation_layout, channel_constraints, computation, &constraints));
// Add any backend-specific constraints.
TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints));
  // Propagate layouts from mandatory and backend constraints.
TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));
// Prior to applying default layouts, we take note of all HLO instructions
// which lack a layout constraint.
for (LogicalBuffer::Id buffer_id : constraints.unconstrained_buffer_ids()) {
unconstrained_layout_instructions_.insert(
points_to_analysis_->GetBuffer(buffer_id).instruction());
}
// While any unconstrained buffers remain, pick an arbitrary buffer, give it a
// layout and propagate the change.
while (!constraints.unconstrained_buffer_ids().empty()) {
int unconstrained_count = constraints.unconstrained_buffer_ids().size();
    // Arbitrarily pick the first unconstrained buffer and give it the default
    // layout (or the literal layout, in case of constants). By construction,
    // unconstrained_buffer_ids() has a stable sort based on LogicalBuffer::Id.
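    // (Illustratively, for a rank-3 array the unconstrained default is
    // typically the descending minor-to-major layout {2,1,0}, i.e. row-major;
    // see GetUnconstrainedLayout.)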
const LogicalBuffer& buffer = points_to_analysis_->GetBuffer(
*constraints.unconstrained_buffer_ids().begin());
const HloInstruction* instruction = buffer.instruction();
Layout new_layout =
instruction->opcode() == HloOpcode::kConstant
? ShapeUtil::GetSubshape(instruction->literal().shape(),
buffer.index())
.layout()
: GetUnconstrainedLayout(buffer);
TF_RETURN_IF_ERROR(constraints.SetBufferLayout(new_layout, buffer,
/*mandatory=*/false));
TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));
// To verify progress has been made, check that the number of unconstrained
// buffers has been reduced.
CHECK_LT(constraints.unconstrained_buffer_ids().size(),
unconstrained_count);
}
  // All logical buffers should have constraints at this point. All that
  // remains is to assign the constraints to the buffers and infer layouts for
  // aliased buffers.
TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation));
  // If the computation layout wasn't specified, now is the time to compute it
  // from the parameter and root instruction layouts.
  // This allows the first pass through this API to record the layouts that
  // naturally flow to the parameters and root instruction.
if (computation_layout == nullptr) {
TF_RETURN_IF_ERROR(CalculateComputationLayout(computation));
}
// Record the layouts assigned for any communication ops in
// channel_constraints so that they are constrained for future modules.
if (channel_constraints != nullptr) {
TF_RETURN_IF_ERROR(
ConstrainChannelLayouts(computation, channel_constraints));
}
// Copy the root instruction's result if its layout does not match the result
// layout constraint.
if (constraints.ResultLayout() != nullptr) {
    // Layout assignment at this point only does minor-to-major assignment, so
    // tiling info should be ignored here for comparison.
if (!constraints.ResultLayout()->MatchesLayoutInShape(
computation->root_instruction()->shape(),
/*minor_to_major_only=*/true)) {
if (conditional_mismatch_.count(computation) > 0) {
*FindOrDie(computation_layouts_, computation).mutable_result_layout() =
FindOrDie(conditional_mismatch_, computation).result_layout();
}
TF_ASSIGN_OR_RETURN(
HloInstruction * new_root,
CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
computation->root_instruction()));
computation->set_root_instruction(new_root);
} else {
// Copy the specified tiling info.
auto assign_tiling = [&constraints](xla::Shape* subshape,
const xla::ShapeIndex& index) {
if (subshape->IsArray()) {
const Shape& result_shape = ShapeUtil::GetSubshape(
constraints.ResultLayout()->shape(), index);
subshape->mutable_layout()->mutable_tiles()->assign(
result_shape.layout().tiles().begin(),
result_shape.layout().tiles().end());
}
};
xla::ShapeUtil::ForEachMutableSubshape(
computation->root_instruction()->mutable_shape(), assign_tiling);
}
}
return Status::OK();
}
Status LayoutAssignment::ConstrainChannelLayouts(
HloComputation* computation,
ChannelLayoutConstraints* channel_constraints) {
auto get_channel_constraints = [&](const HloInstruction* instruction) {
return IsHostSendRecv(instruction) ? &host_channel_constraints_
: channel_constraints;
};
  // We go through the kRecvDone instructions first. These must either impose
  // their layout, or find an existing matching one (in which case
  // ConstrainChannel() returns nullptr).
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kRecvDone) {
const Layout* layout =
get_channel_constraints(instruction)
->ConstrainChannel(
*instruction->channel_id(),
ShapeUtil::GetSubshape(instruction->shape(), {0}).layout());
TF_RET_CHECK(layout == nullptr)
<< instruction->ToString()
<< " cannot constrain layout as it was set to "
<< LayoutUtil::HumanString(*layout);
}
}
  // After that we go through the kSend instructions. These are likely to have
  // a kCopy as operand (otherwise we add one), so if the constrained layout
  // does not match, we can change the kCopy layout (and the kSend one as
  // well).
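  // Illustrative example: if the channel of send(copy(x)) was already
  // constrained to a different layout (e.g. by a paired recv in another
  // module), the kCopy can be relayouted on a later run to honor the channel
  // constraint without disturbing x itself.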
for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
if (instruction->opcode() == HloOpcode::kSend) {
HloInstruction* operand = instruction->mutable_operand(0);
get_channel_constraints(instruction)
->ConstrainChannel(*instruction->channel_id(),
operand->shape().layout());
} else if (instruction->IsCrossModuleAllReduce()) {
get_channel_constraints(instruction)
->ConstrainChannel(instruction->channel_id().value(),
instruction->shape().layout());
}
}
return Status::OK();
}
Status LayoutAssignment::PropagateMemorySpace(HloModule* module) {
TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module));
for (const auto& buffer : alias_analysis->buffers()) {
// First go through values to collect the memory spaces.
int64 buffer_memory_space = Layout::kDefaultMemorySpace;
for (auto value : buffer.values()) {
const Shape& defining_shape = value->defining_position().shape();
int64 memory_space = defining_shape.layout().memory_space();
if (memory_space != Layout::kDefaultMemorySpace) {
if (buffer_memory_space != Layout::kDefaultMemorySpace &&
memory_space != buffer_memory_space) {
return InternalError(
"Buffer %d (%s) has conflicting memory spaces: %d and %d.",
buffer.id(), value->ToShortString(), buffer_memory_space,
memory_space);
}
buffer_memory_space = memory_space;
}
}
    // If we encounter a memory space other than the default, then propagate
    // the buffer's memory space to all of its positions.
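    // Illustrative example: if any aliased value of the buffer was assigned
    // memory space 1 by an earlier pass, every position of the buffer is
    // rewritten to memory space 1 below.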
if (buffer_memory_space != Layout::kDefaultMemorySpace) {
for (auto value : buffer.values()) {
for (auto& position : value->positions()) {
Shape* shape = ShapeUtil::GetMutableSubshape(
position.instruction->mutable_shape(), position.index);
shape->mutable_layout()->set_memory_space(buffer_memory_space);
}
}
}
}
return Status::OK();
}
Status LayoutAssignment::PropagateComputationLayouts(
HloComputation* computation, ComputationLayout* computation_layout) {
ComputationLayout computed_computation_layout(
computation->ComputeProgramShape(),
/*ignore_layouts=*/false);
for (int64 i = 0; i < computed_computation_layout.parameter_count(); ++i) {
ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i);
bool needs_assign = false;
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
param_layout->shape(),
[&](const Shape& subshape, const ShapeIndex& shape_index) {
if (!ShapeUtil::IsLeafIndex(param_layout->shape(), shape_index)) {
return Status::OK();
}
if (!subshape.has_layout()) {
needs_assign = true;
return Status::OK();
}
const auto& computed_subshape = ShapeUtil::GetSubshape(
computed_computation_layout.parameter_shape(i), shape_index);
if (subshape.layout() != computed_subshape.layout()) {
return InternalError(
"Assigned parameter shape %s does not match layout of "
"computation shape: %s",
computed_computation_layout.ToString(),
computation_layout->ToString());
}
return Status::OK();
}));
if (needs_assign) {
VLOG(4) << "Assigning layout to parameter " << i << " of computation "
<< computation->name() << ": "
<< computed_computation_layout.parameter_layout(i).ToString();
*param_layout = computed_computation_layout.parameter_layout(i);
}
}
ShapeLayout* result_layout = computation_layout->mutable_result_layout();
if (!result_layout->LayoutIsSet()) {
VLOG(4) << "Assigning result layout of computation " << computation->name()
<< ": " << computed_computation_layout.result_layout().ToString();
*result_layout = computed_computation_layout.result_layout();
} else {
TF_RET_CHECK(
Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
computed_computation_layout.result_layout().shape(),
result_layout->shape()));
}
return Status::OK();
}
StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
VLOG(2) << "Running layout assignment on module " << module->name();
TF_RETURN_IF_ERROR(Init());
call_graph_ = CallGraph::Build(module);
auto computations = module->computations();
  // Add a copy to the operand of Send instructions, since we cannot call
  // SetOperandLayout on Send instructions: a Send aliases its input to its
  // output.
  //
  // TODO(b/68493863): Remove this once we can call SetOperandLayout() on the
  // operand buffers that alias with the output.
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction :
computation->MakeInstructionPostOrder()) {
if (instruction->opcode() == HloOpcode::kSend) {
TF_RETURN_IF_ERROR(AddCopyForOperand(instruction, 0));
}
}
}
// Clone Conditional computations with multiple callsites.
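  // (Illustratively, if two kConditional callsites share the same branch
  // computation, each callsite may need to impose different parameter/result
  // layouts, so all but the last such callsite receive a fresh clone of the
  // computation below.)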
for (HloComputation* computation : computations) {
CallGraphNode& node = call_graph_->GetNode(computation);
if (node.caller_callsites().size() == 1) {
continue;
}
if (absl::c_none_of(node.caller_callsites(), [](CallSite caller) {
return caller.instruction()->opcode() == HloOpcode::kConditional;
})) {
continue;
}
for (int64 i = 0; i < node.caller_callsites().size() - 1; ++i) {
HloInstruction* caller = node.caller_callsites()[i].instruction();
if (caller->opcode() == HloOpcode::kConditional) {
for (int64 k = 0; k < caller->branch_count(); ++k) {
if (computation == caller->branch_computation(k)) {
caller->set_branch_computation(
k, module->AddEmbeddedComputation(computation->Clone()));
break;
}
}
}
}
}
  // Verify that the entry computation layout is sane.
const HloComputation* entry = module->entry_computation();
TF_RET_CHECK(entry_computation_layout_->parameter_count() ==
entry->num_parameters());
for (int64 i = 0; i < entry->num_parameters(); ++i) {
TF_RET_CHECK(
ShapeUtil::Compatible(entry_computation_layout_->parameter_shape(i),
entry->parameter_instruction(i)->shape()));
}
TF_RET_CHECK(ShapeUtil::Compatible(entry_computation_layout_->result_shape(),
entry->root_instruction()->shape()));
  // We do two passes. In the first one we pass a nullptr ComputationLayout to
  // the RunOnComputation() calls (for non-entry computations), and we record
  // the ComputationLayouts which naturally flow, in DFS fashion, to the
  // parameters and root instruction.
  // Walking in DFS order, though, means that we can end up with incorrect
  // layouts when seen from an outer instruction, which has across-computation
  // constraints to impose.
  // For example, the kWhile instruction needs to enforce the same layouts for
  // the parameters and root of the body, as well as the condition parameters.
  // Similarly, the kConditional instruction needs to enforce the same layouts
  // for the roots of the true and false computations.
  // So in the first pass, while allowing the layouts to flow to parameters and
  // root, we also fix up any inconsistent ComputationLayouts, which are then
  // made mandatory by the second pass.
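  // Schematically (an illustrative sketch), for a non-entry computation C:
  //   pass 0: RunOnComputation(nullptr, C, ...)    -> record flowed layouts
  //   pass 1: RunOnComputation(&recorded, C, ...)  -> enforce them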
for (int64 i = 0; i < 2; ++i) {
VLOG(5) << "Running " << (i == 0 ? "un" : "") << "constrained pass";
TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module));
TF_ASSIGN_OR_RETURN(auto points_to_analysis,
TuplePointsToAnalysis::Run(module));
points_to_analysis_ = std::move(points_to_analysis);
for (auto* computation : module->MakeComputationPostOrder()) {
if (computation->IsFusionComputation()) {
continue;
}
if (computation == module->entry_computation()) {
TF_RETURN_IF_ERROR(RunOnComputation(entry_computation_layout_,
module->entry_computation(),
channel_layout_constraints_));
} else {
ComputationLayout* computation_layout =
(i == 0 || conditional_mismatch_.count(computation) > 0)
? nullptr
: &FindOrDie(computation_layouts_, computation);
TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, computation,
channel_layout_constraints_));
}
}
}
TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(),
entry_computation_layout_));
TF_RETURN_IF_ERROR(PropagateMemorySpace(module));
TF_RETURN_IF_ERROR(CheckLayouts(module));
// All layouts are reset then reassigned by this pass.
return true;
}
/* static */
bool LayoutAssignment::InstructionCanChangeLayout(
const HloInstruction* instruction) {
switch (instruction->opcode()) {
case HloOpcode::kAbs:
case HloOpcode::kAdd:
case HloOpcode::kAddDependency:
case HloOpcode::kAnd:
case HloOpcode::kAtan2:
case HloOpcode::kBitcastConvert:
case HloOpcode::kCeil:
case HloOpcode::kClamp:
case HloOpcode::kClz:
case HloOpcode::kCompare:
case HloOpcode::kComplex:
case HloOpcode::kConcatenate:
case HloOpcode::kConditional:
case HloOpcode::kConvert:
case HloOpcode::kCos:
case HloOpcode::kAllGather:
case HloOpcode::kAllToAll:
case HloOpcode::kCollectivePermute:
case HloOpcode::kDivide:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFft:
case HloOpcode::kFloor:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kLogistic:
case HloOpcode::kMap:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNegate:
case HloOpcode::kNot:
case HloOpcode::kOr:
case HloOpcode::kXor:
case HloOpcode::kPad:
case HloOpcode::kPower:
case HloOpcode::kReal:
case HloOpcode::kReducePrecision:
case HloOpcode::kReduceWindow:
case HloOpcode::kRemainder:
case HloOpcode::kReverse:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kRsqrt:
case HloOpcode::kScatter:
case HloOpcode::kSelect:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSlice:
case HloOpcode::kSort:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kSubtract:
case HloOpcode::kTanh:
case HloOpcode::kPopulationCount:
case HloOpcode::kTriangularSolve:
case HloOpcode::kCholesky:
case HloOpcode::kTupleSelect:
case HloOpcode::kWhile:
case HloOpcode::kSetDimensionSize:
    // AllReduce is variadic, so we need to be careful to assign the same
    // layout to the corresponding input argument and tuple index.
case HloOpcode::kAllReduce:
return false;
case HloOpcode::kBatchNormGrad:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBitcast:
case HloOpcode::kBroadcast:
case HloOpcode::kCall:
case HloOpcode::kCollectivePermuteStart:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kConstant:
case HloOpcode::kConvolution:
case HloOpcode::kCopy:
case HloOpcode::kCopyStart:
case HloOpcode::kCopyDone:
case HloOpcode::kCustomCall:
case HloOpcode::kDomain:
case HloOpcode::kDot:
case HloOpcode::kFusion:
case HloOpcode::kGather:
case HloOpcode::kGetTupleElement:
case HloOpcode::kInfeed:
case HloOpcode::kIota:
case HloOpcode::kOutfeed:
case HloOpcode::kParameter:
case HloOpcode::kPartitionId:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kReduce:
case HloOpcode::kReplicaId:
case HloOpcode::kReshape:
case HloOpcode::kDynamicReshape:
case HloOpcode::kRng:
case HloOpcode::kRngBitGenerator:
case HloOpcode::kRngGetAndUpdateState:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kAfterAll:
case HloOpcode::kTrace:
case HloOpcode::kTranspose:
case HloOpcode::kTuple:
case HloOpcode::kGetDimensionSize:
return true;
}
}
/* static */
bool LayoutAssignment::IsAtMostRank1(const Shape& shape) {
if (shape.IsArray()) {
return shape.rank() <= 1;
}
return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) {
return IsAtMostRank1(subshape);
});
}
Status LayoutAssignment::Init() {
computation_layouts_.clear();
conditional_mismatch_.clear();
*entry_computation_layout_ = saved_entry_computation_layout_;
return Status::OK();
}
Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) {
VLOG(5) << "Clearing previous side effects";
// Clear all the copies which have been added, and all the related
// instructions (like GTE and tuples).
int64 removed_copies = 0;
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction :
computation->MakeInstructionPostOrder()) {
if (instruction->opcode() == HloOpcode::kCopy &&
added_copies_.contains(instruction)) {
VLOG(5) << "Removing added copy: " << instruction->ToString();
TF_RETURN_IF_ERROR(
instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
++removed_copies;
}
}
}
added_copies_.clear();
unconstrained_layout_instructions_.clear();
if (removed_copies > 0) {
TupleSimplifier tuple_simplifier;
HloDCE dce;
TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
TF_RETURN_IF_ERROR(dce.Run(module).status());
call_graph_ = CallGraph::Build(module);
}
return Status::OK();
}
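// Illustrative note: for send(x), Run() uses this to rewrite to send(copy(x))
// unless x is already a single-user kCopy, so the copy's layout can be
// constrained without touching the aliased Send operand.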
Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction,
int64 operand_number) {
HloInstruction* operand = instruction->mutable_operand(operand_number);
if (operand->opcode() != HloOpcode::kCopy || operand->user_count() > 1) {
HloInstruction* copy =
instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
operand->shape(), HloOpcode::kCopy, operand));
SetupCopiedInstruction(*operand, copy, {});
LayoutUtil::ClearLayout(copy->mutable_shape());
TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy));
}
return Status::OK();
}
} // namespace xla