#tf-data-service Add a metric for cross-trainer cache size in bytes.

Also moves metrics collection out of mutex scopes.

PiperOrigin-RevId: 445526260
diff --git a/tensorflow/core/data/service/multi_trainer_cache.h b/tensorflow/core/data/service/multi_trainer_cache.h
index 21f3a4b..d4d2b58 100644
--- a/tensorflow/core/data/service/multi_trainer_cache.h
+++ b/tensorflow/core/data/service/multi_trainer_cache.h
@@ -120,6 +120,14 @@
   bool IsCancelled() const;
 
  private:
+  struct CacheQueryResult {
+    std::shared_ptr<const ElementType> element;
+    bool cache_hit;
+  };
+
+  // Returns the next element and metrics about this query.
+  StatusOr<CacheQueryResult> GetCacheQueryResult(const std::string& trainer_id);
+
   // Returns true if element is ready for `trainer_id`. An element is ready if
   // other trainers have read the data and the data remains in the cache. If the
   // data is not ready, one of the trainers need to extend the cache.
@@ -140,6 +148,9 @@
   // `new_element_size_bytes` is the size of the new element being inserted.
   void FreeSpace(size_t new_element_size_bytes);
 
+  // Records the cache hit rate and cache size.
+  void RecordMetrics(const CacheQueryResult& result);
+
   // Maximum cache size in bytes.
   const size_t max_cache_size_bytes_;
 
@@ -189,15 +200,25 @@
         "tf.data service multi-trainer cache trainer ID must be non-empty.");
   }
 
+  TF_ASSIGN_OR_RETURN(CacheQueryResult result, GetCacheQueryResult(trainer_id));
+  RecordMetrics(result);
+  return result.element;
+}
+
+template <class ElementType>
+StatusOr<typename MultiTrainerCache<ElementType>::CacheQueryResult>
+MultiTrainerCache<ElementType>::GetCacheQueryResult(
+    const std::string& trainer_id) {
   bool should_extend_cache = false;
   while (true) {
     {
       mutex_lock l(mu_);
       TF_RETURN_IF_ERROR(status_);
       if (IsElementReady(trainer_id)) {
-        metrics::RecordTFDataServiceMultiTrainerCacheQuery(
-            /*cache_hit=*/!should_extend_cache);
-        return GetElement(trainer_id);
+        TF_ASSIGN_OR_RETURN(std::shared_ptr<const ElementType> element,
+                            GetElement(trainer_id));
+        return CacheQueryResult{element,
+                                /*cache_hit=*/!should_extend_cache};
       }
 
       // Extends the cache or waits for another thread to extend the cache. When
@@ -313,6 +334,19 @@
   mutex_lock l(mu_);
   return !status_.ok();
 }
+
+template <class ElementType>
+void MultiTrainerCache<ElementType>::RecordMetrics(
+    const CacheQueryResult& result) {
+  metrics::RecordTFDataServiceMultiTrainerCacheQuery(result.cache_hit);
+  size_t cache_size_bytes = 0;
+  {
+    mutex_lock l(mu_);
+    cache_size_bytes = cache_size_bytes_;
+  }
+  metrics::RecordTFDataServiceMultiTrainerCacheSizeBytes(cache_size_bytes);
+}
+
 }  // namespace data
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/data/service/multi_trainer_cache_test.cc b/tensorflow/core/data/service/multi_trainer_cache_test.cc
index 2dbf3ef..a9ae992 100644
--- a/tensorflow/core/data/service/multi_trainer_cache_test.cc
+++ b/tensorflow/core/data/service/multi_trainer_cache_test.cc
@@ -264,6 +264,28 @@
   EXPECT_EQ(cell_reader.Read("false"), 10);
 }
 
+TEST(MultiTrainerCacheTest, CacheSizeMetrics) {
+  CellReader<int64_t> cell_reader(
+      "/tensorflow/data/service/multi_trainer_cache_size_bytes");
+
+  const size_t num_elements = 5;
+  MultiTrainerCache<int64_t> cache(
+      /*max_cache_size_bytes=*/num_elements * sizeof(int64_t),
+      std::make_unique<InfiniteRange>());
+
+  for (size_t i = 0; i < num_elements; ++i) {
+    EXPECT_THAT(cache.Get("Trainer 1"), IsOkAndHolds(Pointee(i)));
+    EXPECT_EQ(cell_reader.Read(), (i + 1) * sizeof(int64_t));
+  }
+
+  // The cache size does not increase after reaching `num_elements`.
+  for (size_t i = 0; i < 100; ++i) {
+    EXPECT_THAT(cache.Get("Trainer 1"),
+                IsOkAndHolds(Pointee(num_elements + i)));
+    EXPECT_EQ(cell_reader.Read(), 5 * sizeof(int64_t));
+  }
+}
+
 TEST(MultiTrainerCacheTest, ConcurrentReaders) {
   size_t num_trainers = 10;
   size_t num_elements_to_read = 200;
diff --git a/tensorflow/core/framework/metrics.cc b/tensorflow/core/framework/metrics.cc
index 00d9a1a..5d6e701 100644
--- a/tensorflow/core/framework/metrics.cc
+++ b/tensorflow/core/framework/metrics.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/core/framework/metrics.h"
 
+#include <cstdint>
 #include <string>
 
 #include "absl/strings/str_cat.h"
@@ -157,6 +158,11 @@
         "hit or miss.",
         "cache_hit");
 
+auto* tf_data_service_multi_trainer_cache_size_bytes =
+    monitoring::Gauge<int64_t, 0>::New(
+        "/tensorflow/data/service/multi_trainer_cache_size_bytes",
+        "tf.data service multi-trainer cache memory usage in bytes.");
+
 auto* tf_data_filename_counter = monitoring::Counter<2>::New(
     "/tensorflow/data/filename", "The file name read by a tf.data Dataset.",
     "name", "filename");
@@ -383,6 +389,11 @@
       ->IncrementBy(1);
 }
 
+void RecordTFDataServiceMultiTrainerCacheSizeBytes(size_t bytes) {
+  tf_data_service_multi_trainer_cache_size_bytes->GetCell()->Set(
+      static_cast<int64_t>(bytes));
+}
+
 void RecordTFDataFilename(const string& name, const string& filename) {
   tf_data_filename_counter->GetCell(name, filename)->IncrementBy(1);
 }
diff --git a/tensorflow/core/framework/metrics.h b/tensorflow/core/framework/metrics.h
index 55342dd..abb01ec 100644
--- a/tensorflow/core/framework/metrics.h
+++ b/tensorflow/core/framework/metrics.h
@@ -126,6 +126,9 @@
 // Records tf.data service multi-trainer cache queries.
 void RecordTFDataServiceMultiTrainerCacheQuery(bool cache_hit);
 
+// Records tf.data service multi-trainer cache memory usage in bytes.
+void RecordTFDataServiceMultiTrainerCacheSizeBytes(size_t bytes);
+
 // Records the file name read by a tf.data Dataset.
 //
 // The `name` argument identifies the Dataset type (e.g. "TFRecordDataset").