Add function tf.config.experimental.reset_memory_stats.

It allows the tracked peak memory to be reset in the middle of a program, which is useful for test and debug.

PiperOrigin-RevId: 374239680
Change-Id: I1db1255f2f6f0acea5c3832df530fc4db7f56f08
diff --git a/RELEASE.md b/RELEASE.md
index 546aae1..c9228b5 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -57,6 +57,10 @@
         `tf.saved_model.SaveOption(experimental_custom_gradients=True)` to
         enable this feature.
 
+*   TF Core:
+    *   Added `tf.config.experimental.reset_memory_stats` to reset the tracked
+        peak memory returned by `tf.config.experimental.get_memory_info`.
+
 ## Bug Fixes and Other Changes
 
 *<SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc
index a4965bf..8870cee 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/bfc_allocator.cc
@@ -1158,11 +1158,12 @@
   return stats_;
 }
 
-void BFCAllocator::ClearStats() {
+bool BFCAllocator::ClearStats() {
   mutex_lock l(lock_);
   stats_.num_allocs = 0;
   stats_.peak_bytes_in_use = stats_.bytes_in_use;
   stats_.largest_alloc_size = 0;
+  return true;
 }
 
 std::array<BFCAllocator::BinDebugInfo, BFCAllocator::kNumBins>
diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h
index a5f426c..67e9194 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.h
+++ b/tensorflow/core/common_runtime/bfc_allocator.h
@@ -75,7 +75,7 @@
 
   absl::optional<AllocatorStats> GetStats() override;
 
-  void ClearStats() override;
+  bool ClearStats() override;
 
   void SetTimingCounter(SharedCounter* sc) { timing_counter_ = sc; }
 
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.cc
index 3c4778a..2344da6 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.cc
@@ -294,12 +294,13 @@
   return *stats_;
 }
 
