Improve TensorBoard Callback documentation.
PiperOrigin-RevId: 335705948
Change-Id: Ia125644d34a7e6cfb9e9dbc28c58a435aed97961
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index b44b6ef..474cb8a 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -1944,30 +1944,6 @@
You can find more information about TensorBoard
[here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
- Example (Basic):
-
- ```python
- tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
- model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
- # run the tensorboard command to view the visualizations.
- ```
-
- Example (Profile):
-
- ```python
- # profile a single batch, e.g. the 5th batch.
- tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs',
- profile_batch=5)
- model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
- # Now run the tensorboard command to view the visualizations (profile plugin).
-
- # profile a range of batches, e.g. from 10 to 20.
- tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs',
- profile_batch='10,20')
- model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
- # Now run the tensorboard command to view the visualizations (profile plugin).
- ```
-
Arguments:
log_dir: the path of the directory where to save the log files to be
parsed by TensorBoard.
@@ -1999,8 +1975,72 @@
about metadata files format. In case if the same metadata file is
used for all embedding layers, string can be passed.
- Raises:
- ValueError: If histogram_freq is set and no validation data is provided.
+ Examples:
+
+ Basic usage:
+
+ ```python
+ tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
+ model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
+ # Then run the tensorboard command to view the visualizations.
+ ```
+
+ Custom batch-level summaries in a subclassed Model:
+
+ ```python
+ class MyModel(tf.keras.Model):
+
+ def build(self, _):
+ self.dense = tf.keras.layers.Dense(10)
+
+ def call(self, x):
+ outputs = self.dense(x)
+ tf.summary.histogram('outputs', outputs)
+ return outputs
+
+ model = MyModel()
+ model.compile('sgd', 'mse')
+
+ # Make sure to set `update_freq=N` to log a batch-level summary every N batches.
+ # In addition to any `tf.summary` contained in `Model.call`, metrics added in
+ # `Model.compile` will be logged every N batches.
+ tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1)
+ model.fit(x_train, y_train, callbacks=[tb_callback])
+ ```
+
+ Custom batch-level summaries in a Functional API Model:
+
+ ```python
+ def my_summary(x):
+ tf.summary.histogram('x', x)
+ return x
+
+ inputs = tf.keras.Input(10)
+ x = tf.keras.layers.Dense(10)(inputs)
+ outputs = tf.keras.layers.Lambda(my_summary)(x)
+ model = tf.keras.Model(inputs, outputs)
+ model.compile('sgd', 'mse')
+
+ # Make sure to set `update_freq=N` to log a batch-level summary every N batches.
+ # In addition to any `tf.summary` contained in `Model.call`, metrics added in
+ # `Model.compile` will be logged every N batches.
+ tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1)
+ model.fit(x_train, y_train, callbacks=[tb_callback])
+ ```
+
+ Profiling:
+
+ ```python
+ # Profile a single batch, e.g. the 5th batch.
+ tensorboard_callback = tf.keras.callbacks.TensorBoard(
+ log_dir='./logs', profile_batch=5)
+ model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
+
+ # Profile a range of batches, e.g. from 10 to 20.
+ tensorboard_callback = tf.keras.callbacks.TensorBoard(
+ log_dir='./logs', profile_batch=(10,20))
+ model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
+ ```
"""
# pylint: enable=line-too-long