[tf.data] Avoid contention in bytes-read metric collection.

This change caches the `CounterCell*` for each source dataset type that
records the number of bytes read. This avoids contention in the thread-safe map
lookup each time a record is read.

PiperOrigin-RevId: 302079362
Change-Id: I4efaeff8e0ffb3df2311228489ba107ce378ce91
diff --git a/tensorflow/core/common_runtime/metrics.cc b/tensorflow/core/common_runtime/metrics.cc
index f05f931..a26a678 100644
--- a/tensorflow/core/common_runtime/metrics.cc
+++ b/tensorflow/core/common_runtime/metrics.cc
@@ -132,8 +132,8 @@
   tf_data_autotune_counter->GetCell(name)->IncrementBy(1);
 }
 
-void RecordTFDataBytesRead(const string& name, int64 num_bytes) {
-  tf_data_bytes_read_counter->GetCell(name)->IncrementBy(num_bytes);
+monitoring::CounterCell* GetTFDataBytesReadCounter(const string& name) {
+  return tf_data_bytes_read_counter->GetCell(name);
 }
 
 void RecordTFDataBytesFetched(int64 num_bytes) {
diff --git a/tensorflow/core/common_runtime/metrics.h b/tensorflow/core/common_runtime/metrics.h
index a5d43da..e95e049 100644
--- a/tensorflow/core/common_runtime/metrics.h
+++ b/tensorflow/core/common_runtime/metrics.h
@@ -16,6 +16,7 @@
 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_METRICS_H_
 #define TENSORFLOW_CORE_COMMON_RUNTIME_METRICS_H_
 
+#include "tensorflow/core/lib/monitoring/counter.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
@@ -26,11 +27,11 @@
 // The `name` argument identifies the Dataset type (e.g. "ParallelMap").
 void RecordTFDataAutotune(const string& name);
 
-// Records the number of bytes read from the filesystem by a tf.data.Dataset
-// source.
+// Returns a counter than can be used to record the number of bytes read from
+// the filesystem by a tf.data.Dataset source.
 //
 // The `name` argument identifies the Dataset type (e.g. "TFRecordDataset").
-void RecordTFDataBytesRead(const string& name, int64 num_bytes);
+monitoring::CounterCell* GetTFDataBytesReadCounter(const string& name);
 
 // Records the number of bytes fetched from tf.data.Dataset iterator.
 void RecordTFDataBytesFetched(int64 num_bytes);
diff --git a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc
index 15bfeb0..468a222 100644
--- a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc
+++ b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc
@@ -138,8 +138,9 @@
             string record;
             TF_RETURN_IF_ERROR(
                 input_buffer_->ReadNBytes(dataset()->record_bytes_, &record));
-            metrics::RecordTFDataBytesRead(kDatasetType,
-                                           dataset()->record_bytes_);
+            static monitoring::CounterCell* bytes_counter =
+                metrics::GetTFDataBytesReadCounter(kDatasetType);
+            bytes_counter->IncrementBy(dataset()->record_bytes_);
 
             // Produce the record as output.
             Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
@@ -251,6 +252,8 @@
     Status GetNextInternal(IteratorContext* ctx,
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
+      static monitoring::CounterCell* bytes_counter =
+          metrics::GetTFDataBytesReadCounter(kDatasetType);
       mutex_lock l(mu_);
       do {
         // We are currently processing a file, so try to read the next record.
@@ -262,8 +265,7 @@
               tstring record;
               TF_RETURN_IF_ERROR(buffered_input_stream_->ReadNBytes(
                   dataset()->record_bytes_, &record));
-              metrics::RecordTFDataBytesRead(kDatasetType,
-                                             dataset()->record_bytes_);
+              bytes_counter->IncrementBy(dataset()->record_bytes_);
 
               // Produce the record as output.
               Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
@@ -277,8 +279,7 @@
             Status s = buffered_input_stream_->ReadNBytes(
                 dataset()->record_bytes_, &record);
             if (s.ok()) {
-              metrics::RecordTFDataBytesRead(kDatasetType,
-                                             dataset()->record_bytes_);
+              bytes_counter->IncrementBy(dataset()->record_bytes_);
               lookahead_cache_.append(record);
               StringPiece lookahead_cache_view(lookahead_cache_);
               record = tstring(
diff --git a/tensorflow/core/kernels/data/text_line_dataset_op.cc b/tensorflow/core/kernels/data/text_line_dataset_op.cc
index c2c3190..dc193f5 100644
--- a/tensorflow/core/kernels/data/text_line_dataset_op.cc
+++ b/tensorflow/core/kernels/data/text_line_dataset_op.cc
@@ -105,12 +105,13 @@
 
           if (s.ok()) {
             // Produce the line as output.
-            metrics::RecordTFDataBytesRead(
-                name_utils::OpName(TextLineDatasetOp::kDatasetType),
-                line_contents.size());
+            static monitoring::CounterCell* bytes_counter =
+                metrics::GetTFDataBytesReadCounter(
+                    name_utils::OpName(TextLineDatasetOp::kDatasetType));
+            bytes_counter->IncrementBy(line_contents.size());
             out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
                                       TensorShape({}));
-            out_tensors->back().scalar<tstring>()() = std::move(line_contents);
+            out_tensors->back().scalar<tstring>()() = line_contents;
             *end_of_sequence = false;
             return Status::OK();
           } else if (!errors::IsOutOfRange(s)) {
diff --git a/tensorflow/core/kernels/data/tf_record_dataset_op.cc b/tensorflow/core/kernels/data/tf_record_dataset_op.cc
index a72d05c..94d523b 100644
--- a/tensorflow/core/kernels/data/tf_record_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tf_record_dataset_op.cc
@@ -119,8 +119,10 @@
           Status s =
               reader_->ReadRecord(&out_tensors->back().scalar<tstring>()());
           if (s.ok()) {
-            metrics::RecordTFDataBytesRead(
-                kDatasetType, out_tensors->back().scalar<tstring>()().size());
+            static monitoring::CounterCell* bytes_counter =
+                metrics::GetTFDataBytesReadCounter(kDatasetType);
+            bytes_counter->IncrementBy(
+                out_tensors->back().scalar<tstring>()().size());
             *end_of_sequence = false;
             return Status::OK();
           }