Add support for shared objects in Keras when (de)serializing.
Previously, objects shared between multiple layers were duplicated when
saving and loading. This commit attaches a unique ID to Keras objects during
serialization, which lets us create only a single instance of each shared
object during deserialization.
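
As a rough sketch of the round trip this enables (illustrative only, not part
of this change; the `Dense` layer is an arbitrary example):

    import tensorflow as tf
    from tensorflow.python.keras.utils import generic_utils

    layer = tf.keras.layers.Dense(4)

    # Inside a saving scope, serializing the same object twice hands back a
    # single shared config dict and attaches a 'shared_object_id' entry to it.
    with generic_utils.SharedObjectSavingScope():
      config = tf.keras.layers.serialize(layer)
      config_again = tf.keras.layers.serialize(layer)
    assert config is config_again
    assert generic_utils.SHARED_OBJECT_KEY in config

    # Inside a loading scope, a config carrying an already-seen ID returns
    # the previously created instance instead of a fresh copy.
    with generic_utils.SharedObjectLoadingScope():
      restored = tf.keras.layers.deserialize(config)
      restored_again = tf.keras.layers.deserialize(config_again)
    assert restored is restored_again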
PiperOrigin-RevId: 354996891
Change-Id: Idfd055f430c7ea7e25c459ed7715a370d7a632c9
diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py
index 743b4c0..a3aa265 100644
--- a/tensorflow/python/keras/engine/functional.py
+++ b/tensorflow/python/keras/engine/functional.py
@@ -671,12 +671,13 @@
Raises:
ValueError: In case of improperly formatted config dict.
"""
- input_tensors, output_tensors, created_layers = reconstruct_from_config(
- config, custom_objects)
- model = cls(inputs=input_tensors, outputs=output_tensors,
- name=config.get('name'))
- connect_ancillary_layers(model, created_layers)
- return model
+ with generic_utils.SharedObjectLoadingScope():
+ input_tensors, output_tensors, created_layers = reconstruct_from_config(
+ config, custom_objects)
+ model = cls(inputs=input_tensors, outputs=output_tensors,
+ name=config.get('name'))
+ connect_ancillary_layers(model, created_layers)
+ return model
def _validate_graph_inputs_and_outputs(self):
"""Validates the inputs and outputs of a Graph Network."""
@@ -1346,21 +1347,23 @@
node_conversion_map[node_key] = kept_nodes
kept_nodes += 1
layer_configs = []
- for layer in network.layers: # From the earliest layers on.
- filtered_inbound_nodes = []
- for original_node_index, node in enumerate(layer._inbound_nodes):
- node_key = _make_node_key(layer.name, original_node_index)
- if node_key in network._network_nodes and not node.is_input:
- # The node is relevant to the model:
- # add to filtered_inbound_nodes.
- node_data = node.serialize(_make_node_key, node_conversion_map)
- filtered_inbound_nodes.append(node_data)
- layer_config = serialize_layer_fn(layer)
- layer_config['name'] = layer.name
- layer_config['inbound_nodes'] = filtered_inbound_nodes
- layer_configs.append(layer_config)
- config['layers'] = layer_configs
+ with generic_utils.SharedObjectSavingScope():
+ for layer in network.layers: # From the earliest layers on.
+ filtered_inbound_nodes = []
+ for original_node_index, node in enumerate(layer._inbound_nodes):
+ node_key = _make_node_key(layer.name, original_node_index)
+ if node_key in network._network_nodes and not node.is_input:
+ # The node is relevant to the model:
+ # add to filtered_inbound_nodes.
+ node_data = node.serialize(_make_node_key, node_conversion_map)
+ filtered_inbound_nodes.append(node_data)
+
+ layer_config = serialize_layer_fn(layer)
+ layer_config['name'] = layer.name
+ layer_config['inbound_nodes'] = filtered_inbound_nodes
+ layer_configs.append(layer_config)
+ config['layers'] = layer_configs
# Gather info about inputs and outputs.
model_inputs = []
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index b16e0d6..0b19f4e 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -393,6 +393,10 @@
except that it creates new layers (and thus new weights) instead
of sharing the weights of the existing layers.
+ `clone_model` will not preserve the uniqueness of shared objects within the
+ model (e.g. a single variable attached to two distinct layers will be
+ cloned as two separate variables).
+
Args:
model: Instance of `Model`
(could be a functional model or a Sequential model).
@@ -420,15 +424,16 @@
Raises:
ValueError: in case of invalid `model` argument value.
"""
- if clone_function is None:
- clone_function = _clone_layer
+ with generic_utils.DisableSharedObjectScope():
+ if clone_function is None:
+ clone_function = _clone_layer
- if isinstance(model, Sequential):
- return _clone_sequential_model(
- model, input_tensors=input_tensors, layer_fn=clone_function)
- else:
- return _clone_functional_model(
- model, input_tensors=input_tensors, layer_fn=clone_function)
+ if isinstance(model, Sequential):
+ return _clone_sequential_model(
+ model, input_tensors=input_tensors, layer_fn=clone_function)
+ else:
+ return _clone_functional_model(
+ model, input_tensors=input_tensors, layer_fn=clone_function)
# "Clone" a subclassed model by reseting all of the attributes.
diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py
index 0ece5ac..12d1c39f 100644
--- a/tensorflow/python/keras/models_test.py
+++ b/tensorflow/python/keras/models_test.py
@@ -245,6 +245,28 @@
loss = model.train_on_batch(x, y)
self.assertEqual(float(loss), 0.)
+ def test_clone_rnn(self):
+ # Test cloning a model with a list of cells in an RNN. This exercises a
+ # few "fancier" features such as the `Bidirectional` wrapper and
+ # `StackedRNNCells` under the hood.
+ inputs = keras.Input(shape=(3, 3))
+ cells = [
+ keras.layers.LSTMCell(
+ units=32,
+ enable_caching_device=True,
+ implementation=2,
+ activation='relu')]
+ rnn = keras.layers.RNN(cells, return_sequences=True)
+ outputs = keras.layers.Bidirectional(rnn)(inputs)
+ outputs = keras.layers.Dense(
+ 12, activation='softmax', name='scores')(outputs)
+ model = keras.Model(inputs=inputs, outputs=outputs)
+ model.compile(
+ loss=keras.losses.CategoricalCrossentropy(),
+ optimizer=keras.optimizer_v2.rmsprop.RMSprop(lr=0.01),
+ metrics=['accuracy'])
+ keras.models.clone_model(model)
+
def test_model_cloning_invalid_use_cases(self):
seq_model = keras.models.Sequential()
seq_model.add(keras.layers.Dense(4, input_shape=(4,)))
diff --git a/tensorflow/python/keras/saving/save.py b/tensorflow/python/keras/saving/save.py
index d4749fc..ef7f699 100644
--- a/tensorflow/python/keras/saving/save.py
+++ b/tensorflow/python/keras/saving/save.py
@@ -148,8 +148,9 @@
hdf5_format.save_model_to_hdf5(
model, filepath, overwrite, include_optimizer)
else:
- saved_model_save.save(model, filepath, overwrite, include_optimizer,
- signatures, options, save_traces)
+ with generic_utils.SharedObjectSavingScope():
+ saved_model_save.save(model, filepath, overwrite, include_optimizer,
+ signatures, options, save_traces)
@keras_export('keras.models.load_model')
@@ -194,17 +195,18 @@
ImportError: if loading from an hdf5 file and h5py is not available.
IOError: In case of an invalid savefile.
"""
- with generic_utils.CustomObjectScope(custom_objects or {}):
- with load_context.load_context(options):
- if (h5py is not None and
- (isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
- return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
- compile)
+ with generic_utils.SharedObjectLoadingScope():
+ with generic_utils.CustomObjectScope(custom_objects or {}):
+ with load_context.load_context(options):
+ if (h5py is not None and
+ (isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
+ return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
+ compile)
- filepath = path_to_string(filepath)
- if isinstance(filepath, six.string_types):
- loader_impl.parse_saved_model(filepath)
- return saved_model_load.load(filepath, compile, options)
+ filepath = path_to_string(filepath)
+ if isinstance(filepath, six.string_types):
+ loader_impl.parse_saved_model(filepath)
+ return saved_model_load.load(filepath, compile, options)
raise IOError(
'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
diff --git a/tensorflow/python/keras/saving/save_test.py b/tensorflow/python/keras/saving/save_test.py
index 00c7bb2..20a779b 100644
--- a/tensorflow/python/keras/saving/save_test.py
+++ b/tensorflow/python/keras/saving/save_test.py
@@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function
+import collections
import os
import shutil
import sys
@@ -25,12 +26,14 @@
from absl.testing import parameterized
import numpy as np
+from six import string_types
from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.feature_column import feature_column_lib
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import combinations
@@ -859,6 +862,125 @@
self.assertAllEqual(loaded_model.predict(args, batch_size=batch_size),
expected)
+ @combinations.generate(combinations.combine(mode=['eager']))
+ def test_shared_objects(self):
+ class OuterLayer(keras.layers.Layer):
+
+ def __init__(self, inner_layer):
+ super(OuterLayer, self).__init__()
+ self.inner_layer = inner_layer
+
+ def call(self, inputs):
+ return self.inner_layer(inputs)
+
+ def get_config(self):
+ return {
+ 'inner_layer': generic_utils.serialize_keras_object(
+ self.inner_layer)
+ }
+
+ @classmethod
+ def from_config(cls, config):
+ return cls(generic_utils.deserialize_keras_object(
+ config['inner_layer']))
+
+ class InnerLayer(keras.layers.Layer):
+
+ def __init__(self):
+ super(InnerLayer, self).__init__()
+ self.v = self.add_weight(name='v', shape=[], dtype=dtypes.float32)
+
+ def call(self, inputs):
+ return self.v + inputs
+
+ @classmethod
+ def from_config(cls, config):
+ return cls()
+
+ # Create a model with 2 output layers that share the same inner layer.
+ inner_layer = InnerLayer()
+ outer_layer_1 = OuterLayer(inner_layer)
+ outer_layer_2 = OuterLayer(inner_layer)
+ input_ = keras.Input(shape=(1,))
+ model = keras.Model(
+ inputs=input_, outputs=[outer_layer_1(input_), outer_layer_2(input_)])
+
+ # Changes to the shared layer should affect both outputs.
+ model.layers[1].inner_layer.v.assign(5)
+ self.assertAllEqual(model(1), [6.0, 6.0])
+ model.layers[1].inner_layer.v.assign(3)
+ self.assertAllEqual(model(1), [4.0, 4.0])
+
+ # After loading, changes to the shared layer should still affect both
+ # outputs.
+ def _do_assertions(loaded):
+ loaded.layers[1].inner_layer.v.assign(5)
+ self.assertAllEqual(loaded(1), [6.0, 6.0])
+ loaded.layers[1].inner_layer.v.assign(3)
+ self.assertAllEqual(loaded(1), [4.0, 4.0])
+ loaded.layers[2].inner_layer.v.assign(5)
+ self.assertAllEqual(loaded(1), [6.0, 6.0])
+ loaded.layers[2].inner_layer.v.assign(3)
+ self.assertAllEqual(loaded(1), [4.0, 4.0])
+
+ # We'd like to make sure we only attach shared object IDs when strictly
+ # necessary, so we'll recursively traverse the generated config, count the
+ # shared object keys, and check that we see exactly the number we expect.
+ def _get_all_keys_recursive(dict_or_iterable):
+ if isinstance(dict_or_iterable, dict):
+ for key in dict_or_iterable.keys():
+ yield key
+ for key in _get_all_keys_recursive(dict_or_iterable.values()):
+ yield key
+ elif isinstance(dict_or_iterable, string_types):
+ return
+ else:
+ try:
+ for item in dict_or_iterable:
+ for key in _get_all_keys_recursive(item):
+ yield key
+ # Not an iterable or dictionary
+ except TypeError:
+ return
+
+ with generic_utils.CustomObjectScope({
+ 'OuterLayer': OuterLayer, 'InnerLayer': InnerLayer}):
+
+ # Test saving and loading to disk
+ save_format = testing_utils.get_save_format()
+ saved_model_dir = self._save_model_dir()
+ keras.models.save_model(model, saved_model_dir, save_format=save_format)
+ loaded = keras.models.load_model(saved_model_dir)
+ _do_assertions(loaded)
+
+ # Test recreating directly from config
+ config = model.get_config()
+ key_count = collections.Counter(_get_all_keys_recursive(config))
+ self.assertEqual(key_count[generic_utils.SHARED_OBJECT_KEY], 2)
+ loaded = keras.Model.from_config(config)
+ _do_assertions(loaded)
+
+ @combinations.generate(combinations.combine(mode=['eager']))
+ def test_shared_objects_wrapper(self):
+ """Tests that shared layers wrapped with `Wrapper` restore correctly."""
+ input_ = keras.Input(shape=(1,))
+ unwrapped = keras.layers.Layer(name='unwrapped')
+ wrapped = keras.layers.Wrapper(unwrapped, name='wrapped')
+ model = keras.Model(inputs=input_,
+ outputs=[unwrapped(input_), wrapped(input_)])
+
+ # Test recreating directly from config
+ config = model.get_config()
+ loaded = keras.Model.from_config(config)
+ self.assertIs(loaded.layers[1], loaded.layers[2].layer)
+
+ # Test saving and loading to disk
+ save_format = testing_utils.get_save_format()
+ saved_model_dir = self._save_model_dir()
+ keras.models.save_model(model, saved_model_dir, save_format=save_format)
+ loaded = keras.models.load_model(saved_model_dir)
+ self.assertIs(loaded.layers[1], loaded.layers[2].layer)
+
# Factory functions to create models that will be serialized inside a Network.
def _make_graph_network(input_size, output_size):
diff --git a/tensorflow/python/keras/saving/saved_model/layer_serialization.py b/tensorflow/python/keras/saving/saved_model/layer_serialization.py
index e2776bc..3f59a8e 100644
--- a/tensorflow/python/keras/saving/saved_model/layer_serialization.py
+++ b/tensorflow/python/keras/saving/saved_model/layer_serialization.py
@@ -46,7 +46,6 @@
# TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
# the python config serialization has caught up.
metadata = dict(
- class_name=generic_utils.get_registered_name(type(self.obj)),
name=self.obj.name,
trainable=self.obj.trainable,
expects_training_arg=self.obj._expects_training_arg, # pylint: disable=protected-access
@@ -56,7 +55,7 @@
must_restore_from_config=self.obj._must_restore_from_config, # pylint: disable=protected-access
)
- metadata.update(get_config(self.obj))
+ metadata.update(get_serialized(self.obj))
if self.obj.input_spec is not None:
# Layer's input_spec has already been type-checked in the property setter.
metadata['input_spec'] = nest.map_structure(
@@ -110,16 +109,12 @@
# TODO(kathywu): Move serialization utils (and related utils from
# generic_utils.py) to a separate file.
-def get_config(obj):
+def get_serialized(obj):
with generic_utils.skip_failed_serialization():
# Store the config dictionary, which may be used when reviving the object.
# When loading, the program will attempt to revive the object from config,
# and if that fails, the object will be revived from the SavedModel.
- config = generic_utils.serialize_keras_object(obj)['config']
-
- if config is not None:
- return {'config': config}
- return {}
+ return generic_utils.serialize_keras_object(obj)
class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
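
With `get_serialized` in place, the layer metadata now embeds the full
`serialize_keras_object` output (including `class_name` and, when the object
is re-used, the shared object ID) rather than a bare config. A rough sketch of
the resulting shape (all values are made up for illustration):

    # Hypothetical metadata for a Dense layer; exact keys and values depend
    # on the object being saved.
    metadata = {
        'name': 'dense_1',
        'trainable': True,
        'expects_training_arg': False,
        'dtype': 'float32',
        'class_name': 'Dense',  # Now supplied via get_serialized().
        'config': {'units': 4, 'name': 'dense_1'},
        'shared_object_id': 0,  # Attached only when the object is re-used.
    }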
diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py
index 217b124..fc34bf3 100644
--- a/tensorflow/python/keras/saving/saved_model/load.py
+++ b/tensorflow/python/keras/saving/saved_model/load.py
@@ -493,13 +493,15 @@
# found.
class_name = metadata.get('class_name')
config = metadata.get('config')
+ shared_object_id = metadata.get('shared_object_id')
must_restore_from_config = metadata.get('must_restore_from_config')
if not generic_utils.validate_config(config):
return None
try:
obj = layers_module.deserialize(
- generic_utils.serialize_keras_class_and_config(class_name, config))
+ generic_utils.serialize_keras_class_and_config(
+ class_name, config, shared_object_id=shared_object_id))
except ValueError:
if must_restore_from_config:
raise RuntimeError(
diff --git a/tensorflow/python/keras/saving/saved_model/metric_serialization.py b/tensorflow/python/keras/saving/saved_model/metric_serialization.py
index fda341d..e2b6d36 100644
--- a/tensorflow/python/keras/saving/saved_model/metric_serialization.py
+++ b/tensorflow/python/keras/saving/saved_model/metric_serialization.py
@@ -36,7 +36,7 @@
class_name=generic_utils.get_registered_name(type(self.obj)),
name=self.obj.name,
dtype=self.obj.dtype)
- metadata.update(layer_serialization.get_config(self.obj))
+ metadata.update(layer_serialization.get_serialized(self.obj))
if self.obj._build_input_shape is not None: # pylint: disable=protected-access
metadata['build_input_shape'] = self.obj._build_input_shape # pylint: disable=protected-access
return metadata
diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py
index ecf3824..89aeaf4 100644
--- a/tensorflow/python/keras/utils/generic_utils.py
+++ b/tensorflow/python/keras/utils/generic_utils.py
@@ -24,8 +24,10 @@
import os
import re
import sys
+import threading
import time
import types as python_types
+import weakref
import numpy as np
import six
@@ -110,9 +112,235 @@
return _GLOBAL_CUSTOM_OBJECTS
-def serialize_keras_class_and_config(cls_name, cls_config):
+# Store a unique, per-object ID for shared objects.
+#
+# We store a unique ID for each object so that we may, at loading time,
+# re-create the network properly. Without this ID, we would have no way of
+# determining whether a config is a description of a new object that
+# should be created or is merely a reference to an already-created object.
+SHARED_OBJECT_KEY = 'shared_object_id'
+
+
+SHARED_OBJECT_DISABLED = threading.local()
+SHARED_OBJECT_LOADING = threading.local()
+SHARED_OBJECT_SAVING = threading.local()
+
+
+# Attributes on the threadlocal variable must be set per-thread, thus we
+# cannot initialize these globally. Instead, we have accessor functions with
+# default values.
+def _shared_object_disabled():
+ """Get whether shared object handling is disabled in a threadsafe manner."""
+ return getattr(SHARED_OBJECT_DISABLED, 'disabled', False)
+
+
+def _shared_object_loading_scope():
+ """Get the current shared object saving scope in a threadsafe manner."""
+ return getattr(SHARED_OBJECT_LOADING, 'scope', NoopLoadingScope())
+
+
+def _shared_object_saving_scope():
+ """Get the current shared object saving scope in a threadsafe manner."""
+ return getattr(SHARED_OBJECT_SAVING, 'scope', None)
+
+
+class DisableSharedObjectScope(object):
+ """A context manager for disabling handling of shared objects.
+
+ Disables shared object handling for both saving and loading.
+
+ Created primarily for use with `clone_model`, which does extra surgery that
+ is incompatible with shared objects.
+ """
+
+ def __enter__(self):
+ SHARED_OBJECT_DISABLED.disabled = True
+ self._orig_loading_scope = _shared_object_loading_scope()
+ self._orig_saving_scope = _shared_object_saving_scope()
+
+ def __exit__(self, *args, **kwargs):
+ SHARED_OBJECT_DISABLED.disabled = False
+ SHARED_OBJECT_LOADING.scope = self._orig_loading_scope
+ SHARED_OBJECT_SAVING.scope = self._orig_saving_scope
+
+
+class NoopLoadingScope(object):
+ """The default shared object loading scope. It does nothing.
+
+ Created to simplify serialization code that doesn't care about shared objects
+ (e.g. when serializing a single object).
+ """
+
+ def get(self, unused_object_id):
+ return None
+
+ def set(self, object_id, obj):
+ pass
+
+
+class SharedObjectLoadingScope(object):
+ """A context manager for keeping track of loaded objects.
+
+ During the deserialization process, we may come across objects that are
+ shared across multiple layers. In order to accurately restore the network
+ structure to its original state, `SharedObjectLoadingScope` allows us to
+ re-use shared objects rather than cloning them.
+ """
+
+ def __enter__(self):
+ if _shared_object_disabled():
+ return NoopLoadingScope()
+
+ global SHARED_OBJECT_LOADING
+ SHARED_OBJECT_LOADING.scope = self
+ self._obj_ids_to_obj = {}
+ return self
+
+ def get(self, object_id):
+ """Given a shared object ID, returns a previously instantiated object.
+
+ Args:
+ object_id: shared object ID to use when attempting to find already-loaded
+ object.
+
+ Returns:
+ The object, if we've seen this ID before. Else, `None`.
+ """
+ # Explicitly check for `None` internally to make external calling code a
+ # bit cleaner.
+ if object_id is None:
+ return
+ return self._obj_ids_to_obj.get(object_id)
+
+ def set(self, object_id, obj):
+ """Stores an instantiated object for future lookup and sharing."""
+ if object_id is None:
+ return
+ self._obj_ids_to_obj[object_id] = obj
+
+ def __exit__(self, *args, **kwargs):
+ global SHARED_OBJECT_LOADING
+ SHARED_OBJECT_LOADING.scope = NoopLoadingScope()
+
+
+class SharedObjectConfig(dict):
+ """A configuration container that keeps track of references.
+
+ `SharedObjectConfig` will automatically attach a shared object ID to any
+ configs which are referenced more than once, allowing for proper shared
+ object reconstruction at load time.
+
+ In most cases, it would be more proper to subclass something like
+ `collections.UserDict` or `collections.Mapping` rather than `dict` directly.
+ Unfortunately, Python's json encoder does not support `Mapping`s. This is
+ important functionality to retain, since we are dealing with serialization.
+
+ We should be safe to subclass `dict` here, since we aren't actually
+ overriding any core methods, only augmenting with a new one for reference
+ counting.
+ """
+
+ def __init__(self, base_config, object_id, **kwargs):
+ self.ref_count = 1
+ self.object_id = object_id
+ super(SharedObjectConfig, self).__init__(base_config, **kwargs)
+
+ def increment_ref_count(self):
+ # As soon as we've seen the object more than once, we want to attach the
+ # shared object ID. This allows us to only attach the shared object ID when
+ # it's strictly necessary, making backwards compatibility breakage less
+ # likely.
+ if self.ref_count == 1:
+ self[SHARED_OBJECT_KEY] = self.object_id
+ self.ref_count += 1
+
+
+class SharedObjectSavingScope(object):
+ """Keeps track of shared object configs when serializing."""
+
+ def __enter__(self):
+ if _shared_object_disabled():
+ return None
+
+ global SHARED_OBJECT_SAVING
+
+ # Serialization can happen at a number of layers for a number of reasons.
+ # We may end up with a case where we're opening a saving scope within
+ # another saving scope. In that case, we'd like to use the outermost scope
+ # available and ignore inner scopes, since there is not (yet) a reasonable
+ # use case for having these nested and distinct.
+ if _shared_object_saving_scope() is not None:
+ self._passthrough = True
+ return _shared_object_saving_scope()
+ else:
+ self._passthrough = False
+
+ SHARED_OBJECT_SAVING.scope = self
+ self._shared_objects_config = weakref.WeakKeyDictionary()
+ self._next_id = 0
+ return self
+
+ def get_config(self, obj):
+ """Gets a `SharedObjectConfig` if one has already been seen for `obj`.
+
+ Args:
+ obj: The object for which to retrieve the `SharedObjectConfig`.
+
+ Returns:
+ The SharedObjectConfig for a given object, if already seen. Else,
+ `None`.
+ """
+ try:
+ shared_object_config = self._shared_objects_config[obj]
+ except (TypeError, KeyError):
+ # If the object is unhashable (e.g. a subclass of `AbstractBaseClass`
+ # that has not overridden `__hash__`), a `TypeError` will be thrown.
+ # We'll just continue on without shared object support.
+ return None
+ shared_object_config.increment_ref_count()
+ return shared_object_config
+
+ def create_config(self, base_config, obj):
+ """Create a new SharedObjectConfig for a given object."""
+ shared_object_config = SharedObjectConfig(base_config, self._next_id)
+ self._next_id += 1
+ try:
+ self._shared_objects_config[obj] = shared_object_config
+ except TypeError:
+ # If the object is unhashable (e.g. a subclass of `AbstractBaseClass`
+ # that has not overridden `__hash__`), a `TypeError` will be thrown.
+ # We'll just continue on without shared object support.
+ pass
+ return shared_object_config
+
+ def __exit__(self, *args, **kwargs):
+ if not getattr(self, '_passthrough', False):
+ global SHARED_OBJECT_SAVING
+ SHARED_OBJECT_SAVING.scope = None
+
+
+def serialize_keras_class_and_config(
+ cls_name, cls_config, obj=None, shared_object_id=None):
"""Returns the serialization of the class with the given config."""
- return {'class_name': cls_name, 'config': cls_config}
+ base_config = {'class_name': cls_name, 'config': cls_config}
+
+ # We call `serialize_keras_class_and_config` for some branches of the load
+ # path. In that case, we may already have a shared object ID we'd like to
+ # retain.
+ if shared_object_id is not None:
+ base_config[SHARED_OBJECT_KEY] = shared_object_id
+
+ # If we have an active `SharedObjectSavingScope`, check whether we've already
+ # serialized this config. If so, just use that config. This will store an
+ # extra ID field in the config, allowing us to re-create the shared object
+ # relationship at load time.
+ if _shared_object_saving_scope() is not None and obj is not None:
+ shared_object_config = _shared_object_saving_scope().get_config(obj)
+ if shared_object_config is None:
+ return _shared_object_saving_scope().create_config(base_config, obj)
+ return shared_object_config
+
+ return base_config
@keras_export('keras.utils.register_keras_serializable')
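
As used on the SavedModel load path above (`load.py`), a previously assigned
ID can be threaded back into the config so the loading scope can deduplicate.
A small sketch with illustrative values:

    config = serialize_keras_class_and_config(
        'Dense', {'units': 4}, shared_object_id=7)
    # config == {'class_name': 'Dense', 'config': {'units': 4},
    #            'shared_object_id': 7}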
@@ -234,7 +462,19 @@
@keras_export('keras.utils.serialize_keras_object')
def serialize_keras_object(instance):
- """Serialize a Keras object into a JSON-compatible representation."""
+ """Serialize a Keras object into a JSON-compatible representation.
+
+ Calls to `serialize_keras_object` while underneath the
+ `SharedObjectSavingScope` context manager will cause any objects re-used
+ across multiple layers to be saved with a special shared object ID. This
+ allows the network to be re-created properly during deserialization.
+
+ Args:
+ instance: The object to serialize.
+
+ Returns:
+ A dict-like, JSON-compatible representation of the object's config.
+ """
_, instance = tf_decorator.unwrap(instance)
if instance is None:
return None
@@ -265,7 +505,8 @@
serialization_config[key] = item
name = get_registered_name(instance.__class__)
- return serialize_keras_class_and_config(name, serialization_config)
+ return serialize_keras_class_and_config(
+ name, serialization_config, instance)
if hasattr(instance, '__name__'):
return get_registered_name(instance)
raise ValueError('Cannot serialize', instance)
@@ -286,8 +527,9 @@
custom_objects=None,
printable_module_name='object'):
"""Returns the class name and config for a serialized keras object."""
- if (not isinstance(config, dict) or 'class_name' not in config or
- 'config' not in config):
+ if (not isinstance(config, dict)
+ or 'class_name' not in config
+ or 'config' not in config):
raise ValueError('Improper config format: ' + str(config))
class_name = config['class_name']
@@ -341,7 +583,24 @@
module_objects=None,
custom_objects=None,
printable_module_name='object'):
- """Turns the serialized form of a Keras object back into an actual object."""
+ """Turns the serialized form of a Keras object back into an actual object.
+
+ Calls to `deserialize_keras_object` while underneath the
+ `SharedObjectLoadingScope` context manager will cause any already-seen shared
+ objects to be returned as-is rather than creating a new object.
+
+ Args:
+ identifier: The serialized form of the object.
+ module_objects: A dictionary of built-in objects to look the name up in.
+ Generally, module_objects is provided by mid-level library implementers.
+ custom_objects: A dictionary of custom objects to look the name up in.
+ Generally, custom_objects is provided by the user.
+ printable_module_name: A human-readable string representing the type of the
+ object. Printed in case of exception.
+
+ Returns:
+ The deserialized object.
+ """
if identifier is None:
return None
@@ -351,25 +610,39 @@
(cls, cls_config) = class_and_config_for_serialized_keras_object(
config, module_objects, custom_objects, printable_module_name)
+ # If this object has already been loaded (i.e. it's shared between multiple
+ # objects), return the already-loaded object.
+ shared_object_id = config.get(SHARED_OBJECT_KEY)
+ shared_object = _shared_object_loading_scope().get(shared_object_id) # pylint: disable=assignment-from-none
+ if shared_object is not None:
+ return shared_object
+
if hasattr(cls, 'from_config'):
arg_spec = tf_inspect.getfullargspec(cls.from_config)
custom_objects = custom_objects or {}
if 'custom_objects' in arg_spec.args:
- return cls.from_config(
+ deserialized_obj = cls.from_config(
cls_config,
custom_objects=dict(
list(_GLOBAL_CUSTOM_OBJECTS.items()) +
list(custom_objects.items())))
- with CustomObjectScope(custom_objects):
- return cls.from_config(cls_config)
+ else:
+ with CustomObjectScope(custom_objects):
+ deserialized_obj = cls.from_config(cls_config)
else:
# Then `cls` may be a function returning a class.
# in this case by convention `config` holds
# the kwargs of the function.
custom_objects = custom_objects or {}
with CustomObjectScope(custom_objects):
- return cls(**cls_config)
+ deserialized_obj = cls(**cls_config)
+
+ # Add object to shared objects, in case we find it referenced again.
+ _shared_object_loading_scope().set(shared_object_id, deserialized_obj)
+
+ return deserialized_obj
+
elif isinstance(identifier, six.string_types):
object_name = identifier
if custom_objects and object_name in custom_objects:
diff --git a/tensorflow/python/keras/utils/generic_utils_test.py b/tensorflow/python/keras/utils/generic_utils_test.py
index 2dc2952..dd28b17 100644
--- a/tensorflow/python/keras/utils/generic_utils_test.py
+++ b/tensorflow/python/keras/utils/generic_utils_test.py
@@ -23,6 +23,7 @@
import numpy as np
from tensorflow.python import keras
+from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.platform import test
@@ -384,5 +385,63 @@
[None, None, None])
+# object() alone isn't compatible with WeakKeyDictionary, which we use to
+# track shared configs.
+class MaybeSharedObject(object):
+ pass
+
+
+class SharedObjectScopeTest(test.TestCase):
+
+ def test_shared_object_saving_scope_single_object_doesnt_export_id(self):
+ with generic_utils.SharedObjectSavingScope() as scope:
+ single_object = MaybeSharedObject()
+ self.assertIsNone(scope.get_config(single_object))
+ single_object_config = scope.create_config({}, single_object)
+ self.assertIsNotNone(single_object_config)
+ self.assertNotIn(generic_utils.SHARED_OBJECT_KEY,
+ single_object_config)
+
+ def test_shared_object_saving_scope_shared_object_exports_id(self):
+ with generic_utils.SharedObjectSavingScope() as scope:
+ shared_object = MaybeSharedObject()
+ self.assertIsNone(scope.get_config(shared_object))
+ scope.create_config({}, shared_object)
+ first_object_config = scope.get_config(shared_object)
+ second_object_config = scope.get_config(shared_object)
+ self.assertIn(generic_utils.SHARED_OBJECT_KEY,
+ first_object_config)
+ self.assertIn(generic_utils.SHARED_OBJECT_KEY,
+ second_object_config)
+ self.assertIs(first_object_config, second_object_config)
+
+ def test_shared_object_loading_scope_noop(self):
+ # Test that, without a context manager scope, adding configs will do
+ # nothing.
+ obj_id = 1
+ obj = MaybeSharedObject()
+ generic_utils._shared_object_loading_scope().set(obj_id, obj)
+ self.assertIsNone(generic_utils._shared_object_loading_scope().get(obj_id))
+
+ def test_shared_object_loading_scope_returns_shared_obj(self):
+ obj_id = 1
+ obj = MaybeSharedObject()
+ with generic_utils.SharedObjectLoadingScope() as scope:
+ scope.set(obj_id, obj)
+ self.assertIs(scope.get(obj_id), obj)
+
+ def test_nested_shared_object_saving_scopes(self):
+ my_obj = MaybeSharedObject()
+ with generic_utils.SharedObjectSavingScope() as scope_1:
+ scope_1.create_config({}, my_obj)
+ with generic_utils.SharedObjectSavingScope() as scope_2:
+ # Nesting saving scopes should return the original scope and should
+ # not clear any objects we're tracking.
+ self.assertIs(scope_1, scope_2)
+ self.assertIsNotNone(scope_2.get_config(my_obj))
+ self.assertIsNotNone(scope_1.get_config(my_obj))
+ self.assertIsNone(generic_utils._shared_object_saving_scope())
+
+
if __name__ == '__main__':
test.main()