Enable experimental SavedModel support for SummaryWriter using TrackableResource

This introduces a new `experimental_trackable` boolean argument to `tf.summary.create_file_writer()`. If True (the default is False), the created `SummaryWriter` will be a `tf.saved_model.experimental.TrackableResource`, allowing the writer to be saved in a SavedModel as a property on a `tf.Module` and then used within `@tf.function` methods on the module.
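
A minimal sketch of the intended usage, condensed from the new tests below
(the `MyModel` name and the paths are illustrative):

    import tensorflow as tf

    class MyModel(tf.Module):

      def __init__(self, logdir):
        # experimental_trackable=True makes the writer a TrackableResource,
        # so it survives tf.saved_model.save() as a module property.
        self._writer = tf.summary.create_file_writer(
            logdir, experimental_trackable=True)

      @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.int64)])
      def train(self, step):
        with self._writer.as_default():
          tf.summary.write('tag', 'foo', step=step)
        return tf.constant(0)

    model = MyModel('logs')
    tf.saved_model.save(model, '/tmp/export',
                        signatures={'train': model.train})

    # Loading re-initializes the resource, which recreates the writer and
    # opens a new event file (see limitation 2 below).
    restored = tf.saved_model.load('/tmp/export')
    restored.train(1)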

There are two significant limitations to this support:

1) The logdir passed to `create_file_writer()` will be baked into the resulting SavedModel and cannot be changed later. One workaround is to specify a relative path and then ensure that any code loading the SavedModel uses a working directory into which logs can be written (see the sketch after this list).

2) Initializing the SavedModel resources (e.g. when loading it back into Python) will recreate the writer and open a new event file, which may be a surprising side effect. A possible future improvement would be to defer either the creation of the writer or the opening of the file.
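
As a sketch of the workaround in limitation 1, reusing the illustrative
`MyModel` from above (paths are likewise illustrative):

    import os
    import tensorflow as tf

    # At export time, bake a relative logdir into the SavedModel.
    model = MyModel('logs/train')
    tf.saved_model.save(model, '/tmp/export',
                        signatures={'train': model.train})

    # At load time, choose the working directory *before* loading, so the
    # relative 'logs/train' resolves to a writable location.
    os.makedirs('/tmp/my_run', exist_ok=True)
    os.chdir('/tmp/my_run')
    restored = tf.saved_model.load('/tmp/export')
    restored.train(1)  # Events land under /tmp/my_run/logs/train.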

PiperOrigin-RevId: 367347580
Change-Id: Ib6b614aaf5edb248e7b6de690d9f3d367a4c97cd
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 1bd940f..e5bd167 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2606,6 +2606,7 @@
         ":control_flow_ops",
         ":math_ops",
         ":resource_variable_ops",
+        ":resource_variable_ops_gen",
         ":summary_op_util",
         ":summary_ops_gen",
         ":training_util",
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 5ab14f8..78c972c 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1292,6 +1292,9 @@
         "//tensorflow/python:variables",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:function",
+        "//tensorflow/python/saved_model:load",
+        "//tensorflow/python/saved_model:loader",
+        "//tensorflow/python/saved_model:tag_constants",
         "@six_archive//:six",
     ],
 )
diff --git a/tensorflow/python/kernel_tests/summary_ops_test.py b/tensorflow/python/kernel_tests/summary_ops_test.py
index 4f4f5cf..62b5235 100644
--- a/tensorflow/python/kernel_tests/summary_ops_test.py
+++ b/tensorflow/python/kernel_tests/summary_ops_test.py
@@ -39,12 +39,17 @@
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import test_util
 from tensorflow.python.lib.io import tf_record
+from tensorflow.python.module import module
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import summary_ops_v2 as summary_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import load as saved_model_load
+from tensorflow.python.saved_model import loader as saved_model_loader
+from tensorflow.python.saved_model import save as saved_model_save
+from tensorflow.python.saved_model import tag_constants
 
 
 class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
@@ -977,6 +982,130 @@
       self.assertNotIn(eventfile, get_open_filenames())
 
 
+class SummaryWriterSavedModelTest(test_util.TensorFlowTestCase):
+
+  def testWriter_savedAsModuleProperty_loadInEagerMode(self):
+    with context.eager_mode():
+
+      class Model(module.Module):
+
+        def __init__(self, model_dir):
+          self._writer = summary_ops.create_file_writer_v2(
+              model_dir, experimental_trackable=True)
+
+        @def_function.function(input_signature=[
+            tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)
+        ])
+        def train(self, step):
+          with self._writer.as_default():
+            summary_ops.write('tag', 'foo', step=step)
+          return constant_op.constant(0)
+
+      logdir = self.get_temp_dir()
+      to_export = Model(logdir)
+      pre_save_files = set(events_from_multifile_logdir(logdir))
+      export_dir = os.path.join(logdir, 'export')
+      saved_model_save.save(
+          to_export, export_dir, signatures={'train': to_export.train})
+
+    # Reset context to ensure we don't share any resources with saving code.
+    context._reset_context()  # pylint: disable=protected-access
+    with context.eager_mode():
+      restored = saved_model_load.load(export_dir)
+      restored.train(1)
+      restored.train(2)
+      post_restore_files = set(events_from_multifile_logdir(logdir))
+      restored2 = saved_model_load.load(export_dir)
+      restored2.train(3)
+      restored2.train(4)
+      files_to_events = events_from_multifile_logdir(logdir)
+      post_restore2_files = set(files_to_events)
+      self.assertLen(files_to_events, 3)
+      def unwrap_singleton(iterable):
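+        # Asserts that the iterable has exactly one element and returns it.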
+        self.assertLen(iterable, 1)
+        return next(iter(iterable))
+      restore_file = unwrap_singleton(post_restore_files - pre_save_files)
+      restore2_file = unwrap_singleton(post_restore2_files - post_restore_files)
+      restore_events = files_to_events[restore_file]
+      restore2_events = files_to_events[restore2_file]
+      self.assertLen(restore_events, 3)
+      self.assertEqual(1, restore_events[1].step)
+      self.assertEqual(2, restore_events[2].step)
+      self.assertLen(restore2_events, 3)
+      self.assertEqual(3, restore2_events[1].step)
+      self.assertEqual(4, restore2_events[2].step)
+
+  def testWriter_savedAsModuleProperty_loadInGraphMode(self):
+    with context.eager_mode():
+
+      class Model(module.Module):
+
+        def __init__(self, model_dir):
+          self._writer = summary_ops.create_file_writer_v2(
+              model_dir, experimental_trackable=True)
+
+        @def_function.function(input_signature=[
+            tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)
+        ])
+        def train(self, step):
+          with self._writer.as_default():
+            summary_ops.write('tag', 'foo', step=step)
+          return constant_op.constant(0)
+
+      logdir = self.get_temp_dir()
+      to_export = Model(logdir)
+      pre_save_files = set(events_from_multifile_logdir(logdir))
+      export_dir = os.path.join(logdir, 'export')
+      saved_model_save.save(
+          to_export, export_dir, signatures={'train': to_export.train})
+
+    # Reset context to ensure we don't share any resources with saving code.
+    context._reset_context()  # pylint: disable=protected-access
+
+    def load_and_run_model(sess, input_values):
+      """Load and run the SavedModel signature in the TF 1.x style."""
+      model = saved_model_loader.load(sess, [tag_constants.SERVING], export_dir)
+      signature = model.signature_def['train']
+      inputs = list(signature.inputs.values())
+      assert len(inputs) == 1, inputs
+      outputs = list(signature.outputs.values())
+      assert len(outputs) == 1, outputs
+      input_tensor = sess.graph.get_tensor_by_name(inputs[0].name)
+      output_tensor = sess.graph.get_tensor_by_name(outputs[0].name)
+      for v in input_values:
+        sess.run(output_tensor, feed_dict={input_tensor: v})
+
+    with context.graph_mode(), ops.Graph().as_default():
+      # Since writer shared_name is fixed, within a single session, all loads of
+      # this SavedModel will refer to a single writer resource, so it will be
+      # initialized only once and write to a single file.
+      with self.session() as sess:
+        load_and_run_model(sess, [1, 2])
+        load_and_run_model(sess, [3, 4])
+      post_restore_files = set(events_from_multifile_logdir(logdir))
+      # A new session will recreate the resource and write to a new file.
+      with self.session() as sess:
+        load_and_run_model(sess, [5, 6])
+      files_to_events = events_from_multifile_logdir(logdir)
+      post_restore2_files = set(files_to_events)
+
+    self.assertLen(files_to_events, 3)
+    def unwrap_singleton(iterable):
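+      # Asserts that the iterable has exactly one element and returns it.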
+      self.assertLen(iterable, 1)
+      return next(iter(iterable))
+    restore_file = unwrap_singleton(post_restore_files - pre_save_files)
+    restore2_file = unwrap_singleton(post_restore2_files - post_restore_files)
+    restore_events = files_to_events[restore_file]
+    restore2_events = files_to_events[restore2_file]
+    self.assertLen(restore_events, 5)
+    self.assertEqual(1, restore_events[1].step)
+    self.assertEqual(2, restore_events[2].step)
+    self.assertEqual(3, restore_events[3].step)
+    self.assertEqual(4, restore_events[4].step)
+    self.assertLen(restore2_events, 3)
+    self.assertEqual(5, restore2_events[1].step)
+    self.assertEqual(6, restore2_events[2].step)
+
+
 class NoopWriterTest(test_util.TensorFlowTestCase):
 
   def testNoopWriter_doesNothing(self):
@@ -1406,6 +1535,23 @@
   return events_from_file(os.path.join(logdir, files[0]))
 
 
+def events_from_multifile_logdir(logdir):
+  """Returns map of filename to events for all `tfevents` files in the logdir.
+
+  Args:
+    logdir: The directory from which to load events.
+
+  Returns:
+    A dict mapping from relative filenames to lists of tf.Event protos.
+
+  Raises:
+    AssertionError: If the logdir does not exist.
+  """
+  assert gfile.Exists(logdir)
+  files = [file for file in gfile.ListDirectory(logdir) if 'tfevents' in file]
+  return {file: events_from_file(os.path.join(logdir, file)) for file in files}
+
+
 def to_numpy(summary_value):
   return tensor_util.MakeNdarray(summary_value.tensor)
 
diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py
index 01c622d..6eebd2c 100644
--- a/tensorflow/python/ops/summary_ops_v2.py
+++ b/tensorflow/python/ops/summary_ops_v2.py
@@ -40,12 +40,14 @@
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import gen_summary_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import summary_op_util
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import training_util
+from tensorflow.python.training.tracking import tracking
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util.tf_export import tf_export
@@ -302,18 +304,21 @@
 class _ResourceSummaryWriter(SummaryWriter):
   """Implementation of SummaryWriter using a SummaryWriterInterface resource."""
 
-  def  __init__(self, shared_name, init_op_fn, name=None):
-    self._resource = gen_summary_ops.summary_writer(
-        shared_name=shared_name, name=name)
-    self._init_op_fn = init_op_fn
+  def __init__(self, create_fn, init_op_fn):
+    self._resource = create_fn()
     self._init_op = init_op_fn(self._resource)
     self._closed = False
     if context.executing_eagerly():
-      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
-          handle=self._resource, handle_device="cpu:0")
+      self._set_up_resource_deleter()
     else:
       ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, self._init_op)
 
+  # Extension point to be overridden by subclasses to customize deletion.
+  def _set_up_resource_deleter(self):
+    self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
+        handle=self._resource, handle_device="cpu:0")
+
   def set_as_default(self, step=None):
     """See `SummaryWriter.set_as_default`."""
     if context.executing_eagerly() and self._closed:
@@ -352,6 +357,40 @@
         self._closed = True
 
 
+class _TrackableResourceSummaryWriter(_ResourceSummaryWriter,
+                                      tracking.TrackableResource):
+  """A `_ResourceSummaryWriter` subclass that implements `TrackableResource`."""
+
+  def __init__(self, create_fn, init_op_fn):
+    # Resolve multiple inheritance via explicit calls to __init__() on parents.
+    tracking.TrackableResource.__init__(self, device="cpu:0")
+    self._create_fn = create_fn
+    self._init_op_fn = init_op_fn
+    # Pass a lambda returning .resource_handle into the _ResourceSummaryWriter
+    # parent class, rather than the original create_fn, to ensure the handle is
+    # accessed only through the cached property and everything shares a single
+    # resource handle.
+    _ResourceSummaryWriter.__init__(
+        self, create_fn=lambda: self.resource_handle, init_op_fn=init_op_fn)
+
+  # Override for TrackableResource implementation.
+  def _create_resource(self):
+    return self._create_fn()
+
+  # Override for TrackableResource implementation.
+  def _initialize(self):
+    return self._init_op_fn(self.resource_handle)
+
+  # Override for TrackableResource implementation.
+  def _destroy_resource(self):
+    gen_resource_variable_ops.destroy_resource_op(
+        self.resource_handle, ignore_lookup_error=True)
+
+  def _set_up_resource_deleter(self):
+    # Override to suppress the _ResourceSummaryWriter implementation; we don't
+    # need the deleter since TrackableResource.__del__() handles it for us.
+    pass
+
+
 class _LegacyResourceSummaryWriter(SummaryWriter):
   """Legacy resource-backed SummaryWriter for tf.contrib.summary."""
 
@@ -448,7 +487,8 @@
                           max_queue=None,
                           flush_millis=None,
                           filename_suffix=None,
-                          name=None):
+                          name=None,
+                          experimental_trackable=False):
   """Creates a summary file writer for the given log directory.
 
   Args:
@@ -458,6 +498,9 @@
     flush_millis: the largest interval between flushes. Defaults to 120,000.
     filename_suffix: optional suffix for the event file name. Defaults to `.v2`.
     name: a name for the op that creates the writer.
+    experimental_trackable: a boolean that controls whether the returned writer
+      will be a `TrackableResource`, which makes it compatible with SavedModel
+      when used as a `tf.Module` property.
 
   Returns:
     A SummaryWriter object.
@@ -482,20 +525,29 @@
         flush_millis = constant_op.constant(2 * 60 * 1000)
       if filename_suffix is None:
         filename_suffix = constant_op.constant(".v2")
-      # Use a unique shared_name to prevent resource sharing.
-      if context.executing_eagerly():
-        shared_name = context.shared_name()
+
+      def create_fn():
+        # Use unique shared_name to prevent resource sharing in eager mode, but
+        # otherwise use a fixed shared_name to allow SavedModel TF 1.x loading.
+        if context.executing_eagerly():
+          shared_name = context.shared_name()
+        else:
+          shared_name = ops.name_from_scope_name(scope)  # pylint: disable=protected-access
+        return gen_summary_ops.summary_writer(
+            shared_name=shared_name, name=name)
+
+      init_op_fn = functools.partial(
+          gen_summary_ops.create_summary_file_writer,
+          logdir=logdir,
+          max_queue=max_queue,
+          flush_millis=flush_millis,
+          filename_suffix=filename_suffix)
+      if experimental_trackable:
+        return _TrackableResourceSummaryWriter(
+            create_fn=create_fn, init_op_fn=init_op_fn)
       else:
-        shared_name = ops.name_from_scope_name(scope)  # pylint: disable=protected-access
-      return _ResourceSummaryWriter(
-          shared_name=shared_name,
-          init_op_fn=functools.partial(
-              gen_summary_ops.create_summary_file_writer,
-              logdir=logdir,
-              max_queue=max_queue,
-              flush_millis=flush_millis,
-              filename_suffix=filename_suffix),
-          name=name)
+        return _ResourceSummaryWriter(
+            create_fn=create_fn, init_op_fn=init_op_fn)
 
 
 def create_file_writer(logdir,
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt
index 237e4bf..76cc940 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt
@@ -14,7 +14,7 @@
   }
   member_method {
     name: "create_file_writer"
-    argspec: "args=[\'logdir\', \'max_queue\', \'flush_millis\', \'filename_suffix\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'logdir\', \'max_queue\', \'flush_millis\', \'filename_suffix\', \'name\', \'experimental_trackable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\'], "
   }
   member_method {
     name: "create_noop_writer"