-void GpuCudaMallocAsyncAllocator::ClearStats() {
-  if (!stats_) return;
+bool GpuCudaMallocAsyncAllocator::ClearStats() {
+  if (!stats_) return false;
   mutex_lock l(lock_);
   stats_->num_allocs = 0;
   stats_->peak_bytes_in_use = stats_->bytes_in_use;
   stats_->largest_alloc_size = 0;
+  return true;
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.h
index 8c76139..ccc9f14 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.h
@@ -81,7 +81,7 @@
 
   absl::optional<AllocatorStats> GetStats() override;
 
-  void ClearStats() override;
+  bool ClearStats() override;
 
   void SetStream(void* stream) override {
 #if TF_CUDA_MALLOC_ASYNC_SUPPORTED
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
index 7fdbbe1..5a330d5 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
@@ -137,7 +137,7 @@
   return base_allocator_->GetStats();
 }
 
-void GPUDebugAllocator::ClearStats() { base_allocator_->ClearStats(); }
+bool GPUDebugAllocator::ClearStats() { return base_allocator_->ClearStats(); }
 
 bool GPUDebugAllocator::CheckHeader(void* ptr) {
   return CheckMask(stream_exec_, static_cast<char*>(ptr) - MASK_BYTES,
@@ -214,6 +214,8 @@
   return base_allocator_->GetStats();
 }
 
-void GPUNanResetAllocator::ClearStats() { base_allocator_->ClearStats(); }
+bool GPUNanResetAllocator::ClearStats() {
+  return base_allocator_->ClearStats();
+}
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
index 0c085fe..aa982bc 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
@@ -44,7 +44,7 @@
   size_t AllocatedSize(const void* ptr) const override;
   int64 AllocationId(const void* ptr) const override;
   absl::optional<AllocatorStats> GetStats() override;
-  void ClearStats() override;
+  bool ClearStats() override;
 
   // For testing.
   bool CheckHeader(void* ptr);
@@ -72,7 +72,7 @@
   size_t RequestedSize(const void* ptr) const override;
   size_t AllocatedSize(const void* ptr) const override;
   absl::optional<AllocatorStats> GetStats() override;
-  void ClearStats() override;
+  bool ClearStats() override;
 
  private:
   Allocator* base_allocator_ = nullptr;  // owned
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
index 27d4f04..7314d29 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
@@ -84,13 +84,14 @@
     return stats_;
   }
 
-  void ClearStats() override {
+  bool ClearStats() override {
     mutex_lock l(mutex_);
     stats_.num_allocs = 0;
     stats_.peak_bytes_in_use = 0;
     stats_.largest_alloc_size = 0;
     stats_.bytes_in_use = 0;
     stats_.bytes_limit = 0;
+    return true;
   }
 
  private:
@@ -257,9 +258,10 @@
     return stats_;
   }
 
-  void ClearStats() override {
-    small_size_allocator_->ClearStats();
-    large_size_allocator_->ClearStats();
+  bool ClearStats() override {
+    bool stats_cleared = small_size_allocator_->ClearStats();
+    stats_cleared &= large_size_allocator_->ClearStats();
+    return stats_cleared;
   }
 
  private:
diff --git a/tensorflow/core/common_runtime/process_state.h b/tensorflow/core/common_runtime/process_state.h
index 356c2f5..a65f263 100644
--- a/tensorflow/core/common_runtime/process_state.h
+++ b/tensorflow/core/common_runtime/process_state.h
@@ -146,7 +146,7 @@
     return a_->AllocatedSize(p);
   }
   absl::optional<AllocatorStats> GetStats() override { return a_->GetStats(); }
-  void ClearStats() override { a_->ClearStats(); }
+  bool ClearStats() override { return a_->ClearStats(); }
   ProcessState::MDMap* mm_;  // not owned
   Allocator* a_;             // not owned
   ProcessState::MemDesc md_;
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index 38fd4be..7032a02 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -282,8 +282,12 @@
   // Fills in 'stats' with statistics collected by this allocator.
   virtual absl::optional<AllocatorStats> GetStats() { return absl::nullopt; }
 
-  // Clears the internal stats except for the `in_use` field.
-  virtual void ClearStats() {}
+  // If implemented, clears the internal stats except for the `in_use` fields
+  // and set the `peak_bytes_in_use` to be equal to the `bytes_in_use`. Returns
+  // true if implemented.
+  //
+  // REQUIRES: GetStats is overridden.
+  virtual bool ClearStats() TF_MUST_USE_RESULT { return false; }
 
   virtual void SetSafeFrontier(uint64 count) {}
 
diff --git a/tensorflow/core/framework/allocator_test.cc b/tensorflow/core/framework/allocator_test.cc
index f1001e7..0f6c6cc 100644
--- a/tensorflow/core/framework/allocator_test.cc
+++ b/tensorflow/core/framework/allocator_test.cc
@@ -160,7 +160,7 @@
 
   CheckStats(a, 1025, 0, 1048576 * sizeof(double) + 1024 * sizeof(float),
              1048576 * sizeof(double));
-  a->ClearStats();
+  CHECK(a->ClearStats());
   CheckStats(a, 0, 0, 0, 0);
   DisableCPUAllocatorStats();
 }
diff --git a/tensorflow/core/framework/cpu_allocator_impl.cc b/tensorflow/core/framework/cpu_allocator_impl.cc
index f3d7fdc..dd24915 100644
--- a/tensorflow/core/framework/cpu_allocator_impl.cc
+++ b/tensorflow/core/framework/cpu_allocator_impl.cc
@@ -115,15 +115,18 @@
   }
 
   absl::optional<AllocatorStats> GetStats() override {
+    if (!cpu_allocator_collect_stats) return absl::nullopt;
     mutex_lock l(mu_);
     return stats_;
   }
 
-  void ClearStats() override {
+  bool ClearStats() override {
+    if (!cpu_allocator_collect_stats) return false;
     mutex_lock l(mu_);
     stats_.num_allocs = 0;
     stats_.peak_bytes_in_use = stats_.bytes_in_use;
     stats_.largest_alloc_size = 0;
+    return true;
   }
 
   size_t AllocatedSizeSlow(const void* ptr) const override {
diff --git a/tensorflow/core/framework/tracking_allocator.cc b/tensorflow/core/framework/tracking_allocator.cc
index a758ffb..29ac76f 100644
--- a/tensorflow/core/framework/tracking_allocator.cc
+++ b/tensorflow/core/framework/tracking_allocator.cc
@@ -156,7 +156,7 @@
   return allocator_->GetStats();
 }
 
-void TrackingAllocator::ClearStats() { allocator_->ClearStats(); }
+bool TrackingAllocator::ClearStats() { return allocator_->ClearStats(); }
 
 std::tuple<size_t, size_t, size_t> TrackingAllocator::GetSizes() {
   size_t high_watermark;
diff --git a/tensorflow/core/framework/tracking_allocator.h b/tensorflow/core/framework/tracking_allocator.h
index 7b5b391..ae3fc88 100644
--- a/tensorflow/core/framework/tracking_allocator.h
+++ b/tensorflow/core/framework/tracking_allocator.h
@@ -66,7 +66,7 @@
   size_t AllocatedSize(const void* ptr) const override;
   int64 AllocationId(const void* ptr) const override;
   absl::optional<AllocatorStats> GetStats() override;
-  void ClearStats() override;
+  bool ClearStats() override;
 
   // If the underlying allocator tracks allocation sizes, this returns
   // a tuple where the first value is the total number of bytes
diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc
index 2975293..aeb151b 100644
--- a/tensorflow/core/grappler/clusters/single_machine.cc
+++ b/tensorflow/core/grappler/clusters/single_machine.cc
@@ -458,7 +458,12 @@
       return Status(error::INVALID_ARGUMENT,
                     "Tracking allocation is not enabled.");
     }
-    allocator->ClearStats();
+    if (!allocator->ClearStats()) {
+      return Status(
+          error::INVALID_ARGUMENT,
+          absl::StrCat("Clearing allocation stats is not supported for ",
+                       device->name()));
+    }
   }
   return Status::OK();
 }
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index e6660d3..2e58533 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -1469,6 +1469,12 @@
     self.ensure_initialized()
     return pywrap_tfe.TFE_GetMemoryInfo(self._context_handle, dev)
 
