| commit | e34470b01eed193dc06ffdc552ff2397b75aed6a | [log] [tgz] |
|---|---|---|
| author | A. Unique TensorFlower <gardener@tensorflow.org> | Tue Mar 10 10:29:22 2020 -0700 |
| committer | TensorFlower Gardener <gardener@tensorflow.org> | Tue Mar 10 10:32:52 2020 -0700 |
| tree | 7450d2e1900d790459f08daa00f9ac8fa44a09cc | |
| parent | fc12e313b69b19bf8d18b83c302a1704727fe9d1 [diff] |
Adds a max_rematerialized_block_size field. PiperOrigin-RevId: 300122933 Change-Id: Ie1acd1aa0deca19e881d7638e72eaec5bf9430d1
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 21be421..bfc6769 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -1648,6 +1648,8 @@ } else { // Found a valid block. Reset to start looking for single instructions // again. + max_rematerialized_block_size_ = + std::max(max_rematerialized_block_size_, max_block_size); changed = true; min_block_size = 1; max_block_size = 1;
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index d1c4b8b..72221fa 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -180,6 +180,10 @@ // dead. Hence, no net instructions were added. int64 net_instructions_added_ = 0; + // Size of the largest block that has been rematerialized. This is actually an + // upper bound (within a factor of 2) on the block size. + int max_rematerialized_block_size_ = 0; + RematerializationMode mode_; };