Invoke TensorBoard scalar v2 in `tf.compat.v1.summary.scalar`.
PiperOrigin-RevId: 422561603
Change-Id: I5261deda14315123cc38361a9bb31bdb4eaace13
diff --git a/tensorflow/python/summary/BUILD b/tensorflow/python/summary/BUILD
index c8533b0..c433e50 100644
--- a/tensorflow/python/summary/BUILD
+++ b/tensorflow/python/summary/BUILD
@@ -44,6 +44,7 @@
"plugin_asset_test.py",
"summary_iterator_test.py",
"summary_test.py",
+ "summary_v2_test.py",
],
python_version = "PY3",
deps = [
diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py
index b5fef3f..e9018a9 100644
--- a/tensorflow/python/summary/summary.py
+++ b/tensorflow/python/summary/summary.py
@@ -22,7 +22,7 @@
from google.protobuf import json_format as _json_format
# exports Summary, SummaryDescription, Event, TaggedRunMetadata, SessionLog
-# pylint: disable=unused-import
+# pylint: disable=unused-import, g-importing-member
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.framework.summary_pb2 import SummaryDescription
from tensorflow.core.framework.summary_pb2 import SummaryMetadata as _SummaryMetadata # pylint: enable=unused-import
@@ -39,6 +39,7 @@
from tensorflow.python.ops import gen_logging_ops as _gen_logging_ops
from tensorflow.python.ops import gen_summary_ops as _gen_summary_ops # pylint: disable=unused-import
from tensorflow.python.ops import summary_op_util as _summary_op_util
+from tensorflow.python.ops import summary_ops_v2 as _summary_ops_v2
# exports FileWriter, FileWriterCache
# pylint: disable=unused-import
@@ -72,8 +73,19 @@
ValueError: If tensor has the wrong shape or type.
@compatibility(TF2)
- This API is not compatible with eager execution or `tf.function`. To migrate
- to TF2, please use `tf.summary.scalar` instead. Please check
+ For compatibility purposes, when invoked in TF2 where the outermost context is
+ eager mode, this API will check if there is a suitable TF2 summary writer
+ context available, and if so will forward this call to that writer instead. A
+ "suitable" writer context means that the writer is set as the default writer,
+ and there is an associated non-empty value for `step` (see
+ `tf.summary.SummaryWriter.as_default`, or alternatively
+ `tf.summary.experimental.set_step`). For the forwarded call, the arguments
+ here will be passed to the TF2 implementation of `tf.summary.scalar`, and the
+ return value will be an empty bytestring tensor, to avoid duplicate summary
+ writing. This forwarding is best-effort and not all arguments will be
+ preserved.
+
+ To migrate to TF2, please use `tf.summary.scalar` instead. Please check
[Migrating tf.summary usage to
TF 2.0](https://www.tensorflow.org/tensorboard/migrate#in_tf_1x) for concrete
steps for migration. `tf.summary.scalar` can also log training metrics in
@@ -98,6 +110,23 @@
@end_compatibility
"""
+ # Special case: invoke v2 op for TF2 users who have a v2 writer.
+ if _ops.executing_eagerly_outside_functions():
+ # Apart from an existing writer, users need to call
+ # `tf.summary.experimental.set_step` in order to invoke v2 API here.
+ if _summary_ops_v2.get_step(
+ ) is not None and _summary_ops_v2.has_default_writer():
+ # Defer the import to happen inside the symbol to prevent breakage due to
+ # missing dependency.
+ # pylint: disable=g-import-not-at-top
+ from tensorboard.summary.v2 import scalar as scalar_v2
+ # TODO(b/210992280): Handle the family argument.
+ scalar_v2(name, data=tensor)
+ # Return an empty Tensor, which will be acceptable as an input to the
+ # `tf.compat.v1.summary.merge()` API.
+ return _constant_op.constant(b'')
+
+ # Fall back to legacy v1 scalar implementation.
if _distribute_summary_op_util.skip_summary():
return _constant_op.constant('')
with _summary_op_util.summary_scope(
diff --git a/tensorflow/python/summary/summary_v2_test.py b/tensorflow/python/summary/summary_v2_test.py
new file mode 100644
index 0000000..4d5ca89
--- /dev/null
+++ b/tensorflow/python/summary/summary_v2_test.py
@@ -0,0 +1,61 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the API surface of the V1 tf.summary ops when TF2 is enabled.
+
+V1 summary ops will invoke V2 TensorBoard summary ops in eager mode.
+"""
+
+from tensorboard.summary import v2 as summary_v2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import summary_ops_v2
+from tensorflow.python.platform import test
+from tensorflow.python.summary import summary as summary_lib
+
+
+class SummaryV2Test(test.TestCase):
+
+ @test_util.run_v2_only
+ def test_scalar_summary_v2__w_writer(self):
+ """Tests scalar v2 invocation with a v2 writer."""
+ with test.mock.patch.object(summary_v2, 'scalar') as mock_scalar_v2:
+ with summary_ops_v2.create_summary_file_writer('/tmp/test').as_default(
+ step=1):
+ i = constant_op.constant(2.5)
+ tensor = summary_lib.scalar('float', i)
+ # Returns empty string.
+ self.assertEqual(tensor.numpy(), b'')
+ self.assertEqual(tensor.dtype, dtypes.string)
+ mock_scalar_v2.assert_called_once_with('float', data=i)
+
+ @test_util.run_v2_only
+ def test_scalar_summary_v2__wo_writer(self):
+ """Tests scalar v2 invocation with no writer."""
+ with test.mock.patch.object(summary_v2, 'scalar') as mock_scalar_v2:
+ summary_lib.scalar('float', constant_op.constant(2.5))
+ mock_scalar_v2.assert_not_called()
+
+ @test_util.run_v2_only
+ def test_scalar_summary_v2__global_step_not_set(self):
+ """Tests scalar v2 invocation when global step is not set."""
+ with test.mock.patch.object(summary_v2, 'scalar') as mock_scalar_v2:
+ with summary_ops_v2.create_summary_file_writer('/tmp/test').as_default():
+ summary_lib.scalar('float', constant_op.constant(2.5))
+ mock_scalar_v2.assert_not_called()
+
+
+if __name__ == '__main__':
+ test.main()