+  def reset_memory_stats(self, dev):
+    """Resets the tracked memory stats for the device."""
+    self._initialize_physical_devices()
+    self.ensure_initialized()
+    pywrap_tfe.TFE_ResetMemoryStats(self._context_handle, dev)
+
   def get_memory_growth(self, dev):
     """Get if memory growth is enabled for a PhysicalDevice."""
     self._initialize_physical_devices()
diff --git a/tensorflow/python/eager/context_test.py b/tensorflow/python/eager/context_test.py
index 75664dd..0fcf664 100644
--- a/tensorflow/python/eager/context_test.py
+++ b/tensorflow/python/eager/context_test.py
@@ -129,13 +129,13 @@
   @test_util.run_gpu_only
   @test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
   def testGetMemoryUsage(self):
-    array_ops.zeros([10]) # Allocate some memory on the GPU.
+    array_ops.zeros([10])  # Allocate some memory on the GPU.
     self.assertGreater(
         context.context().get_memory_info('GPU:0')['current'], 0)
 
   @test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
   def testGetMemoryUsageCPU(self):
-    with self.assertRaisesRegex(ValueError, 'CPU does not support'):
+    with self.assertRaisesRegex(ValueError, 'Allocator stats not available'):
       context.context().get_memory_info('CPU:0')
 
   @test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
diff --git a/tensorflow/python/framework/config.py b/tensorflow/python/framework/config.py
index 872ea0e..60b1b59 100644
--- a/tensorflow/python/framework/config.py
+++ b/tensorflow/python/framework/config.py
@@ -536,9 +536,10 @@
   ...   tf.config.experimental.get_memory_info('GPU:0')
 
   Currently returns the following keys:
-    `'current'`: The current memory used by the device, in bytes.
-    `'peak'`: The peak memory used by the device across the run of the program,
-        in bytes.
+    - `'current'`: The current memory used by the device, in bytes.
+    - `'peak'`: The peak memory used by the device across the run of the
+        program, in bytes. Can be reset with
+        `tf.config.experimental.reset_memory_stats`.
 
   More keys may be added in the future, including device-specific keys.
 
@@ -565,6 +566,42 @@
   return context.context().get_memory_info(device)
 
 
+@tf_export('config.experimental.reset_memory_stats')
+def reset_memory_stats(device):
+  """Resets the tracked memory stats for the chosen device.
+
+  This function sets the tracked peak memory for a device to the device's
+  current memory usage. This allows you to measure the peak memory usage for a
+  specific part of your program. For example:
+
+  >>> if tf.config.list_physical_devices('GPU'):
+  ...   # Sets the peak memory to the current memory.
+  ...   tf.config.experimental.reset_memory_stats('GPU:0')
+  ...   # Creates the first peak memory usage.
+  ...   x1 = tf.ones(1000 * 1000, dtype=tf.float64)
+  ...   del x1 # Frees the memory referenced by `x1`.
+  ...   peak1 = tf.config.experimental.get_memory_info('GPU:0')['peak']
+  ...   # Sets the peak memory to the current memory again.
+  ...   tf.config.experimental.reset_memory_stats('GPU:0')
+  ...   # Creates the second peak memory usage.
+  ...   x2 = tf.ones(1000 * 1000, dtype=tf.float32)
+  ...   del x2
+  ...   peak2 = tf.config.experimental.get_memory_info('GPU:0')['peak']
+  ...   assert peak2 < peak1  # tf.float32 consumes less memory than tf.float64.
+
+  Currently raises an exception for the CPU.
+
+  Args:
+    device: Device string to reset the memory stats, e.g. `"GPU:0"`. See
+      https://www.tensorflow.org/api_docs/python/tf/device for specifying device
+        strings.
+
+  Raises:
+    ValueError: Non-existent or CPU device specified.
+  """
+  context.context().reset_memory_stats(device)
+
+
 @deprecation.deprecated(
     None,
     "Use tf.config.experimental.get_memory_info(device)['current'] instead.")
