Enable experimental SavedModel support for SummaryWriter using TrackableResource
This introduces a new `experimental_trackable` boolean argument to `tf.summary.create_file_writer()`. If True (default is False), the created `SummaryWriter` will be a `tf.saved_model.experimental.TrackableResource`, allowing the writer to be saved in a SavedModel as a property on a `tf.Module`, and then used within @tf.function methods on the module.
There are a couple significant limitations to this support:
1) The logdir passed to `create_file_writer()` will be baked into the resulting SavedModel and cannot be changed later. One workaround is to specify a relative path, and then ensure that any code loading the SavedModel is using a working directory into which logs can be written.
2) Initializing the SavedModel resources (e.g. when loading it back into Python) will recreate the writer and open a new file, which may be a surprising side effect. Possible future solutions would be deferring either the writer creation or the opening of the file.
PiperOrigin-RevId: 367347580
Change-Id: Ib6b614aaf5edb248e7b6de690d9f3d367a4c97cd
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 1bd940f..e5bd167 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2606,6 +2606,7 @@
":control_flow_ops",
":math_ops",
":resource_variable_ops",
+ ":resource_variable_ops_gen",
":summary_op_util",
":summary_ops_gen",
":training_util",
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 5ab14f8..78c972c 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1292,6 +1292,9 @@
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:function",
+ "//tensorflow/python/saved_model:load",
+ "//tensorflow/python/saved_model:loader",
+ "//tensorflow/python/saved_model:tag_constants",
"@six_archive//:six",
],
)
diff --git a/tensorflow/python/kernel_tests/summary_ops_test.py b/tensorflow/python/kernel_tests/summary_ops_test.py
index 4f4f5cf..62b5235 100644
--- a/tensorflow/python/kernel_tests/summary_ops_test.py
+++ b/tensorflow/python/kernel_tests/summary_ops_test.py
@@ -39,12 +39,17 @@
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import tf_record
+from tensorflow.python.module import module
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import load as saved_model_load
+from tensorflow.python.saved_model import loader as saved_model_loader
+from tensorflow.python.saved_model import save as saved_model_save
+from tensorflow.python.saved_model import tag_constants
class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
@@ -977,6 +982,130 @@
self.assertNotIn(eventfile, get_open_filenames())
+class SummaryWriterSavedModelTest(test_util.TensorFlowTestCase):
+
+ def testWriter_savedAsModuleProperty_loadInEagerMode(self):
+ with context.eager_mode():
+ class Model(module.Module):
+
+ def __init__(self, model_dir):
+ self._writer = summary_ops.create_file_writer_v2(
+ model_dir, experimental_trackable=True)
+
+ @def_function.function(input_signature=[
+ tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)
+ ])
+ def train(self, step):
+ with self._writer.as_default():
+ summary_ops.write('tag', 'foo', step=step)
+ return constant_op.constant(0)
+
+ logdir = self.get_temp_dir()
+ to_export = Model(logdir)
+ pre_save_files = set(events_from_multifile_logdir(logdir))
+ export_dir = os.path.join(logdir, 'export')
+ saved_model_save.save(
+ to_export, export_dir, signatures={'train': to_export.train})
+
+ # Reset context to ensure we don't share any resources with saving code.
+ context._reset_context() # pylint: disable=protected-access
+ with context.eager_mode():
+ restored = saved_model_load.load(export_dir)
+ restored.train(1)
+ restored.train(2)
+ post_restore_files = set(events_from_multifile_logdir(logdir))
+ restored2 = saved_model_load.load(export_dir)
+ restored2.train(3)
+ restored2.train(4)
+ files_to_events = events_from_multifile_logdir(logdir)
+ post_restore2_files = set(files_to_events)
+ self.assertLen(files_to_events, 3)
+ def unwrap_singleton(iterable):
+ self.assertLen(iterable, 1)
+ return next(iter(iterable))
+ restore_file = unwrap_singleton(post_restore_files - pre_save_files)
+ restore2_file = unwrap_singleton(post_restore2_files - post_restore_files)
+ restore_events = files_to_events[restore_file]
+ restore2_events = files_to_events[restore2_file]
+ self.assertLen(restore_events, 3)
+ self.assertEqual(1, restore_events[1].step)
+ self.assertEqual(2, restore_events[2].step)
+ self.assertLen(restore2_events, 3)
+ self.assertEqual(3, restore2_events[1].step)
+ self.assertEqual(4, restore2_events[2].step)
+
+ def testWriter_savedAsModuleProperty_loadInGraphMode(self):
+ with context.eager_mode():
+
+ class Model(module.Module):
+
+ def __init__(self, model_dir):
+ self._writer = summary_ops.create_file_writer_v2(
+ model_dir, experimental_trackable=True)
+
+ @def_function.function(input_signature=[
+ tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)
+ ])
+ def train(self, step):
+ with self._writer.as_default():
+ summary_ops.write('tag', 'foo', step=step)
+ return constant_op.constant(0)
+
+ logdir = self.get_temp_dir()
+ to_export = Model(logdir)
+ pre_save_files = set(events_from_multifile_logdir(logdir))
+ export_dir = os.path.join(logdir, 'export')
+ saved_model_save.save(
+ to_export, export_dir, signatures={'train': to_export.train})
+
+ # Reset context to ensure we don't share any resources with saving code.
+ context._reset_context() # pylint: disable=protected-access
+
+ def load_and_run_model(sess, input_values):
+ """Load and run the SavedModel signature in the TF 1.x style."""
+ model = saved_model_loader.load(sess, [tag_constants.SERVING], export_dir)
+ signature = model.signature_def['train']
+ inputs = list(signature.inputs.values())
+ assert len(inputs) == 1, inputs
+ outputs = list(signature.outputs.values())
+ assert len(outputs) == 1, outputs
+ input_tensor = sess.graph.get_tensor_by_name(inputs[0].name)
+ output_tensor = sess.graph.get_tensor_by_name(outputs[0].name)
+ for v in input_values:
+ sess.run(output_tensor, feed_dict={input_tensor: v})
+
+ with context.graph_mode(), ops.Graph().as_default():
+ # Since writer shared_name is fixed, within a single session, all loads of
+      # this SavedModel will refer to a single writer resource, so it will be
+ # initialized only once and write to a single file.
+ with self.session() as sess:
+ load_and_run_model(sess, [1, 2])
+ load_and_run_model(sess, [3, 4])
+ post_restore_files = set(events_from_multifile_logdir(logdir))
+ # New session will recreate the resource and write to a second file.
+ with self.session() as sess:
+ load_and_run_model(sess, [5, 6])
+ files_to_events = events_from_multifile_logdir(logdir)
+ post_restore2_files = set(files_to_events)
+
+ self.assertLen(files_to_events, 3)
+ def unwrap_singleton(iterable):
+ self.assertLen(iterable, 1)
+ return next(iter(iterable))
+ restore_file = unwrap_singleton(post_restore_files - pre_save_files)
+ restore2_file = unwrap_singleton(post_restore2_files - post_restore_files)
+ restore_events = files_to_events[restore_file]
+ restore2_events = files_to_events[restore2_file]
+ self.assertLen(restore_events, 5)
+ self.assertEqual(1, restore_events[1].step)
+ self.assertEqual(2, restore_events[2].step)
+ self.assertEqual(3, restore_events[3].step)
+ self.assertEqual(4, restore_events[4].step)
+ self.assertLen(restore2_events, 3)
+ self.assertEqual(5, restore2_events[1].step)
+ self.assertEqual(6, restore2_events[2].step)
+
+
class NoopWriterTest(test_util.TensorFlowTestCase):
def testNoopWriter_doesNothing(self):
@@ -1406,6 +1535,23 @@
return events_from_file(os.path.join(logdir, files[0]))
+def events_from_multifile_logdir(logdir):
+ """Returns map of filename to events for all `tfevents` files in the logdir.
+
+ Args:
+ logdir: The directory from which to load events.
+
+ Returns:
+ A dict mapping from relative filenames to lists of tf.Event protos.
+
+ Raises:
+    AssertionError: If logdir does not exist.
+ """
+ assert gfile.Exists(logdir)
+ files = [file for file in gfile.ListDirectory(logdir) if 'tfevents' in file]
+ return {file: events_from_file(os.path.join(logdir, file)) for file in files}
+
+
def to_numpy(summary_value):
return tensor_util.MakeNdarray(summary_value.tensor)
diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py
index 01c622d..6eebd2c 100644
--- a/tensorflow/python/ops/summary_ops_v2.py
+++ b/tensorflow/python/ops/summary_ops_v2.py
@@ -40,12 +40,14 @@
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_summary_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import summary_op_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training_util
+from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
@@ -302,18 +304,21 @@
class _ResourceSummaryWriter(SummaryWriter):
"""Implementation of SummaryWriter using a SummaryWriterInterface resource."""
- def __init__(self, shared_name, init_op_fn, name=None):
- self._resource = gen_summary_ops.summary_writer(
- shared_name=shared_name, name=name)
- self._init_op_fn = init_op_fn
+ def __init__(self, create_fn, init_op_fn):
+ self._resource = create_fn()
self._init_op = init_op_fn(self._resource)
self._closed = False
if context.executing_eagerly():
- self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
- handle=self._resource, handle_device="cpu:0")
+ self._set_up_resource_deleter()
else:
ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, self._init_op)
+ # Extension point to be overridden by subclasses to customize deletion.
+
+ def _set_up_resource_deleter(self):
+ self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
+ handle=self._resource, handle_device="cpu:0")
+
def set_as_default(self, step=None):
"""See `SummaryWriter.set_as_default`."""
if context.executing_eagerly() and self._closed:
@@ -352,6 +357,40 @@
self._closed = True
+class _TrackableResourceSummaryWriter(_ResourceSummaryWriter,
+ tracking.TrackableResource):
+ """A `_ResourceSummaryWriter` subclass that implements `TrackableResource`."""
+
+ def __init__(self, create_fn, init_op_fn):
+ # Resolve multiple inheritance via explicit calls to __init__() on parents.
+ tracking.TrackableResource.__init__(self, device="cpu:0")
+ self._create_fn = create_fn
+ self._init_op_fn = init_op_fn
+ # Pass .resource_handle into _ResourceSummaryWriter parent class rather than
+ # create_fn, to ensure it accesses the resource handle only through the
+ # cached property so that everything is using a single resource handle.
+ _ResourceSummaryWriter.__init__(
+ self, create_fn=lambda: self.resource_handle, init_op_fn=init_op_fn)
+
+ # Override for TrackableResource implementation.
+ def _create_resource(self):
+ return self._create_fn()
+
+ # Override for TrackableResource implementation.
+ def _initialize(self):
+ return self._init_op_fn(self.resource_handle)
+
+ # Override for TrackableResource implementation.
+ def _destroy_resource(self):
+ gen_resource_variable_ops.destroy_resource_op(
+ self.resource_handle, ignore_lookup_error=True)
+
+ def _set_up_resource_deleter(self):
+ # Override to suppress ResourceSummaryWriter implementation; we don't need
+ # the deleter since TrackableResource.__del__() handles it for us.
+ pass
+
+
class _LegacyResourceSummaryWriter(SummaryWriter):
"""Legacy resource-backed SummaryWriter for tf.contrib.summary."""
@@ -448,7 +487,8 @@
max_queue=None,
flush_millis=None,
filename_suffix=None,
- name=None):
+ name=None,
+ experimental_trackable=False):
"""Creates a summary file writer for the given log directory.
Args:
@@ -458,6 +498,9 @@
flush_millis: the largest interval between flushes. Defaults to 120,000.
filename_suffix: optional suffix for the event file name. Defaults to `.v2`.
name: a name for the op that creates the writer.
+ experimental_trackable: a boolean that controls whether the returned writer
+ will be a `TrackableResource`, which makes it compatible with SavedModel
+ when used as a `tf.Module` property.
Returns:
A SummaryWriter object.
@@ -482,20 +525,29 @@
flush_millis = constant_op.constant(2 * 60 * 1000)
if filename_suffix is None:
filename_suffix = constant_op.constant(".v2")
- # Use a unique shared_name to prevent resource sharing.
- if context.executing_eagerly():
- shared_name = context.shared_name()
+
+ def create_fn():
+ # Use unique shared_name to prevent resource sharing in eager mode, but
+ # otherwise use a fixed shared_name to allow SavedModel TF 1.x loading.
+ if context.executing_eagerly():
+ shared_name = context.shared_name()
+ else:
+ shared_name = ops.name_from_scope_name(scope) # pylint: disable=protected-access
+ return gen_summary_ops.summary_writer(
+ shared_name=shared_name, name=name)
+
+ init_op_fn = functools.partial(
+ gen_summary_ops.create_summary_file_writer,
+ logdir=logdir,
+ max_queue=max_queue,
+ flush_millis=flush_millis,
+ filename_suffix=filename_suffix)
+ if experimental_trackable:
+ return _TrackableResourceSummaryWriter(
+ create_fn=create_fn, init_op_fn=init_op_fn)
else:
- shared_name = ops.name_from_scope_name(scope) # pylint: disable=protected-access
- return _ResourceSummaryWriter(
- shared_name=shared_name,
- init_op_fn=functools.partial(
- gen_summary_ops.create_summary_file_writer,
- logdir=logdir,
- max_queue=max_queue,
- flush_millis=flush_millis,
- filename_suffix=filename_suffix),
- name=name)
+ return _ResourceSummaryWriter(
+ create_fn=create_fn, init_op_fn=init_op_fn)
def create_file_writer(logdir,
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt
index 237e4bf..76cc940 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt
@@ -14,7 +14,7 @@
}
member_method {
name: "create_file_writer"
- argspec: "args=[\'logdir\', \'max_queue\', \'flush_millis\', \'filename_suffix\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'logdir\', \'max_queue\', \'flush_millis\', \'filename_suffix\', \'name\', \'experimental_trackable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\'], "
}
member_method {
name: "create_noop_writer"