Invoke TensorBoard scalar v2 in `tf.compat.v1.summary.scalar`.

PiperOrigin-RevId: 422561603
Change-Id: I5261deda14315123cc38361a9bb31bdb4eaace13
diff --git a/tensorflow/python/summary/BUILD b/tensorflow/python/summary/BUILD
index c8533b0..c433e50 100644
--- a/tensorflow/python/summary/BUILD
+++ b/tensorflow/python/summary/BUILD
@@ -44,6 +44,7 @@
         "plugin_asset_test.py",
         "summary_iterator_test.py",
         "summary_test.py",
+        "summary_v2_test.py",
     ],
     python_version = "PY3",
     deps = [
diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py
index b5fef3f..e9018a9 100644
--- a/tensorflow/python/summary/summary.py
+++ b/tensorflow/python/summary/summary.py
@@ -22,7 +22,7 @@
 from google.protobuf import json_format as _json_format
 
 # exports Summary, SummaryDescription, Event, TaggedRunMetadata, SessionLog
-# pylint: disable=unused-import
+# pylint: disable=unused-import, g-importing-member
 from tensorflow.core.framework.summary_pb2 import Summary
 from tensorflow.core.framework.summary_pb2 import SummaryDescription
 from tensorflow.core.framework.summary_pb2 import SummaryMetadata as _SummaryMetadata  # pylint: enable=unused-import
@@ -39,6 +39,7 @@
 from tensorflow.python.ops import gen_logging_ops as _gen_logging_ops
 from tensorflow.python.ops import gen_summary_ops as _gen_summary_ops  # pylint: disable=unused-import
 from tensorflow.python.ops import summary_op_util as _summary_op_util
+from tensorflow.python.ops import summary_ops_v2 as _summary_ops_v2
 
 # exports FileWriter, FileWriterCache
 # pylint: disable=unused-import
@@ -72,8 +73,19 @@
     ValueError: If tensor has the wrong shape or type.
 
   @compatibility(TF2)
-  This API is not compatible with eager execution or `tf.function`. To migrate
-  to TF2, please use `tf.summary.scalar` instead. Please check
+  For compatibility purposes, when invoked in TF2 where the outermost context is
+  eager mode, this API will check if there is a suitable TF2 summary writer
+  context available, and if so will forward this call to that writer instead. A
+  "suitable" writer context means that the writer is set as the default writer,
+  and there is an associated non-empty value for `step` (see
+  `tf.summary.SummaryWriter.as_default`, or alternatively
+  `tf.summary.experimental.set_step`). For the forwarded call, the arguments
+  here will be passed to the TF2 implementation of `tf.summary.scalar`, and the
+  return value will be an empty bytestring tensor, to avoid duplicate summary
+  writing. This forwarding is best-effort and not all arguments will be
+  preserved.
+
+  To migrate to TF2, please use `tf.summary.scalar` instead. Please check
   [Migrating tf.summary usage to
   TF 2.0](https://www.tensorflow.org/tensorboard/migrate#in_tf_1x) for concrete
   steps for migration. `tf.summary.scalar` can also log training metrics in
@@ -98,6 +110,23 @@
 
   @end_compatibility
   """
+  # Special case: invoke v2 op for TF2 users who have a v2 writer.
+  if _ops.executing_eagerly_outside_functions():
+    # Apart from an existing writer, users need to call
+    # `tf.summary.experimental.set_step` in order to invoke v2 API here.
+    if _summary_ops_v2.get_step(
+    ) is not None and _summary_ops_v2.has_default_writer():
+      # Defer the import to happen inside the symbol to prevent breakage due to
+      # missing dependency.
+      # pylint: disable=g-import-not-at-top
+      from tensorboard.summary.v2 import scalar as scalar_v2
+      # TODO(b/210992280): Handle the family argument.
+      scalar_v2(name, data=tensor)
+      # Return an empty Tensor, which will be acceptable as an input to the
+      # `tf.compat.v1.summary.merge()` API.
+      return _constant_op.constant(b'')
+
+  # Fall back to legacy v1 scalar implementation.
   if _distribute_summary_op_util.skip_summary():
     return _constant_op.constant('')
   with _summary_op_util.summary_scope(
diff --git a/tensorflow/python/summary/summary_v2_test.py b/tensorflow/python/summary/summary_v2_test.py
new file mode 100644
index 0000000..4d5ca89
--- /dev/null
+++ b/tensorflow/python/summary/summary_v2_test.py
@@ -0,0 +1,61 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the API surface of the V1 tf.summary ops when TF2 is enabled.
+
+V1 summary ops will invoke V2 TensorBoard summary ops in eager mode.
+"""
+
+from tensorboard.summary import v2 as summary_v2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import summary_ops_v2
+from tensorflow.python.platform import test
+from tensorflow.python.summary import summary as summary_lib
+
+
+class SummaryV2Test(test.TestCase):
+
+  @test_util.run_v2_only
+  def test_scalar_summary_v2__w_writer(self):
+    """Tests scalar v2 invocation with a v2 writer."""
+    with test.mock.patch.object(summary_v2, 'scalar') as mock_scalar_v2:
+      with summary_ops_v2.create_summary_file_writer('/tmp/test').as_default(
+          step=1):
+        i = constant_op.constant(2.5)
+        tensor = summary_lib.scalar('float', i)
+    # Returns empty string.
+    self.assertEqual(tensor.numpy(), b'')
+    self.assertEqual(tensor.dtype, dtypes.string)
+    mock_scalar_v2.assert_called_once_with('float', data=i)
+
+  @test_util.run_v2_only
+  def test_scalar_summary_v2__wo_writer(self):
+    """Tests scalar v2 invocation with no writer."""
+    with test.mock.patch.object(summary_v2, 'scalar') as mock_scalar_v2:
+      summary_lib.scalar('float', constant_op.constant(2.5))
+    mock_scalar_v2.assert_not_called()
+
+  @test_util.run_v2_only
+  def test_scalar_summary_v2__global_step_not_set(self):
+    """Tests scalar v2 invocation when global step is not set."""
+    with test.mock.patch.object(summary_v2, 'scalar') as mock_scalar_v2:
+      with summary_ops_v2.create_summary_file_writer('/tmp/test').as_default():
+        summary_lib.scalar('float', constant_op.constant(2.5))
+    mock_scalar_v2.assert_not_called()
+
+
+if __name__ == '__main__':
+  test.main()