diff --git a/tensorflow/python/framework/config_test.py b/tensorflow/python/framework/config_test.py
index dfbdb43..f2083f0 100644
--- a/tensorflow/python/framework/config_test.py
+++ b/tensorflow/python/framework/config_test.py
@@ -611,9 +611,9 @@
 
   @reset_eager
   def testGetMemoryInfoCPU(self):
-    with self.assertRaisesRegex(ValueError, 'CPU does not support'):
+    with self.assertRaisesRegex(ValueError, 'Allocator stats not available'):
       config.get_memory_info('CPU:0')
-    with self.assertRaisesRegex(ValueError, 'CPU does not support'):
+    with self.assertRaisesRegex(ValueError, 'Allocator stats not available'):
       config.get_memory_usage('CPU:0')
 
   @reset_eager
@@ -647,6 +647,39 @@
 
   @test_util.run_gpu_only
   @reset_eager
+  def testResetMemoryStats(self):
+    x = array_ops.zeros((1000, 1000), dtype=dtypes.float32)
+    config.reset_memory_stats('GPU:0')
+    info1 = config.get_memory_info('GPU:0')
+    self.assertGreaterEqual(info1['peak'], 4 * 1000 * 1000)
+    self.assertGreaterEqual(info1['peak'], info1['current'])
+    self.assertGreater(info1['current'], 0)
+
+    del x  # With CPython, causes tensor memory to be immediately freed
+    config.reset_memory_stats('GPU:0')
+    info2 = config.get_memory_info('GPU:0')
+    self.assertLess(info2['peak'], info1['peak'])
+
+  @reset_eager
+  def testResetMemoryStatsCPU(self):
+    with self.assertRaisesRegex(ValueError, 'Cannot reset memory stats'):
+      config.reset_memory_stats('CPU:0')
+
+  @reset_eager
+  def testResetMemoryStatsUnknownDevice(self):
+    with self.assertRaisesRegex(ValueError, 'Failed parsing device name'):
+      config.reset_memory_stats('unknown_device')
+
+  @test_util.run_gpu_only
+  @reset_eager
+  def testResetMemoryStatsAmbiguousDevice(self):
+    if len(config.list_physical_devices('GPU')) < 2:
+      self.skipTest('Need at least 2 GPUs')
+    with self.assertRaisesRegex(ValueError, 'Multiple devices'):
+      config.reset_memory_stats('GPU')
+
+  @test_util.run_gpu_only
+  @reset_eager
   def testGpuInvalidConfig(self):
     gpus = config.list_physical_devices('GPU')
     self.assertNotEqual(len(gpus), 0)
diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc
index d0097f7..aee3ba7 100644
--- a/tensorflow/python/tfe_wrapper.cc
+++ b/tensorflow/python/tfe_wrapper.cc
@@ -224,6 +224,47 @@
   return output_tensor_handles;
 }
 
