Add TPU support to reset_memory_stats.

`tf.config.experimental.reset_memory_stats(device)` resets the device's peak memory to its current memory. This CL makes reset_memory_stats support TPU, in addition to GPU, which is already supported.

PiperOrigin-RevId: 377394433
Change-Id: Ic050e4008dcef2394f601cffe8e8df682aca921e
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 07c466c..d805778 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -73,6 +73,13 @@
   return tf_stats;
 }
 
+bool XlaDeviceAllocator::ClearStats() {
+  if (!stream_executor_->SynchronizeAllActivity()) {
+    return false;
+  }
+  return stream_executor_->ClearAllocatorStats();
+}
+
 XlaDeviceContext::XlaDeviceContext(
     std::shared_ptr<se::Stream> compute_stream,
     std::shared_ptr<se::Stream> host_to_device_stream,
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index 5689e81..12ef35f 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -42,6 +42,7 @@
   void* AllocateRaw(size_t alignment, size_t num_bytes) override;
   void DeallocateRaw(void* ptr) override;
   absl::optional<AllocatorStats> GetStats() override;
+  bool ClearStats() override;
 
  private:
   // The stream executor of the device.
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index 7032a02..ce0d195 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -283,8 +283,8 @@
   virtual absl::optional<AllocatorStats> GetStats() { return absl::nullopt; }
 
   // If implemented, clears the internal stats except for the `in_use` fields
-  // and set the `peak_bytes_in_use` to be equal to the `bytes_in_use`. Returns
-  // true if implemented.
+  // and sets the `peak_bytes_in_use` to be equal to the `bytes_in_use`. Returns
+  // true if implemented.
   //
   // REQUIRES: GetStats is overridden.
   virtual bool ClearStats() TF_MUST_USE_RESULT { return false; }
diff --git a/tensorflow/python/framework/config.py b/tensorflow/python/framework/config.py
index ac159f7..6bdb97b 100644
--- a/tensorflow/python/framework/config.py
+++ b/tensorflow/python/framework/config.py
@@ -565,6 +565,7 @@
   return context.context().get_memory_info(device)
 
 
+# TODO(b/189498350): Unify the behavior of CPU, GPU and TPU.
 @tf_export('config.experimental.reset_memory_stats')
 def reset_memory_stats(device):
   """Resets the tracked memory stats for the chosen device.
@@ -588,12 +589,13 @@
   ...   peak2 = tf.config.experimental.get_memory_info('GPU:0')['peak']
   ...   assert peak2 < peak1  # tf.float32 consumes less memory than tf.float64.
 
-  Currently raises an exception for the CPU.
+  Currently only supports GPU and TPU. If called on a CPU device, an exception
+  will be raised.
 
   Args:
-    device: Device string to reset the memory stats, e.g. `"GPU:0"`. See
-      https://www.tensorflow.org/api_docs/python/tf/device for specifying device
-        strings.
+    device: Device string to reset the memory stats, e.g. `"GPU:0"`, `"TPU:0"`.
+      See https://www.tensorflow.org/api_docs/python/tf/device for specifying
+      device strings.
 
   Raises:
     ValueError: No device found with the device name, like '"nonexistent"'.
diff --git a/tensorflow/python/framework/config_test.py b/tensorflow/python/framework/config_test.py
index 811eba3..44d190a 100644
--- a/tensorflow/python/framework/config_test.py
+++ b/tensorflow/python/framework/config_test.py
@@ -644,19 +644,21 @@
     self.assertGreaterEqual(peak3, peak2)
     self.assertGreaterEqual(peak3, config.get_memory_info(device)['current'])
 
-  @test_util.run_gpu_only
+  @test_util.run_gpu_or_tpu
   @reset_eager
-  def testResetMemoryStats(self):
-    x = array_ops.zeros((1000, 1000), dtype=dtypes.float32)
-    config.reset_memory_stats('GPU:0')
-    info1 = config.get_memory_info('GPU:0')
+  def testResetMemoryStats(self, device_type):
+    device = f'{device_type}:0'
+    with ops.device(device):
+      x = array_ops.zeros((1000, 1000), dtype=dtypes.float32)
+    config.reset_memory_stats(device)
+    info1 = config.get_memory_info(device)
     self.assertGreaterEqual(info1['peak'], 4 * 1000 * 1000)
     self.assertGreaterEqual(info1['peak'], info1['current'])
     self.assertGreater(info1['current'], 0)
 
     del x  # With CPython, causes tensor memory to be immediately freed
-    config.reset_memory_stats('GPU:0')
-    info2 = config.get_memory_info('GPU:0')
+    config.reset_memory_stats(device)
+    info2 = config.get_memory_info(device)
     self.assertLess(info2['peak'], info1['peak'])
 
   @reset_eager
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h
index 4373380..788305e 100644
--- a/tensorflow/stream_executor/stream_executor_internal.h
+++ b/tensorflow/stream_executor/stream_executor_internal.h
@@ -377,6 +377,13 @@
     return absl::nullopt;
   }
 
+  // If implemented, clears the internal stats except for the `in_use` fields
+  // and sets the `peak_bytes_in_use` to be equal to the `bytes_in_use`. Returns
+  // true if implemented.
+  //
+  // REQUIRES: GetAllocatorStats is overridden.
+  virtual bool ClearAllocatorStats() { return false; }
+
   // Clears the compilation cache from volatile memory. Returns OK if no
   // compilation cache exists or if clearing the compilation cache is
   // unsupported. Caches in non-volatile storage are unaffected.
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 7c581b4..d4a13e3 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -903,6 +903,10 @@
   return implementation_->GetAllocatorStats();
 }
 
+bool StreamExecutor::ClearAllocatorStats() {
+  return implementation_->ClearAllocatorStats();
+}
+
 template <typename TraceCallT, typename... ArgsT>
 void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&...args) {
   if (tracing_enabled_) {
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 636fdce..533fe02 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -542,6 +542,10 @@
   // Return allocator statistics.
   absl::optional<AllocatorStats> GetAllocatorStats();
 
+  // Clears the internal stats except for the `in_use` fields
+  // and sets the `peak_bytes_in_use` to be equal to the `bytes_in_use`.
+  bool ClearAllocatorStats();
+
   // Return an allocator which delegates to this stream executor for memory
   // allocation.
   StreamExecutorMemoryAllocator *GetAllocator() { return &allocator_; }