blob: 693ccad2d36af0ea8eb95261f9857f62dcdb9004 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include <algorithm>
#include <cstddef>
#include <functional>
#include <list>
#include <queue>
#include <set>
#include <sstream>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
using absl::StrCat;
// Finalizes the builder into an HloComputation. `root_instruction` may be
// null, in which case the most recently added instruction becomes the root.
// CHECK-fails if no root can be determined.
std::unique_ptr<HloComputation> HloComputation::Builder::Build(
    HloInstruction* root_instruction) {
  // The computation constructor needs to know how many parameters it owns.
  const int num_params = static_cast<int>(
      absl::c_count_if(instructions_, [](const auto& candidate) {
        return candidate->opcode() == HloOpcode::kParameter;
      }));

  // Fall back to the last added instruction when no root was named.
  HloInstruction* chosen_root = root_instruction;
  if (chosen_root == nullptr) {
    chosen_root = last_added_instruction_;
  }
  CHECK_NE(nullptr, chosen_root);

  return absl::WrapUnique(new HloComputation(
      name_, num_params, &instructions_, chosen_root, fusion_instruction_));
}
// Builds a computation that takes ownership of every instruction in
// `*instructions`. Parameter instructions are indexed by parameter number;
// CHECK-fails on an out-of-range or duplicate parameter number, or if
// `root_instruction` is not among the supplied instructions.
HloComputation::HloComputation(
    const string& name, int parameter_count,
    std::vector<std::unique_ptr<HloInstruction>>* instructions,
    HloInstruction* root_instruction, HloInstruction* fusion_instruction)
    : name_(NameUniquer::GetSanitizedName(name)),
      unique_id_(-1),
      root_instruction_(root_instruction),
      fusion_instruction_(fusion_instruction),
      is_fusion_computation_(fusion_instruction != nullptr) {
  param_instructions_.resize(parameter_count, nullptr);

  bool saw_root = false;
  for (std::unique_ptr<HloInstruction>& owned : *instructions) {
    HloInstruction* const raw = owned.get();
    if (raw->opcode() == HloOpcode::kParameter) {
      const int64 param_no = raw->parameter_number();
      CHECK(param_no >= 0 && param_no < parameter_count)
          << "\nERROR: invalid parameter number.  Expected [0, "
          << parameter_count << "), got " << param_no;
      CHECK(param_instructions_[param_no] == nullptr)
          << "\nERROR: parameter number " << param_no
          << " already allocated in this computation";
      param_instructions_[param_no] = raw;
    }
    if (raw == root_instruction_) {
      saw_root = true;
    }
    // Ownership moves into this computation's instruction list.
    AddInstructionInternal(std::move(owned));
  }
  CHECK(saw_root)
      << "\nERROR: root instruction is not present in computation.";
}
// If this computation is attached to a fusion instruction, clears that
// instruction's called computations before going away.
HloComputation::~HloComputation() {
  if (fusion_instruction_ == nullptr) {
    return;
  }
  CHECK(fusion_instruction_->fused_instructions_computation() == this);
  fusion_instruction_->ClearCalledComputations();
  fusion_instruction_ = nullptr;
}
// Adds a non-parameter instruction to this computation and returns a raw
// pointer to it. When `new_name` is non-empty the instruction is renamed
// (sanitized) first. CHECK-fails for parameter instructions.
HloInstruction* HloComputation::AddInstruction(
    std::unique_ptr<HloInstruction> instruction, const std::string& new_name) {
  CHECK(instruction->opcode() != HloOpcode::kParameter)
      << "Parameter instructions cannot be added to a computation after "
      << "it has been built";
  const bool rename_requested = !new_name.empty();
  if (rename_requested) {
    instruction->SetAndSanitizeName(new_name);
  }
  return AddInstructionInternal(std::move(instruction));
}
// Takes ownership of `instruction`, appends it to the instruction list and
// records its list position in `instruction_iterators_`. If this computation
// already belongs to a module, the instruction also receives a uniquified name
// and a fresh unique id from that module.
HloInstruction* HloComputation::AddInstructionInternal(
    std::unique_ptr<HloInstruction> instruction) {
  HloInstruction* const raw = instruction.get();
  auto* const module = parent();
  if (module != nullptr) {
    raw->UniquifyName(&module->instruction_name_uniquer());
    raw->SetUniqueId(module->NewUniqueInstructionId());
  }
  raw->set_parent(this);
  instruction_iterators_[raw] =
      instructions_.insert(instructions_.end(), std::move(instruction));
  return raw;
}
// Appends a new parameter to a fusion computation and returns it. CHECK-fails
// if `instruction` is not a parameter, if this is not a fusion computation, or
// if the fusion instruction's operand count differs from the current number of
// parameters.
HloInstruction* HloComputation::AddParameter(
    std::unique_ptr<HloInstruction> instruction) {
  CHECK(instruction->opcode() == HloOpcode::kParameter);
  CHECK(IsFusionComputation());
  CHECK(fusion_instruction_->operand_count() == param_instructions_.size());

  HloInstruction* const new_param = instruction.get();
  new_param->set_parent(this);
  param_instructions_.push_back(new_param);
  // AddInstructionInternal appends to the end of the instruction list and
  // returns the pointer it just stored.
  return AddInstructionInternal(std::move(instruction));
}
// Appends `instruction` as the next parameter of the entry computation and
// adds a matching parameter layout to the module config's entry computation
// layout. CHECK-fails unless `instruction` is a parameter whose number equals
// the current parameter count and this is the module's entry computation.
HloInstruction* HloComputation::AddEntryComputationParameter(
    std::unique_ptr<HloInstruction> instruction) {
  CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
  CHECK_EQ(instruction->parameter_number(), num_parameters());
  CHECK(parent()->entry_computation() == this);

  HloInstruction* const new_param = instruction.get();

  // Edit a copy of the module config, then write it back.
  HloModuleConfig updated_config = parent()->config();
  updated_config.mutable_entry_computation_layout()->add_parameter_layout(
      ShapeLayout(new_param->shape()));
  parent()->set_config(updated_config);

  new_param->set_parent(this);
  param_instructions_.push_back(new_param);
  // AddInstructionInternal appends to the end of the instruction list and
  // returns the pointer it just stored.
  return AddInstructionInternal(std::move(instruction));
}
// Installs `instruction` as entry-computation parameter `param_no`, updates
// the corresponding parameter layout in the module config, and force-removes
// `old_instruction`. Returns the status of that removal.
Status HloComputation::ReplaceEntryComputationParameter(
    int64 param_no, HloInstruction* old_instruction,
    std::unique_ptr<HloInstruction> instruction) {
  CHECK_GE(param_no, 0);
  CHECK_LT(param_no, param_instructions_.size());
  CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
  CHECK(parent()->entry_computation() == this);

  HloInstruction* const replacement = instruction.get();

  // Edit a copy of the module config, then write it back.
  HloModuleConfig updated_config = parent()->config();
  auto* const layout_slot =
      updated_config.mutable_entry_computation_layout()
          ->mutable_parameter_layout(param_no);
  *layout_slot = ShapeLayout(replacement->shape());
  parent()->set_config(updated_config);

  replacement->set_parent(this);
  param_instructions_[param_no] = replacement;
  AddInstructionInternal(std::move(instruction));

  // The forced variant skips the IsSafelyRemovable check.
  return ForceRemoveInstruction(old_instruction);
}
// Removes fused parameter `param_no` and renumbers every later parameter down
// by one. Each later parameter is replaced by a freshly created parameter
// instruction named "param_<n>" that takes over all of its uses.
Status HloComputation::RemoveParameter(int64 param_no) {
  CHECK_GE(param_no, 0);
  CHECK_LT(param_no, param_instructions_.size());
  CHECK(IsFusionComputation());

  // Drop the parameter from the index, then throw the instruction away.
  HloInstruction* const doomed = param_instructions_[param_no];
  param_instructions_.erase(param_instructions_.begin() + param_no);
  TF_RETURN_IF_ERROR(RemoveInstruction(doomed));

  // Everything that followed the removed parameter now sits one slot lower
  // and needs a parameter instruction carrying the new number.
  for (int64 slot = param_no; slot < param_instructions_.size(); ++slot) {
    HloInstruction* const stale = param_instructions_[slot];
    HloInstruction* const renumbered =
        AddInstructionInternal(HloInstruction::CreateParameter(
            slot, stale->shape(), StrCat("param_", slot)));
    TF_RETURN_IF_ERROR(stale->ReplaceAllUsesWith(renumbered));
    param_instructions_[slot] = renumbered;
    TF_RETURN_IF_ERROR(RemoveInstruction(stale));
  }
  return Status::OK();
}
// Replaces fused parameter `param_no` with `instruction` and returns the newly
// installed parameter. All uses of the old parameter are redirected to the new
// one (shapes are allowed to differ) and the old parameter is removed.
//
// CHECK-fails if `param_no` is out of range, if `instruction` is not a
// parameter, if this is not a fusion computation, or if the fusion
// instruction's operand count differs from the number of parameters.
HloInstruction* HloComputation::ReplaceParameter(
    int64 param_no, std::unique_ptr<HloInstruction> instruction) {
  CHECK_GE(param_no, 0);
  CHECK_LT(param_no, param_instructions_.size());
  CHECK(instruction->opcode() == HloOpcode::kParameter);
  CHECK(IsFusionComputation());
  CHECK_EQ(fusion_instruction_->operand_count(), param_instructions_.size());

  instruction->set_parent(this);
  HloInstruction* new_instruction =
      AddInstructionInternal(std::move(instruction));
  HloInstruction* old_instruction = param_instructions_[param_no];

  // TF_CHECK_OK rather than CHECK(status.ok()): on failure the status itself
  // appears in the crash message, matching FuseInstructionsInto in this file.
  TF_CHECK_OK(
      old_instruction->ReplaceAllUsesWithDifferentShape(new_instruction));
  param_instructions_[param_no] = new_instruction;
  TF_CHECK_OK(RemoveInstruction(old_instruction));
  return new_instruction;
}
// Removes unused parameters; CHECK-fails (inside the shared implementation)
// when this is not a fusion computation.
Status HloComputation::RemoveUnusedParametersFromFusedComputation() {
  constexpr bool kAllowNonFusion = false;
  return RemoveUnusedParametersImpl(kAllowNonFusion);
}
// Removes unused parameters without requiring this to be a fusion
// computation.
Status HloComputation::RemoveUnusedParametersFromAnyComputation() {
  constexpr bool kAllowNonFusion = true;
  return RemoveUnusedParametersImpl(kAllowNonFusion);
}
// Shared implementation of the RemoveUnusedParameters* entry points. Walks the
// parameters in order, deleting each one that has no users and is not the
// root, and renumbering the survivors so parameter numbers stay dense.
// `allow_non_fusion` lifts the fusion-only CHECK and is also passed to
// RemoveInstructionImpl as its "ignore safety check" flag.
Status HloComputation::RemoveUnusedParametersImpl(bool allow_non_fusion) {
  CHECK(allow_non_fusion || IsFusionComputation());

  int64 num_dropped = 0;
  for (int64 i = 0; i < param_instructions_.size(); ++i) {
    HloInstruction* const param = param_instructions_[i];
    const bool is_unused =
        param->user_count() == 0 && param != root_instruction();
    if (is_unused) {
      TF_RETURN_IF_ERROR(RemoveInstructionImpl(param, allow_non_fusion));
      ++num_dropped;
    } else if (num_dropped > 0) {
      // A surviving parameter that follows a removed one is recreated under
      // its new, lower parameter number.
      const int64 new_param_no = i - num_dropped;
      HloInstruction* const renumbered =
          AddInstructionInternal(HloInstruction::CreateParameter(
              new_param_no, param->shape(), StrCat("param_", new_param_no)));
      TF_RETURN_IF_ERROR(param->ReplaceAllUsesWith(renumbered));
      param_instructions_[new_param_no] = renumbered;
      TF_RETURN_IF_ERROR(RemoveInstructionImpl(param, allow_non_fusion));
    }
  }
  // Survivors were compacted to the front; trim the now-stale tail.
  param_instructions_.resize(param_instructions_.size() - num_dropped);
  return Status::OK();
}
// Returns whether `instruction` may be removed from this computation: it must
// have no control edges, and a parameter is only removable in a fusion
// computation.
bool HloComputation::IsSafelyRemovable(const HloInstruction* instruction) {
  // Removing an instruction with control predecessors or successors would
  // violate ordering constraints (added, for example, to avert interference
  // due to buffer aliasing).
  const bool has_control_edges =
      !instruction->control_predecessors().empty() ||
      !instruction->control_successors().empty();
  if (has_control_edges) {
    return false;
  }

  const bool is_non_fusion_parameter =
      instruction->opcode() == HloOpcode::kParameter &&
      !IsFusionComputation();
  return !is_non_fusion_parameter;
}
bool HloComputation::HasSideEffect() const {
for (auto* instruction : instructions()) {
if (instruction->HasSideEffect()) {
return true;
}
}
return false;
}
// Reports whether `inst` has been marked as dead; forwards to
// HloInstruction::IsMarkedAsDead().
bool HloComputation::IsMarkedAsDead(const HloInstruction* inst) {
  const bool dead = inst->IsMarkedAsDead();
  return dead;
}
// Removes `instruction` and then, transitively, any of its operands that are
// left without users. `instruction` must not be the root, must have no users
// and must be safely removable. If `cleanup` is set it is invoked on each
// instruction just before that instruction is removed.
Status HloComputation::RemoveInstructionAndUnusedOperands(
    HloInstruction* instruction, std::function<void(HloInstruction*)> cleanup) {
  TF_RET_CHECK(root_instruction() != instruction);
  TF_RET_CHECK(instruction->user_count() == 0);
  TF_RET_CHECK(IsSafelyRemovable(instruction))
      << "Cannot remove instruction: " << instruction->ToString();

  absl::flat_hash_set<HloInstruction*> already_removed;
  std::queue<HloInstruction*> pending;
  pending.push(instruction);
  while (!pending.empty()) {
    HloInstruction* const candidate = pending.front();
    pending.pop();

    // An operand may be queued several times; skip ones already handled, ones
    // still in use, the root, unsafe ones, and side-effecting instructions
    // other than the one the caller asked to remove.
    const bool removable =
        !already_removed.contains(candidate) &&
        candidate->user_count() == 0 && candidate != root_instruction() &&
        IsSafelyRemovable(candidate) &&
        (!candidate->HasSideEffect() || candidate == instruction);
    if (!removable) {
      continue;
    }

    // Queue the operands before detaching; they may become dead next.
    for (int i = 0; i < candidate->operand_count(); ++i) {
      pending.push(candidate->mutable_operand(i));
    }

    if (cleanup) {
      cleanup(candidate);
    }
    TF_RETURN_IF_ERROR(RemoveInstruction(candidate));
    already_removed.insert(candidate);
  }
  return Status::OK();
}
// Removes `instruction`, enforcing the IsSafelyRemovable check.
Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
  constexpr bool kIgnoreSafetyCheck = false;
  return RemoveInstructionImpl(instruction, kIgnoreSafetyCheck);
}
// Removes `instruction` without the IsSafelyRemovable check. The remaining
// checks in RemoveInstructionImpl still apply.
Status HloComputation::ForceRemoveInstruction(HloInstruction* instruction) {
  constexpr bool kIgnoreSafetyCheck = true;
  return RemoveInstructionImpl(instruction, kIgnoreSafetyCheck);
}
// Detaches `instruction` from this computation. Fails unless the instruction
// is not the root, has no users and has no control edges;
// `ignore_safety_check` only bypasses the IsSafelyRemovable test. The
// instruction is not destroyed here: ownership moves to `to_be_deleted_`.
Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction,
                                             bool ignore_safety_check) {
  VLOG(2) << "Removing instruction " << instruction->name()
          << " from computation " << name();
  TF_RET_CHECK(ignore_safety_check || IsSafelyRemovable(instruction))
      << "cannot remove instruction: " << instruction->ToString();
  TF_RET_CHECK(root_instruction() != instruction)
      << "cannot remove root instruction " << instruction->name();
  TF_RET_CHECK(instruction->user_count() == 0)
      << "instruction " << instruction->name()
      << " has users and cannot be removed";
  TF_RET_CHECK(instruction->control_predecessors().empty())
      << "instruction " << instruction->name()
      << " has control predecessors and cannot be removed";
  TF_RET_CHECK(instruction->control_successors().empty())
      << "instruction " << instruction->name()
      << " has control successors and cannot be removed";

  auto map_entry = instruction_iterators_.find(instruction);
  TF_RET_CHECK(map_entry != instruction_iterators_.end());
  auto list_pos = map_entry->second;

  // Orphan the instruction and park it in the deferred-deletion list.
  HloInstruction* const doomed = list_pos->get();
  doomed->set_parent(nullptr);
  to_be_deleted_.emplace_back(list_pos->release());
  doomed->DetachFromOperandsAndUsers();
  // Clear all operands to avoid Null operands.
  doomed->RemoveAllOperands();
  doomed->ClearCalledComputations();
  doomed->MarkAsDead();

  // Finally drop the (now empty) list slot and its index entry.
  instructions_.erase(list_pos);
  instruction_iterators_.erase(map_entry);
  return Status::OK();
}
// Makes `new_root_instruction` the root of this computation. For non-fusion
// computations the new root's shape must be compatible with the old one unless
// `accept_different_shape` is set. If this is the module's entry computation
// and the root shape changed (ignoring layout), the module's input/output
// alias config is rebuilt for the new shape.
void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
                                          bool accept_different_shape) {
  // The shape of the root (ignoring layout) is an invariant of the computation
  // for non-fusion cases.
  const bool must_keep_shape =
      !IsFusionComputation() && !accept_different_shape;
  if (must_keep_shape) {
    CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
                                root_instruction_->shape()))
        << new_root_instruction->shape() << " is incompatible with "
        << root_instruction_->shape();
  }

  // Debug-only sanity check that the new root belongs to this computation.
  const bool root_found =
      absl::c_any_of(instructions_, [&](const auto& owned) {
        return owned.get() == new_root_instruction;
      });
  DCHECK(root_found);

  const bool is_entry = parent() && parent()->has_entry_computation() &&
                        parent()->entry_computation() == this;
  if (is_entry &&
      !Shape::Equal().IgnoreLayout()(new_root_instruction->shape(),
                                     root_instruction_->shape())) {
    // Rebuild input output alias config now that we have a new output shape.
    parent()->input_output_alias_config() =
        HloInputOutputAliasConfig(new_root_instruction->shape());
  }

  root_instruction_ = new_root_instruction;
}
namespace {

// Appends to `post_order` every computation reachable from `computation`
// through called computations, callees first and `computation` itself last.
// `visited` ensures each computation is emitted only once.
void ComputeComputationPostOrder(HloComputation* computation,
                                 absl::flat_hash_set<HloComputation*>* visited,
                                 std::vector<HloComputation*>* post_order) {
  const bool first_visit = visited->insert(computation).second;
  if (!first_visit) {
    return;
  }
  for (auto* instruction : computation->instructions()) {
    for (HloComputation* callee : instruction->called_computations()) {
      ComputeComputationPostOrder(callee, visited, post_order);
    }
  }
  post_order->push_back(computation);
}

}  // namespace
// Iterative DFS from `root` that appends instructions to `post_order` so that
// operands and control predecessors come before their users. `visited` is
// shared across calls so already-emitted instructions are skipped.
// Instructions that share a channel id in `channel_dependency_group` are
// scheduled together: reaching one pushes the whole group, and expanding one
// expands the predecessors of every group member.
void HloComputation::ComputeInstructionPostOrder(
    const HloComputation::ChannelDependencyGroup& channel_dependency_group,
    std::vector<HloInstruction*>* post_order, HloInstruction* root,
    absl::flat_hash_map<HloInstruction*, VisitState>* visited) const {
  std::vector<HloInstruction*> dfs_stack;

  // Only RecvDone and AllReduce are looked up in the dependency group here.
  const auto get_channel_id =
      [](HloInstruction* inst) -> absl::optional<int64> {
    switch (inst->opcode()) {
      case HloOpcode::kRecvDone:
      case HloOpcode::kAllReduce:
        return inst->channel_id();
      default:
        return absl::nullopt;
    }
  };

  // When adding a predecessor to the dfs_stack, we need to also add its
  // associated channel dependencies.
  const auto push_with_channel_peers = [&](HloInstruction* inst) {
    const auto channel_id = get_channel_id(inst);
    const auto group = channel_id
                           ? channel_dependency_group.find(*channel_id)
                           : channel_dependency_group.end();
    if (group == channel_dependency_group.end()) {
      dfs_stack.emplace_back(inst);
      return;
    }
    for (HloInstruction* peer : group->second) {
      dfs_stack.emplace_back(peer);
    }
  };

  const auto push_predecessors = [&](HloInstruction* inst) {
    // Add the operands to the stack in reverse order so the first operand is
    // processed first. This will produce a more natural ordering and a nicer
    // result for things like HLO stringification.
    const auto& operands = inst->operands();
    for (int64 i = operands.size() - 1; i >= 0; --i) {
      push_with_channel_peers(operands[i]);
    }
    for (HloInstruction* predecessor : inst->control_predecessors()) {
      push_with_channel_peers(predecessor);
    }
  };

  dfs_stack.push_back(root);
  while (!dfs_stack.empty()) {
    HloInstruction* const current = dfs_stack.back();
    CHECK_EQ(current->parent(), this)
        << "Instruction " << current->name()
        << " is not in the current computation (" << name() << ").";

    auto state = visited->find(current);
    if (state != visited->end()) {
      // Second encounter: either it was already emitted, or all of its
      // predecessors have now been handled and it can be emitted.
      dfs_stack.pop_back();
      if (state->second != kVisited) {
        CHECK_EQ(kVisiting, state->second);
        post_order->push_back(current);
        state->second = kVisited;
      }
      continue;
    }

    // First encounter: leave it on the stack and expand its predecessors.
    visited->insert({current, kVisiting});

    // If the current instruction is a channel instruction, add the
    // dependencies from all associated instructions of the channel.
    const auto channel_id = get_channel_id(current);
    const auto group = channel_id
                           ? channel_dependency_group.find(*channel_id)
                           : channel_dependency_group.end();
    if (group == channel_dependency_group.end()) {
      push_predecessors(current);
    } else {
      for (HloInstruction* peer : group->second) {
        push_predecessors(peer);
      }
    }
  }
}
// Groups the Send, RecvDone and AllReduce instructions of this computation by
// channel id. Instructions of those kinds that have no channel id are left
// out.
HloComputation::ChannelDependencyGroup
HloComputation::ComputeChannelDependencies() const {
  ChannelDependencyGroup groups;
  for (const auto& owned : instructions_) {
    const HloOpcode opcode = owned->opcode();
    const bool is_channel_op = opcode == HloOpcode::kSend ||
                               opcode == HloOpcode::kRecvDone ||
                               opcode == HloOpcode::kAllReduce;
    if (!is_channel_op) {
      continue;
    }
    auto channel_id = owned->channel_id();
    if (channel_id) {
      groups[channel_id.value()].push_back(owned.get());
    }
  }
  return groups;
}
static inline bool HasOnlyTraceUsers(const HloInstruction* instruction) {
return absl::c_all_of(instruction->users(), [](HloInstruction* user) {
return user->opcode() == HloOpcode::kTrace;
});
}
std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
auto channel_dependency_group = ComputeChannelDependencies();
std::vector<HloInstruction*> post_order;
post_order.reserve(instruction_count());
std::vector<HloInstruction*> trace_instructions;
absl::flat_hash_map<HloInstruction*, VisitState> visited;
visited.reserve(instruction_count());
for (auto& instruction : instructions_) {
if (instruction->opcode() == HloOpcode::kTrace) {
// Trace instructions aren't handled by the DFS visitor. Add trace
// instructions to the post order at the end (necessarily they have no
// users).
trace_instructions.push_back(instruction.get());
} else if (HasOnlyTraceUsers(instruction.get())) {
ComputeInstructionPostOrder(channel_dependency_group, &post_order,
instruction.get(), &visited);
}
}
post_order.insert(post_order.end(), trace_instructions.begin(),
trace_instructions.end());
CHECK_EQ(instructions_.size(), post_order.size())
<< "number of instructions does not match post order size";
return post_order;
}
std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
const {
absl::flat_hash_set<HloComputation*> visited;
std::vector<HloComputation*> post_order;
// To avoid special handling of this computation, cast away const of
// 'this'. 'this' is immediately removed from the post order after
// construction.
//
// TODO(b/78350259): This violates const-correctness, since while the original
// computation is not returned, we still retrieve non-const computations from
// a const one. Consider also avoiding const for HloComputation, or review XLA
// for const-correctness of non-HloInstruction* types like this.
ComputeComputationPostOrder(const_cast<HloComputation*>(this), &visited,
&post_order);
// We don't want to include this computation in the post order.
CHECK_EQ(this, post_order.back());
post_order.pop_back();
return post_order;
}
// Renders this computation as text, listing instructions in post order.
string HloComputation::ToString(const HloPrintOptions& options) const {
  const std::vector<HloInstruction*> post_order = MakeInstructionPostOrder();
  return ToString(options, post_order);
}
// Renders this computation as text, listing instructions in the given
// `instruction_order`, which must contain exactly as many entries as the
// computation has instructions. Only instructions selected by
// `options.print_instruction`, plus a window of neighbours around each, are
// printed; every skipped run is collapsed into a single "..." line.
string HloComputation::ToString(
    const HloPrintOptions& options,
    absl::Span<const HloInstruction* const> instruction_order) const {
  CHECK_EQ(instruction_order.size(), instruction_count());

  const string tab(2 * options.indent_amount(), ' ');
  std::ostringstream s;
  s << tab;

  // Header: optional "%", optional name, optional program shape.
  if (!options.is_in_nested_computation()) {
    if (options.print_percent()) {
      s << "%";
    }
    if (options.print_ids()) {
      // Exclude entry computation's name because it includes and leads to
      // non-deterministic fingerprint.
      s << PrintName(name(), options.print_ids()) << " ";
    }
  }
  if (options.print_program_shape()) {
    s << ShapeUtil::HumanString(ComputeProgramShape(options.print_ids()))
      << " ";
  }
  s << "{\n";

  // There are instructions which are required to be printed. Additionally, we
  // print some instructions before and after required ones. The resulting
  // output has the following format.
  //
  //  computation {
  //    ...
  //    additional_instructions
  //    required_instructions
  //    additional_instructions
  //    ...
  //    additional_instructions
  //    required_instructions
  //    additional_instructions
  //    ...
  //  }
  std::set<int> instructions_to_print;
  {
    // Marks `center` and its leading/trailing neighbours, ignoring positions
    // that fall outside the instruction order.
    const auto mark_window = [&instructions_to_print, &instruction_order,
                              &options](int center) {
      for (int i = center - options.leading_and_trailing_instructions_number();
           i <= center + options.leading_and_trailing_instructions_number();
           ++i) {
        if (i < 0 || i >= instruction_order.size()) {
          continue;
        }
        instructions_to_print.insert(i);
      }
    };

    for (int i = 0; i < instruction_order.size(); ++i) {
      const HloInstruction* candidate = instruction_order[i];
      CHECK_EQ(this, candidate->parent());
      if (options.print_instruction(candidate)) {
        mark_window(i);
      }
    }
  }

  {
    // Print the instructions in this computation, one indent level deeper.
    HloPrintOptions new_options = options;
    new_options.set_indent_amount(options.indent_amount() + 1)
        .set_is_in_nested_computation(true);
    const string new_tab(2 * new_options.indent_amount(), ' ');

    CanonicalNameMap name_map;
    // True while the previous position was printed; a skipped run therefore
    // produces exactly one "..." line.
    bool print_prev = true;
    for (int index = 0; index < instruction_order.size(); ++index) {
      const HloInstruction* current = instruction_order[index];
      if (instructions_to_print.count(index) == 0) {
        if (print_prev) {
          s << new_tab << "...\n";
          print_prev = false;
        }
        continue;
      }
      s << new_options.format_instruction(
               current,
               current->ToStringWithCanonicalNameMap(new_options, &name_map),
               new_options.indent_amount(), current == root_instruction_)
        << "\n";
      print_prev = true;
    }
  }

  s << tab << "}";
  return s.str();
}
// Serializes this computation: id, name, instructions in post order, root id
// and program shape. CHECK-fails if the computation has no valid unique id.
HloComputationProto HloComputation::ToProto() const {
  HloComputationProto proto;
  CHECK(unique_id_ != -1)
      << "This computation does not have a valid id. Please make sure the "
         "computation is inside a module before dumping it.";
  proto.set_id(unique_id_);
  proto.set_name(name_);

  const std::vector<HloInstruction*> ordered = MakeInstructionPostOrder();
  for (const HloInstruction* hlo : ordered) {
    HloInstructionProto serialized = hlo->ToProto();
    // Swap into the repeated field instead of copying the message.
    proto.add_instructions()->Swap(&serialized);
  }

  proto.set_root_id(root_instruction()->unique_id());
  *proto.mutable_program_shape() = ComputeProgramShape().ToProto();
  return proto;
}
// Reconstructs a computation from `proto`. Instructions are deserialized in
// proto order (so operands can be resolved through the id map built so far),
// then sorted by proto id. Returns an error if an instruction id repeats, the
// root id is missing or unknown, or the parameter numbers do not cover
// [0, parameter_count) exactly once each.
/* static */ StatusOr<std::unique_ptr<HloComputation>>
HloComputation::CreateFromProto(
    const HloComputationProto& proto,
    const absl::flat_hash_map<int64, HloComputation*>& computation_map,
    bool prohibit_empty_literal) {
  absl::flat_hash_map<int64, HloInstruction*> instruction_map;
  absl::flat_hash_map<HloInstruction*, int64> to_proto_id;
  std::vector<std::unique_ptr<HloInstruction>> instructions;
  int64 parameter_count = 0;

  for (const HloInstructionProto& instruction_proto : proto.instructions()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction,
                        HloInstruction::CreateFromProto(
                            instruction_proto, instruction_map, computation_map,
                            prohibit_empty_literal));
    if (instruction->opcode() == HloOpcode::kParameter) {
      parameter_count++;
    }

    const int64 proto_id = instruction_proto.id();
    TF_RET_CHECK(!ContainsKey(instruction_map, proto_id));
    HloInstruction* const raw = instruction.get();
    instruction_map[proto_id] = raw;
    to_proto_id[raw] = proto_id;
    instructions.push_back(std::move(instruction));
  }

  TF_RET_CHECK(proto.root_id() != -1);
  TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id()));
  HloInstruction* root = instruction_map.at(proto.root_id());

  // Sort the instructions in the proto id's order.
  const auto by_proto_id = [&](const std::unique_ptr<HloInstruction>& lhs,
                               const std::unique_ptr<HloInstruction>& rhs) {
    return to_proto_id.at(lhs.get()) < to_proto_id.at(rhs.get());
  };
  absl::c_sort(instructions, by_proto_id);

  // Validate the parameter numbering before handing the instructions to the
  // constructor, which would CHECK-fail instead of returning an error.
  const auto validate_parameters = [&]() -> Status {
    std::vector<bool> parameters_seen(parameter_count);
    int parameters_seen_count = 0;
    for (auto& instruction : instructions) {
      if (instruction->opcode() != HloOpcode::kParameter) {
        continue;
      }
      int64 param_no = instruction->parameter_number();
      TF_RET_CHECK(param_no >= 0 && param_no < parameter_count)
          << "Invalid parameter number.  Expected [0, " << parameter_count
          << "), got " << param_no;
      TF_RET_CHECK(!parameters_seen[param_no])
          << "Parameter number " << param_no
          << " already allocated in this computation";
      parameters_seen[param_no] = true;
      parameters_seen_count++;
    }
    TF_RET_CHECK(parameters_seen_count == parameter_count)
        << "Not all parameters in range [0, " << parameter_count
        << ") were referenced";
    return Status::OK();
  };
  TF_RETURN_IF_ERROR(validate_parameters());

  auto computation = absl::WrapUnique(
      new HloComputation(proto.name(), parameter_count, &instructions, root,
                         /*fusion_instruction=*/nullptr));
  computation->unique_id_ = proto.id();
  return std::move(computation);
}
// Fuses `instructions_to_fuse` into the existing `fusion_instruction`.
//
// The first element is treated as the root of the fused region: all of its
// uses are redirected to `fusion_instruction`, which also becomes the
// computation root if that element was the root, and the element is then
// removed. Each remaining element is fused in order and removed from this
// computation once it has no users left.
//
// CHECK-fails if `fusion_instruction` is not a kFusion instruction, if
// `instructions_to_fuse` is empty, or if any replacement/removal fails.
void HloComputation::FuseInstructionsInto(
    absl::Span<HloInstruction* const> instructions_to_fuse,
    HloInstruction* fusion_instruction) {
  CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
  // front() on an empty span is undefined; fail with a clear message instead.
  CHECK(!instructions_to_fuse.empty())
      << "FuseInstructionsInto requires at least one instruction to fuse";

  HloInstruction* root = instructions_to_fuse.front();
  TF_CHECK_OK(root->ReplaceAllUsesWith(fusion_instruction));
  if (root == root_instruction()) {
    set_root_instruction(fusion_instruction);
  }
  TF_CHECK_OK(RemoveInstruction(root));

  for (size_t i = 1; i < instructions_to_fuse.size(); ++i) {
    HloInstruction* instruction = instructions_to_fuse[i];
    fusion_instruction->FuseInstruction(instruction);
    // Keep the original only if something outside the fusion still uses it.
    if (instruction->user_count() == 0) {
      TF_CHECK_OK(RemoveInstruction(instruction));
    }
  }
}
// Creates a new fusion instruction of kind `fusion_kind`, seeded from the
// first element of `instructions_to_fuse` (whose shape it takes), adds it to
// this computation and fuses all listed instructions into it. Returns the new
// fusion instruction.
//
// CHECK-fails if `instructions_to_fuse` is empty.
HloInstruction* HloComputation::CreateFusionInstruction(
    absl::Span<HloInstruction* const> instructions_to_fuse,
    HloInstruction::FusionKind fusion_kind) {
  // front() on an empty span is undefined; fail with a clear message instead.
  CHECK(!instructions_to_fuse.empty())
      << "CreateFusionInstruction requires at least one instruction to fuse";

  HloInstruction* root = instructions_to_fuse.front();
  HloInstruction* fusion_instruction = AddInstruction(
      HloInstruction::CreateFusion(root->shape(), fusion_kind, root));
  FuseInstructionsInto(instructions_to_fuse, fusion_instruction);
  return fusion_instruction;
}
// Recursive worker for the DeepCopy* entry points. Tuples are taken apart with
// get-tuple-element instructions, each element is processed recursively, and
// the results are reassembled into a new tuple. Tokens are returned unchanged.
// Array leaves are handed to `copy_leaf` together with their shape index.
// `index` tracks the current position and is restored before returning.
StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
    HloInstruction* instruction, ShapeIndex* index,
    const std::function<
        HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
                        HloComputation* computation)>& copy_leaf) {
  if (instruction->shape().IsTuple()) {
    std::vector<HloInstruction*> copied_elements;
    for (int64 element_no = 0;
         element_no < ShapeUtil::TupleElementCount(instruction->shape());
         element_no++) {
      HloInstruction* element_source =
          AddInstruction(HloInstruction::CreateGetTupleElement(
              ShapeUtil::GetTupleElementShape(instruction->shape(),
                                              element_no),
              instruction, element_no));

      // Descend into this element, then restore the index on the way out.
      index->push_back(element_no);
      TF_ASSIGN_OR_RETURN(HloInstruction * element_copy,
                          DeepCopyHelper(element_source, index, copy_leaf));
      copied_elements.push_back(element_copy);
      index->pop_back();
    }
    return AddInstruction(HloInstruction::CreateTuple(copied_elements));
  }

  if (instruction->shape().IsToken()) {
    // Tokens have no on-device representation and cannot be copied. Pass
    // through transparently.
    return instruction;
  }

  // Anything left must be an array shape.
  TF_RET_CHECK(instruction->shape().IsArray());
  return copy_leaf(instruction, *index, this);
}
// Deep-copies `instruction` by inserting kCopy instructions at its array
// leaves. With a non-null `indices_to_copy`, only leaves whose entry is true
// are copied and the rest are passed through. With a non-null `copies_added`,
// each inserted copy is recorded at its leaf index. Returns FailedPrecondition
// if the instruction is not in this computation or the `indices_to_copy` shape
// is incompatible with the instruction's shape.
StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
    HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
    ShapeTree<HloInstruction*>* copies_added) {
  if (instruction->parent() != this) {
    return FailedPrecondition(
        "Can't deep copy instruction %s: instruction is not in computation %s",
        instruction->name(), name());
  }

  const bool has_selection = indices_to_copy != nullptr;
  if (has_selection &&
      !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) {
    return FailedPrecondition(
        "Can't deep copy instruction %s: given shape tree of indices to copy "
        "has incompatible shapes: %s vs. %s",
        instruction->name(), ShapeUtil::HumanString(instruction->shape()),
        ShapeUtil::HumanString(indices_to_copy->shape()));
  }

  auto copy_leaf = [indices_to_copy, copies_added](
                       HloInstruction* leaf, const ShapeIndex& leaf_index,
                       HloComputation* computation) {
    const bool wants_copy =
        indices_to_copy == nullptr || indices_to_copy->element(leaf_index);
    if (!wants_copy) {
      // Elements which are not to be copied are passed through
      // transparently.
      return leaf;
    }
    HloInstruction* copy = computation->AddInstruction(
        HloInstruction::CreateUnary(leaf->shape(), HloOpcode::kCopy, leaf));
    if (copies_added != nullptr) {
      *copies_added->mutable_element(leaf_index) = copy;
    }
    return copy;
  };

  ShapeIndex index;
  return DeepCopyHelper(instruction, &index, copy_leaf);
}
// Deep-copies `instruction`, delegating every array leaf to the caller-supplied
// `copy_leaf`. Returns FailedPrecondition if the instruction does not belong
// to this computation.
StatusOr<HloInstruction*> HloComputation::DeepCopyInstructionWithCustomCopier(
    HloInstruction* instruction,
    const std::function<
        HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
                        HloComputation* computation)>& copy_leaf) {
  const bool owned_here = instruction->parent() == this;
  if (!owned_here) {
    return FailedPrecondition(
        "Can't deep copy instruction %s: instruction is not in computation %s",
        instruction->name(), name());
  }
  // Start the traversal at the top-level (empty) shape index.
  ShapeIndex root_index;
  return DeepCopyHelper(instruction, &root_index, copy_leaf);
}
// Builds the program shape of this computation: one shape and name per
// parameter, in parameter order, plus the root's shape as the result.
// `include_ids` is forwarded to PrintName for the parameter names.
ProgramShape HloComputation::ComputeProgramShape(bool include_ids) const {
  ProgramShape signature;
  for (HloInstruction* param : param_instructions_) {
    *signature.add_parameters() = param->shape();
    *signature.add_parameter_names() = PrintName(param->name(), include_ids);
  }
  *signature.mutable_result() = root_instruction_->shape();
  return signature;
}
bool HloComputation::EqualInternal(const HloComputation& other,
bool is_layout_sensitive,
bool ignore_channel_id_values) const {
if (this == &other) {
return true;
}
absl::flat_hash_set<std::pair<const HloInstruction*, const HloInstruction*>>
visited;
std::vector<std::pair<const HloInstruction*, const HloInstruction*>> worklist;
worklist.push_back({root_instruction(), other.root_instruction()});
while (!worklist.empty()) {
auto pair = worklist.back();
worklist.pop_back();
if (visited.contains(pair)) {
continue;
}
visited.emplace(pair);
// TODO(b/123082518): Avoid recursively invoking Equal because it may
// cause a stack overflow with deeply nested subcomputations.
auto operands_eq = [](const HloInstruction*, const HloInstruction*) {
return true;
};
auto comp_eq = [&](const HloComputation* a, const HloComputation* b) {
return a->EqualInternal(*b, is_layout_sensitive,
ignore_channel_id_values);
};
bool identical_ignoring_operands =
ignore_channel_id_values
? pair.first->IdenticalIgnoringChannelIdValues(
*pair.second, operands_eq, comp_eq, is_layout_sensitive)
: pair.first->Identical(*pair.second, operands_eq, comp_eq,
is_layout_sensitive);
if (!identical_ignoring_operands) {
return false;
}
for (size_t i = 0; i < pair.first->operands().size(); ++i) {
worklist.push_back({pair.first->operand(i), pair.second->operand(i)});
}
}
return true;
}
// Adds `new_instruction` to this computation and then replaces
// `old_instruction` with it.
Status HloComputation::ReplaceWithNewInstruction(
    HloInstruction* old_instruction,
    std::unique_ptr<HloInstruction> new_instruction) {
  HloInstruction* const added = AddInstruction(std::move(new_instruction));
  return ReplaceInstruction(old_instruction, added);
}
// Adds `new_instruction` as a new entry-computation parameter and then
// replaces `old_instruction` with it.
Status HloComputation::ReplaceWithNewEntryComputationParameter(
    HloInstruction* old_instruction,
    std::unique_ptr<HloInstruction> new_instruction) {
  HloInstruction* const added_param =
      AddEntryComputationParameter(std::move(new_instruction));
  return ReplaceInstruction(old_instruction, added_param);
}
// Replaces `old_instruction` with `new_instruction` after checking that their
// shapes are compatible; returns an error if they are not.
Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
                                          HloInstruction* new_instruction) {
  const auto& old_shape = old_instruction->shape();
  const auto& new_shape = new_instruction->shape();
  TF_RET_CHECK(ShapeUtil::Compatible(old_shape, new_shape))
      << ShapeUtil::HumanString(old_shape) << " vs "
      << ShapeUtil::HumanString(new_shape);
  return ReplaceInstructionWithDifferentShape(old_instruction, new_instruction);
}
// Redirects every use of `old_instruction` to `new_instruction` (shapes may
// differ), then removes the old instruction together with any operands that
// become unused. Before that, metadata, frontend attributes and sharding are
// carried over from the old instruction where the new one has none.
Status HloComputation::ReplaceInstructionWithDifferentShape(
    HloInstruction* old_instruction, HloInstruction* new_instruction) {
  VLOG(10) << "transformed " << old_instruction->ToString() << " to "
           << new_instruction->ToString();

  // Try to add metadata for HLO instructions that are created to replace
  // existing HLO instructions (e.g. during optimizations). The assumption is
  // that the old instruction and the new instruction would perform the same
  // function, and that they would be correlated to the same TF op. This might
  // not always be correct since HLO optimizations can cross TF op boundaries.
  // But still this seems to be better than nothing.
  const bool new_lacks_op_name = new_instruction->metadata().op_name().empty();
  const bool overwrite_op_name =
      new_lacks_op_name && !old_instruction->metadata().op_name().empty();
  const bool overwrite_pass_id =
      new_lacks_op_name &&
      new_instruction->metadata().logical_creation_pass_id() == 0 &&
      old_instruction->metadata().logical_creation_pass_id() != 0;
  if (overwrite_op_name || overwrite_pass_id) {
    new_instruction->set_metadata(old_instruction->metadata());
  }

  const bool new_lacks_frontend_attributes =
      new_instruction->frontend_attributes().map().empty();
  if (new_lacks_frontend_attributes) {
    new_instruction->set_frontend_attributes(
        old_instruction->frontend_attributes());
  }

  // Like the metadata above, if the user didn't specify any sharding
  // information on the new instruction we should copy the old sharding
  // information (if any).
  if (!new_instruction->has_sharding()) {
    new_instruction->set_sharding(old_instruction->sharding_ptr());
  }

  TF_RETURN_IF_ERROR(
      old_instruction->ReplaceAllUsesWithDifferentShape(new_instruction));
  return RemoveInstructionAndUnusedOperands(old_instruction);
}
std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const {
std::vector<HloInstruction*> unreachable_roots;
for (auto* instruction : instructions()) {
if (instruction->user_count() == 0 &&
instruction->control_successors().empty() &&
instruction != root_instruction()) {
unreachable_roots.push_back(instruction);
}
}
VLOG(3) << "Unreachable roots:"
<< absl::StrJoin(unreachable_roots, "\n\t",
[](string* out, const HloInstruction* hlo) {
absl::StrAppend(out, hlo->ToString());
});
return unreachable_roots;
}
// Runs `visitor` over this computation, using `operand_order` to order the
// operands of each instruction. The unreachable roots are visited first
// (without finishing the visit), and the computation's root instruction is
// visited last with call_finish_visit set. The first non-OK status aborts the
// traversal and is returned.
Status HloComputation::AcceptWithOperandOrder(
    DfsHloVisitor* visitor,
    const HloInstruction::CompareFunction& operand_order) const {
  // Snapshot the unreachable roots before visiting anything: the visitor may
  // delete the root it is currently visiting, which would invalidate iterators
  // if the list were computed lazily.
  const std::vector<HloInstruction*> unreachable_roots =
      CollectUnreachableRoots();
  for (HloInstruction* unreachable_root : unreachable_roots) {
    TF_RETURN_IF_ERROR(unreachable_root->AcceptWithOperandOrder(
        visitor, operand_order, /*call_finish_visit=*/false));
  }
  // The computation root goes last; it is looked up only now, after the
  // unreachable roots have been processed.
  return root_instruction()->AcceptWithOperandOrder(
      visitor, operand_order, /*call_finish_visit=*/true);
}
std::unique_ptr<HloComputation> HloComputation::Clone(
const string& suffix, HloCloneContext* context) {
return CloneWithReplacements(
/*replacements=*/absl::flat_hash_map<const HloInstruction*,
std::unique_ptr<HloInstruction>>(),
/*extra_parameters=*/{}, context, suffix);
}
// Convenience overload: packs the single (instruction, replacement) pair `r1`
// into a replacements map and forwards to CloneWithReplacements() with no
// extra parameters.
std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
    std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
    HloCloneContext* context, const string& suffix) {
  absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
      replacement_map;
  replacement_map.emplace(r1.first, std::move(r1.second));
  return CloneWithReplacements(std::move(replacement_map),
                               /*extra_parameters=*/{}, context, suffix);
}
// Convenience overload: packs the two (instruction, replacement) pairs `r1`
// and `r2`, in that order, into a replacements map and forwards to
// CloneWithReplacements() with no extra parameters.
std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
    std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
    std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
    HloCloneContext* context, const string& suffix) {
  absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
      replacement_map;
  // Insert in argument order so an earlier pair wins on a duplicate key.
  for (auto* replacement_pair : {&r1, &r2}) {
    replacement_map.emplace(std::move(*replacement_pair));
  }
  return CloneWithReplacements(std::move(replacement_map),
                               /*extra_parameters=*/{}, context, suffix);
}
// Convenience overload: packs the three (instruction, replacement) pairs
// `r1`, `r2` and `r3`, in that order, into a replacements map and forwards to
// CloneWithReplacements() with no extra parameters.
std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
    std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
    std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
    std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
    HloCloneContext* context, const string& suffix) {
  absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
      replacement_map;
  // Insert in argument order so an earlier pair wins on a duplicate key.
  for (auto* replacement_pair : {&r1, &r2, &r3}) {
    replacement_map.emplace(std::move(*replacement_pair));
  }
  return CloneWithReplacements(std::move(replacement_map),
                               /*extra_parameters=*/{}, context, suffix);
}
// Clones this computation while substituting instructions according to
// `replacements`.
//
//  - replacements: maps an instruction of this computation to the instruction
//    that should stand in for it in the clone. A null mapped value means the
//    instruction is dropped from the clone.
//  - extra_parameters: parameter instructions that are cloned and added to the
//    new computation ahead of everything else (CHECK-fails on non-parameters).
//  - context: clone context to use; when null a temporary one is created from
//    parent() and `suffix`.
//  - suffix: appended (after a ".") to this computation's name to name the
//    clone.
//  - new_root: instruction whose (replaced) clone becomes the root of the new
//    computation; when null the current root instruction is used.
//
// The new computation is registered in the context as the clone of `this`.
std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
    absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
        replacements,
    absl::Span<const HloInstruction* const> extra_parameters,
    HloCloneContext* context, const string& suffix,
    const HloInstruction* new_root) {
  // Own a clone context locally when the caller did not supply one.
  std::unique_ptr<HloCloneContext> owned_context;
  if (context == nullptr) {
    owned_context = absl::make_unique<HloCloneContext>(parent(), suffix);
    context = owned_context.get();
  }
  if (new_root == nullptr) {
    new_root = root_instruction();
  }

  // Maps an instruction to its replacement, or to itself when the replacements
  // map has no entry for it.
  //
  // Note: the result may be null, meaning the instruction must not appear in
  // the new computation.
  auto replace =
      [&replacements](const HloInstruction* instr) -> const HloInstruction* {
    auto found = replacements.find(instr);
    if (found == replacements.end()) {
      return instr;
    }
    return found->second.get();
  };

  VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";

  // We need a postorder over [replace(i) for i in instructions_].
  // MakeInstructionPostOrder() cannot be reused here: it orders the plain
  // instructions_, and the replacements may change the postorder. Only operand
  // dependencies matter for our purposes, so a small hand-written iterative
  // DFS is enough.
  std::vector<const HloInstruction*> postorder;
  absl::flat_hash_map<const HloInstruction*, VisitState> visited;
  for (const auto& owned_instr : instructions_) {
    const HloInstruction* dfs_root = replace(owned_instr.get());
    if (dfs_root == nullptr) {
      continue;
    }
    std::vector<const HloInstruction*> dfs_stack;
    dfs_stack.push_back(dfs_root);
    while (!dfs_stack.empty()) {
      const HloInstruction* top = dfs_stack.back();
      auto it = visited.find(top);
      if (it == visited.end()) {
        // First encounter: mark as in progress, leave it on the stack, and
        // schedule its (replaced) operands to be handled before it.
        visited.insert({top, kVisiting});
        for (HloInstruction* operand : top->operands()) {
          const HloInstruction* replaced = replace(operand);
          if (replaced != nullptr) {
            dfs_stack.emplace_back(replaced);
          }
        }
        continue;
      }
      // Seen before: either its operands are now done (emit it), or it was
      // already emitted through another path (just discard the entry).
      dfs_stack.pop_back();
      if (it->second != kVisited) {
        CHECK_EQ(it->second, kVisiting);
        postorder.push_back(top);
        it->second = kVisited;
      }
    }
  }

  std::vector<std::unique_ptr<HloInstruction>> instructions;
  // The extra parameters go into 'instructions' first.
  for (const HloInstruction* instr : extra_parameters) {
    CHECK_EQ(instr->opcode(), HloOpcode::kParameter)
        << "Only parameter instructions are allowed in 'extra_parameters'";
    instructions.emplace_back(instr->Clone());
  }
  // Then clone every instruction in postorder, wiring each clone to the clones
  // of its (replaced) operands.
  for (const HloInstruction* source : postorder) {
    std::vector<HloInstruction*> cloned_operands;
    for (const HloInstruction* operand : source->operands()) {
      const HloInstruction* replaced_operand = replace(operand);
      CHECK_NE(replaced_operand, nullptr)
          << "replacements map tried to eliminate a used instruction "
          << operand->ToString() << ", used by " << source->ToString();
      cloned_operands.push_back(context->GetInstruction(replaced_operand));
    }
    std::unique_ptr<HloInstruction> clone =
        source->CloneWithNewOperands(source->shape(), cloned_operands, context);
    // Carry over the replicated-at-leaf-buffers setting of parameters.
    if (source->opcode() == HloOpcode::kParameter) {
      const auto& replicated = source->parameter_replicated_at_leaf_buffers();
      if (replicated.has_value()) {
        clone->set_parameter_replicated_at_leaf_buffers(*replicated);
      }
    }
    instructions.push_back(std::move(clone));
  }

  Builder builder(StrCat(name(), ".", suffix));
  for (auto& pending : instructions) {
    builder.AddInstruction(std::move(pending));
  }
  auto result = builder.Build(
      /*root_instruction=*/context->GetInstruction(replace(new_root)));

  // Clone control dependencies.
  for (const HloInstruction* source : postorder) {
    HloInstruction* new_instr = context->GetInstruction(source);
    for (const HloInstruction* successor : source->control_successors()) {
      const HloInstruction* replaced_successor = replace(successor);
      // The successor may have been removed by the replacements map, in which
      // case there is nothing to connect to.
      if (replaced_successor == nullptr) {
        continue;
      }
      TF_CHECK_OK(new_instr->AddControlDependencyTo(
          context->GetInstruction(replaced_successor)));
    }
  }

  context->MapComputation(this, result.get());
  return result;
}
// Replaces this computation's name with the name that `name_uniquer` returns
// from GetUniqueName() for the current name.
void HloComputation::UniquifyName(NameUniquer* name_uniquer) {
  auto unique_name = name_uniquer->GetUniqueName(name_);
  name_ = std::move(unique_name);
}
// Returns the first instruction in this computation whose name equals `name`,
// or nullptr when no instruction matches.
HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) {
  for (HloInstruction* candidate : instructions()) {
    if (candidate->name() == name) {
      return candidate;
    }
  }
  return nullptr;
}
// Returns true when this computation is the entry computation of the module
// returned by parent().
bool HloComputation::IsEntryComputation() const {
  const auto* owning_module = parent();
  return this == owning_module->entry_computation();
}
} // namespace xla