Extend the Keras TensorBoard callback to optionally display the global steps per second happening during training.

PiperOrigin-RevId: 344210886
Change-Id: I36854954709ec29944696629ec67ef453e4dd2a7
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index f3169d2..493e738 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -669,7 +669,7 @@
         epoch: Integer, index of epoch.
         logs: Dict, metric results for this training epoch, and for the
           validation epoch if validation is performed. Validation result keys
-          are prefixed with `val_`. For training epoch, the values of the  
+          are prefixed with `val_`. For training epoch, the values of the
          `Model`'s metrics are returned. Example : `{'loss': 0.2, 'acc': 0.7}`.
     """
 
@@ -2002,6 +2002,8 @@
         can become quite large when write_graph is set to True.
       write_images: whether to write model weights to visualize as image in
         TensorBoard.
+      write_steps_per_second: whether to log the training steps per second into
+        Tensorboard. This supports both epoch and batch frequency logging.
       update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`,
         writes the losses and metrics to TensorBoard after each batch. The same
         applies for `'epoch'`. If using an integer, let's say `1000`, the
@@ -2097,6 +2099,7 @@
                histogram_freq=0,
                write_graph=True,
                write_images=False,
+               write_steps_per_second=False,
                update_freq='epoch',
                profile_batch=2,
                embeddings_freq=0,
@@ -2110,12 +2113,16 @@
     self.histogram_freq = histogram_freq
     self.write_graph = write_graph
     self.write_images = write_images
+    self.write_steps_per_second = write_steps_per_second
     self.update_freq = 1 if update_freq == 'batch' else update_freq
     self.embeddings_freq = embeddings_freq
     self.embeddings_metadata = embeddings_metadata
     self._init_profile_batch(profile_batch)
     self._epoch = 0
     self._global_train_batch = 0
+    self._previous_epoch_iterations = 0
+    self._train_accumulated_time = 0
+    self._batch_start_time = 0
 
     # Lazily initialized in order to avoid creating event files when
     # not needed.
@@ -2336,6 +2343,8 @@
 
   def on_train_begin(self, logs=None):
     self._global_train_batch = 0
+    self._previous_epoch_iterations = 0
+    self._train_accumulated_time = 0
     self._push_writer(self._train_writer, self._train_step)
 
   def on_train_end(self, logs=None):
@@ -2358,6 +2367,8 @@
 
   def on_train_batch_begin(self, batch, logs=None):
     self._global_train_batch += 1
+    if self.write_steps_per_second:
+      self._batch_start_time = time.time()
     if not self._should_trace:
       return
 
@@ -2368,6 +2379,10 @@
     if self._should_write_train_graph:
       self._write_keras_model_train_graph()
       self._should_write_train_graph = False
+    if self.write_steps_per_second:
+      batch_run_time = time.time() - self._batch_start_time
+      self._train_accumulated_time += batch_run_time
+      summary_ops_v2.scalar('batch_steps_per_second', 1. / batch_run_time)
     if not self._should_trace:
       return
 
@@ -2377,6 +2392,9 @@
   def on_epoch_begin(self, epoch, logs=None):
     # Keeps track of epoch for profiling.
     self._epoch = epoch
+    if self.write_steps_per_second:
+      self._previous_epoch_iterations = self.model.optimizer.iterations.numpy()
+      self._train_accumulated_time = 0
 
   def on_epoch_end(self, epoch, logs=None):
     """Runs metrics and histogram summaries at epoch end."""
@@ -2410,6 +2428,12 @@
       logs['learning_rate'] = lr_schedule(self.model.optimizer.iterations)
     return logs
 
+  def _compute_steps_per_second(self):
+    current_iteration = self.model.optimizer.iterations.numpy()
+    steps_per_second = ((current_iteration - self._previous_epoch_iterations) /
+                        (self._train_accumulated_time))
+    return steps_per_second
+
   def _log_epoch_metrics(self, epoch, logs):
     """Writes epoch metrics out as scalar summaries.
 
@@ -2423,6 +2447,8 @@
     train_logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
     val_logs = {k: v for k, v in logs.items() if k.startswith('val_')}
     train_logs = self._collect_learning_rate(train_logs)
+    if self.write_steps_per_second:
+      train_logs['steps_per_second'] = self._compute_steps_per_second()
 
     with summary_ops_v2.record_if(True):
       if train_logs:
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index 538f981..686e10c 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -2029,6 +2029,37 @@
         },
     )
 
+  def test_TensorBoard_global_step(self):
+    model = self._get_model(compile_model=False)
+    opt = gradient_descent.SGD(learning_rate_schedule.CosineDecay(0.01, 1))
+    model.compile(opt, 'mse', run_eagerly=testing_utils.should_run_eagerly())
+
+    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
+
+    model.fit(
+        x,
+        y,
+        batch_size=2,
+        epochs=2,
+        callbacks=[
+            keras.callbacks.TensorBoard(
+                self.logdir, update_freq=1, write_steps_per_second=True)
+        ])
+
+    summary_file = list_summaries(self.logdir)
+    self.assertEqual(
+        summary_file.scalars,
+        {
+            _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'),
+            _ObservedSummary(logdir=self.train_dir, tag='batch_loss'),
+            _ObservedSummary(logdir=self.train_dir, tag='epoch_learning_rate'),
+            _ObservedSummary(
+                logdir=self.train_dir, tag='epoch_steps_per_second'),
+            _ObservedSummary(
+                logdir=self.train_dir, tag='batch_steps_per_second'),
+        },
+    )
+
   def test_TensorBoard_weight_histograms(self):
     model = self._get_model()
     x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt
index 51d6901..f0e3c04 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt
@@ -6,7 +6,7 @@
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'log_dir\', \'histogram_freq\', \'write_graph\', \'write_images\', \'update_freq\', \'profile_batch\', \'embeddings_freq\', \'embeddings_metadata\'], varargs=None, keywords=kwargs, defaults=[\'logs\', \'0\', \'True\', \'False\', \'epoch\', \'2\', \'0\', \'None\'], "
+    argspec: "args=[\'self\', \'log_dir\', \'histogram_freq\', \'write_graph\', \'write_images\', \'write_steps_per_second\', \'update_freq\', \'profile_batch\', \'embeddings_freq\', \'embeddings_metadata\'], varargs=None, keywords=kwargs, defaults=[\'logs\', \'0\', \'True\', \'False\', \'False\', \'epoch\', \'2\', \'0\', \'None\'], "
   }
   member_method {
     name: "on_batch_begin"