[tf.data] Avoid contention in bytes-read metric collection. This change caches the `CounterCell*` for each source dataset type that records the number of bytes read. This avoids contention in the thread-safe map lookup each time a record is read. PiperOrigin-RevId: 302079362 Change-Id: I4efaeff8e0ffb3df2311228489ba107ce378ce91
diff --git a/tensorflow/core/common_runtime/metrics.cc b/tensorflow/core/common_runtime/metrics.cc index f05f931..a26a678 100644 --- a/tensorflow/core/common_runtime/metrics.cc +++ b/tensorflow/core/common_runtime/metrics.cc
@@ -132,8 +132,8 @@ tf_data_autotune_counter->GetCell(name)->IncrementBy(1); } -void RecordTFDataBytesRead(const string& name, int64 num_bytes) { - tf_data_bytes_read_counter->GetCell(name)->IncrementBy(num_bytes); +monitoring::CounterCell* GetTFDataBytesReadCounter(const string& name) { + return tf_data_bytes_read_counter->GetCell(name); } void RecordTFDataBytesFetched(int64 num_bytes) {
diff --git a/tensorflow/core/common_runtime/metrics.h b/tensorflow/core/common_runtime/metrics.h index a5d43da..e95e049 100644 --- a/tensorflow/core/common_runtime/metrics.h +++ b/tensorflow/core/common_runtime/metrics.h
@@ -16,6 +16,7 @@ #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_METRICS_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_METRICS_H_ +#include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -26,11 +27,11 @@ // The `name` argument identifies the Dataset type (e.g. "ParallelMap"). void RecordTFDataAutotune(const string& name); -// Records the number of bytes read from the filesystem by a tf.data.Dataset -// source. +// Returns a counter that can be used to record the number of bytes read from +// the filesystem by a tf.data.Dataset source. // // The `name` argument identifies the Dataset type (e.g. "TFRecordDataset"). -void RecordTFDataBytesRead(const string& name, int64 num_bytes); +monitoring::CounterCell* GetTFDataBytesReadCounter(const string& name); // Records the number of bytes fetched from tf.data.Dataset iterator. void RecordTFDataBytesFetched(int64 num_bytes);
diff --git a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc index 15bfeb0..468a222 100644 --- a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc +++ b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc
@@ -138,8 +138,9 @@ string record; TF_RETURN_IF_ERROR( input_buffer_->ReadNBytes(dataset()->record_bytes_, &record)); - metrics::RecordTFDataBytesRead(kDatasetType, - dataset()->record_bytes_); + static monitoring::CounterCell* bytes_counter = + metrics::GetTFDataBytesReadCounter(kDatasetType); + bytes_counter->IncrementBy(dataset()->record_bytes_); // Produce the record as output. Tensor record_tensor(ctx->allocator({}), DT_STRING, {}); @@ -251,6 +252,8 @@ Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors, bool* end_of_sequence) override { + static monitoring::CounterCell* bytes_counter = + metrics::GetTFDataBytesReadCounter(kDatasetType); mutex_lock l(mu_); do { // We are currently processing a file, so try to read the next record. @@ -262,8 +265,7 @@ tstring record; TF_RETURN_IF_ERROR(buffered_input_stream_->ReadNBytes( dataset()->record_bytes_, &record)); - metrics::RecordTFDataBytesRead(kDatasetType, - dataset()->record_bytes_); + bytes_counter->IncrementBy(dataset()->record_bytes_); // Produce the record as output. Tensor record_tensor(ctx->allocator({}), DT_STRING, {}); @@ -277,8 +279,7 @@ Status s = buffered_input_stream_->ReadNBytes( dataset()->record_bytes_, &record); if (s.ok()) { - metrics::RecordTFDataBytesRead(kDatasetType, - dataset()->record_bytes_); + bytes_counter->IncrementBy(dataset()->record_bytes_); lookahead_cache_.append(record); StringPiece lookahead_cache_view(lookahead_cache_); record = tstring(
diff --git a/tensorflow/core/kernels/data/text_line_dataset_op.cc b/tensorflow/core/kernels/data/text_line_dataset_op.cc index c2c3190..dc193f5 100644 --- a/tensorflow/core/kernels/data/text_line_dataset_op.cc +++ b/tensorflow/core/kernels/data/text_line_dataset_op.cc
@@ -105,12 +105,13 @@ if (s.ok()) { // Produce the line as output. - metrics::RecordTFDataBytesRead( - name_utils::OpName(TextLineDatasetOp::kDatasetType), - line_contents.size()); + static monitoring::CounterCell* bytes_counter = + metrics::GetTFDataBytesReadCounter( + name_utils::OpName(TextLineDatasetOp::kDatasetType)); + bytes_counter->IncrementBy(line_contents.size()); out_tensors->emplace_back(ctx->allocator({}), DT_STRING, TensorShape({})); - out_tensors->back().scalar<tstring>()() = std::move(line_contents); + out_tensors->back().scalar<tstring>()() = line_contents; *end_of_sequence = false; return Status::OK(); } else if (!errors::IsOutOfRange(s)) {
diff --git a/tensorflow/core/kernels/data/tf_record_dataset_op.cc b/tensorflow/core/kernels/data/tf_record_dataset_op.cc index a72d05c..94d523b 100644 --- a/tensorflow/core/kernels/data/tf_record_dataset_op.cc +++ b/tensorflow/core/kernels/data/tf_record_dataset_op.cc
@@ -119,8 +119,10 @@ Status s = reader_->ReadRecord(&out_tensors->back().scalar<tstring>()()); if (s.ok()) { - metrics::RecordTFDataBytesRead( - kDatasetType, out_tensors->back().scalar<tstring>()().size()); + static monitoring::CounterCell* bytes_counter = + metrics::GetTFDataBytesReadCounter(kDatasetType); + bytes_counter->IncrementBy( + out_tensors->back().scalar<tstring>()().size()); *end_of_sequence = false; return Status::OK(); }