Extend the Keras TensorBoard callback to optionally log the global steps per second observed during training.
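
The epoch-level value is computed from the optimizer's iteration counter as
(iterations run this epoch) / (accumulated batch wall time); the batch-level
value is 1 / (batch wall time). For example, an epoch that runs 100 optimizer
iterations over 4 seconds of accumulated batch time reports
epoch_steps_per_second = 25.

A minimal usage sketch (the model, data, and log directory below are
illustrative only and not part of this change):

    import numpy as np
    import tensorflow as tf

    # Toy model and data, for illustration only.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer='sgd', loss='mse')
    x = np.ones((32, 4), dtype='float32')
    y = np.ones((32, 1), dtype='float32')

    tb = tf.keras.callbacks.TensorBoard(
        log_dir='/tmp/tb_logs',        # hypothetical log directory
        update_freq='batch',           # also write per-batch summaries
        write_steps_per_second=True)   # flag added by this change
    model.fit(x, y, epochs=2, batch_size=8, callbacks=[tb])

With write_steps_per_second=True the train run records an
epoch_steps_per_second scalar, plus batch_steps_per_second when
batch-frequency logging is enabled.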
PiperOrigin-RevId: 344210886
Change-Id: I36854954709ec29944696629ec67ef453e4dd2a7
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index f3169d2..493e738 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -669,7 +669,7 @@
epoch: Integer, index of epoch.
logs: Dict, metric results for this training epoch, and for the
validation epoch if validation is performed. Validation result keys
- are prefixed with `val_`. For training epoch, the values of the
+ are prefixed with `val_`. For training epoch, the values of the
`Model`'s metrics are returned. Example : `{'loss': 0.2, 'acc': 0.7}`.
"""
@@ -2002,6 +2002,8 @@
can become quite large when write_graph is set to True.
write_images: whether to write model weights to visualize as image in
TensorBoard.
+ write_steps_per_second: whether to log the training steps per second into
+ TensorBoard. This supports both epoch and batch frequency logging.
update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`,
writes the losses and metrics to TensorBoard after each batch. The same
applies for `'epoch'`. If using an integer, let's say `1000`, the
@@ -2097,6 +2099,7 @@
histogram_freq=0,
write_graph=True,
write_images=False,
+ write_steps_per_second=False,
update_freq='epoch',
profile_batch=2,
embeddings_freq=0,
@@ -2110,12 +2113,16 @@
self.histogram_freq = histogram_freq
self.write_graph = write_graph
self.write_images = write_images
+ self.write_steps_per_second = write_steps_per_second
self.update_freq = 1 if update_freq == 'batch' else update_freq
self.embeddings_freq = embeddings_freq
self.embeddings_metadata = embeddings_metadata
self._init_profile_batch(profile_batch)
self._epoch = 0
self._global_train_batch = 0
+ self._previous_epoch_iterations = 0
+ self._train_accumulated_time = 0
+ self._batch_start_time = 0
# Lazily initialized in order to avoid creating event files when
# not needed.
@@ -2336,6 +2343,8 @@
def on_train_begin(self, logs=None):
self._global_train_batch = 0
+ self._previous_epoch_iterations = 0
+ self._train_accumulated_time = 0
self._push_writer(self._train_writer, self._train_step)
def on_train_end(self, logs=None):
@@ -2358,6 +2367,8 @@
def on_train_batch_begin(self, batch, logs=None):
self._global_train_batch += 1
+ if self.write_steps_per_second:
+ self._batch_start_time = time.time()
if not self._should_trace:
return
@@ -2368,6 +2379,10 @@
if self._should_write_train_graph:
self._write_keras_model_train_graph()
self._should_write_train_graph = False
+ if self.write_steps_per_second:
+ batch_run_time = time.time() - self._batch_start_time
+ self._train_accumulated_time += batch_run_time
+ summary_ops_v2.scalar('batch_steps_per_second', 1. / batch_run_time)
if not self._should_trace:
return
@@ -2377,6 +2392,9 @@
def on_epoch_begin(self, epoch, logs=None):
# Keeps track of epoch for profiling.
self._epoch = epoch
+ if self.write_steps_per_second:
+ self._previous_epoch_iterations = self.model.optimizer.iterations.numpy()
+ self._train_accumulated_time = 0
def on_epoch_end(self, epoch, logs=None):
"""Runs metrics and histogram summaries at epoch end."""
@@ -2410,6 +2428,12 @@
logs['learning_rate'] = lr_schedule(self.model.optimizer.iterations)
return logs
+ def _compute_steps_per_second(self):
+ current_iteration = self.model.optimizer.iterations.numpy()
+ steps_per_second = ((current_iteration - self._previous_epoch_iterations) /
+ (self._train_accumulated_time))
+ return steps_per_second
+
def _log_epoch_metrics(self, epoch, logs):
"""Writes epoch metrics out as scalar summaries.
@@ -2423,6 +2447,8 @@
train_logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
val_logs = {k: v for k, v in logs.items() if k.startswith('val_')}
train_logs = self._collect_learning_rate(train_logs)
+ if self.write_steps_per_second:
+ train_logs['steps_per_second'] = self._compute_steps_per_second()
with summary_ops_v2.record_if(True):
if train_logs:
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index 538f981..686e10c 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -2029,6 +2029,37 @@
},
)
+ def test_TensorBoard_global_step(self):
+ model = self._get_model(compile_model=False)
+ opt = gradient_descent.SGD(learning_rate_schedule.CosineDecay(0.01, 1))
+ model.compile(opt, 'mse', run_eagerly=testing_utils.should_run_eagerly())
+
+ x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
+
+ model.fit(
+ x,
+ y,
+ batch_size=2,
+ epochs=2,
+ callbacks=[
+ keras.callbacks.TensorBoard(
+ self.logdir, update_freq=1, write_steps_per_second=True)
+ ])
+
+ summary_file = list_summaries(self.logdir)
+ self.assertEqual(
+ summary_file.scalars,
+ {
+ _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'),
+ _ObservedSummary(logdir=self.train_dir, tag='batch_loss'),
+ _ObservedSummary(logdir=self.train_dir, tag='epoch_learning_rate'),
+ _ObservedSummary(
+ logdir=self.train_dir, tag='epoch_steps_per_second'),
+ _ObservedSummary(
+ logdir=self.train_dir, tag='batch_steps_per_second'),
+ },
+ )
+
def test_TensorBoard_weight_histograms(self):
model = self._get_model()
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt
index 51d6901..f0e3c04 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt
@@ -6,7 +6,7 @@
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'log_dir\', \'histogram_freq\', \'write_graph\', \'write_images\', \'update_freq\', \'profile_batch\', \'embeddings_freq\', \'embeddings_metadata\'], varargs=None, keywords=kwargs, defaults=[\'logs\', \'0\', \'True\', \'False\', \'epoch\', \'2\', \'0\', \'None\'], "
+ argspec: "args=[\'self\', \'log_dir\', \'histogram_freq\', \'write_graph\', \'write_images\', \'write_steps_per_second\', \'update_freq\', \'profile_batch\', \'embeddings_freq\', \'embeddings_metadata\'], varargs=None, keywords=kwargs, defaults=[\'logs\', \'0\', \'True\', \'False\', \'False\', \'epoch\', \'2\', \'0\', \'None\'], "
}
member_method {
name: "on_batch_begin"