Make tf.summary.write() take in callable as `tensor` parameter.
PiperOrigin-RevId: 278662061
Change-Id: I5ac01e5167341645b27abdc4df8f461f1d34a6ec
diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py
index b61c071..868d256 100644
--- a/tensorflow/python/ops/summary_ops_v2.py
+++ b/tensorflow/python/ops/summary_ops_v2.py
@@ -615,7 +615,10 @@
Args:
tag: string tag used to identify the summary (e.g. in TensorBoard), usually
generated with `tf.summary.summary_scope`
- tensor: the Tensor holding the summary data to write
+ tensor: the Tensor holding the summary data to write or a callable that
+ returns this Tensor. If a callable is passed, it will only be called when
+ a default SummaryWriter exists and the recording condition specified by
+ `record_if()` is met.
step: Explicit `int64`-castable monotonic step value for this summary. If
omitted, this defaults to `tf.summary.experimental.get_step()`, which must
not be None.
@@ -649,10 +652,12 @@
"""Record the actual summary and return True."""
# Note the identity to move the tensor to the CPU.
with ops.device("cpu:0"):
+ summary_tensor = tensor() if callable(tensor) else array_ops.identity(
+ tensor)
write_summary_op = gen_summary_ops.write_summary(
_summary_state.writer._resource, # pylint: disable=protected-access
step,
- array_ops.identity(tensor),
+ summary_tensor,
tag,
serialized_metadata,
name=scope)