Introduce a cache to avoid O(n^2) behavior in GPU multi-output fusion (MOF).
Some models have many multi-output fusion opportunities. When we evaluate
these, we do an O(n^2) iteration over all pairs of nodes in an equivalence
class.
Before this change, we would recompute the shared memory required for each
fusion, and the number of unnested reductions in each fusion, O(n^2) times.
Computing these is tantamount to iterating over the whole fusion node, so it is
especially expensive when the fusion is large.
Now we cache these values, so we compute them only O(n) times.
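To illustrate the intended usage (a sketch only; `candidates` and the loop
structure are hypothetical, while FusionInfoCache, FusionWouldBeTooLarge, and
Invalidate are the APIs touched by this change):

  FusionInfoCache cache;
  for (HloInstruction* a : candidates) {
    for (HloInstruction* b : candidates) {
      if (a == b) continue;
      // Expensive per-node properties of `a` and `b` are computed at most
      // once each and then served from the cache.
      if (FusionWouldBeTooLarge(*a, *b,
                                /*is_consumer_producer_fusion=*/false,
                                &cache)) {
        continue;
      }
      // ...fuse a and b, then drop their now-stale cache entries...
      cache.Invalidate(a);
      cache.Invalidate(b);
    }
  }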
PiperOrigin-RevId: 394539476
Change-Id: I9b8dba935ee9aa6735b952977091f1c2b62c5311
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index 4ae35d2..ad0003d 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -426,7 +426,7 @@
/* static */ bool LayoutUtil::AreDimensionsConsecutive(
const Layout& layout, absl::Span<const int64_t> dims) {
CHECK(IsDense(layout));
- std::vector<int64_t> positions_in_layout;
+ absl::InlinedVector<int64_t, 8> positions_in_layout;
for (int64_t dim : dims) {
positions_in_layout.push_back(
PositionInContainer(layout.minor_to_major(), dim));
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
index 4f1dde8..0addc10 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
@@ -316,9 +316,10 @@
}
// Returns shared memory usage for a given instruction in bytes.
-static int64_t SharedMemoryUsage(const HloInstruction& instr) {
+static int64_t SharedMemoryUsageNoCache(const HloInstruction& instr) {
// For now we are only fusing reductions.
- if (IsReductionFromOrToContiguousDimensions(instr)) {
+ if (instr.opcode() == HloOpcode::kReduce &&
+ IsReductionFromOrToContiguousDimensions(instr)) {
ReductionDimensions reduction_info =
GetReductionKindAndContiguousComponents(instr);
int64_t primitive_size =
@@ -335,7 +336,7 @@
int64_t sum = 0;
for (const HloInstruction* hlo :
instr.fused_instructions_computation()->instructions()) {
- sum += SharedMemoryUsage(*hlo);
+ sum += SharedMemoryUsageNoCache(*hlo);
}
return sum;
}
@@ -343,26 +344,67 @@
return 0;
}
+static int64_t SharedMemoryUsage(const HloInstruction& instr,
+ FusionInfoCache* cache = nullptr) {
+ if (!cache) {
+ return SharedMemoryUsageNoCache(instr);
+ }
+
+ // nb: Users are only expected to call cache.Invalidate() on top-level
+ // instructions, not instructions inside fusion nodes. Therefore we can only
+ // cache top-level instructions; it would not be valid to pass the cache to
+ // SharedMemoryUsageNoCache and use the cache *within* the fusion.
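+  // emplace() looks up the key and, if absent, inserts a -1 placeholder, all
+  // in a single hash lookup; if we did insert, compute the real value below.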
+ auto it_and_inserted = cache->shared_memory_usage.emplace(&instr, -1);
+ auto it = it_and_inserted.first;
+ auto inserted = it_and_inserted.second;
+
+ if (inserted) {
+ it->second = SharedMemoryUsageNoCache(instr);
+ }
+ return it->second;
+}
+
// Codegen'ing unnested reductions requires a lot of registers, so a MOF
// combining many of those runs a high risk of spilling.
constexpr int64_t kMaxUnnestedReductionOutputsPerFusion = 8;
// Returns the number of unnested reductions in the instruction output.
-static int64_t NumUnnestedReductions(const HloInstruction& instr) {
- if (IsReductionFromOrToContiguousDimensions(instr)) {
+static int64_t NumUnnestedReductionsNoCache(const HloInstruction& instr) {
+ if (instr.opcode() == HloOpcode::kReduce &&
+ IsReductionFromOrToContiguousDimensions(instr)) {
return 1;
}
if (instr.opcode() == HloOpcode::kFusion) {
int64_t sum = 0;
for (const HloInstruction* hlo :
instr.fused_instructions_computation()->instructions()) {
- sum += NumUnnestedReductions(*hlo);
+ sum += NumUnnestedReductionsNoCache(*hlo);
}
return sum;
}
return 0;
}
+static int64_t NumUnnestedReductions(const HloInstruction& instr,
+ FusionInfoCache* cache) {
+ if (!cache) {
+ return NumUnnestedReductionsNoCache(instr);
+ }
+
+ // nb: Users are only expected to call cache.Invalidate() on top-level
+ // instructions, not instructions inside fusion nodes. Therefore we can only
+ // cache top-level instructions; it would not be valid to pass the cache to
+ // NumUnnestedReductionsNoCache and use the cache *within* the fusion.
+ auto it_and_inserted = cache->num_unnested_reductions.emplace(&instr, -1);
+ auto it = it_and_inserted.first;
+ auto inserted = it_and_inserted.second;
+
+ if (inserted) {
+ it->second = NumUnnestedReductionsNoCache(instr);
+ }
+ return it->second;
+}
+
// This function limits the maximum number of operands to a fusion, and the
// amount of shared memory which can be consumed by the fusion.
//
@@ -388,8 +430,9 @@
// to true to enable more fusion.
bool FusionWouldBeTooLarge(const HloInstruction& instr1,
const HloInstruction& instr2,
- bool is_consumer_producer_fusion) {
- if (SharedMemoryUsage(instr1) + SharedMemoryUsage(instr2) >
+ bool is_consumer_producer_fusion,
+ FusionInfoCache* cache /*=nullptr*/) {
+ if (SharedMemoryUsage(instr1, cache) + SharedMemoryUsage(instr2, cache) >
kSharedMemoryBudgetInBytes) {
VLOG(5) << "Shared memory usage of fusion of " << instr1.ToString()
<< " and " << instr2.ToString() << " would be over the budget of "
@@ -397,7 +440,8 @@
return true;
}
- if (NumUnnestedReductions(instr1) + NumUnnestedReductions(instr2) >
+ if (NumUnnestedReductions(instr1, cache) +
+ NumUnnestedReductions(instr2, cache) >
kMaxUnnestedReductionOutputsPerFusion) {
VLOG(5) << "Not fusing over " << kMaxUnnestedReductionOutputsPerFusion
<< " unnested reductions in fusion";
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h
index 812d687..8ae5e34 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h
@@ -24,6 +24,27 @@
namespace xla {
namespace gpu {
+// Fusion passes frequently do checks across all pairs of "interesting" nodes.
+// Computing e.g. FusionWouldBeTooLarge(a, b) requires computing expensive
+// properties of `a` and `b` individually. This cache lets us avoid recomputing
+// those properties n^2 times.
+//
+// Invariant: After modifying or removing a fusion node, call Invalidate(node).
+struct FusionInfoCache {
+ public:
+ // Must be called after modifying or removing a fusion node (or other node
+ // that's part of this cache).
+ void Invalidate(const HloInstruction* instr) {
+ shared_memory_usage.erase(instr);
+ num_unnested_reductions.erase(instr);
+ }
+
+ // The rest of the members of this class are for internal use within
+ // gpu_fusible. You shouldn't need to use them yourself.
+ absl::flat_hash_map<const HloInstruction*, int64_t> shared_memory_usage;
+ absl::flat_hash_map<const HloInstruction*, int64_t> num_unnested_reductions;
+};
+
constexpr int64_t kMaxOperandsAndOutputsPerFusion = 64;
bool IsInputFusible(const HloInstruction& instr);
@@ -64,7 +85,8 @@
// to true to enable more fusion.
bool FusionWouldBeTooLarge(const HloInstruction& instr1,
const HloInstruction& instr2,
- bool is_consumer_producer_fusion = false);
+ bool is_consumer_producer_fusion = false,
+ FusionInfoCache* cache = nullptr);
// Check if fusing producer and consumer will generate a nested loop, e.g. both
// producer and consumer are `reduce-window` HLO instructions.
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 9e319e1..f1d223f 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -268,7 +268,7 @@
// Whether we can/should use the unnested emitter for reduction.
static bool IsReductionFromOrToContiguousDimensionsHelper(
const Shape& operand_shape, absl::Span<int64_t const> dims_to_reduce) {
- std::vector<int64_t> dims_to_keep;
+ DimensionVector dims_to_keep;
for (int64_t dim = 0; dim < operand_shape.dimensions().size(); ++dim) {
if (!absl::c_linear_search(dims_to_reduce, dim)) {
dims_to_keep.push_back(dim);
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index 7ac2d98..26b8bea 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -50,7 +50,8 @@
return true;
}
-bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2) {
+bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2,
+ FusionInfoCache* fusion_info_cache) {
// If we're fusing fusions only do it if the fusion kind matches. Loop fusions
// merge into bigger loop fusions and input (reduce) fusions become fusions
// with multiple reduce outputs. We could fuse reduce and loop fusions
@@ -76,7 +77,9 @@
return false;
}
// Do this check last, as it may be expensive.
- return !FusionWouldBeTooLarge(*instr1, *instr2);
+ return !FusionWouldBeTooLarge(*instr1, *instr2,
+ /*is_consumer_producer_fusion=*/false,
+ fusion_info_cache);
}
// We prefer multi-output fusions over other fusions over unfused ops, because
@@ -104,7 +107,8 @@
}
std::vector<HloInstruction*> GetProducerConsumerMultiOutputFusionCandidates(
- const HloInstruction* producer, const HloReachabilityMap& reachability) {
+ const HloInstruction* producer, const HloReachabilityMap& reachability,
+ FusionInfoCache* fusion_info_cache) {
std::vector<HloInstruction*> fusion_candidates;
// If there is only one user, and it is not a multi-output fusion node, this
// fusion possibility was already considered and rejected by the FusionMerger
@@ -145,7 +149,9 @@
VLOG(3) << producer->name() << " would introduce a cycle when fused.";
continue;
}
- if (FusionWouldBeTooLarge(*producer, *consumer)) {
+ if (FusionWouldBeTooLarge(*producer, *consumer,
+ /*is_consumer_producer_fusion=*/false,
+ fusion_info_cache)) {
VLOG(3) << producer->name() << " and " << consumer->name()
<< " would be too large of a fusion.";
continue;
@@ -191,7 +197,8 @@
reachability_ = HloReachabilityMap::Build(computation_);
}
-bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent) {
+bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent,
+ FusionInfoCache* fusion_info_cache) {
if (!IsProfitableOperand(parent)) {
return false;
}
@@ -213,7 +220,7 @@
VLOG(3) << "Considering " << (*i)->name() << " and " << (*j)->name();
if (!IsSiblingFusionCandidate(*j) || reachability_->IsConnected(*i, *j) ||
!ShapesCompatibleForMultiOutputFusion(*(*i), *(*j)) ||
- !LegalToFuse(*i, *j)) {
+ !LegalToFuse(*i, *j, fusion_info_cache)) {
++j;
continue;
}
@@ -225,6 +232,8 @@
continue;
}
VLOG(2) << "Fuse siblings " << (*i)->name() << " and " << (*j)->name();
+ fusion_info_cache->Invalidate(*i);
+ fusion_info_cache->Invalidate(*j);
HloInstruction* remaining = *i;
HloInstruction* fused = *j;
if (fused->opcode() == HloOpcode::kFusion) {
@@ -260,6 +269,7 @@
return Status::OK();
};
+ FusionInfoCache fusion_info_cache;
while (!defs_before_uses.empty()) {
// Traverse the HLO in uses-before-defs order by removing instruction from
// the back of the vector.
@@ -272,7 +282,7 @@
continue;
}
// First, fuse the consumer ops of the current op, which are siblings.
- if (FuseSiblings(/*parent=*/producer)) {
+ if (FuseSiblings(/*parent=*/producer, &fusion_info_cache)) {
changed = true;
}
// Second, perform producer-consumer multi-output fusion. This order will
@@ -280,7 +290,7 @@
// multi-output fusion will occur before the current op in the order of
// traversal, and hence, not get into the way of subsequent fusion attempts.
const auto candidates = GetProducerConsumerMultiOutputFusionCandidates(
- producer, *reachability_);
+ producer, *reachability_, &fusion_info_cache);
auto* consumer_for_fusion = SelectPreferredFusionCandidate(candidates);
if (consumer_for_fusion == nullptr) {
continue;
@@ -292,6 +302,9 @@
continue;
}
changed = true;
+ fusion_info_cache.Invalidate(producer);
+ fusion_info_cache.Invalidate(consumer_for_fusion);
+
if (consumer_for_fusion->opcode() == HloOpcode::kFusion) {
VLOG(2) << "Fuse producer " << producer->name() << " into its consumer "
<< consumer_for_fusion->name();
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
index 78cda6c..617f726 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
@@ -21,6 +21,7 @@
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
@@ -96,7 +97,7 @@
StatusOr<bool> Run(HloModule* module) override;
private:
- bool FuseSiblings(HloInstruction* parent);
+ bool FuseSiblings(HloInstruction* parent, FusionInfoCache* fusion_info_cache);
StatusOr<bool> DoMultiOutputFusion();