Metal benchmark now reports the memory consumed by intermediate tensors.

PiperOrigin-RevId: 410732285
Change-Id: I0c6d5c24108e5af75ad28f7fdb1e6159d71d7d44
diff --git a/tensorflow/lite/delegates/gpu/metal/benchmarking/main.mm b/tensorflow/lite/delegates/gpu/metal/benchmarking/main.mm
index 73f2b24..89c1b3e 100644
--- a/tensorflow/lite/delegates/gpu/metal/benchmarking/main.mm
+++ b/tensorflow/lite/delegates/gpu/metal/benchmarking/main.mm
@@ -71,6 +71,9 @@
     inference_context.Profile(device, &profiling_info);
     std::cout << profiling_info.GetDetailedReport() << std::endl;
   }
+  uint64_t mem_bytes = inference_context.GetIntermediateTensorsSize();
+  std::cout << "Memory for intermediate tensors - " << mem_bytes / 1024.0 / 1024.0 << " MB"
+            << std::endl;
   const std::string precision_str = use_fp16 ? "FP16" : "FP32";
   std::cout << "Measuring started: (" << num_tests << " tests, " << iterations
       << " iterations every test, " << precision_str << " precision)" << std::endl;
diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.cc b/tensorflow/lite/delegates/gpu/metal/inference_context.cc
index 25f9e23..4170553 100644
--- a/tensorflow/lite/delegates/gpu/metal/inference_context.cc
+++ b/tensorflow/lite/delegates/gpu/metal/inference_context.cc
@@ -761,6 +761,18 @@
   }
 }
 
+uint64_t InferenceContext::GetIntermediateTensorsSize() const {
+  uint64_t total_memory = 0;
+  for (const auto& t : strong_shape_tensors_) {
+    total_memory += t.second.GetMemorySizeInBytes();
+  }
+  for (const auto& b : shared_buffers_) {
+    total_memory += [b length];
+  }
+
+  return total_memory;
+}
+
 void InferenceContext::EncodeWithCommandBuffer(
     id<MTLCommandBuffer> command_buffer) {
   for (int i = 0; i < nodes_.size(); ++i) {
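As a side note, the accounting above just sums the byte sizes of the context's runtime allocations: per-tensor sizes via GetMemorySizeInBytes() and raw Metal buffer sizes via [buffer length]. A minimal standalone Objective-C++ sketch of the same buffer-summing pattern follows; it is not part of this patch, and the buffer count and sizes are hypothetical:

    #import <Metal/Metal.h>

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      id<MTLDevice> device = MTLCreateSystemDefaultDevice();
      if (!device) return 1;  // No Metal device available.

      // Two illustrative shared-storage buffers standing in for shared_buffers_.
      std::vector<id<MTLBuffer>> buffers;
      buffers.push_back([device newBufferWithLength:1024 * 1024
                                            options:MTLResourceStorageModeShared]);
      buffers.push_back([device newBufferWithLength:256 * 1024
                                            options:MTLResourceStorageModeShared]);

      // Same pattern as GetIntermediateTensorsSize(): accumulate [buffer length].
      uint64_t total_memory = 0;
      for (const auto& b : buffers) {
        total_memory += [b length];  // Allocated size in bytes.
      }
      std::printf("Total buffer memory - %.2f MB\n",
                  total_memory / 1024.0 / 1024.0);
      return 0;
    }

(Compiled as a .mm file with -framework Metal.)
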
diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.h b/tensorflow/lite/delegates/gpu/metal/inference_context.h
index 3ed7c67..1a062f7 100644
--- a/tensorflow/lite/delegates/gpu/metal/inference_context.h
+++ b/tensorflow/lite/delegates/gpu/metal/inference_context.h
@@ -133,6 +133,9 @@
                               int flush_period);
 
   void Profile(id<MTLDevice> device, ProfilingInfo* result);
+  // Returns the size in bytes of all intermediate (runtime) tensors owned by
+  // this inference context. Constant tensors are not included.
+  uint64_t GetIntermediateTensorsSize() const;
 
  private:
   enum class TensorMemoryType {