blob: fcbd9e760c65f9050a3bc8fbf992726d7cf2f30c [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
#include <stdint.h>
#include <memory>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace gpu {
namespace {
// Returns false when `instr` is a constant whose shape is an effective scalar,
// and true for every other instruction. Such constants do not cause memory
// reads, so sharing them between siblings is not a source of profit.
bool IsProfitableOperand(HloInstruction* instr) {
  const bool is_scalar_constant =
      instr->opcode() == HloOpcode::kConstant &&
      ShapeUtil::IsEffectiveScalar(instr->shape());
  return !is_scalar_constant;
}
// Decides whether sibling `instr2` may be merged into the fusion `instr1`.
//
// `instr1` must be a fusion (CHECK-enforced). `instr2` may or may not be a
// fusion. Fusion is rejected when:
//   * both are fusions but their fusion kinds differ (loop fusions only grow
//     into bigger loop fusions, input/reduce fusions only gain more reduce
//     outputs; mixing reduce and loop fusions could be explored later);
//   * `instr2` is a reduction from/to contiguous dimensions and `instr1` is a
//     loop fusion;
//   * either fusion has a dynamic-update-slice at its root;
//   * the combined fusion would be too large.
bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2) {
  CHECK(instr1->opcode() == HloOpcode::kFusion);
  const bool instr2_is_fusion = instr2->opcode() == HloOpcode::kFusion;

  // Fusion kinds must agree when both sides are fusions.
  if (instr2_is_fusion && instr1->fusion_kind() != instr2->fusion_kind()) {
    return false;
  }
  // Do not pull a standalone input-fusible reduce into a loop fusion.
  if (IsReductionFromOrToContiguousDimensions(*instr2) &&
      instr1->IsLoopFusion()) {
    return false;
  }

  // The emitter only supports in-place DUS for fusions with a single DUS at
  // the root, so neither side may be rooted at a DUS.
  // TODO(b/119178699): Multi-output fusing DUS can improve performance if we
  // share the input and output buffers and add support to the emitter.
  if (instr1->fused_expression_root()->opcode() ==
      HloOpcode::kDynamicUpdateSlice) {
    return false;
  }
  if (instr2_is_fusion && instr2->fused_expression_root()->opcode() ==
                              HloOpcode::kDynamicUpdateSlice) {
    return false;
  }

  // Evaluated last because it may be expensive.
  return !FusionWouldBeTooLarge(*instr1, *instr2);
}
// Ranks an instruction for fusion preference: multi-output fusions (2) beat
// other fusions (1), which beat unfused ops (0). Preferring existing fusions
// helps preserve fusion opportunities.
int FusionPriority(const HloInstruction* instr) {
  if (instr->IsMultiOutputFusion()) {
    return 2;
  }
  return instr->opcode() == HloOpcode::kFusion ? 1 : 0;
}
// Picks the candidate with the highest FusionPriority, or nullptr if there
// are no candidates. On ties, the earliest candidate in `candidates` wins,
// because std::max_element returns the first maximal element.
//
// `candidates` is taken by const reference. The previous by-value const
// parameter copied the whole vector on every call without ever modifying it.
HloInstruction* SelectPreferredFusionCandidate(
    const std::vector<HloInstruction*>& candidates) {
  if (candidates.empty()) {
    return nullptr;
  }
  return *std::max_element(
      candidates.begin(), candidates.end(),
      [](const HloInstruction* a, const HloInstruction* b) {
        return FusionPriority(a) < FusionPriority(b);
      });
}
std::vector<HloInstruction*> GetProducerConsumerMultiOutputFusionCandidates(
const HloInstruction* producer, const HloReachabilityMap& reachability) {
std::vector<HloInstruction*> fusion_candidates;
for (HloInstruction* consumer : producer->users()) {
VLOG(3) << "Looking at producer " << producer->name()
<< " and its consumer " << consumer->name();
if (!IsFusibleAsMultiOutputFusionRoot(*consumer)) {
VLOG(3) << "Consumer " << consumer->name()
<< " is not eligible as multi-output fusion root.";
continue;
}
if (!IsProducerConsumerMultiOutputFusible(*producer, *consumer)) {
VLOG(3) << producer->name() << " and " << consumer->name()
<< " are not fusible.";
continue;
}
// Do not fuse a producer if the other operands of the fusion are
// reachable from the producer, this would create a cycle.
auto operand_reachable_from_producer = [&](const HloInstruction* operand) {
// If a get-tuple-element instruction is not in the reachability
// map, it has been created by fusion in this pass. Simply move
// on to its operand, which is in the reachability map.
if (!reachability.IsPresent(operand) &&
operand->opcode() == HloOpcode::kGetTupleElement) {
operand = operand->operand(0);
}
CHECK(reachability.IsPresent(operand) && reachability.IsPresent(producer))
<< "Reachability map is incomplete. This should never "
"happen.";
return producer != operand && reachability.IsReachable(producer, operand);
};
if (absl::c_any_of(consumer->operands(), operand_reachable_from_producer)) {
VLOG(3) << producer->name() << " would introduce a cycle when fused.";
continue;
}
if (FusionWouldBeTooLarge(*producer, *consumer)) {
VLOG(3) << producer->name() << " and " << consumer->name()
<< " would be too large of a fusion.";
continue;
}
fusion_candidates.push_back(consumer);
}
return fusion_candidates;
}
bool IsSiblingFusionCandidate(const HloInstruction* instr) {
if (instr->user_count() == 0) {
return false;
}
if (!IsFusibleAsMultiOutputFusionRoot(*instr)) {
return false;
}
// Check if the users of multioutput fusion is not a get-tuple-element.
// If this is the case, we bail out because the transformation assumes
// the users are get-tuple-element.
if (instr->IsMultiOutputFusion()) {
for (auto user : instr->users()) {
if (user->opcode() != HloOpcode::kGetTupleElement) {
return false;
}
}
}
return true;
}
} // namespace
void GpuMultiOutputFusion::RecomputeReachability() {
reachability_ = HloReachabilityMap::Build(computation_);
}
// Attempts sibling multi-output fusion among the users of `parent`.
//
// Siblings are visited in priority order: multi-output fusions, then
// fusions, then unfused ops. Each fusion sibling `*i` tries to absorb every
// later sibling `*j` that is a legal, unconnected, shape-compatible
// candidate. Absorbed siblings are dropped from the work list. Returns true
// if any fusion happened.
bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent) {
  if (!IsProfitableOperand(parent)) {
    return false;
  }
  bool changed = false;
  std::vector<HloInstruction*> siblings = parent->users();
  // Sort the siblings such that multi-output fusion ops occur first, followed
  // by fusion ops, followed by unfused ops.
  absl::c_stable_sort(siblings,
                      [](const HloInstruction* a, const HloInstruction* b) {
                        return FusionPriority(a) > FusionPriority(b);
                      });
  for (auto i = siblings.begin(); i != siblings.end();) {
    VLOG(3) << "Considering " << (*i)->name();
    if ((*i)->opcode() != HloOpcode::kFusion || !IsSiblingFusionCandidate(*i)) {
      ++i;
      continue;
    }
    for (auto j = i + 1; j != siblings.end();) {
      VLOG(3) << "Considering " << (*i)->name() << " and " << (*j)->name();
      if (!IsSiblingFusionCandidate(*j) || reachability_->IsConnected(*i, *j) ||
          !ShapesCompatibleForMultiOutputFusion(*(*i), *(*j)) ||
          !LegalToFuse(*i, *j)) {
        ++j;
        continue;
      }
      if (!ConsumeFuel(name(), [&] {
            return absl::StrFormat("Not fusing %s and %s.", (*i)->name(),
                                   (*j)->name());
          })) {
        // Out of fuel: skip this pair. `j` must advance here. Without the
        // increment, the same pair is retried forever once fuel is exhausted.
        ++j;
        continue;
      }
      VLOG(2) << "Fuse siblings " << (*i)->name() << " and " << (*j)->name();
      HloInstruction* remaining = *i;
      HloInstruction* fused = *j;
      if (fused->opcode() == HloOpcode::kFusion) {
        remaining->MergeFusionInstructionIntoMultiOutput(fused);
      } else {
        remaining->FuseInstructionIntoMultiOutput(fused);
        CHECK_EQ(0, fused->user_count());
        TF_CHECK_OK(computation_->RemoveInstruction(fused));
      }
      changed = true;
      // vector::erase invalidates `j`; continue from the returned iterator,
      // which points at the element that followed the erased one. `i` comes
      // before the erase point, so it remains valid.
      j = siblings.erase(j);
      RecomputeReachability();
    }
    ++i;
  }
  return changed;
}
bool GpuMultiOutputFusion::DoMultiOutputFusion() {
bool changed = false;
RecomputeReachability();
std::vector<HloInstruction*> defs_before_uses =
computation_->MakeInstructionPostOrder();
while (!defs_before_uses.empty()) {
// Traverse the HLO in uses-before-defs order by removing instruction from
// the back of the vector.
HloInstruction* producer = defs_before_uses.back();
defs_before_uses.pop_back();
// Never multi-output fuse constants. To the extent that we want to fuse
// constants, that should be handled by the regular fusion pass.
if (producer->opcode() == HloOpcode::kConstant) {
VLOG(3) << producer->name() << " is a constant.";
continue;
}
// First, fuse the consumer ops of the current op, which are siblings.
if (FuseSiblings(/*parent=*/producer)) {
changed = true;
}
// Second, perform producer-consumer multi-output fusion. This order will
// ensure that all get-tuple-element ops inserted as a by-product of
// multi-output fusion will occur before the current op in the order of
// traversal, and hence, not get into the way of subsequent fusion attempts.
const auto candidates = GetProducerConsumerMultiOutputFusionCandidates(
producer, *reachability_);
auto* consumer_for_fusion = SelectPreferredFusionCandidate(candidates);
if (consumer_for_fusion == nullptr) {
continue;
}
if (!ConsumeFuel(name(), [&] {
return absl::StrFormat("Not fusing %s and %s.", producer->name(),
consumer_for_fusion->name());
})) {
continue;
}
changed = true;
if (consumer_for_fusion->opcode() == HloOpcode::kFusion) {
VLOG(2) << "Fuse producer " << producer->name() << " into its consumer "
<< consumer_for_fusion->name();
if (producer->opcode() == HloOpcode::kFusion) {
consumer_for_fusion->MergeFusionInstructionIntoMultiOutput(producer);
} else {
consumer_for_fusion->FuseInstructionIntoMultiOutput(producer);
CHECK_EQ(0, producer->user_count());
TF_CHECK_OK(computation_->RemoveInstruction(producer));
}
RecomputeReachability();
continue;
}
HloInstruction* input_fusion =
computation_->AddInstruction(HloInstruction::CreateFusion(
consumer_for_fusion->shape(),
ChooseFusionKind(*producer, *consumer_for_fusion),
consumer_for_fusion));
VLOG(2) << "Fuse producer " << producer->name() << " and its consumer "
<< consumer_for_fusion->name() << " into " << input_fusion->name();
TF_CHECK_OK(
computation_->ReplaceInstruction(consumer_for_fusion, input_fusion));
if (producer->opcode() == HloOpcode::kFusion) {
input_fusion->MergeFusionInstructionIntoMultiOutput(producer);
} else {
input_fusion->FuseInstructionIntoMultiOutput(producer);
CHECK_EQ(0, producer->user_count());
TF_CHECK_OK(computation_->RemoveInstruction(producer));
}
RecomputeReachability();
}
return changed;
}
// Applies multi-output fusion to every non-fusion computation of `module`.
// Returns true if at least one computation changed.
StatusOr<bool> GpuMultiOutputFusion::Run(HloModule* module) {
  bool module_changed = false;
  for (auto* computation : module->MakeNonfusionComputations()) {
    computation_ = computation;
    // DoMultiOutputFusion is always invoked; its result is only OR-ed in.
    module_changed |= DoMultiOutputFusion();
  }
  return module_changed;
}
} // namespace gpu
} // namespace xla