blob: 582e59349e8e0a3c3b3bffd743d86ebf4a05a2f0 [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/multi_output_fusion.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
// Runs multi-output fusion over every non-fusion computation of `module`.
//
// For each computation this rebuilds the reachability map and the candidate
// tables, collects for every fusible instruction the set of "sibling" users
// that share a profitable operand with it, queues each pair with positive
// profit on worklist_, and finally lets Perform() do the actual fusing.
//
// Returns true if Perform() reported a change for at least one computation.
StatusOr<bool> MultiOutputFusion::Run(HloModule* module) {
  bool changed = false;
  for (auto* computation : module->MakeNonfusionComputations()) {
    // Reset the per-computation state.
    computation_ = computation;
    RecomputeReachability();
    candidates_.clear();
    candidates_index_.clear();
    all_fusion_candidates_.clear();

    // Register one node per instruction, in the order returned by
    // MakeInstructionPostOrder(), and remember each instruction's index.
    int64 next_index = 0;
    for (auto* hlo : computation_->MakeInstructionPostOrder()) {
      candidates_.emplace_back(hlo);
      InsertOrDie(&candidates_index_, hlo, next_index);
      ++next_index;
    }

    // Build the initial list of fusion partners for each node.
    for (auto& node : candidates_) {
      HloInstruction* instruction = node.hlo;
      const int64 instruction_id = get_candidate_id(instruction);
      FusionCandidate& instr_node = candidates_[instruction_id];
      if (!IsFusible(instruction)) {
        continue;
      }
      all_fusion_candidates_.push_back(instruction);

      // `ordered_siblings` preserves discovery order; `seen_siblings` only
      // serves to de-duplicate.
      std::vector<HloInstruction*> ordered_siblings;
      absl::flat_hash_set<HloInstruction*> seen_siblings;
      VLOG(10) << "Looking at instruction: " << instruction->name();
      for (auto* operand : instruction->operands()) {
        // Operands that are not profitable cannot produce any savings, so
        // their users are not worth inspecting.
        if (!IsProfitableOperand(operand)) {
          VLOG(10) << "Operand not profitable: " << operand->name();
          continue;
        }
        VLOG(10) << "Operand profitable: " << operand->name();

        for (auto* user : operand->users()) {
          VLOG(10) << "User: " << user->name();
          if (user == instruction || !IsFusible(user)) {
            VLOG(10) << "User is not fusible, or is the instruction itself: "
                     << user->name();
            continue;
          }
          const int64 user_id = get_candidate_id(user);
          if (is_connected(instruction, user)) {
            VLOG(10) << "User is connected: " << user->name();
            continue;
          }
          // A fusion user with a larger index is skipped here.
          if (instruction_id < user_id &&
              user->opcode() == HloOpcode::kFusion) {
            VLOG(10) << "User ID for user: " << user->name() << " is "
                     << user_id << " which is higher than " << instruction_id;
            continue;
          }
          if (!LegalToFuse(instruction, user)) {
            VLOG(10) << "User not legal to fuse: " << user->name();
            continue;
          }
          const bool newly_seen = seen_siblings.insert(user).second;
          if (newly_seen) {
            VLOG(10) << "User added to candidate list: " << user->name();
            ordered_siblings.push_back(user);
          }
        }
      }

      // Walk the vector, not the hash set, so the resulting order is
      // deterministic.
      for (HloInstruction* sibling : ordered_siblings) {
        const int64 profit = GetProfit(instruction, sibling);
        if (profit <= 0) {
          continue;
        }
        FusionCandidate& sibling_node = candidates_[get_candidate_id(sibling)];
        instr_node.fusibles.emplace_back(sibling, profit);
        sibling_node.fusibles.emplace_back(instruction, profit);
        worklist_.emplace(instruction, sibling, profit);
      }
    }

    changed = Perform() || changed;
  }
  return changed;
}
// Fuses `instr1` and `instr2` into a single multi-output fusion and returns
// the instruction that survives (the one the other was fused into).
//
// When the absorbed instruction is not itself a fusion, it is removed from
// the computation after being fused; it must have no users left at that
// point.
HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1,
                                        HloInstruction* instr2) {
  HloInstruction* survivor = instr1;
  HloInstruction* absorbed = instr2;
  // If only one of the two is a fusion, or only one of them is a multi-output
  // fusion, make that one the survivor that gets fused into.
  if (absorbed->opcode() == HloOpcode::kFusion) {
    std::swap(survivor, absorbed);
  }
  if (absorbed->IsMultiOutputFusion()) {
    std::swap(survivor, absorbed);
  }

  if (absorbed->opcode() != HloOpcode::kFusion) {
    survivor->FuseInstructionIntoMultiOutput(absorbed);
    CHECK_EQ(0, absorbed->user_count());
    TF_CHECK_OK(computation()->RemoveInstruction(absorbed));
  } else {
    survivor->MergeFusionInstructionIntoMultiOutput(absorbed);
  }
  return survivor;
}
// Returns whether sharing `instr` as an operand between sibling instructions
// can be a source of profit.
bool MultiOutputFusion::IsProfitableOperand(HloInstruction* instr) {
  // An effectively-scalar kConstant will not have memory reads, so it cannot
  // be a profit source.
  const bool is_scalar_constant =
      instr->opcode() == HloOpcode::kConstant &&
      ShapeUtil::IsEffectiveScalar(instr->shape());
  if (is_scalar_constant) {
    return false;
  }
  // Producer/consumer fusion is not the target here -- that is left to the
  // instruction_fusion pass. An operand with fewer than two users has no
  // sibling instructions, so it is not considered.
  return instr->user_count() >= 2;
}
// Refreshes the bookkeeping for the pair (`instr1`, `instr2`) that is being
// fused: registers users that have no candidate entry yet, updates the
// reachability map, merges the fusible lists of the two nodes into the
// fusion's node, and queues every new or improved pairing on worklist_.
void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) {
  // Whichever instruction is not already marked fused plays the role of the
  // fusion; by default that is instr1.
  const bool first_is_fused = is_fused(instr1);
  HloInstruction* fusion = first_is_fused ? instr2 : instr1;
  HloInstruction* fused = first_is_fused ? instr1 : instr2;

  // Add any user of the fusion that has no entry in candidates_ yet (e.g. a
  // newly created instruction).
  for (auto* user : fusion->users()) {
    if (candidates_index_.find(user) != candidates_index_.end()) {
      continue;
    }
    const int64 new_index = candidates_.size();
    candidates_.emplace_back(user);
    InsertOrDie(&candidates_index_, user, new_index);
  }

  // Taken only after the insertions above, which may grow candidates_.
  FusionCandidate& fusion_node = candidates_[get_candidate_id(fusion)];
  FusionCandidate& fused_node = candidates_[get_candidate_id(fused)];

  // Update the reachability graph, skipping already-fused instructions.
  UpdateReachability(fusion, fused, all_fusion_candidates_,
                     [this](HloInstruction* instr) { return is_fused(instr); });

  // `requeue` collects the entries that are new or whose profit increased.
  std::vector<std::pair<HloInstruction*, int64>> requeue;
  absl::flat_hash_set<HloInstruction*> in_list;

  // Pass 1: prune the fusion's own list and refresh the profits.
  for (auto iter = fusion_node.fusibles.begin();
       iter != fusion_node.fusibles.end();) {
    HloInstruction* partner = iter->first;
    if (is_fused(partner) || is_connected(fusion, partner)) {
      iter = fusion_node.fusibles.erase(iter);
      continue;
    }
    in_list.insert(partner);
    const int64 refreshed_profit = GetProfit(partner, fusion);
    if (refreshed_profit > iter->second) {
      iter->second = refreshed_profit;
      requeue.emplace_back(partner, refreshed_profit);
    }
    ++iter;
  }

  // Pass 2: fused_node has been fused into fusion_node, so adopt its
  // fusibles, filtering out those that are no longer valid or that are
  // already in the list.
  for (const auto& entry : fused_node.fusibles) {
    HloInstruction* partner = entry.first;
    if (partner == fusion || is_fused(partner) ||
        is_connected(fusion, partner) || in_list.contains(partner)) {
      continue;
    }
    const int64 profit = GetProfit(partner, fusion);
    fusion_node.fusibles.emplace_back(partner, profit);
    requeue.emplace_back(partner, profit);
  }
  fused_node.fusibles.clear();

  // Push the new or changed pairings onto the worklist.
  for (const auto& entry : requeue) {
    worklist_.emplace(fusion, entry.first, entry.second);
  }
}
// Returns true if `instr1` and `instr2` may be combined into a multi-output
// fusion. `instr1` must be a fusion instruction; both must have users, must
// not be connected to each other, and must have fusion-compatible shapes.
bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1,
                                    HloInstruction* instr2) {
  // An instruction cannot be fused with itself, and the first instruction
  // has to be a fusion.
  if (instr1 == instr2 || instr1->opcode() != HloOpcode::kFusion) {
    return false;
  }

  // Fusing nodes with 0 users makes no sense, and the rest of the
  // implementation does not support it either.
  if (instr1->user_count() == 0 || instr2->user_count() == 0) {
    return false;
  }

  // The transformation assumes that every user of a multi-output fusion is a
  // get-tuple-element. This predicate detects a multi-output fusion that
  // violates the assumption, in which case we bail out.
  auto has_non_gte_user = [](HloInstruction* instr) {
    if (instr->IsMultiOutputFusion()) {
      for (auto* user : instr->users()) {
        if (user->opcode() != HloOpcode::kGetTupleElement) {
          return true;
        }
      }
    }
    return false;
  };
  if (has_non_gte_user(instr1) || has_non_gte_user(instr2)) {
    return false;
  }

  return !is_connected(instr1, instr2) &&
         ShapesCompatibleForFusion(instr1, instr2);
}
// Rebuilds reachability_ from scratch for the current computation_ via
// HloReachabilityMap::Build. Run() calls this once per computation, right
// after setting computation_.
void MultiOutputFusion::RecomputeReachability() {
reachability_ = HloReachabilityMap::Build(computation_);
}
// Incrementally patches reachability_ for the pair (`instr1`, `instr2`).
//
// For every instruction in `instrs_to_update` that `skip` does not reject:
// if it is reachable from exactly one of the two instructions, its
// reachability set is extended with the set of the other one. Instructions
// already reachable from both are left untouched. `skip` may be empty.
void MultiOutputFusion::UpdateReachability(
    HloInstruction* instr1, HloInstruction* instr2,
    absl::Span<HloInstruction* const> instrs_to_update,
    const std::function<bool(HloInstruction*)>& skip) {
  for (HloInstruction* target : instrs_to_update) {
    if (skip != nullptr && skip(target)) {
      continue;
    }

    const bool reachable_from_instr2 =
        reachability_->IsReachable(instr2, target);
    if (reachable_from_instr2) {
      if (reachability_->IsReachable(instr1, target)) {
        // Already reachable from both; nothing to update.
        continue;
      }
      reachability_->FastSetReachabilityToUnion({target, instr1}, target);
    }

    // Queried again on purpose: the union above may have just changed what
    // is reachable for `target`.
    if (reachability_->IsReachable(instr1, target)) {
      reachability_->FastSetReachabilityToUnion({target, instr2}, target);
    }
  }
}
// Drains worklist_, fusing each candidate pair that is still valid and legal,
// then gives DoProducerConsumerMultiOutputFusion() a chance to run.
//
// Returns true if any fusion was performed or if
// DoProducerConsumerMultiOutputFusion() reported a change.
//
// The worklist is always empty when this function returns, including when
// fuel runs out: Run() clears candidates_ and candidates_index_ for every
// computation but never clears worklist_, so entries left behind here would
// be popped while processing a later computation even though they refer to
// instructions of an earlier one.
bool MultiOutputFusion::Perform() {
  bool changed = false;
  // Pick the top candidate from the queue and try to merge.
  while (!worklist_.empty()) {
    ToBeFused candidate = worklist_.top();
    worklist_.pop();
    HloInstruction* instr1 = candidate.instr1;
    HloInstruction* instr2 = candidate.instr2;
    // Either side may have been fused away since the pair was queued.
    if (is_fused(instr1) || is_fused(instr2)) {
      continue;
    }
    VLOG(1) << "Considering candidate profit_score=" << candidate.score
            << "\n\t\tinstr1 = " << instr1->ToString()
            << "\n\t\tinstr2 = " << instr2->ToString();
    if (!LegalToFuse(instr1, instr2)) {
      continue;
    }
    if (!ConsumeFuel(name(), [&] {
          return absl::StrFormat("Not fusing %s and %s.", instr1->ToString(),
                                 instr2->ToString());
        })) {
      // Out of fuel: nothing else will be fused in this call. Discard the
      // remaining candidates so that no stale pair survives into the next
      // computation.
      while (!worklist_.empty()) {
        worklist_.pop();
      }
      break;
    }
    VLOG(1) << "Fuse!";
    VLOG(2) << "Before multi_output_fusion:";
    VLOG(2) << "instr1: " << instr1->ToString();
    // LegalToFuse() only accepts a kFusion instr1, so it has a fused
    // computation to print; instr2 has to be checked.
    VLOG(2) << "\n"
            << instr1->fused_instructions_computation()->ToString(
                   HloPrintOptions().set_indent_amount(1));
    VLOG(2) << "instr2: " << instr2->ToString();
    if (instr2->opcode() == HloOpcode::kFusion) {
      VLOG(2) << "\n"
              << instr2->fused_instructions_computation()->ToString(
                     HloPrintOptions().set_indent_amount(1));
    }
    // Bookkeeping is updated before the instructions are actually fused.
    Update(instr1, instr2);
    HloInstruction* ret = Fuse(instr1, instr2);
    // Mark whichever instruction did not survive as fused.
    set_is_fused(ret == instr1 ? instr2 : instr1);
    changed = true;
    VLOG(2) << "After fusion, \t this: " << ret->name() << "\n"
            << ret->fused_instructions_computation()->ToString(
                   HloPrintOptions().set_indent_amount(1));
  }
  if (DoProducerConsumerMultiOutputFusion()) {
    changed = true;
  }
  return changed;
}
// Base implementation: performs no producer-consumer multi-output fusion and
// reports that nothing changed. Perform() calls this after the worklist loop
// and ORs the result into its own return value.
// NOTE(review): presumably meant to be overridden by backend-specific
// subclasses -- confirm against the class declaration in the header.
bool MultiOutputFusion::DoProducerConsumerMultiOutputFusion() { return false; }
} // namespace xla