#tf-data-service Add a metric for cross-trainer cache size in bytes.
Also moves metrics collection out of mutex scopes.
PiperOrigin-RevId: 445526260
diff --git a/tensorflow/core/data/service/multi_trainer_cache.h b/tensorflow/core/data/service/multi_trainer_cache.h
index 21f3a4b..d4d2b58 100644
--- a/tensorflow/core/data/service/multi_trainer_cache.h
+++ b/tensorflow/core/data/service/multi_trainer_cache.h
@@ -120,6 +120,14 @@
bool IsCancelled() const;
private:
+ struct CacheQueryResult {
+ std::shared_ptr<const ElementType> element;
+ bool cache_hit;
+ };
+
+ // Returns the next element and metrics about this query.
+ StatusOr<CacheQueryResult> GetCacheQueryResult(const std::string& trainer_id);
+
// Returns true if element is ready for `trainer_id`. An element is ready if
// other trainers have read the data and the data remains in the cache. If the
// data is not ready, one of the trainers needs to extend the cache.
@@ -140,6 +148,9 @@
// `new_element_size_bytes` is the size of the new element being inserted.
void FreeSpace(size_t new_element_size_bytes);
+ // Records the cache query result (hit or miss) and the current cache size
+ // in bytes.
+ void RecordMetrics(const CacheQueryResult& result);
+
// Maximum cache size in bytes.
const size_t max_cache_size_bytes_;
@@ -189,15 +200,25 @@
"tf.data service multi-trainer cache trainer ID must be non-empty.");
}
+ TF_ASSIGN_OR_RETURN(CacheQueryResult result, GetCacheQueryResult(trainer_id));
+ RecordMetrics(result);
+ return result.element;
+}
+
+template <class ElementType>
+StatusOr<typename MultiTrainerCache<ElementType>::CacheQueryResult>
+MultiTrainerCache<ElementType>::GetCacheQueryResult(
+ const std::string& trainer_id) {
bool should_extend_cache = false;
while (true) {
{
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(status_);
if (IsElementReady(trainer_id)) {
- metrics::RecordTFDataServiceMultiTrainerCacheQuery(
- /*cache_hit=*/!should_extend_cache);
- return GetElement(trainer_id);
+ TF_ASSIGN_OR_RETURN(std::shared_ptr<const ElementType> element,
+ GetElement(trainer_id));
+ return CacheQueryResult{element,
+ /*cache_hit=*/!should_extend_cache};
}
// Extends the cache or waits for another thread to extend the cache. When
@@ -313,6 +334,19 @@
mutex_lock l(mu_);
return !status_.ok();
}
+
+template <class ElementType>
+void MultiTrainerCache<ElementType>::RecordMetrics(
+ const CacheQueryResult& result) {
+ metrics::RecordTFDataServiceMultiTrainerCacheQuery(result.cache_hit);
+ size_t cache_size_bytes = 0;
+ {
+ mutex_lock l(mu_);
+ cache_size_bytes = cache_size_bytes_;
+ }
+ metrics::RecordTFDataServiceMultiTrainerCacheSizeBytes(cache_size_bytes);
+}
+
} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/data/service/multi_trainer_cache_test.cc b/tensorflow/core/data/service/multi_trainer_cache_test.cc
index 2dbf3ef..a9ae992 100644
--- a/tensorflow/core/data/service/multi_trainer_cache_test.cc
+++ b/tensorflow/core/data/service/multi_trainer_cache_test.cc
@@ -264,6 +264,28 @@
EXPECT_EQ(cell_reader.Read("false"), 10);
}
+TEST(MultiTrainerCacheTest, CacheSizeMetrics) {
+ CellReader<int64_t> cell_reader(
+ "/tensorflow/data/service/multi_trainer_cache_size_bytes");
+
+ const size_t num_elements = 5;
+ MultiTrainerCache<int64_t> cache(
+ /*max_cache_size_bytes=*/num_elements * sizeof(int64_t),
+ std::make_unique<InfiniteRange>());
+
+ for (size_t i = 0; i < num_elements; ++i) {
+ EXPECT_THAT(cache.Get("Trainer 1"), IsOkAndHolds(Pointee(i)));
+ EXPECT_EQ(cell_reader.Read(), (i + 1) * sizeof(int64_t));
+ }
+
+ // The cache size does not increase after reaching `num_elements`.
+ for (size_t i = 0; i < 100; ++i) {
+ EXPECT_THAT(cache.Get("Trainer 1"),
+ IsOkAndHolds(Pointee(num_elements + i)));
+ EXPECT_EQ(cell_reader.Read(), 5 * sizeof(int64_t));
+ }
+}
+
TEST(MultiTrainerCacheTest, ConcurrentReaders) {
size_t num_trainers = 10;
size_t num_elements_to_read = 200;
diff --git a/tensorflow/core/framework/metrics.cc b/tensorflow/core/framework/metrics.cc
index 00d9a1a..5d6e701 100644
--- a/tensorflow/core/framework/metrics.cc
+++ b/tensorflow/core/framework/metrics.cc
@@ -15,6 +15,7 @@
#include "tensorflow/core/framework/metrics.h"
+#include <cstdint>
#include <string>
#include "absl/strings/str_cat.h"
@@ -157,6 +158,11 @@
"hit or miss.",
"cache_hit");
+auto* tf_data_service_multi_trainer_cache_size_bytes =
+ monitoring::Gauge<int64_t, 0>::New(
+ "/tensorflow/data/service/multi_trainer_cache_size_bytes",
+ "tf.data service multi-trainer cache memory usage in bytes.");
+
auto* tf_data_filename_counter = monitoring::Counter<2>::New(
"/tensorflow/data/filename", "The file name read by a tf.data Dataset.",
"name", "filename");
@@ -383,6 +389,11 @@
->IncrementBy(1);
}
+void RecordTFDataServiceMultiTrainerCacheSizeBytes(size_t bytes) {
+ tf_data_service_multi_trainer_cache_size_bytes->GetCell()->Set(
+ static_cast<int64_t>(bytes));
+}
+
void RecordTFDataFilename(const string& name, const string& filename) {
tf_data_filename_counter->GetCell(name, filename)->IncrementBy(1);
}
diff --git a/tensorflow/core/framework/metrics.h b/tensorflow/core/framework/metrics.h
index 55342dd..abb01ec 100644
--- a/tensorflow/core/framework/metrics.h
+++ b/tensorflow/core/framework/metrics.h
@@ -126,6 +126,9 @@
// Records tf.data service multi-trainer cache queries.
void RecordTFDataServiceMultiTrainerCacheQuery(bool cache_hit);
+// Records tf.data service multi-trainer cache memory usage in bytes.
+void RecordTFDataServiceMultiTrainerCacheSizeBytes(size_t bytes);
+
// Records the file name read by a tf.data Dataset.
//
// The `name` argument identifies the Dataset type (e.g. "TFRecordDataset").