/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include <algorithm>
#include <iterator>
#include <stack>
#include <vector>
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
namespace xla {
namespace gpu {
namespace {
// The amount of shared memory a CUDA kernel can use.
//
// Stay on the conservative side: this is smaller than the full 64kB, but
// allows some extra space for cache.
constexpr int64_t kSharedMemoryBudgetInBytes = 40000;
bool IfFusedReadsElementsMultipleTimes(const HloInstruction& instr) {
CHECK_NE(instr.opcode(), HloOpcode::kFusion) << "`instr` has to be unfused.";
if (instr.opcode() == HloOpcode::kReduce &&
!IsReductionFromOrToContiguousDimensions(instr)) {
return true;
}
// Avoid fusing reduce-window when stride is less than window size to minimize
// the number of reads of the same elements.
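  // For example (illustrative): a reduce-window with window size 3 and
  // stride 1 reads most interior input elements three times, so fusing its
  // producer would repeat the producer's work on each of those reads.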
if (instr.opcode() == HloOpcode::kReduceWindow) {
for (const auto& dim : instr.window().dimensions()) {
if (dim.size() > dim.stride()) {
return true;
}
}
}
return false;
}
} // namespace
bool IsPhysicallyTransposing(const HloInstruction& instr) {
if (instr.opcode() == HloOpcode::kFusion) {
for (const HloInstruction* fused_instr : instr.fused_instructions()) {
if (IsPhysicallyTransposing(*fused_instr)) {
return true;
}
}
}
// A fusion iterates over its output in physically-contiguous order. This
// applies "upwards" to operands. Only an operator that changes an operand's
// physical layout can create a "bad" memory access pattern.
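  // For example (illustrative): a transpose from f32[16,32]{1,0} to
  // f32[32,16]{1,0} with dimensions={1,0} must physically reorder the data,
  // whereas the same transpose producing layout {0,1} leaves the bytes in
  // place and is just a bitcast.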
return instr.opcode() == HloOpcode::kCopy ||
(instr.opcode() == HloOpcode::kTranspose &&
!ShapeUtil::TransposeIsBitcast(instr.operand(0)->shape(),
instr.shape(), instr.dimensions()));
}
bool IsReduceInputFusion(const HloInstruction& instr) {
if (instr.IsMultiOutputFusion()) {
for (const HloInstruction* operand :
instr.fused_expression_root()->operands()) {
if (IsReductionFromOrToContiguousDimensions(*operand)) {
CHECK(instr.IsInputFusion())
<< " Multi-output fusion rooted at reduction-to-vector ops must be "
"of kind kInput: "
<< instr.ToString();
return true;
}
}
} else if (instr.opcode() == HloOpcode::kFusion &&
IsReductionFromOrToContiguousDimensions(
*instr.fused_expression_root())) {
CHECK(instr.IsInputFusion())
<< " Fusion rooted at reduction-to-vector op must be of kind kInput: "
<< instr.ToString();
return true;
}
return false;
}
bool IsInputFusibleReduction(const HloInstruction& instr) {
return IsReduceInputFusion(instr) ||
IsReductionFromOrToContiguousDimensions(instr);
}
const HloInstruction* GetRealHeroForMultiOutputFusion(
const HloInstruction& instr) {
if (instr.opcode() != HloOpcode::kFusion) {
return &instr;
}
auto fused_expression_root = instr.fused_expression_root();
if (!instr.IsMultiOutputFusion()) {
return fused_expression_root;
}
// If possible, we want to pick a reduction-from-or-to-contiguous-dims
// operand of the fusion root, because it has the most constraints.
for (const auto* inst : fused_expression_root->operands()) {
if (IsReductionFromOrToContiguousDimensions(*inst)) {
return inst;
}
}
return fused_expression_root->operands()[0];
}
bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
const HloInstruction& instr2) {
// Multi-output fusion kernels share a common parallel loop. The loop
// dimensions are determined by instruction shapes.
auto get_loop_shape = [&](const HloInstruction* element_instr) {
// Special-case reduction-to-vector ops: The loop dimensions are determined
// by the shape of the first operand.
if (IsReductionFromOrToContiguousDimensions(*element_instr)) {
return element_instr->operand(0)->shape();
}
return element_instr->shape();
};
  // All shapes of the root tuple of multi-output fusions should agree, i.e.
  // all root ops should have equal output shapes. An exception is
  // reduction-to-vector ops: there the input shape of the reduction (the
  // first operand's shape) and the reduction dimensions need to match.
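  // For example (illustrative): a loop root of shape f32[1024,128]{1,0} and a
  // row reduction from f32[1024,128]{1,0} to f32[1024]{0} are compatible,
  // because the loop shape of the reduction is its first operand's shape.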
auto* instr_1 = GetRealHeroForMultiOutputFusion(instr1);
auto* instr_2 = GetRealHeroForMultiOutputFusion(instr2);
if (IsReductionFromOrToContiguousDimensions(*instr_1) &&
IsReductionFromOrToContiguousDimensions(*instr_2) &&
!AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) {
return false;
}
// The elementwise output shapes must be the same (including layout).
return ShapeUtil::EqualIgnoringElementType(get_loop_shape(instr_1),
get_loop_shape(instr_2));
}
bool IsInputFusibleScatter(const HloInstruction& instr) {
if (instr.opcode() == HloOpcode::kScatter ||
(instr.opcode() == HloOpcode::kFusion &&
instr.fusion_kind() == HloInstruction::FusionKind::kInput &&
instr.fused_expression_root()->opcode() == HloOpcode::kScatter)) {
return true;
}
return false;
}
bool IsInputFusible(const HloInstruction& instr) {
  // Input fusion only handles non-elementwise reduction and scatter
  // operations.
return instr.IsFusible() &&
(IsInputFusibleReduction(instr) || IsInputFusibleScatter(instr));
}
bool IsLoopFusible(const HloInstruction& instr) {
// Don't fuse get-tuple-element on GPU: We can, but it's slower than not
// fusing. We never generate kernels for unfused GTEs. Instead, if an
// unfused GTE is an input to a kernel (including a fusion kernel), we
// compute the address of the GTE at the top of the kernel. Often we know the
// address of the GTE result statically, so we can do this without chasing any
// pointers.
return instr.IsFusible() &&
((instr.IsElementwise() && instr.operand_count() > 0) ||
instr.opcode() == HloOpcode::kBitcast ||
instr.opcode() == HloOpcode::kBroadcast ||
instr.opcode() == HloOpcode::kConcatenate ||
instr.opcode() == HloOpcode::kDynamicSlice ||
instr.opcode() == HloOpcode::kDynamicUpdateSlice ||
(instr.opcode() == HloOpcode::kFusion &&
instr.fusion_kind() == HloInstruction::FusionKind::kLoop) ||
instr.opcode() == HloOpcode::kGather ||
instr.opcode() == HloOpcode::kIota ||
instr.opcode() == HloOpcode::kPad ||
(instr.opcode() == HloOpcode::kReduce &&
!IsReductionFromOrToContiguousDimensions(instr) &&
!instr.shape().IsTuple()) || // TODO(b/129089333): Don't fuse
// variadic reductions.
instr.opcode() == HloOpcode::kReduceWindow ||
instr.opcode() == HloOpcode::kReshape ||
instr.opcode() == HloOpcode::kReverse ||
instr.opcode() == HloOpcode::kSlice ||
instr.opcode() == HloOpcode::kConstant ||
instr.opcode() == HloOpcode::kTranspose);
}
FusionDecision IsProducerConsumerFusible(const HloInstruction& producer,
const HloInstruction& consumer) {
if (!IsLoopFusible(producer)) {
return "the producer is not loop-fusible";
}
if (!IsInputFusible(consumer) && !IsLoopFusible(consumer)) {
return "the consumer is not input-fusible and not loop-fusible";
}
  // Multi-output fusion producers are not yet supported; skip them.
if (producer.IsMultiOutputFusion()) {
return "the producer is not fusible as it is a multi-output fusion";
}
if (CreatesNestedLoop(producer, consumer)) {
return "the fusion would create a nested loop";
}
// Do not fuse into fusions if the resulting kernel would suffer from
// uncoalesced reads due to a transposed memory access pattern.
if (IsInputFusibleReduction(consumer) && IsPhysicallyTransposing(producer)) {
return "fusing the producer would break read coalescing";
}
// Fuse scalar constants into loop fusion nodes. This reduces the number of
// parameters and makes matching scalar broadcasts easier.
//
// Don't fuse other constants: Unfused constants in GPU land can be
// represented as an external constant (i.e. not emitted in LLVM IR / PTX),
  // but fused constants are handled by shared CPU/GPU code and always emitted
// in the IR/PTX. The external constant representation makes for faster
// compiles and significantly smaller assembly code.
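  // For example (illustrative): a scalar s32[] constant feeding a loop fusion
  // is fused into it, while an s32[1024] constant is left unfused so it can
  // stay an external constant.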
if (producer.opcode() == HloOpcode::kConstant &&
(!ShapeUtil::IsEffectiveScalar(producer.shape()) ||
consumer.opcode() != HloOpcode::kFusion)) {
return "not fusing constant";
}
// Make sure the new fusion obeys the in-place semantics.
return InstructionFusion::ShouldFuseInPlaceOp(&producer, &consumer);
}
bool IsProducerConsumerMultiOutputFusible(const HloInstruction& producer,
const HloInstruction& consumer) {
  // Multi-output fusion producers are not yet supported; skip them.
if (producer.IsMultiOutputFusion()) {
return false;
}
// Allowing multi-output fusions that contain in-place operations makes code
// generation more difficult. For the generated loop to iterate over all
// outputs in parallel, it must find an iteration order that guarantees that
// no loop iteration writes an element of any in-place operand that is read
// or written by any other iteration. For example:
//
// %fused_computation {
// %param_0 = s32[4,4]{1,0} parameter(0)
// ...
// %updated = s32[4,4]{1,0} dynamic-update-slice(
// %param_0, %add, %constant_1, %constant_0)
// %transpose = s32[4,4]{0,1} transpose(%updated), dimensions={1,0}
// ROOT %tuple.5 = tuple(%transpose, %updated)
// }
//
// Iterating 'transpose' and 'updated' in parallel by array index is
// not valid, because an iteration that produces some element of 'transpose'
// will read from an element of 'param_0' that has been overwritten by some
// other iteration (writing to 'updated').
//
// To avoid these problems, we simply ban fusion altogether when the producer
// is in-place. (We can relax this restriction by establishing an explicit
// contract that describes what multi-output fusion scenarios are supported by
// codegen and then changing this check to allow exactly those fusions).
if (!HloDataflowAnalysis::GetInPlaceInputOutputPairs(&producer).empty()) {
return false;
}
if (!IsLoopFusible(producer) || !IsFusibleAsMultiOutputFusionRoot(consumer)) {
return false;
}
if (CreatesNestedLoop(producer, consumer)) {
return false;
}
if (!ShapesCompatibleForMultiOutputFusion(producer, consumer)) {
return false;
}
if (IsPhysicallyTransposing(producer)) {
return false;
}
return true;
}
// Returns shared memory usage for a given instruction in bytes.
static int64_t SharedMemoryUsageNoCache(const HloInstruction& instr) {
// For now we are only fusing reductions.
if (instr.opcode() == HloOpcode::kReduce &&
IsReductionFromOrToContiguousDimensions(instr)) {
ReductionDimensions reduction_info =
GetReductionKindAndContiguousComponents(instr);
int64_t primitive_size = ShapeUtil::ByteSizeOfPrimitiveType(
instr.operand(0)->shape().element_type());
int num_variadic =
instr.shape().IsTuple() ? instr.shape().tuple_shapes_size() : 1;
if (reduction_info.is_row_reduction) {
// __shared__[32] is used for row reduction.
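      // For example, a single f32 output needs 32 * 4 = 128 bytes.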
return 32 * primitive_size * num_variadic;
} else {
// __shared__[2][32][33] cache is used for column reduction ("2" comes
// from potential x-tiling).
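      // For example, a single f32 output needs 2 * 32 * 33 * 4 = 8448 bytes.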
return 2 * 32 * 33 * primitive_size * num_variadic;
}
} else if (instr.opcode() == HloOpcode::kFusion) {
int64_t sum = 0;
for (const HloInstruction* hlo :
instr.fused_instructions_computation()->instructions()) {
sum += SharedMemoryUsageNoCache(*hlo);
}
return sum;
}
// Other fused expressions for now don't need the shared memory budget.
return 0;
}
static int64_t SharedMemoryUsage(const HloInstruction& instr,
FusionInfoCache* cache = nullptr) {
if (!cache) {
return SharedMemoryUsageNoCache(instr);
}
// nb: Users are only expected to call cache.Invalidate() on top-level
// instructions, not instructions inside fusion nodes. Therefore we can only
// cache top-level instructions; it would not be valid to pass the cache to
// SharedMemoryUsageNoCache and use the cache *within* the fusion.
auto it_and_inserted = cache->shared_memory_usage.emplace(&instr, -1);
auto it = it_and_inserted.first;
auto inserted = it_and_inserted.second;
if (inserted) {
it->second = SharedMemoryUsageNoCache(instr);
}
return it->second;
}
// Codegen'ing unnested reductions requires a lot of registers, so a
// multi-output fusion (MOF) combining many of those runs a high risk of
// spilling.
constexpr int64_t kMaxUnnestedReductionOutputsPerFusion = 8;
// Returns the number of unnested reductions in the instruction output.
static int64_t NumUnnestedReductionsNoCache(const HloInstruction& instr) {
if (instr.opcode() == HloOpcode::kReduce &&
IsReductionFromOrToContiguousDimensions(instr)) {
return 1;
}
if (instr.opcode() == HloOpcode::kFusion) {
int64_t sum = 0;
for (const HloInstruction* hlo :
instr.fused_instructions_computation()->instructions()) {
sum += NumUnnestedReductionsNoCache(*hlo);
}
return sum;
}
return 0;
}
static int64_t NumUnnestedReductions(const HloInstruction& instr,
FusionInfoCache* cache) {
if (!cache) {
return NumUnnestedReductionsNoCache(instr);
}
// nb: Users are only expected to call cache.Invalidate() on top-level
// instructions, not instructions inside fusion nodes. Therefore we can only
// cache top-level instructions; it would not be valid to pass the cache to
// NumUnnestedReductionsNoCache and use the cache *within* the fusion.
auto it_and_inserted = cache->num_unnested_reductions.emplace(&instr, -1);
auto it = it_and_inserted.first;
auto inserted = it_and_inserted.second;
if (inserted) {
it->second = NumUnnestedReductionsNoCache(instr);
}
return it->second;
}
// This function limits the maximum number of operands to a fusion, and the
// amount of shared memory which can be consumed by the fusion.
//
// There's a cap on how many parameters we can pass to a CUDA kernel, but
// exactly what that limit is hazy, as it depends on (among other things) how
// much GPU constant memory is in use for other purposes.
//
// Moreover, we don't even know at the point that we're running fusion how many
// arguments the CUDA kernel for a fusion node will have: It depends on buffer
// assignment, where we will decide which of the fusion's operands live in XLA's
// big temp buffer versus in other allocations.
//
// As a heuristic, we simply cap the number of fusion operands plus outputs at
// MaxOperandsAndOutputsPerFusion(). This puts an upper bound on the number of
// parameters to the kernel, working around the correctness problem.
//
// This limit is also often good for performance. In a fusion with many
// operands, each GPU thread likely has to do a lot of work, and so possibly
// uses a lot of registers, thus limiting occupancy.
//
// For a producer/consumer fusion, pass the consumer as `instr1` and the
// producer as `instr2`, and set `is_consumer_producer_fusion` to true to
// enable more fusion.
FusionDecision FusionFitsInBudget(const HloInstruction& instr1,
const HloInstruction& instr2,
bool is_consumer_producer_fusion,
FusionInfoCache* cache /*=nullptr*/) {
if (SharedMemoryUsage(instr1, cache) + SharedMemoryUsage(instr2, cache) >
kSharedMemoryBudgetInBytes) {
return FusionDecision{}
<< "shared memory usage would be over the budget of "
<< kSharedMemoryBudgetInBytes << "B";
}
if (NumUnnestedReductions(instr1, cache) +
NumUnnestedReductions(instr2, cache) >
kMaxUnnestedReductionOutputsPerFusion) {
return FusionDecision{} << "over " << kMaxUnnestedReductionOutputsPerFusion
<< " unnested reductions in fusion";
}
// Compute the number of outputs of the (possibly multi-output) fusion node
// we're considering creating.
//
// This isn't precise; we may be off by one if
// - We're creating a multi-output fusion out of two non-MOFs. Creating a
// MOF adds a new buffer, namely, the tuple buffer.
// - We're merging two MOFs. In this case, we should count the tuple buffer
// only once.
// - WLOG there's an edge from `a` to `b` and `b` is the only consumer of
// `a`. In this case the result of `a` is not part of the output of the
// fusion.
//
// But because this is a heuristic and our limit
// MaxOperandsAndOutputsPerFusion() is a large value (so +/- 1 doesn't make a
// big difference), we ignore this small inaccuracy in favor of simplicity.
int64_t num_output_buffers = ShapeUtil::SubshapeCount(instr1.shape()) +
ShapeUtil::SubshapeCount(instr2.shape());
// The new fusion will have no more operands and outputs than
// producer_operands + consumer_operands - 1 + num_output_buffers
// (minus one because we may be fusing a producer->consumer edge between `a`
// and `b`).
//
// This fact may be enough to let us avoid having to compute the true total
// number of operands, which can be expensive.
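  // For example (illustrative): a producer with 3 operands fused into a
  // consumer with 4 operands (one of which is the producer) and
  // num_output_buffers = 2 is bounded by 3 + 4 - 1 + 2 = 8.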
if (instr1.operand_count() + instr2.operand_count() - 1 +
num_output_buffers <=
MaxOperandsAndOutputsPerFusion()) {
return {};
} else {
VLOG(5) << "Operand count of "
<< "(" << instr1.ToString() << " ) = " << instr1.operand_count()
<< " and ( " << instr2.ToString()
<< " ) = " << instr2.operand_count()
<< " and num_output_buffers = " << num_output_buffers
<< " is bigger than the bound of "
<< MaxOperandsAndOutputsPerFusion();
}
// Compute the precise number of operands to the new fusion.
absl::flat_hash_set<const HloInstruction*> operands(instr1.operands().begin(),
instr1.operands().end());
operands.insert(instr2.operands().begin(), instr2.operands().end());
// If there's an edge between `a` and `b`, don't count it: We're fusing that
// producer -> consumer relationship.
operands.erase(&instr1);
operands.erase(&instr2);
  // If the merged fusion has no more inputs than the consumer already had, it
  // won't be bigger after fusion, so accept it. Because this is a
  // producer/consumer fusion, the number of consumer outputs does not change,
  // so there is no need to check it.
if (is_consumer_producer_fusion &&
operands.size() <= instr1.operands().size()) {
return {};
}
// Does the new fusion have more operands and outputs than the max?
if (operands.size() + num_output_buffers > MaxOperandsAndOutputsPerFusion()) {
return "Number of operands and output buffers is larger than allowed "
"budget per fusion";
}
return {};
}
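// Returns true if fusing `producer` into `consumer` would nest
// element-reading loops: the producer (or one of its fused instructions)
// reads elements multiple times, and on the consumer side the producer's
// result feeds an instruction that also reads elements multiple times.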
bool CreatesNestedLoop(const HloInstruction& producer,
const HloInstruction& consumer) {
// If producer does not have an instruction that codegens a loop then there is
// nothing to do.
  auto producer_has_loop_codegen = [](const HloInstruction& instr) {
    if (instr.opcode() != HloOpcode::kFusion) {
      return IfFusedReadsElementsMultipleTimes(instr);
    }
    for (const HloInstruction* fused_instr : instr.fused_instructions()) {
      if (IfFusedReadsElementsMultipleTimes(*fused_instr)) {
        return true;
      }
    }
    return false;
  };
if (!producer_has_loop_codegen(producer)) {
return false;
}
// If consumer is a non-fusion instruction then we have to check if it
// generates a loop.
if (consumer.opcode() != HloOpcode::kFusion) {
return IfFusedReadsElementsMultipleTimes(consumer);
}
// If consumer is a fusion then we have to check if the output of producer is
// used directly or indirectly as an input to an HLO instruction that
// generates a loop, i.e. there is a path in the graph from an operand
// corresponding to the producer to an HLO instruction generating a loop in
// the consumer.
for (const HloInstruction* operand : consumer.operands()) {
if (operand != &producer) {
continue;
}
    // Start the DFS at the fusion parameter that corresponds to `producer`.
    const HloInstruction* start =
        consumer.fused_instructions_computation()->parameter_instruction(
            consumer.operand_index(operand));
    std::stack<const HloInstruction*> dfs;
    dfs.push(start);
absl::flat_hash_set<const HloInstruction*> visited;
while (!dfs.empty()) {
const HloInstruction* cur = dfs.top();
dfs.pop();
if (visited.contains(cur)) {
continue;
}
visited.insert(cur);
if (IfFusedReadsElementsMultipleTimes(*cur)) {
return true;
}
for (const auto& user : cur->users()) {
if (visited.contains(user)) {
continue;
}
dfs.push(user);
}
}
}
return false;
}
bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr) {
// We can fuse reduces and loop fusions. Elementwise instructions can be fused
// with any other instruction.
// Note that scatter cannot be the root of a multi-output fusion because
// its emitter doesn't support it.
return instr.IsFusible() &&
(IsInputFusibleReduction(instr) ||
instr.IsLoopFusion() || // TODO(b/130013493): Use IsLoopFusible here.
instr.IsElementwise());
}
HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& /*producer*/,
const HloInstruction& consumer) {
return IsInputFusible(consumer) ? HloInstruction::FusionKind::kInput
: HloInstruction::FusionKind::kLoop;
}
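// Returns true if every user of `instr` is either `consumer`, the computation
// root, or a get-tuple-element whose users in turn satisfy this property.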
bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr,
const HloInstruction& consumer) {
return absl::c_all_of(instr.users(), [&](const HloInstruction* user) {
if (user->opcode() == HloOpcode::kGetTupleElement) {
// Skip GTE.
return IsConsumerTheOnlyNonRootUser(*user, consumer);
}
if (user == &consumer) {
// `user` is `consumer`.
return true;
}
if (user == user->parent()->root_instruction()) {
// Consumed by ROOT.
return true;
}
return false;
});
}
size_t GetInstrCountOfFusible(const HloInstruction& instr) {
if (instr.opcode() != HloOpcode::kFusion) {
return 1;
} else {
return instr.fused_instruction_count();
}
}
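// Returns the outputs of `instr`: the instruction itself if it is not a
// fusion, otherwise the fusion root, or the root's operands if the root is a
// tuple.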
absl::InlinedVector<const HloInstruction*, 2> GetOutputsOfFusible(
const HloInstruction& instr) {
if (instr.opcode() != HloOpcode::kFusion) {
return {&instr};
}
HloInstruction* root = instr.fused_expression_root();
if (root->opcode() != HloOpcode::kTuple) {
return {root};
} else {
auto v = root->operands();
return absl::InlinedVector<const HloInstruction*, 2>(v.begin(), v.end());
}
}
size_t GetOutputSizeOfFusible(const HloInstruction& instr) {
if (!instr.IsMultiOutputFusion()) {
return 1;
}
const HloInstruction* root = instr.fused_expression_root();
return ShapeUtil::TupleElementCount(root->shape());
}
} // namespace gpu
} // namespace xla