Metal benchmark now reports the memory consumed by intermediate tensors.
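
For context, a minimal sketch of how a caller might use the new accessor once
an inference context is built. This mirrors the benchmark change below; the
helper name, the includes, and the tflite::gpu::metal namespace qualification
are assumptions for illustration, not part of this change:

    #include <cstdint>
    #include <iostream>

    #include "tensorflow/lite/delegates/gpu/metal/inference_context.h"

    // Hypothetical helper: assumes an already-initialized InferenceContext,
    // as in benchmarking/main.mm.
    void PrintIntermediateTensorsSize(
        const tflite::gpu::metal::InferenceContext& inference_context) {
      // GetIntermediateTensorsSize() returns bytes; convert to MB for display.
      const uint64_t mem_bytes = inference_context.GetIntermediateTensorsSize();
      std::cout << "Memory for intermediate tensors - "
                << mem_bytes / 1024.0 / 1024.0 << " MB" << std::endl;
    }
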
PiperOrigin-RevId: 410732285
Change-Id: I0c6d5c24108e5af75ad28f7fdb1e6159d71d7d44
diff --git a/tensorflow/lite/delegates/gpu/metal/benchmarking/main.mm b/tensorflow/lite/delegates/gpu/metal/benchmarking/main.mm
index 73f2b24..89c1b3e 100644
--- a/tensorflow/lite/delegates/gpu/metal/benchmarking/main.mm
+++ b/tensorflow/lite/delegates/gpu/metal/benchmarking/main.mm
@@ -71,6 +71,9 @@
inference_context.Profile(device, &profiling_info);
std::cout << profiling_info.GetDetailedReport() << std::endl;
}
+ uint64_t mem_bytes = inference_context.GetIntermediateTensorsSize();
+ std::cout << "Memory for intermediate tensors - " << mem_bytes / 1024.0 / 1024.0 << " MB"
+ << std::endl;
const std::string precision_str = use_fp16 ? "FP16" : "FP32";
std::cout << "Measuring started: (" << num_tests << " tests, " << iterations
<< " iterations every test, " << precision_str << " precision)" << std::endl;
diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.cc b/tensorflow/lite/delegates/gpu/metal/inference_context.cc
index 25f9e23..4170553 100644
--- a/tensorflow/lite/delegates/gpu/metal/inference_context.cc
+++ b/tensorflow/lite/delegates/gpu/metal/inference_context.cc
@@ -761,6 +761,18 @@
}
}
+uint64_t InferenceContext::GetIntermediateTensorsSize() const {
+ uint64_t total_memory = 0;
+ for (const auto& t : strong_shape_tensors_) {
+ total_memory += t.second.GetMemorySizeInBytes();
+ }
+ for (const auto& b : shared_buffers_) {
+ total_memory += [b length];
+ }
+
+ return total_memory;
+}
+
void InferenceContext::EncodeWithCommandBuffer(
id<MTLCommandBuffer> command_buffer) {
for (int i = 0; i < nodes_.size(); ++i) {
diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.h b/tensorflow/lite/delegates/gpu/metal/inference_context.h
index 3ed7c67..1a062f7 100644
--- a/tensorflow/lite/delegates/gpu/metal/inference_context.h
+++ b/tensorflow/lite/delegates/gpu/metal/inference_context.h
@@ -133,6 +133,9 @@
int flush_period);
void Profile(id<MTLDevice> device, ProfilingInfo* result);
+ // Returns the size in bytes of all intermediate (runtime) tensors owned by
+ // this inference context. Constant tensors are not included.
+ uint64_t GetIntermediateTensorsSize() const;
private:
enum class TensorMemoryType {