Add TPU support to reset_memory_stats.
`tf.config.experimental.reset_memory_stats(device)` resets the device's peak memory to its current memory. This CL makes `reset_memory_stats` support TPU, in addition to GPU, which is already supported.
PiperOrigin-RevId: 377394433
Change-Id: Ic050e4008dcef2394f601cffe8e8df682aca921e
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 07c466c..d805778 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -73,6 +73,13 @@
return tf_stats;
}
+bool XlaDeviceAllocator::ClearStats() {
+ if (!stream_executor_->SynchronizeAllActivity()) {
+ return false;
+ }
+ return stream_executor_->ClearAllocatorStats();
+}
+
XlaDeviceContext::XlaDeviceContext(
std::shared_ptr<se::Stream> compute_stream,
std::shared_ptr<se::Stream> host_to_device_stream,
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index 5689e81..12ef35f 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -42,6 +42,7 @@
void* AllocateRaw(size_t alignment, size_t num_bytes) override;
void DeallocateRaw(void* ptr) override;
absl::optional<AllocatorStats> GetStats() override;
+ bool ClearStats() override;
private:
// The stream executor of the device.
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index 7032a02..ce0d195 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -283,8 +283,8 @@
virtual absl::optional<AllocatorStats> GetStats() { return absl::nullopt; }
// If implemented, clears the internal stats except for the `in_use` fields
- // and set the `peak_bytes_in_use` to be equal to the `bytes_in_use`. Returns
- // true if implemented.
+ // and sets the `peak_bytes_in_use` to be equal to the `bytes_in_use`. Returns
+ // true if implemented.
//
// REQUIRES: GetStats is overridden.
virtual bool ClearStats() TF_MUST_USE_RESULT { return false; }
diff --git a/tensorflow/python/framework/config.py b/tensorflow/python/framework/config.py
index ac159f7..6bdb97b 100644
--- a/tensorflow/python/framework/config.py
+++ b/tensorflow/python/framework/config.py
@@ -565,6 +565,7 @@
return context.context().get_memory_info(device)
+# TODO(b/189498350): Unify the behavior of CPU, GPU and TPU.
@tf_export('config.experimental.reset_memory_stats')
def reset_memory_stats(device):
"""Resets the tracked memory stats for the chosen device.
@@ -588,12 +589,13 @@
... peak2 = tf.config.experimental.get_memory_info('GPU:0')['peak']
... assert peak2 < peak1 # tf.float32 consumes less memory than tf.float64.
- Currently raises an exception for the CPU.
+ Currently only supports GPU and TPU. If called on a CPU device, an exception
+ will be raised.
Args:
- device: Device string to reset the memory stats, e.g. `"GPU:0"`. See
- https://www.tensorflow.org/api_docs/python/tf/device for specifying device
- strings.
+ device: Device string to reset the memory stats, e.g. `"GPU:0"`, `"TPU:0"`.
+ See https://www.tensorflow.org/api_docs/python/tf/device for specifying
+ device strings.
Raises:
ValueError: No device found with the device name, like '"nonexistent"'.
diff --git a/tensorflow/python/framework/config_test.py b/tensorflow/python/framework/config_test.py
index 811eba3..44d190a 100644
--- a/tensorflow/python/framework/config_test.py
+++ b/tensorflow/python/framework/config_test.py
@@ -644,19 +644,21 @@
self.assertGreaterEqual(peak3, peak2)
self.assertGreaterEqual(peak3, config.get_memory_info(device)['current'])
- @test_util.run_gpu_only
+ @test_util.run_gpu_or_tpu
@reset_eager
- def testResetMemoryStats(self):
- x = array_ops.zeros((1000, 1000), dtype=dtypes.float32)
- config.reset_memory_stats('GPU:0')
- info1 = config.get_memory_info('GPU:0')
+ def testResetMemoryStats(self, device_type):
+ device = f'{device_type}:0'
+ with ops.device(device):
+ x = array_ops.zeros((1000, 1000), dtype=dtypes.float32)
+ config.reset_memory_stats(device)
+ info1 = config.get_memory_info(device)
self.assertGreaterEqual(info1['peak'], 4 * 1000 * 1000)
self.assertGreaterEqual(info1['peak'], info1['current'])
self.assertGreater(info1['current'], 0)
del x # With CPython, causes tensor memory to be immediately freed
- config.reset_memory_stats('GPU:0')
- info2 = config.get_memory_info('GPU:0')
+ config.reset_memory_stats(device)
+ info2 = config.get_memory_info(device)
self.assertLess(info2['peak'], info1['peak'])
@reset_eager
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h
index 4373380..788305e 100644
--- a/tensorflow/stream_executor/stream_executor_internal.h
+++ b/tensorflow/stream_executor/stream_executor_internal.h
@@ -377,6 +377,13 @@
return absl::nullopt;
}
+ // If implemented, clears the internal stats except for the `in_use` fields
+ // and sets the `peak_bytes_in_use` to be equal to the `bytes_in_use`. Returns
+ // true if implemented.
+ //
+ // REQUIRES: GetAllocatorStats is overridden.
+ virtual bool ClearAllocatorStats() { return false; }
+
// Clears the compilation cache from volatile memory. Returns OK if no
// compilation cache exists or if clearing the compilation cache is
// unsupported. Caches in non-volatile storage are unaffected.
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 7c581b4..d4a13e3 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -903,6 +903,10 @@
return implementation_->GetAllocatorStats();
}
+bool StreamExecutor::ClearAllocatorStats() {
+ return implementation_->ClearAllocatorStats();
+}
+
template <typename TraceCallT, typename... ArgsT>
void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&...args) {
if (tracing_enabled_) {
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 636fdce..533fe02 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -542,6 +542,10 @@
// Return allocator statistics.
absl::optional<AllocatorStats> GetAllocatorStats();
+ // Clears the internal stats except for the `in_use` fields
+ // and sets the `peak_bytes_in_use` to be equal to the `bytes_in_use`.
+ bool ClearAllocatorStats();
+
// Return an allocator which delegates to this stream executor for memory
// allocation.
StreamExecutorMemoryAllocator *GetAllocator() { return &allocator_; }