[XLA] Use cost analysis to order buffers in memory space assignment

Using cost analysis, we can estimate the "memory boundedness" of an HLO
instruction and prioritize the buffers that will benefit the most from being
placed in the alternate (fast) memory space.
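
For reference, a minimal sketch of how a backend opts into the new ordering
(this mirrors the test change below; module, cost_analysis, and the other
arguments are assumed to already be set up at the call site):

  std::unique_ptr<PresetAssignments> preset_assignments =
      MemorySpaceAssignment::Run(
          module, alternate_memory_space, max_size_in_bytes,
          MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
              cost_analysis),
          &prefetch_interval_picker,
          alternate_memory_space_alignment_in_bytes, size_fn,
          is_allowed_in_alternate_mem, max_outstanding_async_copies)
          .ValueOrDie();

Passing an empty optional for buffer_interval_compare keeps the previous
behavior (the spatial sort of GlobalDecreasingSizeBestFitHeap).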

PiperOrigin-RevId: 277595877
Change-Id: I78d186e518195c73edfb7e56fad805ac5bbdbb56
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index fc7f36e..65b813b 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -478,6 +478,54 @@
   return result;
 }
 
+GlobalDecreasingSizeBestFitHeap::GlobalDecreasingSizeBestFitHeap(
+    int64 alignment, Type type)
+    : alignment_(alignment) {
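+  // Pick the sort order used by GetSortedBufferIntervals() based on the heap
+  // type.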
+  if (type == kTemporal) {
+    buffer_interval_compare_ = GetTemporalBufferIntervalCompare();
+  } else {
+    CHECK(type == kSpatial);
+    buffer_interval_compare_ = GetSpatialBufferIntervalCompare();
+  }
+}
+
+GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare
+GlobalDecreasingSizeBestFitHeap::GetTemporalBufferIntervalCompare() const {
+  return [&](const BufferInterval& x, const BufferInterval& y) {
+    int64 x_end = x.end;
+    for (auto colocation : GetTransitiveColocations(x)) {
+      x_end = std::max(x_end, buffer_intervals_.at(colocation).end);
+    }
+
+    int64 y_end = y.end;
+    for (auto colocation : GetTransitiveColocations(y)) {
+      y_end = std::max(y_end, buffer_intervals_.at(colocation).end);
+    }
+
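+    // Prefer longer (colocation-extended) live ranges, then larger sizes, and
+    // finally ascending buffer id so the order is deterministic.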
+    if (x_end - x.start != y_end - y.start) {
+      return x_end - x.start > y_end - y.start;
+    }
+
+    if (x.size != y.size) {
+      return x.size > y.size;
+    }
+    return x.buffer->id() < y.buffer->id();
+  };
+}
+
+/*static*/ GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare
+GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare() {
+  return [&](const BufferInterval& x, const BufferInterval& y) {
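+    // Prefer larger buffers, then longer live ranges, and finally ascending
+    // buffer id so the order is deterministic.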
+    if (x.size != y.size) {
+      return x.size > y.size;
+    }
+    if (x.end - x.start != y.end - y.start) {
+      return x.end - x.start > y.end - y.start;
+    }
+    return x.buffer->id() < y.buffer->id();
+  };
+}
+
 void GlobalDecreasingSizeBestFitHeap::Alloc(const HloValue* buffer,
                                             int64 size) {
   // Degenerate case: 0-sized buffers are always allocated at offset 0.
@@ -627,47 +675,7 @@
   for (auto& entry : buffer_intervals_) {
     sorted_buffer_intervals.push_back(entry.second);
   }
-  if (type_ == kTemporal) {
-    // Sort by live-range. A live range is defined by the range between the
-    // start of the first buffer and the end of the last co-located
-    // buffer. There could be "holes" in the live ranges of each co-located
-    // buffers, but in this heuristics we think they are contiguous.
-    absl::c_sort(sorted_buffer_intervals, [&](const BufferInterval& x,
-                                              const BufferInterval& y) {
-      int64 x_end = x.end;
-      for (auto colocation : GetTransitiveColocations(x)) {
-        x_end = std::max(x_end, buffer_intervals_.at(colocation).end);
-      }
-
-      int64 y_end = y.end;
-      for (auto colocation : GetTransitiveColocations(y)) {
-        y_end = std::max(y_end, buffer_intervals_.at(colocation).end);
-      }
-
-      if (x_end - x.start != y_end - y.start) {
-        return x_end - x.start > y_end - y.start;
-      }
-
-      if (x.size != y.size) {
-        return x.size > y.size;
-      }
-      return x.buffer->id() < y.buffer->id();
-    });
-  } else {
-    // Sort by spatial size. We don't look at co-locates as they should have the
-    // same size.
-    CHECK(type_ == kSpatial);
-    absl::c_sort(sorted_buffer_intervals,
-                 [&](const BufferInterval& x, const BufferInterval& y) {
-                   if (x.size != y.size) {
-                     return x.size > y.size;
-                   }
-                   if (x.end - x.start != y.end - y.start) {
-                     return x.end - x.start > y.end - y.start;
-                   }
-                   return x.buffer->id() < y.buffer->id();
-                 });
-  }
+  absl::c_sort(sorted_buffer_intervals, buffer_interval_compare_);
 
   return sorted_buffer_intervals;
 }
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h
index d8f5996..7e9ccd5 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.h
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -296,20 +296,6 @@
     kTemporal,
   };
 
-  explicit GlobalDecreasingSizeBestFitHeap(int64 alignment,
-                                           Type type = kSpatial)
-      : alignment_(alignment), type_(type) {}
-  ~GlobalDecreasingSizeBestFitHeap() override {}
-
-  void Alloc(const HloValue* buffer, int64 size) override;
-  void Free(const HloValue* buffer, int64 size) override;
-
-  void ShareWith(const HloValue* buffer, const HloValue* share_with,
-                 int64 size) override;
-
-  Result Finish() override;
-
- protected:
   // BufferInterval stores a buffer's size and time interval.
   struct BufferInterval {
     const HloValue* buffer;
@@ -327,6 +313,27 @@
     bool need_allocation;
   };
 
+  // Comparison function used to sort buffer intervals.
+  using BufferIntervalCompare =
+      std::function<bool(const BufferInterval&, const BufferInterval&)>;
+
+  explicit GlobalDecreasingSizeBestFitHeap(int64 alignment,
+                                           Type type = kSpatial);
+  ~GlobalDecreasingSizeBestFitHeap() override {}
+
+  void Alloc(const HloValue* buffer, int64 size) override;
+  void Free(const HloValue* buffer, int64 size) override;
+
+  void ShareWith(const HloValue* buffer, const HloValue* share_with,
+                 int64 size) override;
+
+  Result Finish() override;
+
+  // Returns a BufferIntervalCompare function that sorts by spatial size. We
+  // don't look at colocations, as they should have the same size.
+  static BufferIntervalCompare GetSpatialBufferIntervalCompare();
+
+ protected:
   // Node in BufferIntervalTree that stores the alloc and free times of a
   // buffer, and the chunk assigned to it.
   struct BufferIntervalTreeNode {
@@ -367,7 +374,7 @@
     int64 heap_size;
   };
 
-  // Returns the buffer intervals sorted according to type_.
+  // Returns the buffer intervals sorted according to buffer_interval_compare_.
   std::vector<BufferInterval> GetSortedBufferIntervals() const;
 
   // These two methods below are exposed to other heap algorithms that inherit
@@ -385,12 +392,19 @@
   // Adds the buffer and the chunk to the result chunk map.
   virtual void AddToChunkMap(const HloValue* buffer, Chunk chunk);
 
+  // Returns a BufferIntervalCompare function that sorts by live range. A live
+  // range is defined as the interval between the start of the first buffer
+  // and the end of the last colocated buffer. There could be "holes" in the
+  // live ranges of the colocated buffers, but this heuristic treats them as
+  // contiguous.
+  BufferIntervalCompare GetTemporalBufferIntervalCompare() const;
+
   absl::flat_hash_map<const HloValue*, BufferInterval> buffer_intervals_;
   Result result_;
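+  // Comparator used by GetSortedBufferIntervals(). Initialized from the heap
+  // Type passed to the constructor; subclasses may override it.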
+  BufferIntervalCompare buffer_interval_compare_;
 
  private:
   int64 alignment_;
-  Type type_;
 
   // The current time represented as an integer. It increments by 1 at each
   // Alloc or Free call.
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index 7118fe8..bea731b 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -751,9 +751,71 @@
   return max_copies;
 }
 
+/*static*/ MemorySpaceAssignment::BufferIntervalCompare
+MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
+    const MemorySpaceAssignmentCostAnalysis& cost_analysis) {
+  return [&](const BufferInterval& x, const BufferInterval& y) {
+    // Returns a heuristic value that captures how much putting this tensor in
+    // the alternate memory would help if the op is memory bound, or otherwise
+    // how far the op is from being memory bound. The larger this number, the
+    // higher priority the buffer gets for placement in the alternate memory.
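+    // For example (illustrative numbers): an op with 10us of memory time, 4us
+    // of compute time, and 3us of memory time when its data is in the
+    // alternate memory yields a benefit of 7us; the same op with 12us of
+    // compute time is compute bound and yields -2us.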
+    auto get_alternate_mem_benefit =
+        [&](const HloInstruction& instruction,
+            float elapsed_time_due_to_alternate_mem) {
+          float elapsed_time_due_to_compute =
+              cost_analysis.GetInstructionElapsedDueToCompute(instruction);
+          float elapsed_time_due_to_memory =
+              cost_analysis.GetInstructionElapsedDueToMemory(instruction);
+          if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) {
+            // Memory bound; return how much the alternate memory would help.
+            return elapsed_time_due_to_memory -
+                   elapsed_time_due_to_alternate_mem;
+          } else {
+            // Compute bound; return how far we are from being memory bound.
+            return elapsed_time_due_to_memory - elapsed_time_due_to_compute;
+          }
+        };
+
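+    // Aggregates the alternate memory benefit over the buffer's defining
+    // instruction and all of its uses.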
+    auto get_memory_boundedness = [&](const BufferInterval& interval) {
+      const HloInstruction& defining_instruction =
+          *interval.buffer->defining_instruction();
+      float alternate_mem_benefit = get_alternate_mem_benefit(
+          defining_instruction, cost_analysis.GetInstructionElapsedDueToMemory(
+                                    defining_instruction,
+                                    /*operand_in_alternate_mem=*/{},
+                                    /*output_in_alternate_mem=*/true));
+      for (const HloUse& use : interval.buffer->uses()) {
+        float use_alternate_mem_benefit = get_alternate_mem_benefit(
+            *use.instruction, cost_analysis.GetInstructionElapsedDueToMemory(
+                                  *use.instruction, use.operand_number));
+        // If both this buffer's benefit and the use's benefit are positive
+        // (memory bound), accumulate them. Otherwise (at least one side is
+        // compute bound), take the maximum of the two.
+        if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
+          alternate_mem_benefit += use_alternate_mem_benefit;
+        } else {
+          alternate_mem_benefit =
+              std::max(alternate_mem_benefit, use_alternate_mem_benefit);
+        }
+      }
+      return alternate_mem_benefit;
+    };
+
+    float x_memory_boundedness = get_memory_boundedness(x);
+    float y_memory_boundedness = get_memory_boundedness(y);
+    if (x_memory_boundedness != y_memory_boundedness) {
+      return x_memory_boundedness > y_memory_boundedness;
+    }
+    // Tie-break if the memory boundedness is the same.
+    return GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare()(
+        x, y);
+  };
+}
+
 /*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
 MemorySpaceAssignment::Run(
     HloModule* module, int64 alternate_memory_space, int64 max_size_in_bytes,
+    absl::optional<BufferIntervalCompare> buffer_interval_compare,
     PrefetchIntervalPicker* prefetch_interval_picker,
     int64 alternate_memory_space_alignment_in_bytes,
     BufferValue::SizeFunction size_fn,
@@ -773,7 +835,7 @@
                                         entry_computation));
   auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
       &memory_space_assignment.allocation_map_, max_size_in_bytes,
-      prefetch_interval_picker, *alias_analysis,
+      buffer_interval_compare, prefetch_interval_picker, *alias_analysis,
       *memory_space_assignment.hlo_live_range_,
       alternate_memory_space_alignment_in_bytes, is_allowed_in_alternate_mem,
       max_outstanding_async_copies);
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index 0de74a5..dd341a9 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -229,6 +229,9 @@
 class MemorySpaceAssignment {
  public:
   using Chunk = HeapSimulator::Chunk;
+  using BufferInterval = GlobalDecreasingSizeBestFitHeap::BufferInterval;
+  using BufferIntervalCompare =
+      GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare;
 
   // MemorySpaceAssignment uses a notion of a slow and large default memory
   // space and a fast and small alternate memory space.
@@ -395,6 +398,8 @@
   // Runs the MemorySpaceAssignment pass. alternate_memory_space is the
   // architecture-specific integer value that describes the alternate memory.
   // max_size_in_bytes is the maximum size of the alternate memory.
+  // If a buffer_interval_compare is provided, we sort the buffers using that
+  // (otherwise, we use GlobalDecreasingSizeBestFitHeap::kSpatial).
   // prefetch_interval_picker determines how early and how late can prefetches
   // occur. alternate_memory_space_alignment_in_bytes is the alignment required
   // in the alternate memory space, size_fn is the size function for buffer
@@ -404,6 +409,7 @@
   // outstanding asynchronous copies, -1 for unlimited.
   static StatusOr<std::unique_ptr<PresetAssignments>> Run(
       HloModule* module, int64 alternate_memory_space, int64 max_size_in_bytes,
+      absl::optional<BufferIntervalCompare> buffer_interval_compare,
       PrefetchIntervalPicker* prefetch_interval_picker,
       int64 alternate_memory_space_alignment_in_bytes,
       BufferValue::SizeFunction size_fn,
@@ -414,6 +420,9 @@
   // module.
   static int64 CountMaximumOutstandingAsyncCopies(const HloModule& module);
 
+  static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare(
+      const MemorySpaceAssignmentCostAnalysis& cost_analysis);
+
  private:
   MemorySpaceAssignment(HloModule* module, int64 alternate_memory_space)
       : module_(module),
@@ -481,7 +490,9 @@
 
   AlternateMemoryBestFitHeap(
       MemorySpaceAssignment::AllocationMap* allocation_map,
-      int64 max_size_in_bytes, PrefetchIntervalPicker* prefetch_interval_picker,
+      int64 max_size_in_bytes,
+      absl::optional<BufferIntervalCompare> buffer_interval_compare,
+      PrefetchIntervalPicker* prefetch_interval_picker,
       const HloAliasAnalysis& alias_analysis,
       const HloLiveRange& hlo_live_range, int64 alignment,
       IsAllowedInAlternateMemoryFunction is_allowed_in_alternate_mem,
@@ -493,7 +504,12 @@
         alias_analysis_(alias_analysis),
         hlo_live_range_(hlo_live_range),
         is_allowed_in_alternate_mem_(is_allowed_in_alternate_mem),
-        max_outstanding_async_copies_(max_outstanding_async_copies) {}
+        max_outstanding_async_copies_(max_outstanding_async_copies) {
+    // Override buffer interval compare if provided.
+    if (buffer_interval_compare) {
+      buffer_interval_compare_ = *buffer_interval_compare;
+    }
+  }
 
   HeapSimulator::Result Finish() override;
 
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
index 4b10b03..f064bce 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -57,8 +57,11 @@
         CostAnalysisPrefetchIntervalPicker(
             cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8,
             /*max_async_copy_to_overlap_ratio=*/10.0));
-    return AssignMemorySpace(module, /*max_outstanding_async_copies=*/-1,
-                             &prefetch_interval_picker);
+    return AssignMemorySpace(
+        module, /*max_outstanding_async_copies=*/-1,
+        MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
+            cost_analysis),
+        &prefetch_interval_picker);
   }
 
   std::unique_ptr<PresetAssignments> AssignMemorySpace(
@@ -67,11 +70,14 @@
     InstructionCountPrefetchIntervalPicker prefetch_interval_picker(
         /*min_overlap_count=*/2, max_prefetch_interval);
     return AssignMemorySpace(module, max_outstanding_async_copies,
+                             /*buffer_interval_compare=*/{},
                              &prefetch_interval_picker);
   }
 
   std::unique_ptr<PresetAssignments> AssignMemorySpace(
       HloModule* module, int64 max_outstanding_async_copies,
+      absl::optional<MemorySpaceAssignment::BufferIntervalCompare>
+          buffer_interval_compare,
       PrefetchIntervalPicker* prefetch_interval_picker) {
     auto size_fn = [](const BufferValue& buffer) {
       return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
@@ -93,7 +99,8 @@
     std::unique_ptr<PresetAssignments> preset_assignments =
         MemorySpaceAssignment::Run(
             module, kAlternateMemorySpace,
-            /*max_size_in_bytes=*/128, prefetch_interval_picker,
+            /*max_size_in_bytes=*/128, buffer_interval_compare,
+            prefetch_interval_picker,
             /*alternate_memory_space_alignment_in_bytes=*/8, size_fn,
             is_allowed_in_alternate_mem, max_outstanding_async_copies)
             .ValueOrDie();
@@ -1640,5 +1647,83 @@
   EXPECT_THAT(negate6, op::ShapeWithLayout(shape_in_alternate_mem));
 }
 
+TEST_F(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {
+  // This test is carefully crafted to force only negates to be allocated to
+  // the alternate memory. The graph consists of interleaving negate and tanh
+  // operations:
+  //
+  //        +------+      +-------+      +-----
+  //       /        \    /         \    /
+  //  negate  tanh  negate  tanh   negate  tanh
+  //             \          /  \           /
+  //              +--------+    +---------+
+  //
+  // The alternate memory is sized to fit only one f32[4,6] tensor at a time.
+  // Also, transcendentals are modeled with lower throughput than regular
+  // FLOPs, which makes the tanhs compute bound. So, the
+  // MemoryBoundednessBufferIntervalCompare should prioritize the negates,
+  // which are more memory bound.
+  HloComputation::Builder builder(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {4, 6});
+  HloInstruction* p0 =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
+  HloInstruction* p1 =
+      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
+  HloInstruction* tanh0 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
+  HloInstruction* negate0 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p1));
+  HloInstruction* tanh1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh0));
+  HloInstruction* negate1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
+  HloInstruction* tanh2 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh1));
+  HloInstruction* negate2 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
+  HloInstruction* tanh3 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh2));
+  HloInstruction* negate3 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
+  HloInstruction* tanh4 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh3));
+  HloInstruction* negate4 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
+  HloInstruction* add = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, tanh4, negate4));
+
+  auto module = CreateNewVerifiedModule();
+  HloComputation* computation = module->AddEntryComputation(builder.Build());
+
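+  // Schedule the instructions in the interleaved order shown in the diagram
+  // above.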
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(computation,
+                        {p0, p1, tanh0, negate0, tanh1, negate1, tanh2, negate2,
+                         tanh3, negate3, tanh4, negate4, add});
+  TF_CHECK_OK(module->set_schedule(schedule));
+
+  AssignMemorySpaceUsingCostAnalysis(module.get());
+  // Parameters are in the default memory space.
+  EXPECT_THAT(p0, op::ShapeWithLayout(shape));
+  EXPECT_THAT(p1, op::ShapeWithLayout(shape));
+  Shape shape_in_default_mem = ShapeUtil::MakeShapeWithLayout(
+      F32, {4, 6},
+      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
+      kDefaultMemorySpace);
+  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
+      F32, {4, 6},
+      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
+      kAlternateMemorySpace);
+  // Expect only negates to be in alternate memory space.
+  EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem));
+  EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem));
+  EXPECT_THAT(negate2, op::ShapeWithLayout(shape_in_alternate_mem));
+  EXPECT_THAT(negate3, op::ShapeWithLayout(shape_in_alternate_mem));
+  EXPECT_THAT(negate4, op::ShapeWithLayout(shape_in_alternate_mem));
+  EXPECT_THAT(tanh0, op::ShapeWithLayout(shape_in_default_mem));
+  EXPECT_THAT(tanh1, op::ShapeWithLayout(shape_in_default_mem));
+  EXPECT_THAT(tanh2, op::ShapeWithLayout(shape_in_default_mem));
+  EXPECT_THAT(tanh3, op::ShapeWithLayout(shape_in_default_mem));
+  EXPECT_THAT(tanh4, op::ShapeWithLayout(shape_in_default_mem));
+}
+
 }  // namespace
 }  // namespace xla