Enables running HloRematerialization as a compress-only or recompute-only pass.
PiperOrigin-RevId: 286519401
Change-Id: Iaf720be722e1aa86c3deb5816aabbee535f90c37
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 445a3ea..5d38bbe 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -370,7 +370,8 @@
const HloRematerialization::ShapeSizeFunction& size_function,
const HloRematerialization::CompactShapeFunction& compact_shape_function,
const TuplePointsToAnalysis& points_to_analysis,
- const InstructionList& instruction_list);
+ const InstructionList& instruction_list,
+ HloRematerialization::RematerializationMode mode);
// Starts the placement of the given instruction. This adds the sizes of the
// LogicalBuffers defined by the instruction to the current memory
@@ -607,6 +608,7 @@
// between the calling of BeginInstruction and EndInstruction.
Item* in_progress_item_ = nullptr;
+ HloRematerialization::RematerializationMode mode_;
// All buffers in the computation.
std::vector<Buffer> buffers_;
};
@@ -616,11 +618,13 @@
const HloRematerialization::ShapeSizeFunction& size_function,
const HloRematerialization::CompactShapeFunction& compact_shape_function,
const TuplePointsToAnalysis& points_to_analysis,
- const InstructionList& instruction_list)
+ const InstructionList& instruction_list,
+ HloRematerialization::RematerializationMode mode)
: computation_(computation),
instruction_list_(instruction_list),
size_function_(size_function),
- compact_shape_function_(compact_shape_function) {
+ compact_shape_function_(compact_shape_function),
+ mode_(mode) {
PointsToSet::BufferSet live_out_set =
points_to_analysis.GetPointsToSet(computation_->root_instruction())
.CreateFlattenedSet();
@@ -1155,7 +1159,10 @@
continue;
}
- if (item->buffers_output.size() == 1) {
+ if (item->buffers_output.size() == 1 &&
+ (mode_ == HloRematerialization::RematerializationMode::kCompressOnly ||
+ mode_ == HloRematerialization::RematerializationMode::
+ kRecomputeAndCompress)) {
// Only consider compressing single output instruction.
const Buffer& output_buffer = buffers_.at(item->buffers_output[0]);
@@ -1196,6 +1203,11 @@
continue;
}
+ // Do not consider recomputation in compress-only mode.
+ if (mode_ == HloRematerialization::RematerializationMode::kCompressOnly) {
+ continue;
+ }
+
const int64 memory_reduced = MemoryReducedIfRematerialized(item);
if (memory_reduced > 0) {
@@ -1370,7 +1382,7 @@
InstructionList instruction_list(order);
MemoryUsageTracker tracker(computation, size_function_,
compact_shape_function_, *points_to_analysis_,
- instruction_list);
+ instruction_list, mode_);
int64 peak_memory = tracker.memory_usage();
for (auto* item = instruction_list.first(); item != nullptr;
item = instruction_list.next(item)) {
@@ -1412,9 +1424,9 @@
CHECK(!ContainsKey(rematerialized_computations_, computation));
InstructionList instruction_list(schedule->sequence(computation));
- MemoryUsageTracker memory_tracker(computation, size_function_,
- compact_shape_function_,
- *points_to_analysis_, instruction_list);
+ MemoryUsageTracker memory_tracker(
+ computation, size_function_, compact_shape_function_,
+ *points_to_analysis_, instruction_list, mode_);
bool changed = false;
// If the rematerialization makes the source instruction dead, then the
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index 9ab34b4..69cdc84 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -49,6 +49,13 @@
int64 after_bytes;
};
+ // Mode in which the rematerialization algorithm should be run.
+ enum class RematerializationMode {
+ kRecomputeOnly, // Only consider the kCompress RematStrategy.
+ kCompressOnly, // Only consider the kRecompute RematStrategy.
+ kRecomputeAndCompress // Consider both kRecompute and kRemat.
+ };
+
static Shape DefaultCompactShapeFunction(const Shape& shape) { return shape; }
// Constructor parameters:
@@ -69,13 +76,15 @@
explicit HloRematerialization(
const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
RematerializationSizes* sizes,
- CompactShapeFunction compact_shape_function = nullptr)
+ CompactShapeFunction compact_shape_function = nullptr,
+ RematerializationMode mode = RematerializationMode::kRecomputeAndCompress)
: size_function_(size_function),
memory_limit_bytes_(memory_limit_bytes),
sizes_(sizes),
compact_shape_function_(compact_shape_function == nullptr
? DefaultCompactShapeFunction
- : std::move(compact_shape_function)) {}
+ : std::move(compact_shape_function)),
+ mode_(mode) {}
~HloRematerialization() override = default;
absl::string_view name() const override { return "rematerialization"; }
@@ -152,6 +161,8 @@
// uses of the original instruction and the original instruction is
// dead. Hence, no net instructions were added.
int64 net_instructions_added_ = 0;
+
+ RematerializationMode mode_;
};
} // namespace xla