blob: 928a3130b24dcdadce70e98adc4acafa57dcad4c [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tools/hlo_control_flow_flattening.h"
#include <algorithm>
#include <functional>
#include <string>
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"
namespace xla {
namespace {
// Create a constant (recursively for tuples) of the given shape and add it to
// the computation.
// Materializes a constant of `shape` inside `computation` and returns it.
//
// Array shapes become a single kConstant built from
// Literal::CreateFromShape(shape) (presumably default/zero-filled; see
// RemovePartitionOrReplicaId, which uses it as a "zero"). Tuple shapes are
// handled recursively: one constant is created per element, in element order,
// and the results are wrapped in a kTuple instruction.
HloInstruction* CreateConstant(const Shape& shape,
                               HloComputation* computation) {
  if (!shape.IsTuple()) {
    return computation->AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateFromShape(shape)));
  }
  std::vector<HloInstruction*> elements;
  elements.reserve(shape.tuple_shapes_size());
  for (const Shape& element_shape : shape.tuple_shapes()) {
    elements.push_back(CreateConstant(element_shape, computation));
  }
  return computation->AddInstruction(HloInstruction::CreateTuple(elements));
}
// Finds an instruction accepted by `filter`, looking through fusions.
//
// `hlo` itself is tested first and returned if it matches. Otherwise, when
// `hlo` is a kFusion, the instructions of its fused computation are tested in
// iteration order and the first match is returned. Returns nullptr if no
// candidate matches (including when `hlo` is a non-matching non-fusion).
const HloInstruction* ExtractInstruction(
    const HloInstruction* hlo,
    const std::function<bool(const HloInstruction*)>& filter) {
  if (filter(hlo)) {
    return hlo;
  }
  if (hlo->opcode() == HloOpcode::kFusion) {
    for (HloInstruction* candidate :
         hlo->fused_instructions_computation()->instructions()) {
      if (filter(candidate)) {
        return candidate;
      }
    }
  }
  return nullptr;
}
// Logs, at VLOG(2), the expression tree rooted at `inst`, descending at most
// `depth` levels (a depth of 0 prints nothing). Operands are logged before the
// instruction that uses them, so the output reads bottom-up; an operand
// reachable through several paths is logged once per path.
void PrintSubexpression(HloInstruction* inst, int depth) {
  if (depth != 0) {
    const int remaining_depth = depth - 1;
    for (HloInstruction* operand : inst->operands()) {
      PrintSubexpression(operand, remaining_depth);
    }
    VLOG(2) << inst->ToString();
  }
}
bool IsConstantScalarInt(const HloInstruction* inst) {
return inst->opcode() == HloOpcode::kConstant &&
ShapeUtil::IsEffectiveScalar(inst->shape()) &&
inst->shape().IsInteger();
}
bool IsNotContainedInLoop(const HloInstruction& while_hlo,
const CallGraph& call_graph) {
const HloComputation* computation = while_hlo.parent();
while (!computation->IsEntryComputation()) {
auto& node = call_graph.GetNode(computation);
CHECK_EQ(node.caller_callsites().size(), 1)
<< "The module is not flattened!";
auto& callsite = node.caller_callsites()[0];
if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
// Another while loop has been found traversing up the call tree.
return false;
}
computation = callsite.instruction()->parent();
}
// No calling while loops were found.
return true;
}
} // namespace
int GetLoopBound(const HloInstruction& while_hlo, const int default_loop_count,
const int max_loop_count) {
HloInstruction* condition = while_hlo.while_condition()->root_instruction();
if (condition->opcode() == HloOpcode::kCompare) {
int64_t value = 0;
Comparison::Direction cmp = condition->comparison_direction();
if ((cmp == Comparison::Direction::kLt ||
cmp == Comparison::Direction::kLe ||
cmp == Comparison::Direction::kNe) &&
IsConstantScalarInt(condition->operand(1))) {
value = *condition->operand(1)->literal().GetFirstInteger();
} else if ((cmp == Comparison::Direction::kGt ||
cmp == Comparison::Direction::kGe ||
cmp == Comparison::Direction::kNe) &&
IsConstantScalarInt(condition->operand(0))) {
value = *condition->operand(0)->literal().GetFirstInteger();
}
if (value > 0) {
// Caps to a max loop count to avoid long execution times.
return std::min(value, static_cast<int64_t>(max_loop_count));
}
}
return default_loop_count;
}
// Like GetLoopBound(), but a loop that is not nested inside another while loop
// (an "outer" loop) is additionally capped at `max_outer_loop_count`.
//
// The cheap bound comparison is evaluated first. The call-graph walk in
// IsNotContainedInLoop() only runs when the bound actually exceeds the outer
// cap.
int GetLoopBoundWithOuterLoopMax(const HloInstruction& while_hlo,
                                 const CallGraph& call_graph,
                                 const int default_loop_count,
                                 const int max_outer_loop_count,
                                 const int max_loop_count) {
  const int bound =
      GetLoopBound(while_hlo, default_loop_count, max_loop_count);
  const bool apply_outer_cap = bound > max_outer_loop_count &&
                               IsNotContainedInLoop(while_hlo, call_graph);
  return apply_outer_cap ? max_outer_loop_count : bound;
}
// Rewrites `while_hlo` so that it runs a fixed number of iterations instead of
// being terminated by its original condition.
//
// A fresh S32 induction variable, initialized to 0, is appended as the last
// element of the loop-carried tuple. The body increments it by 1 on every
// iteration. The condition's root is replaced by
// `induction_variable < loop_bound`, where loop_bound comes from
// GetLoopBoundWithOuterLoopMax().
//
// Because the variable is appended at the end, existing get-tuple-element
// indices stay valid. Non-GTE users of the widened tuples (the condition/body
// parameters and the while itself) are rewired to a prefix tuple holding the
// original N - 1 elements. If the while was the root of its computation, that
// prefix becomes the new root so the computation's result shape is unchanged.
Status HloControlFlowFlattening::FlattenWhileLoop(
HloInstruction* while_hlo, const CallGraph& call_graph) const {
CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile);
HloComputation* computation = while_hlo->parent();
// Add a new induction variable.
HloInstruction* initialization = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
// Create a new while operand with the induction variable added.
HloInstruction* old_tuple = while_hlo->mutable_operand(0);
HloInstruction* new_tuple =
TupleUtil::AppendSuffix(old_tuple, {initialization});
int new_tuple_size = new_tuple->shape().tuple_shapes().size();
TF_RETURN_IF_ERROR(while_hlo->ReplaceOperandWithDifferentShape(0, new_tuple));
// Widens `instruction`'s tuple shape in place by appending an S32 scalar
// element. The CHECK guards that it still has the original N - 1 elements,
// i.e. that it has not been widened already.
auto change_op_shape = [&](HloInstruction* instruction) {
Shape* shape = instruction->mutable_shape();
CHECK(shape->IsTuple());
CHECK_EQ(shape->tuple_shapes().size(), new_tuple_size - 1);
Shape* subshape = shape->add_tuple_shapes();
return ShapeUtil::PopulateShape(S32, {}, subshape);
};
// Replace the given tuple-shaped instruction of size N in each of its
// non-get-tuple-element users with a new tuple instruction which has the
// first N - 1 elements.
// Returns the prefix instruction, or nullptr if there were no non-GTE users
// (in which case no prefix is created).
auto replace_non_gte_users =
[](HloInstruction* new_tuple) -> StatusOr<HloInstruction*> {
CHECK(new_tuple->shape().IsTuple());
HloInstruction* prefix = nullptr;
// Copy the user list: replacing uses mutates new_tuple->users().
std::vector<HloInstruction*> users(new_tuple->users());
for (HloInstruction* user : users) {
if (user->opcode() == HloOpcode::kGetTupleElement) {
continue;
}
// Lazily extract the prefix on demand, reuse it as needed.
if (prefix == nullptr) {
prefix = TupleUtil::ExtractPrefix(
new_tuple, new_tuple->shape().tuple_shapes_size() - 1);
}
TF_RETURN_IF_ERROR(new_tuple->ReplaceUseWithDifferentShape(user, prefix));
}
return prefix;
};
{
// Add the new variable to the while loop condition.
HloComputation* condition = while_hlo->while_condition();
TF_RETURN_IF_ERROR(change_op_shape(condition->parameter_instruction(0)));
TF_RETURN_IF_ERROR(
replace_non_gte_users(condition->parameter_instruction(0)).status());
if (VLOG_IS_ON(2)) {
VLOG(2) << "Loop condition in " << while_hlo->parent()->name();
PrintSubexpression(condition->root_instruction(), /*depth=*/3);
}
// The bound is derived from the original condition root, so it must be
// computed before that root is replaced below.
const int loop_bound = GetLoopBoundWithOuterLoopMax(
*while_hlo, call_graph, while_execution_count_, max_outer_loop_count_,
max_loop_count_);
VLOG(1) << "loop_bound = " << loop_bound;
HloInstruction* limit = condition->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(loop_bound)));
Shape shape = initialization->shape();
HloInstruction* induction_variable =
condition->AddInstruction(HloInstruction::CreateGetTupleElement(
shape, condition->parameter_instruction(0), new_tuple_size - 1));
HloInstruction* compare =
condition->AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {}), induction_variable, limit,
ComparisonDirection::kLt));
// The new comparison becomes the condition's root; the original condition
// no longer decides when the loop terminates.
TF_RETURN_IF_ERROR(
condition->ReplaceInstruction(condition->root_instruction(), compare));
}
{
// Add the new variable to the while loop body.
HloComputation* body = while_hlo->while_body();
TF_RETURN_IF_ERROR(change_op_shape(body->parameter_instruction(0)));
TF_RETURN_IF_ERROR(
replace_non_gte_users(body->parameter_instruction(0)).status());
HloInstruction* old_root = body->root_instruction();
Shape shape = initialization->shape();
HloInstruction* induction_variable =
body->AddInstruction(HloInstruction::CreateGetTupleElement(
shape, body->parameter_instruction(0), new_tuple_size - 1));
HloInstruction* increment = body->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
induction_variable = body->AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, induction_variable, increment));
// The body result must carry the incremented counter as its last element
// so that it matches the widened loop-carried tuple shape.
HloInstruction* new_root =
TupleUtil::AppendSuffix(old_root, {induction_variable});
body->set_root_instruction(new_root, /*accept_different_shape=*/true);
}
// Snapshot the users of while hlo before we add new users.
std::vector<HloInstruction*> while_users(while_hlo->users().begin(),
while_hlo->users().end());
// Take care of the users of this while loop.
TF_RETURN_IF_ERROR(change_op_shape(while_hlo));
TF_ASSIGN_OR_RETURN(HloInstruction * prefix,
replace_non_gte_users(while_hlo));
// If the while loop had been the root of its computation, make the prefix new
// root.
if (while_hlo->parent()->root_instruction() == while_hlo) {
// We need to set accept_different_shape=true to reset the root shape to the
// original, because we have already changed the shape of the old root
// (while).
if (prefix == nullptr) {
prefix = TupleUtil::ExtractPrefix(while_hlo, new_tuple_size - 1);
}
while_hlo->parent()->set_root_instruction(prefix,
/*accept_different_shape=*/true);
}
return ::tensorflow::OkStatus();
}
// Replaces an infeed with a tuple whose elements are:
//   0: a no-op custom call producing the infeed's data shape (element 0), and
//   1: the token the infeed originally consumed.
// Users therefore still see a two-element tuple of the same shape, but no data
// is actually read from the host.
Status HloControlFlowFlattening::RemoveInfeed(
    HloInstruction* infeed_hlo) const {
  CHECK_EQ(infeed_hlo->opcode(), HloOpcode::kInfeed);
  HloComputation* parent = infeed_hlo->parent();
  CHECK_EQ(infeed_hlo->shape().tuple_shapes_size(), 2);
  const Shape& data_shape = ShapeUtil::GetSubshape(infeed_hlo->shape(), {0});
  HloInstruction* placeholder = parent->AddInstruction(
      HloInstruction::CreateCustomCall(data_shape, {}, kNopCustomCallTarget));
  HloInstruction* token = infeed_hlo->mutable_operand(0);
  TF_RETURN_IF_ERROR(parent->ReplaceWithNewInstruction(
      infeed_hlo, HloInstruction::CreateTuple({placeholder, token})));
  return ::tensorflow::OkStatus();
}
// Removes a recv-done / recv pair.
//
// The recv-done is replaced by a tuple of (no-op custom call with the received
// data's shape, the token that fed the recv). The now-unused recv is then
// deleted. It is recorded in `additional_removed` first so the caller can skip
// it if it shows up later in a precomputed instruction list.
Status HloControlFlowFlattening::RemoveRecvDone(
    HloInstruction* recv_done,
    absl::flat_hash_set<HloInstruction*>* additional_removed) const {
  CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
  CHECK_EQ(recv_done->operand_count(), 1);
  HloInstruction* recv = recv_done->mutable_operand(0);
  CHECK_EQ(recv->opcode(), HloOpcode::kRecv);
  HloComputation* parent = recv_done->parent();
  CHECK_EQ(recv_done->shape().tuple_shapes_size(), 2);
  const Shape& payload_shape =
      ShapeUtil::GetSubshape(recv_done->shape(), {0});
  HloInstruction* payload = parent->AddInstruction(
      HloInstruction::CreateCustomCall(payload_shape, {},
                                       kNopCustomCallTarget));
  HloInstruction* token = recv->mutable_operand(0);
  TF_RETURN_IF_ERROR(parent->ReplaceWithNewInstruction(
      recv_done, HloInstruction::CreateTuple({payload, token})));
  additional_removed->insert(recv);
  TF_RETURN_IF_ERROR(parent->RemoveInstruction(recv));
  return ::tensorflow::OkStatus();
}
// Replaces an outfeed with a "NopReturnToken" custom call that takes the same
// operands and produces the same result shape. The call is marked as having a
// side effect so that DCE does not discard it (and with it the outfeed's
// operands).
Status HloControlFlowFlattening::RemoveOutfeed(
    HloInstruction* outfeed_hlo) const {
  CHECK_EQ(outfeed_hlo->opcode(), HloOpcode::kOutfeed);
  HloComputation* parent = outfeed_hlo->parent();
  HloInstruction* nop_call = parent->AddInstruction(
      HloInstruction::CreateCustomCall(outfeed_hlo->shape(),
                                       outfeed_hlo->operands(),
                                       "NopReturnToken"));
  auto* as_custom_call = Cast<HloCustomCallInstruction>(nop_call);
  as_custom_call->set_custom_call_has_side_effect(true);
  TF_RETURN_IF_ERROR(parent->ReplaceInstruction(outfeed_hlo, nop_call));
  return ::tensorflow::OkStatus();
}
// Removes a send-done / send pair.
//
// The send-done is replaced by a side-effecting "NopReturnToken" custom call
// that consumes the send's operands and has the send-done's result shape, so
// those operands are not dead-code eliminated. The now-unused send is then
// deleted. It is recorded in `additional_removed` first so the caller can skip
// it if it shows up later in a precomputed instruction list.
Status HloControlFlowFlattening::RemoveSendDone(
    HloInstruction* send_done,
    absl::flat_hash_set<HloInstruction*>* additional_removed) const {
  CHECK_EQ(send_done->opcode(), HloOpcode::kSendDone);
  CHECK_EQ(send_done->operand_count(), 1);
  HloInstruction* send = send_done->mutable_operand(0);
  CHECK_EQ(send->opcode(), HloOpcode::kSend);
  HloComputation* parent = send_done->parent();
  HloInstruction* nop_call = parent->AddInstruction(
      HloInstruction::CreateCustomCall(send_done->shape(), send->operands(),
                                       "NopReturnToken"));
  auto* as_custom_call = Cast<HloCustomCallInstruction>(nop_call);
  as_custom_call->set_custom_call_has_side_effect(true);
  TF_RETURN_IF_ERROR(parent->ReplaceInstruction(send_done, nop_call));
  additional_removed->insert(send);
  TF_RETURN_IF_ERROR(parent->RemoveInstruction(send));
  return ::tensorflow::OkStatus();
}
// Replaces a collective with a no-op custom call that has the same shape and
// operands. The collective's backend config is carried over, and its canonical
// text is stored as the replacement's "replaced op" metadata.
Status HloControlFlowFlattening::RemoveCollective(HloInstruction* hlo) const {
  HloComputation* parent = hlo->parent();
  HloInstruction* nop_call =
      parent->AddInstruction(HloInstruction::CreateCustomCall(
          hlo->shape(), hlo->operands(), kNopCustomCallTarget));
  // The backend config must survive: a collective op in megacore fusion
  // depends on it.
  nop_call->CopyBackendConfigFrom(hlo);
  // Render the collective before ReplaceInstruction, which may delete `hlo`.
  std::string replaced_op_text = hlo->ToString(HloPrintOptions().Canonical());
  TF_RETURN_IF_ERROR(parent->ReplaceInstruction(hlo, nop_call));
  nop_call->set_metadata_replaced_op(replaced_op_text);
  return ::tensorflow::OkStatus();
}
// Replaces a partition-id / replica-id instruction with a constant of the same
// shape, built by CreateConstant(), in the same computation.
Status HloControlFlowFlattening::RemovePartitionOrReplicaId(
    HloInstruction* hlo) const {
  HloComputation* parent = hlo->parent();
  TF_RETURN_IF_ERROR(
      parent->ReplaceInstruction(hlo, CreateConstant(hlo->shape(), parent)));
  return ::tensorflow::OkStatus();
}
// Applies every enabled rewrite to all computations of `module`:
//   * flatten_while_loop_: while loops get a fixed trip count
//     (FlattenWhileLoop).
//   * remove_infeed_outfeed_: infeeds/outfeeds become no-op custom calls.
//   * remove_comm_, or remove_host_transfer_ for host transfers:
//     send-done/recv-done (and their send/recv) become no-op custom calls.
//   * remove_comm_: collectives outside fusion computations become no-op
//     custom calls, and partition-id / replica-id become constants.
// HloDCE then cleans up dead instructions. An existing schedule is updated
// when anything changed. Returns whether the module was modified.
StatusOr<bool> HloControlFlowFlattening::Run(HloModule* module) {
  auto call_graph = CallGraph::Build(module);
  bool changed = false;
  // Instructions deleted while removing a different instruction (the kSend /
  // kRecv behind a removed kSendDone / kRecvDone). The post-order lists below
  // are computed up front, so such entries are skipped rather than visited.
  absl::flat_hash_set<HloInstruction*> removed;
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      if (removed.contains(instruction)) {
        // Skip the instruction if it is already removed.
        continue;
      }
      if (flatten_while_loop_ && instruction->opcode() == HloOpcode::kWhile) {
        VLOG(1) << "Remove " << instruction->name();
        TF_RETURN_IF_ERROR(FlattenWhileLoop(instruction, *call_graph));
        changed = true;
      } else if (remove_infeed_outfeed_ &&
                 instruction->opcode() == HloOpcode::kInfeed) {
        VLOG(1) << "Remove " << instruction->name();
        TF_RETURN_IF_ERROR(RemoveInfeed(instruction));
        changed = true;
      } else if (remove_infeed_outfeed_ &&
                 instruction->opcode() == HloOpcode::kOutfeed) {
        VLOG(1) << "Remove " << instruction->name();
        TF_RETURN_IF_ERROR(RemoveOutfeed(instruction));
        changed = true;
      } else if (instruction->opcode() == HloOpcode::kSendDone) {
        auto send_done_instruction =
            DynCast<HloSendDoneInstruction>(instruction);
        CHECK(send_done_instruction);
        if (remove_comm_ || (remove_host_transfer_ &&
                             send_done_instruction->is_host_transfer())) {
          VLOG(1) << "Remove " << instruction->name();
          TF_RETURN_IF_ERROR(RemoveSendDone(instruction, &removed));
          changed = true;
        }
      } else if (instruction->opcode() == HloOpcode::kRecvDone) {
        auto recv_done_instruction =
            DynCast<HloRecvDoneInstruction>(instruction);
        CHECK(recv_done_instruction);
        if (remove_comm_ || (remove_host_transfer_ &&
                             recv_done_instruction->is_host_transfer())) {
          VLOG(1) << "Remove " << instruction->name();
          TF_RETURN_IF_ERROR(RemoveRecvDone(instruction, &removed));
          changed = true;
        }
      } else if (remove_comm_ && IsCollective(instruction) &&
                 !instruction->parent()->IsFusionComputation()) {
        VLOG(1) << "Remove " << instruction->name();
        TF_RETURN_IF_ERROR(RemoveCollective(instruction));
        changed = true;
      } else if (remove_comm_ &&
                 (instruction->opcode() == HloOpcode::kPartitionId ||
                  instruction->opcode() == HloOpcode::kReplicaId)) {
        VLOG(1) << "Remove " << instruction->name();
        TF_RETURN_IF_ERROR(RemovePartitionOrReplicaId(instruction));
        // Report this rewrite like every other one. The replaced instruction
        // is already gone, so HloDCE will typically report no change, and the
        // schedule update below would otherwise be skipped, leaving the
        // schedule stale.
        changed = true;
      }
    }
  }
  HloDCE hlo_dce;
  TF_ASSIGN_OR_RETURN(bool dce_changed, hlo_dce.Run(module));
  changed |= dce_changed;
  // Fix the schedule if the module was scheduled.
  if (changed && module->has_schedule()) {
    TF_RETURN_IF_ERROR(module->schedule().Update());
  }
  XLA_VLOG_LINES(3, module->ToString());
  return changed;
}
} // namespace xla