Add support for shared objects in keras when (de)serializing.
Previously, objects shared between multiple layers would be duplicated when
saving and loading. This commit adds a unique ID for keras objects when
serializing that allows us to correctly create only a single instance of
shared objects when deserializing.
PiperOrigin-RevId: 353123479
Change-Id: I057ba61fe587d5cb97238fb31ca562dc5791cf88
diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py
index a3aa265..743b4c0 100644
--- a/tensorflow/python/keras/engine/functional.py
+++ b/tensorflow/python/keras/engine/functional.py
@@ -671,13 +671,12 @@
Raises:
ValueError: In case of improperly formatted config dict.
"""
- with generic_utils.SharedObjectLoadingScope():
- input_tensors, output_tensors, created_layers = reconstruct_from_config(
- config, custom_objects)
- model = cls(inputs=input_tensors, outputs=output_tensors,
- name=config.get('name'))
- connect_ancillary_layers(model, created_layers)
- return model
+ input_tensors, output_tensors, created_layers = reconstruct_from_config(
+ config, custom_objects)
+ model = cls(inputs=input_tensors, outputs=output_tensors,
+ name=config.get('name'))
+ connect_ancillary_layers(model, created_layers)
+ return model
def _validate_graph_inputs_and_outputs(self):
"""Validates the inputs and outputs of a Graph Network."""
@@ -1347,23 +1346,21 @@
node_conversion_map[node_key] = kept_nodes
kept_nodes += 1
layer_configs = []
+ for layer in network.layers: # From the earliest layers on.
+ filtered_inbound_nodes = []
+ for original_node_index, node in enumerate(layer._inbound_nodes):
+ node_key = _make_node_key(layer.name, original_node_index)
+ if node_key in network._network_nodes and not node.is_input:
+ # The node is relevant to the model:
+ # add to filtered_inbound_nodes.
+ node_data = node.serialize(_make_node_key, node_conversion_map)
+ filtered_inbound_nodes.append(node_data)
- with generic_utils.SharedObjectSavingScope():
- for layer in network.layers: # From the earliest layers on.
- filtered_inbound_nodes = []
- for original_node_index, node in enumerate(layer._inbound_nodes):
- node_key = _make_node_key(layer.name, original_node_index)
- if node_key in network._network_nodes and not node.is_input:
- # The node is relevant to the model:
- # add to filtered_inbound_nodes.
- node_data = node.serialize(_make_node_key, node_conversion_map)
- filtered_inbound_nodes.append(node_data)
-
- layer_config = serialize_layer_fn(layer)
- layer_config['name'] = layer.name
- layer_config['inbound_nodes'] = filtered_inbound_nodes
- layer_configs.append(layer_config)
- config['layers'] = layer_configs
+ layer_config = serialize_layer_fn(layer)
+ layer_config['name'] = layer.name
+ layer_config['inbound_nodes'] = filtered_inbound_nodes
+ layer_configs.append(layer_config)
+ config['layers'] = layer_configs
# Gather info about inputs and outputs.
model_inputs = []
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index 2262538..b16e0d6 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -393,10 +393,6 @@
except that it creates new layers (and thus new weights) instead
of sharing the weights of the existing layers.
- `clone_model` will not preserve the uniqueness of shared objects within the
- model (e.g. a single variable attached to two distinct layers will be
- restored as two separate variables).
-
Args:
model: Instance of `Model`
(could be a functional model or a Sequential model).
diff --git a/tensorflow/python/keras/saving/save.py b/tensorflow/python/keras/saving/save.py
index ef7f699..d4749fc 100644
--- a/tensorflow/python/keras/saving/save.py
+++ b/tensorflow/python/keras/saving/save.py
@@ -148,9 +148,8 @@
hdf5_format.save_model_to_hdf5(
model, filepath, overwrite, include_optimizer)
else:
- with generic_utils.SharedObjectSavingScope():
- saved_model_save.save(model, filepath, overwrite, include_optimizer,
- signatures, options, save_traces)
+ saved_model_save.save(model, filepath, overwrite, include_optimizer,
+ signatures, options, save_traces)
@keras_export('keras.models.load_model')
@@ -195,18 +194,17 @@
ImportError: if loading from an hdf5 file and h5py is not available.
IOError: In case of an invalid savefile.
"""
- with generic_utils.SharedObjectLoadingScope():
- with generic_utils.CustomObjectScope(custom_objects or {}):
- with load_context.load_context(options):
- if (h5py is not None and
- (isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
- return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
- compile)
+ with generic_utils.CustomObjectScope(custom_objects or {}):
+ with load_context.load_context(options):
+ if (h5py is not None and
+ (isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
+ return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
+ compile)
- filepath = path_to_string(filepath)
- if isinstance(filepath, six.string_types):
- loader_impl.parse_saved_model(filepath)
- return saved_model_load.load(filepath, compile, options)
+ filepath = path_to_string(filepath)
+ if isinstance(filepath, six.string_types):
+ loader_impl.parse_saved_model(filepath)
+ return saved_model_load.load(filepath, compile, options)
raise IOError(
'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
diff --git a/tensorflow/python/keras/saving/save_test.py b/tensorflow/python/keras/saving/save_test.py
index 20a779b..00c7bb2 100644
--- a/tensorflow/python/keras/saving/save_test.py
+++ b/tensorflow/python/keras/saving/save_test.py
@@ -18,7 +18,6 @@
from __future__ import division
from __future__ import print_function
-import collections
import os
import shutil
import sys
@@ -26,14 +25,12 @@
from absl.testing import parameterized
import numpy as np
-from six import string_types
from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.feature_column import feature_column_lib
from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import combinations
@@ -862,125 +859,6 @@
self.assertAllEqual(loaded_model.predict(args, batch_size=batch_size),
expected)
- @combinations.generate(combinations.combine(mode=['eager']))
- def test_shared_objects(self):
- class OuterLayer(keras.layers.Layer):
-
- def __init__(self, inner_layer):
- super(OuterLayer, self).__init__()
- self.inner_layer = inner_layer
-
- def call(self, inputs):
- return self.inner_layer(inputs)
-
- def get_config(self):
- return {
- 'inner_layer': generic_utils.serialize_keras_object(
- self.inner_layer)
- }
-
- @classmethod
- def from_config(cls, config):
- return cls(generic_utils.deserialize_keras_object(
- config['inner_layer']))
-
- class InnerLayer(keras.layers.Layer):
-
- def __init__(self):
- super(InnerLayer, self).__init__()
- self.v = self.add_weight(name='v', shape=[], dtype=dtypes.float32)
-
- def call(self, inputs):
- return self.v + inputs
-
- @classmethod
- def from_config(cls, config):
- return cls()
-
- # Create a model with 2 output layers that share the same inner layer.
- inner_layer = InnerLayer()
- outer_layer_1 = OuterLayer(inner_layer)
- outer_layer_2 = OuterLayer(inner_layer)
- input_ = keras.Input(shape=(1,))
- model = keras.Model(
- inputs=input_, outputs=[outer_layer_1(input_), outer_layer_2(input_)])
-
- # Changes to the shared layer should affect both outputs.
- model.layers[1].inner_layer.v.assign(5)
- self.assertAllEqual(model(1), [6.0, 6.0])
- model.layers[1].inner_layer.v.assign(3)
- self.assertAllEqual(model(1), [4.0, 4.0])
-
- # After loading, changes to the shared layer should still affect both
- # outputs.
- def _do_assertions(loaded):
- loaded.layers[1].inner_layer.v.assign(5)
- self.assertAllEqual(loaded(1), [6.0, 6.0])
- loaded.layers[1].inner_layer.v.assign(3)
- self.assertAllEqual(loaded(1), [4.0, 4.0])
- loaded.layers[2].inner_layer.v.assign(5)
- self.assertAllEqual(loaded(1), [6.0, 6.0])
- loaded.layers[2].inner_layer.v.assign(3)
- self.assertAllEqual(loaded(1), [4.0, 4.0])
-
- # We'd like to make sure we only attach shared object IDs when strictly
- # necessary, so we'll recursively traverse the generated config to count
- # whether we have the exact number we expect.
- def _get_all_keys_recursive(dict_or_iterable):
- if isinstance(dict_or_iterable, dict):
- for key in dict_or_iterable.keys():
- yield key
- for key in _get_all_keys_recursive(dict_or_iterable.values()):
- yield key
- elif isinstance(dict_or_iterable, string_types):
- return
- else:
- try:
- for item in dict_or_iterable:
- for key in _get_all_keys_recursive(item):
- yield key
- # Not an iterable or dictionary
- except TypeError:
- return
-
- with generic_utils.CustomObjectScope({
- 'OuterLayer': OuterLayer, 'InnerLayer': InnerLayer}):
-
- # Test saving and loading to disk
- save_format = testing_utils.get_save_format()
- saved_model_dir = self._save_model_dir()
- keras.models.save_model(model, saved_model_dir, save_format=save_format)
- loaded = keras.models.load_model(saved_model_dir)
- _do_assertions(loaded)
-
- # Test recreating directly from config
- config = model.get_config()
- key_count = collections.Counter(_get_all_keys_recursive(config))
- self.assertEqual(key_count[generic_utils.SHARED_OBJECT_KEY], 2)
- loaded = keras.Model.from_config(config)
- _do_assertions(loaded)
-
- @combinations.generate(combinations.combine(mode=['eager']))
- def test_shared_objects_wrapper(self):
- """Tests that shared layers wrapped with `Wrapper` restore correctly."""
- input_ = keras.Input(shape=(1,))
- unwrapped = keras.layers.Layer(name='unwrapped')
- wrapped = keras.layers.Wrapper(unwrapped, name='wrapped')
- model = keras.Model(inputs=input_,
- outputs=[unwrapped(input_), wrapped(input_)])
-
- # Test recreating directly from config
- config = model.get_config()
- loaded = keras.Model.from_config(config)
- self.assertIs(loaded.layers[1], loaded.layers[2].layer)
-
- # Test saving and loading to disk
- save_format = testing_utils.get_save_format()
- saved_model_dir = self._save_model_dir()
- keras.models.save_model(model, saved_model_dir, save_format=save_format)
- loaded = keras.models.load_model(saved_model_dir)
- self.assertIs(loaded.layers[1], loaded.layers[2].layer)
-
# Factory functions to create models that will be serialized inside a Network.
def _make_graph_network(input_size, output_size):
diff --git a/tensorflow/python/keras/saving/saved_model/layer_serialization.py b/tensorflow/python/keras/saving/saved_model/layer_serialization.py
index 3f59a8e..e2776bc 100644
--- a/tensorflow/python/keras/saving/saved_model/layer_serialization.py
+++ b/tensorflow/python/keras/saving/saved_model/layer_serialization.py
@@ -46,6 +46,7 @@
# TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
# the python config serialization has caught up.
metadata = dict(
+ class_name=generic_utils.get_registered_name(type(self.obj)),
name=self.obj.name,
trainable=self.obj.trainable,
expects_training_arg=self.obj._expects_training_arg, # pylint: disable=protected-access
@@ -55,7 +56,7 @@
must_restore_from_config=self.obj._must_restore_from_config, # pylint: disable=protected-access
)
- metadata.update(get_serialized(self.obj))
+ metadata.update(get_config(self.obj))
if self.obj.input_spec is not None:
# Layer's input_spec has already been type-checked in the property setter.
metadata['input_spec'] = nest.map_structure(
@@ -109,12 +110,16 @@
# TODO(kathywu): Move serialization utils (and related utils from
# generic_utils.py) to a separate file.
-def get_serialized(obj):
+def get_config(obj):
with generic_utils.skip_failed_serialization():
# Store the config dictionary, which may be used when reviving the object.
# When loading, the program will attempt to revive the object from config,
# and if that fails, the object will be revived from the SavedModel.
- return generic_utils.serialize_keras_object(obj)
+ config = generic_utils.serialize_keras_object(obj)['config']
+
+ if config is not None:
+ return {'config': config}
+ return {}
class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py
index 7f38f05..586394e 100644
--- a/tensorflow/python/keras/saving/saved_model/load.py
+++ b/tensorflow/python/keras/saving/saved_model/load.py
@@ -492,15 +492,13 @@
# found.
class_name = metadata.get('class_name')
config = metadata.get('config')
- shared_object_id = metadata.get('shared_object_id')
must_restore_from_config = metadata.get('must_restore_from_config')
if not generic_utils.validate_config(config):
return None
try:
obj = layers_module.deserialize(
- generic_utils.serialize_keras_class_and_config(
- class_name, config, shared_object_id=shared_object_id))
+ generic_utils.serialize_keras_class_and_config(class_name, config))
except ValueError:
if must_restore_from_config:
raise RuntimeError(
diff --git a/tensorflow/python/keras/saving/saved_model/metric_serialization.py b/tensorflow/python/keras/saving/saved_model/metric_serialization.py
index e2b6d36..fda341d 100644
--- a/tensorflow/python/keras/saving/saved_model/metric_serialization.py
+++ b/tensorflow/python/keras/saving/saved_model/metric_serialization.py
@@ -36,7 +36,7 @@
class_name=generic_utils.get_registered_name(type(self.obj)),
name=self.obj.name,
dtype=self.obj.dtype)
- metadata.update(layer_serialization.get_serialized(self.obj))
+ metadata.update(layer_serialization.get_config(self.obj))
if self.obj._build_input_shape is not None: # pylint: disable=protected-access
metadata['build_input_shape'] = self.obj._build_input_shape # pylint: disable=protected-access
return metadata
diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py
index ea1c051..ecf3824 100644
--- a/tensorflow/python/keras/utils/generic_utils.py
+++ b/tensorflow/python/keras/utils/generic_utils.py
@@ -24,10 +24,8 @@
import os
import re
import sys
-import threading
import time
import types as python_types
-import weakref
import numpy as np
import six
@@ -112,205 +110,9 @@
return _GLOBAL_CUSTOM_OBJECTS
-# Store a unique, per-object ID for shared objects.
-#
-# We store a unique ID for each object so that we may, at loading time,
-# re-create the network properly. Without this ID, we would have no way of
-# determining whether a config is a description of a new object that
-# should be created or is merely a reference to an already-created object.
-SHARED_OBJECT_KEY = 'shared_object_id'
-
-
-class NoopLoadingScope(object):
- """The default shared object loading scope. It does nothing.
-
- Created to simplify serialization code that doesn't care about shared objects
- (e.g. when serializing a single object).
- """
-
- def get(self, unused_object_id):
- return None
-
- def set(self, object_id, obj):
- pass
-
-
-SHARED_OBJECT_LOADING = threading.local()
-
-
-def _shared_object_loading_scope():
- """Get the current shared object saving scope in a threadsafe manner.
-
- Attributes on the threadlocal variable must be set per-thread, thus we
- cannot initialize these globally.
-
- Returns:
- A SharedObjectLoadingScope or NoopLoadingScope object.
- """
- return getattr(SHARED_OBJECT_LOADING, 'scope', NoopLoadingScope())
-
-
-class SharedObjectLoadingScope(object):
- """A context manager for keeping track of loaded objects.
-
- During the deserialization process, we may come across objects that are
- shared across multiple layers. In order to accurately restore the network
- structure to its original state, `SharedObjectLoadingScope` allows us to
- re-use shared objects rather than cloning them.
- """
-
- def __enter__(self):
- global SHARED_OBJECT_LOADING
-
- SHARED_OBJECT_LOADING.scope = self
- self._obj_ids_to_obj = {}
- return self
-
- def get(self, object_id):
- """Given a shared object ID, returns a previously instantiated object.
-
- Args:
- object_id: shared object ID to use when attempting to find already-loaded
- object.
-
- Returns:
- The object, if we've seen this ID before. Else, `None`.
- """
- # Explicitly check for `None` internally to make external calling code a
- # bit cleaner.
- if object_id is None:
- return
- return self._obj_ids_to_obj.get(object_id)
-
- def set(self, object_id, obj):
- """Stores an instantiated object for future lookup and sharing."""
- if object_id is None:
- return
- self._obj_ids_to_obj[object_id] = obj
-
- def __exit__(self, *args, **kwargs):
- global SHARED_OBJECT_LOADING
- SHARED_OBJECT_LOADING.scope = NoopLoadingScope()
-
-
-SHARED_OBJECT_SAVING = threading.local()
-
-
-def _shared_object_saving_scope():
- """Get the current shared object saving scope in a threadsafe manner.
-
- Attributes on the threadlocal variable must be set per-thread, thus we
- cannot initialize these globally.
-
- Returns:
- A SharedObjectSavingScope object or None.
- """
- return getattr(SHARED_OBJECT_SAVING, 'scope', None)
-
-
-class SharedObjectConfig(dict):
- """A configuration container that keeps track of references.
-
- `SharedObjectConfig` will automatically attach a shared object ID to any
- configs which are referenced more than once, allowing for proper shared
- object reconstruction at load time.
-
- In most cases, it would be more proper to subclass something like
- `collections.UserDict` or `collections.Mapping` rather than `dict` directly.
- Unfortunately, python's json encoder does not support `Mapping`s. This is
- important functionality to retain, since we are dealing with serialization.
-
- We should be safe to subclass `dict` here, since we aren't actually
- overriding any core methods, only augmenting with a new one for reference
- counting.
- """
-
- def __init__(self, base_config, object_id, **kwargs):
- self.ref_count = 1
- self.object_id = object_id
- super(SharedObjectConfig, self).__init__(base_config, **kwargs)
-
- def increment_ref_count(self):
- # As soon as we've seen the object more than once, we want to attach the
- # shared object ID. This allows us to only attach the shared object ID when
- # it's strictly necessary, making backwards compatibility breakage less
- # likely.
- if self.ref_count == 1:
- self[SHARED_OBJECT_KEY] = self.object_id
- self.ref_count += 1
-
-
-class SharedObjectSavingScope(object):
- """Keeps track of shared object configs when serializing."""
-
- def __enter__(self):
- global SHARED_OBJECT_SAVING
-
- # Serialization can happen at a number of layers for a number of reasons.
- # We may end up with a case where we're opening a saving scope within
- # another saving scope. In that case, we'd like to use the outermost scope
- # available and ignore inner scopes, since there is not (yet) a reasonable
- # use case for having these nested and distinct.
- if _shared_object_saving_scope() is not None:
- self._passthrough = True
- return _shared_object_saving_scope()
- else:
- self._passthrough = False
-
- SHARED_OBJECT_SAVING.scope = self
- self._shared_objects_config = weakref.WeakKeyDictionary()
- self._next_id = 0
- return self
-
- def get_config(self, obj):
- """Gets a `SharedObjectConfig` if one has already been seen for `obj`.
-
- Args:
- obj: The object for which to retrieve the `SharedObjectConfig`.
-
- Returns:
- The SharedObjectConfig for a given object, if already seen. Else,
- `None`.
- """
- if obj in self._shared_objects_config:
- shared_object_config = self._shared_objects_config[obj]
- shared_object_config.increment_ref_count()
- return shared_object_config
-
- def create_config(self, base_config, obj):
- shared_object_config = SharedObjectConfig(base_config, self._next_id)
- self._next_id += 1
- self._shared_objects_config[obj] = shared_object_config
- return shared_object_config
-
- def __exit__(self, *args, **kwargs):
- if not self._passthrough:
- global SHARED_OBJECT_SAVING
- SHARED_OBJECT_SAVING.scope = None
-
-
-def serialize_keras_class_and_config(
- cls_name, cls_config, obj=None, shared_object_id=None):
+def serialize_keras_class_and_config(cls_name, cls_config):
"""Returns the serialization of the class with the given config."""
- base_config = {'class_name': cls_name, 'config': cls_config}
-
- # We call `serialize_keras_class_and_config` for some branches of the load
- # path. In that case, we may already have a shared object ID we'd like to
- # retain.
- if shared_object_id is not None:
- base_config[SHARED_OBJECT_KEY] = shared_object_id
-
- # If we have an active `SharedObjectSavingScope`, check whether we've already
- # serialized this config. If so, just use that config. This will store an
- # extra ID field in the config, allowing us to re-create the shared object
- # relationship at load time.
- if _shared_object_saving_scope() is not None and obj is not None:
- shared_object_config = _shared_object_saving_scope().get_config(obj)
- if shared_object_config is None:
- return _shared_object_saving_scope().create_config(base_config, obj)
- return shared_object_config
-
- return base_config
+ return {'class_name': cls_name, 'config': cls_config}
@keras_export('keras.utils.register_keras_serializable')
@@ -432,19 +234,7 @@
@keras_export('keras.utils.serialize_keras_object')
def serialize_keras_object(instance):
- """Serialize a Keras object into a JSON-compatible representation.
-
- Calls to `serialize_keras_object` while underneath the
- `SharedObjectSavingScope` context manager will cause any objects re-used
- across multiple layers to be saved with a special shared object ID. This
- allows the network to be re-created properly during deserialization.
-
- Args:
- instance: The object to serialize.
-
- Returns:
- A dict-like, JSON-compatible representation of the object's config.
- """
+ """Serialize a Keras object into a JSON-compatible representation."""
_, instance = tf_decorator.unwrap(instance)
if instance is None:
return None
@@ -475,8 +265,7 @@
serialization_config[key] = item
name = get_registered_name(instance.__class__)
- return serialize_keras_class_and_config(
- name, serialization_config, instance)
+ return serialize_keras_class_and_config(name, serialization_config)
if hasattr(instance, '__name__'):
return get_registered_name(instance)
raise ValueError('Cannot serialize', instance)
@@ -497,9 +286,8 @@
custom_objects=None,
printable_module_name='object'):
"""Returns the class name and config for a serialized keras object."""
- if (not isinstance(config, dict)
- or 'class_name' not in config
- or 'config' not in config):
+ if (not isinstance(config, dict) or 'class_name' not in config or
+ 'config' not in config):
raise ValueError('Improper config format: ' + str(config))
class_name = config['class_name']
@@ -553,24 +341,7 @@
module_objects=None,
custom_objects=None,
printable_module_name='object'):
- """Turns the serialized form of a Keras object back into an actual object.
-
- Calls to `deserialize_keras_object` while underneath the
- `SharedObjectLoadingScope` context manager will cause any already-seen shared
- objects to be returned as-is rather than creating a new object.
-
- Args:
- identifier: the serialized form of the object.
- module_objects: A dictionary of custom objects to look the name up in.
- Generally, module_objects is provided by midlevel library implementers.
- custom_objects: A dictionary of custom objects to look the name up in.
- Generally, custom_objects is provided by the user.
- printable_module_name: A human-readable string representing the type of the
- object. Printed in case of exception.
-
- Returns:
- The deserialized object.
- """
+ """Turns the serialized form of a Keras object back into an actual object."""
if identifier is None:
return None
@@ -580,39 +351,25 @@
(cls, cls_config) = class_and_config_for_serialized_keras_object(
config, module_objects, custom_objects, printable_module_name)
- # If this object has already been loaded (i.e. it's shared between multiple
- # objects), return the already-loaded object.
- shared_object_id = config.get(SHARED_OBJECT_KEY)
- shared_object = _shared_object_loading_scope().get(shared_object_id) # pylint: disable=assignment-from-none
- if shared_object is not None:
- return shared_object
-
if hasattr(cls, 'from_config'):
arg_spec = tf_inspect.getfullargspec(cls.from_config)
custom_objects = custom_objects or {}
if 'custom_objects' in arg_spec.args:
- deserialized_obj = cls.from_config(
+ return cls.from_config(
cls_config,
custom_objects=dict(
list(_GLOBAL_CUSTOM_OBJECTS.items()) +
list(custom_objects.items())))
- else:
- with CustomObjectScope(custom_objects):
- deserialized_obj = cls.from_config(cls_config)
+ with CustomObjectScope(custom_objects):
+ return cls.from_config(cls_config)
else:
# Then `cls` may be a function returning a class.
# in this case by convention `config` holds
# the kwargs of the function.
custom_objects = custom_objects or {}
with CustomObjectScope(custom_objects):
- deserialized_obj = cls(**cls_config)
-
- # Add object to shared objects, in case we find it referenced again.
- _shared_object_loading_scope().set(shared_object_id, deserialized_obj)
-
- return deserialized_obj
-
+ return cls(**cls_config)
elif isinstance(identifier, six.string_types):
object_name = identifier
if custom_objects and object_name in custom_objects:
diff --git a/tensorflow/python/keras/utils/generic_utils_test.py b/tensorflow/python/keras/utils/generic_utils_test.py
index dd28b17..2dc2952 100644
--- a/tensorflow/python/keras/utils/generic_utils_test.py
+++ b/tensorflow/python/keras/utils/generic_utils_test.py
@@ -23,7 +23,6 @@
import numpy as np
from tensorflow.python import keras
-from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.platform import test
@@ -385,63 +384,5 @@
[None, None, None])
-# object() alone isn't compatible with WeakKeyDictionary, which we use to
-# track shared configs.
-class MaybeSharedObject(object):
- pass
-
-
-class SharedObjectScopeTest(test.TestCase):
-
- def test_shared_object_saving_scope_single_object_doesnt_export_id(self):
- with generic_utils.SharedObjectSavingScope() as scope:
- single_object = MaybeSharedObject()
- self.assertIsNone(scope.get_config(single_object))
- single_object_config = scope.create_config({}, single_object)
- self.assertIsNotNone(single_object_config)
- self.assertNotIn(generic_utils.SHARED_OBJECT_KEY,
- single_object_config)
-
- def test_shared_object_saving_scope_shared_object_exports_id(self):
- with generic_utils.SharedObjectSavingScope() as scope:
- shared_object = MaybeSharedObject()
- self.assertIsNone(scope.get_config(shared_object))
- scope.create_config({}, shared_object)
- first_object_config = scope.get_config(shared_object)
- second_object_config = scope.get_config(shared_object)
- self.assertIn(generic_utils.SHARED_OBJECT_KEY,
- first_object_config)
- self.assertIn(generic_utils.SHARED_OBJECT_KEY,
- second_object_config)
- self.assertIs(first_object_config, second_object_config)
-
- def test_shared_object_loading_scope_noop(self):
- # Test that, without a context manager scope, adding configs will do
- # nothing.
- obj_id = 1
- obj = MaybeSharedObject()
- generic_utils._shared_object_loading_scope().set(obj_id, obj)
- self.assertIsNone(generic_utils._shared_object_loading_scope().get(obj_id))
-
- def test_shared_object_loading_scope_returns_shared_obj(self):
- obj_id = 1
- obj = MaybeSharedObject()
- with generic_utils.SharedObjectLoadingScope() as scope:
- scope.set(obj_id, obj)
- self.assertIs(scope.get(obj_id), obj)
-
- def test_nested_shared_object_saving_scopes(self):
- my_obj = MaybeSharedObject()
- with generic_utils.SharedObjectSavingScope() as scope_1:
- scope_1.create_config({}, my_obj)
- with generic_utils.SharedObjectSavingScope() as scope_2:
- # Nesting saving scopes should return the original scope and should
- # not clear any objects we're tracking.
- self.assertIs(scope_1, scope_2)
- self.assertIsNotNone(scope_2.get_config(my_obj))
- self.assertIsNotNone(scope_1.get_config(my_obj))
- self.assertIsNone(generic_utils._shared_object_saving_scope())
-
-
if __name__ == '__main__':
test.main()