Enables running HloRematerialization as a compress-only or recompute-only pass.

PiperOrigin-RevId: 286519401
Change-Id: Iaf720be722e1aa86c3deb5816aabbee535f90c37
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 445a3ea..5d38bbe 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -370,7 +370,8 @@
       const HloRematerialization::ShapeSizeFunction& size_function,
       const HloRematerialization::CompactShapeFunction& compact_shape_function,
       const TuplePointsToAnalysis& points_to_analysis,
-      const InstructionList& instruction_list);
+      const InstructionList& instruction_list,
+      HloRematerialization::RematerializationMode mode);
 
   // Starts the placement of the given instruction. This adds the sizes of the
   // LogicalBuffers defined by the instruction to the current memory
@@ -607,6 +608,7 @@
   // between the calling of BeginInstruction and EndInstruction.
   Item* in_progress_item_ = nullptr;
 
+  HloRematerialization::RematerializationMode mode_;
   // All buffers in the computation.
   std::vector<Buffer> buffers_;
 };
@@ -616,11 +618,13 @@
     const HloRematerialization::ShapeSizeFunction& size_function,
     const HloRematerialization::CompactShapeFunction& compact_shape_function,
     const TuplePointsToAnalysis& points_to_analysis,
-    const InstructionList& instruction_list)
+    const InstructionList& instruction_list,
+    HloRematerialization::RematerializationMode mode)
     : computation_(computation),
       instruction_list_(instruction_list),
       size_function_(size_function),
-      compact_shape_function_(compact_shape_function) {
+      compact_shape_function_(compact_shape_function),
+      mode_(mode) {
   PointsToSet::BufferSet live_out_set =
       points_to_analysis.GetPointsToSet(computation_->root_instruction())
           .CreateFlattenedSet();
@@ -1155,7 +1159,10 @@
       continue;
     }
 
-    if (item->buffers_output.size() == 1) {
+    if (item->buffers_output.size() == 1 &&
+        (mode_ == HloRematerialization::RematerializationMode::kCompressOnly ||
+         mode_ == HloRematerialization::RematerializationMode::
+                      kRecomputeAndCompress)) {
       // Only consider compressing single output instruction.
       const Buffer& output_buffer = buffers_.at(item->buffers_output[0]);
 
@@ -1196,6 +1203,11 @@
       continue;
     }
 
+    // Do not consider recomputation in compress-only mode.
+    if (mode_ == HloRematerialization::RematerializationMode::kCompressOnly) {
+      continue;
+    }
+
     const int64 memory_reduced = MemoryReducedIfRematerialized(item);
 
     if (memory_reduced > 0) {
@@ -1370,7 +1382,7 @@
   InstructionList instruction_list(order);
   MemoryUsageTracker tracker(computation, size_function_,
                              compact_shape_function_, *points_to_analysis_,
-                             instruction_list);
+                             instruction_list, mode_);
   int64 peak_memory = tracker.memory_usage();
   for (auto* item = instruction_list.first(); item != nullptr;
        item = instruction_list.next(item)) {
@@ -1412,9 +1424,9 @@
   CHECK(!ContainsKey(rematerialized_computations_, computation));
 
   InstructionList instruction_list(schedule->sequence(computation));
-  MemoryUsageTracker memory_tracker(computation, size_function_,
-                                    compact_shape_function_,
-                                    *points_to_analysis_, instruction_list);
+  MemoryUsageTracker memory_tracker(
+      computation, size_function_, compact_shape_function_,
+      *points_to_analysis_, instruction_list, mode_);
   bool changed = false;
 
   // If the rematerialization makes the source instruction dead, then the
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index 9ab34b4..69cdc84 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -49,6 +49,13 @@
     int64 after_bytes;
   };
 
+  // Mode in which the rematerialization algorithm should be run.
+  enum class RematerializationMode {
+    kRecomputeOnly,        // Only consider the kRecompute RematStrategy.
+    kCompressOnly,         // Only consider the kCompress RematStrategy.
+    kRecomputeAndCompress  // Consider both kRecompute and kCompress.
+  };
+
   static Shape DefaultCompactShapeFunction(const Shape& shape) { return shape; }
 
   // Constructor parameters:
@@ -69,13 +76,15 @@
   explicit HloRematerialization(
       const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
       RematerializationSizes* sizes,
-      CompactShapeFunction compact_shape_function = nullptr)
+      CompactShapeFunction compact_shape_function = nullptr,
+      RematerializationMode mode = RematerializationMode::kRecomputeAndCompress)
       : size_function_(size_function),
         memory_limit_bytes_(memory_limit_bytes),
         sizes_(sizes),
         compact_shape_function_(compact_shape_function == nullptr
                                     ? DefaultCompactShapeFunction
-                                    : std::move(compact_shape_function)) {}
+                                    : std::move(compact_shape_function)),
+        mode_(mode) {}
   ~HloRematerialization() override = default;
 
   absl::string_view name() const override { return "rematerialization"; }
@@ -152,6 +161,8 @@
   // uses of the original instruction and the original instruction is
   // dead. Hence, no net instructions were added.
   int64 net_instructions_added_ = 0;
+
+  RematerializationMode mode_;
 };
 
 }  // namespace xla