+tensorflow::Device* GetMatchedDevice(py::handle& ctx, const char* device_name) {
+  auto* context = reinterpret_cast<tensorflow::ImmediateExecutionContext*>(
+      tensorflow::InputTFE_Context(ctx));
+
+  tensorflow::DeviceNameUtils::ParsedName input_device_name;
+  if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(device_name,
+                                                         &input_device_name)) {
+    tensorflow::ThrowValueError(
+        absl::StrFormat("Failed parsing device name: '%s'", device_name)
+            .c_str());
+  }
+
+  std::vector<tensorflow::Device*> devices = context->ListLocalTfDevices();
+
+  tensorflow::Device* matched_device = nullptr;
+  for (int device_idx = 0; device_idx < devices.size(); device_idx++) {
+    tensorflow::Device* device = devices[device_idx];
+
+    if (tensorflow::DeviceNameUtils::AreCompatibleDevNames(
+            input_device_name, device->parsed_name())) {
+      if (matched_device != nullptr) {
+        tensorflow::ThrowValueError(
+            absl::StrFormat("Multiple devices match the provided string "
+                            "'%s': '%s' and "
+                            "'%s' ",
+                            device_name, matched_device->name(), device->name())
+                .c_str());
+      }
+      matched_device = device;
+    }
+  }
+
+  if (matched_device == nullptr) {
+    tensorflow::ThrowValueError(
+        absl::StrFormat("No matching devices found for '%s'", device_name)
+            .c_str());
+  }
+
+  return matched_device;
+}
+
 // Packs multiple `EagerTensor`s of the same dtype and shape into one
 // `EagerTensor`.
 py::object TFE_Py_PackEagerTensors_wrapper(const py::handle& context,
@@ -540,48 +581,8 @@
   });
 
   m.def("TFE_GetMemoryInfo", [](py::handle& ctx, const char* device_name) {
-    auto* context = reinterpret_cast<tensorflow::ImmediateExecutionContext*>(
-        tensorflow::InputTFE_Context(ctx));
-
-    tensorflow::DeviceNameUtils::ParsedName input_device_name;
-    if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(
-            device_name, &input_device_name)) {
-      tensorflow::ThrowValueError(
-          absl::StrFormat("Failed parsing device name: '%s'", device_name)
-              .c_str());
-    }
-
-    std::vector<tensorflow::Device*> devices = context->ListLocalTfDevices();
-
-    tensorflow::Device* matched_device = nullptr;
-    for (int device_idx = 0; device_idx < devices.size(); device_idx++) {
-      tensorflow::Device* device = devices[device_idx];
-
-      if (tensorflow::DeviceNameUtils::AreCompatibleDevNames(
-              input_device_name, device->parsed_name())) {
-        if (device->device_type() == tensorflow::DEVICE_CPU) {
-          tensorflow::ThrowValueError(
-              "CPU does not support getting allocator information");
-        }
-
-        if (matched_device != nullptr) {
-          tensorflow::ThrowValueError(
-              absl::StrFormat("Multiple devices matching the provided string "
-                              "'%s': '%s' and "
-                              "'%s' ",
-                              device_name, matched_device->name(),
-                              device->name())
-                  .c_str());
-        }
-        matched_device = device;
-      }
-    }
-
-    if (matched_device == nullptr) {
-      tensorflow::ThrowValueError(
-          absl::StrFormat("No matching devices found for '%s'", device_name)
-              .c_str());
-    }
+    tensorflow::Device* matched_device =
+        tensorflow::GetMatchedDevice(ctx, device_name);
 
     tensorflow::AllocatorAttributes attrs;
     tensorflow::Allocator* allocator = matched_device->GetAllocator(attrs);
@@ -592,12 +593,27 @@
                                             {"peak", stats->peak_bytes_in_use}};
     }
 
-    tensorflow::ThrowTypeError(
+    tensorflow::ThrowValueError(
         absl::StrFormat("Allocator stats not available for device '%s'",
-                        matched_device->name())
+                        device_name)
             .c_str());
   });
 
+  m.def("TFE_ResetMemoryStats", [](py::handle& ctx, const char* device_name) {
+    tensorflow::Device* matched_device =
+        tensorflow::GetMatchedDevice(ctx, device_name);
+
+    tensorflow::AllocatorAttributes attrs;
+    tensorflow::Allocator* allocator = matched_device->GetAllocator(attrs);
+
+    if (!allocator->ClearStats()) {
+      tensorflow::ThrowValueError(
+          absl::StrFormat("Cannot reset memory stats for device '%s'",
+                          device_name)
+              .c_str());
+    }
+  });
+
   // XLA Eager Logic
   m.def("TF_SetXlaEnableLazyCompilation", &TF_SetXlaEnableLazyCompilation);
   m.def("TF_SetTfXlaCpuGlobalJit", &TF_SetTfXlaCpuGlobalJit);
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt
index 7f3da24..c85e546 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt
@@ -69,6 +69,10 @@
     argspec: "args=[\'device_type\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "reset_memory_stats"
+    argspec: "args=[\'device\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "set_device_policy"
     argspec: "args=[\'device_policy\'], varargs=None, keywords=None, defaults=None"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt
index 7f3da24..c85e546 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt
@@ -69,6 +69,10 @@
     argspec: "args=[\'device_type\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "reset_memory_stats"
+    argspec: "args=[\'device\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "set_device_policy"
     argspec: "args=[\'device_policy\'], varargs=None, keywords=None, defaults=None"
   }