Add support for shared objects in keras when (de)serializing.

Previously, objects shared between multiple layers would be duplicated when
saving and loading.  This commit adds a unique ID for keras objects when
serializing that allows us to correctly create only a single instance of
shared objects when deserializing.

PiperOrigin-RevId: 353123479
Change-Id: I057ba61fe587d5cb97238fb31ca562dc5791cf88
diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py
index a3aa265..743b4c0 100644
--- a/tensorflow/python/keras/engine/functional.py
+++ b/tensorflow/python/keras/engine/functional.py
@@ -671,13 +671,12 @@
     Raises:
         ValueError: In case of improperly formatted config dict.
     """
-    with generic_utils.SharedObjectLoadingScope():
-      input_tensors, output_tensors, created_layers = reconstruct_from_config(
-          config, custom_objects)
-      model = cls(inputs=input_tensors, outputs=output_tensors,
-                  name=config.get('name'))
-      connect_ancillary_layers(model, created_layers)
-      return model
+    input_tensors, output_tensors, created_layers = reconstruct_from_config(
+        config, custom_objects)
+    model = cls(inputs=input_tensors, outputs=output_tensors,
+                name=config.get('name'))
+    connect_ancillary_layers(model, created_layers)
+    return model
 
   def _validate_graph_inputs_and_outputs(self):
     """Validates the inputs and outputs of a Graph Network."""
@@ -1347,23 +1346,21 @@
         node_conversion_map[node_key] = kept_nodes
         kept_nodes += 1
   layer_configs = []
+  for layer in network.layers:  # From the earliest layers on.
+    filtered_inbound_nodes = []
+    for original_node_index, node in enumerate(layer._inbound_nodes):
+      node_key = _make_node_key(layer.name, original_node_index)
+      if node_key in network._network_nodes and not node.is_input:
+        # The node is relevant to the model:
+        # add to filtered_inbound_nodes.
+        node_data = node.serialize(_make_node_key, node_conversion_map)
+        filtered_inbound_nodes.append(node_data)
 
-  with generic_utils.SharedObjectSavingScope():
-    for layer in network.layers:  # From the earliest layers on.
-      filtered_inbound_nodes = []
-      for original_node_index, node in enumerate(layer._inbound_nodes):
-        node_key = _make_node_key(layer.name, original_node_index)
-        if node_key in network._network_nodes and not node.is_input:
-          # The node is relevant to the model:
-          # add to filtered_inbound_nodes.
-          node_data = node.serialize(_make_node_key, node_conversion_map)
-          filtered_inbound_nodes.append(node_data)
-
-      layer_config = serialize_layer_fn(layer)
-      layer_config['name'] = layer.name
-      layer_config['inbound_nodes'] = filtered_inbound_nodes
-      layer_configs.append(layer_config)
-    config['layers'] = layer_configs
+    layer_config = serialize_layer_fn(layer)
+    layer_config['name'] = layer.name
+    layer_config['inbound_nodes'] = filtered_inbound_nodes
+    layer_configs.append(layer_config)
+  config['layers'] = layer_configs
 
   # Gather info about inputs and outputs.
   model_inputs = []
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index 2262538..b16e0d6 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -393,10 +393,6 @@
   except that it creates new layers (and thus new weights) instead
   of sharing the weights of the existing layers.
 
-  `clone_model` will not preserve the uniqueness of shared objects within the
-  model (e.g. a single variable attached to two distinct layers will be
-  restored as two separate variables).
-
   Args:
       model: Instance of `Model`
           (could be a functional model or a Sequential model).
diff --git a/tensorflow/python/keras/saving/save.py b/tensorflow/python/keras/saving/save.py
index ef7f699..d4749fc 100644
--- a/tensorflow/python/keras/saving/save.py
+++ b/tensorflow/python/keras/saving/save.py
@@ -148,9 +148,8 @@
     hdf5_format.save_model_to_hdf5(
         model, filepath, overwrite, include_optimizer)
   else:
-    with generic_utils.SharedObjectSavingScope():
-      saved_model_save.save(model, filepath, overwrite, include_optimizer,
-                            signatures, options, save_traces)
+    saved_model_save.save(model, filepath, overwrite, include_optimizer,
+                          signatures, options, save_traces)
 
 
 @keras_export('keras.models.load_model')
@@ -195,18 +194,17 @@
       ImportError: if loading from an hdf5 file and h5py is not available.
       IOError: In case of an invalid savefile.
   """
-  with generic_utils.SharedObjectLoadingScope():
-    with generic_utils.CustomObjectScope(custom_objects or {}):
-      with load_context.load_context(options):
-        if (h5py is not None and
-            (isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
-          return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
-                                                  compile)
+  with generic_utils.CustomObjectScope(custom_objects or {}):
+    with load_context.load_context(options):
+      if (h5py is not None and
+          (isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
+        return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
+                                                compile)
 
-        filepath = path_to_string(filepath)
-        if isinstance(filepath, six.string_types):
-          loader_impl.parse_saved_model(filepath)
-          return saved_model_load.load(filepath, compile, options)
+      filepath = path_to_string(filepath)
+      if isinstance(filepath, six.string_types):
+        loader_impl.parse_saved_model(filepath)
+        return saved_model_load.load(filepath, compile, options)
 
   raise IOError(
       'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
diff --git a/tensorflow/python/keras/saving/save_test.py b/tensorflow/python/keras/saving/save_test.py
index 20a779b..00c7bb2 100644
--- a/tensorflow/python/keras/saving/save_test.py
+++ b/tensorflow/python/keras/saving/save_test.py
@@ -18,7 +18,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import os
 import shutil
 import sys
@@ -26,14 +25,12 @@
 
 from absl.testing import parameterized
 import numpy as np
-from six import string_types
 
 from tensorflow.python import keras
 from tensorflow.python import tf2
 from tensorflow.python.eager import context
 from tensorflow.python.feature_column import feature_column_lib
 from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.keras import combinations
@@ -862,125 +859,6 @@
     self.assertAllEqual(loaded_model.predict(args, batch_size=batch_size),
                         expected)
 
-  @combinations.generate(combinations.combine(mode=['eager']))
-  def test_shared_objects(self):
-    class OuterLayer(keras.layers.Layer):
-
-      def __init__(self, inner_layer):
-        super(OuterLayer, self).__init__()
-        self.inner_layer = inner_layer
-
-      def call(self, inputs):
-        return self.inner_layer(inputs)
-
-      def get_config(self):
-        return {
-            'inner_layer': generic_utils.serialize_keras_object(
-                self.inner_layer)
-        }
-
-      @classmethod
-      def from_config(cls, config):
-        return cls(generic_utils.deserialize_keras_object(
-            config['inner_layer']))
-
-    class InnerLayer(keras.layers.Layer):
-
-      def __init__(self):
-        super(InnerLayer, self).__init__()
-        self.v = self.add_weight(name='v', shape=[], dtype=dtypes.float32)
-
-      def call(self, inputs):
-        return self.v + inputs
-
-      @classmethod
-      def from_config(cls, config):
-        return cls()
-
-    # Create a model with 2 output layers that share the same inner layer.
-    inner_layer = InnerLayer()
-    outer_layer_1 = OuterLayer(inner_layer)
-    outer_layer_2 = OuterLayer(inner_layer)
-    input_ = keras.Input(shape=(1,))
-    model = keras.Model(
-        inputs=input_, outputs=[outer_layer_1(input_), outer_layer_2(input_)])
-
-    # Changes to the shared layer should affect both outputs.
-    model.layers[1].inner_layer.v.assign(5)
-    self.assertAllEqual(model(1), [6.0, 6.0])
-    model.layers[1].inner_layer.v.assign(3)
-    self.assertAllEqual(model(1), [4.0, 4.0])
-
-    # After loading, changes to the shared layer should still affect both
-    # outputs.
-    def _do_assertions(loaded):
-      loaded.layers[1].inner_layer.v.assign(5)
-      self.assertAllEqual(loaded(1), [6.0, 6.0])
-      loaded.layers[1].inner_layer.v.assign(3)
-      self.assertAllEqual(loaded(1), [4.0, 4.0])
-      loaded.layers[2].inner_layer.v.assign(5)
-      self.assertAllEqual(loaded(1), [6.0, 6.0])
-      loaded.layers[2].inner_layer.v.assign(3)
-      self.assertAllEqual(loaded(1), [4.0, 4.0])
-
-    # We'd like to make sure we only attach shared object IDs when strictly
-    # necessary, so we'll recursively traverse the generated config to count
-    # whether we have the exact number we expect.
-    def _get_all_keys_recursive(dict_or_iterable):
-      if isinstance(dict_or_iterable, dict):
-        for key in dict_or_iterable.keys():
-          yield key
-        for key in _get_all_keys_recursive(dict_or_iterable.values()):
-          yield key
-      elif isinstance(dict_or_iterable, string_types):
-        return
-      else:
-        try:
-          for item in dict_or_iterable:
-            for key in _get_all_keys_recursive(item):
-              yield key
-        # Not an iterable or dictionary
-        except TypeError:
-          return
-
-    with generic_utils.CustomObjectScope({
-        'OuterLayer': OuterLayer, 'InnerLayer': InnerLayer}):
-
-      # Test saving and loading to disk
-      save_format = testing_utils.get_save_format()
-      saved_model_dir = self._save_model_dir()
-      keras.models.save_model(model, saved_model_dir, save_format=save_format)
-      loaded = keras.models.load_model(saved_model_dir)
-      _do_assertions(loaded)
-
-      # Test recreating directly from config
-      config = model.get_config()
-      key_count = collections.Counter(_get_all_keys_recursive(config))
-      self.assertEqual(key_count[generic_utils.SHARED_OBJECT_KEY], 2)
-      loaded = keras.Model.from_config(config)
-      _do_assertions(loaded)
-
-  @combinations.generate(combinations.combine(mode=['eager']))
-  def test_shared_objects_wrapper(self):
-    """Tests that shared layers wrapped with `Wrapper` restore correctly."""
-    input_ = keras.Input(shape=(1,))
-    unwrapped = keras.layers.Layer(name='unwrapped')
-    wrapped = keras.layers.Wrapper(unwrapped, name='wrapped')
-    model = keras.Model(inputs=input_,
-                        outputs=[unwrapped(input_), wrapped(input_)])
-
-    # Test recreating directly from config
-    config = model.get_config()
-    loaded = keras.Model.from_config(config)
-    self.assertIs(loaded.layers[1], loaded.layers[2].layer)
-
-    # Test saving and loading to disk
-    save_format = testing_utils.get_save_format()
-    saved_model_dir = self._save_model_dir()
-    keras.models.save_model(model, saved_model_dir, save_format=save_format)
-    loaded = keras.models.load_model(saved_model_dir)
-    self.assertIs(loaded.layers[1], loaded.layers[2].layer)
-
 
 # Factory functions to create models that will be serialized inside a Network.
 def _make_graph_network(input_size, output_size):
diff --git a/tensorflow/python/keras/saving/saved_model/layer_serialization.py b/tensorflow/python/keras/saving/saved_model/layer_serialization.py
index 3f59a8e..e2776bc 100644
--- a/tensorflow/python/keras/saving/saved_model/layer_serialization.py
+++ b/tensorflow/python/keras/saving/saved_model/layer_serialization.py
@@ -46,6 +46,7 @@
     # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
     # the python config serialization has caught up.
     metadata = dict(
+        class_name=generic_utils.get_registered_name(type(self.obj)),
         name=self.obj.name,
         trainable=self.obj.trainable,
         expects_training_arg=self.obj._expects_training_arg,  # pylint: disable=protected-access
@@ -55,7 +56,7 @@
         must_restore_from_config=self.obj._must_restore_from_config,  # pylint: disable=protected-access
     )
 
-    metadata.update(get_serialized(self.obj))
+    metadata.update(get_config(self.obj))
     if self.obj.input_spec is not None:
       # Layer's input_spec has already been type-checked in the property setter.
       metadata['input_spec'] = nest.map_structure(
@@ -109,12 +110,16 @@
 
 # TODO(kathywu): Move serialization utils (and related utils from
 # generic_utils.py) to a separate file.
-def get_serialized(obj):
+def get_config(obj):
   with generic_utils.skip_failed_serialization():
     # Store the config dictionary, which may be used when reviving the object.
     # When loading, the program will attempt to revive the object from config,
     # and if that fails, the object will be revived from the SavedModel.
-    return generic_utils.serialize_keras_object(obj)
+    config = generic_utils.serialize_keras_object(obj)['config']
+
+  if config is not None:
+    return {'config': config}
+  return {}
 
 
 class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py
index 7f38f05..586394e 100644
--- a/tensorflow/python/keras/saving/saved_model/load.py
+++ b/tensorflow/python/keras/saving/saved_model/load.py
@@ -492,15 +492,13 @@
     #       found.
     class_name = metadata.get('class_name')
     config = metadata.get('config')
-    shared_object_id = metadata.get('shared_object_id')
     must_restore_from_config = metadata.get('must_restore_from_config')
     if not generic_utils.validate_config(config):
       return None
 
     try:
       obj = layers_module.deserialize(
-          generic_utils.serialize_keras_class_and_config(
-              class_name, config, shared_object_id=shared_object_id))
+          generic_utils.serialize_keras_class_and_config(class_name, config))
     except ValueError:
       if must_restore_from_config:
         raise RuntimeError(
diff --git a/tensorflow/python/keras/saving/saved_model/metric_serialization.py b/tensorflow/python/keras/saving/saved_model/metric_serialization.py
index e2b6d36..fda341d 100644
--- a/tensorflow/python/keras/saving/saved_model/metric_serialization.py
+++ b/tensorflow/python/keras/saving/saved_model/metric_serialization.py
@@ -36,7 +36,7 @@
         class_name=generic_utils.get_registered_name(type(self.obj)),
         name=self.obj.name,
         dtype=self.obj.dtype)
-    metadata.update(layer_serialization.get_serialized(self.obj))
+    metadata.update(layer_serialization.get_config(self.obj))
     if self.obj._build_input_shape is not None:  # pylint: disable=protected-access
       metadata['build_input_shape'] = self.obj._build_input_shape  # pylint: disable=protected-access
     return metadata
diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py
index ea1c051..ecf3824 100644
--- a/tensorflow/python/keras/utils/generic_utils.py
+++ b/tensorflow/python/keras/utils/generic_utils.py
@@ -24,10 +24,8 @@
 import os
 import re
 import sys
-import threading
 import time
 import types as python_types
-import weakref
 
 import numpy as np
 import six
@@ -112,205 +110,9 @@
   return _GLOBAL_CUSTOM_OBJECTS
 
 
-# Store a unique, per-object ID for shared objects.
-#
-# We store a unique ID for each object so that we may, at loading time,
-# re-create the network properly.  Without this ID, we would have no way of
-# determining whether a config is a description of a new object that
-# should be created or is merely a reference to an already-created object.
-SHARED_OBJECT_KEY = 'shared_object_id'
-
-
-class NoopLoadingScope(object):
-  """The default shared object loading scope. It does nothing.
-
-  Created to simplify serialization code that doesn't care about shared objects
-  (e.g. when serializing a single object).
-  """
-
-  def get(self, unused_object_id):
-    return None
-
-  def set(self, object_id, obj):
-    pass
-
-
-SHARED_OBJECT_LOADING = threading.local()
-
-
-def _shared_object_loading_scope():
-  """Get the current shared object saving scope in a threadsafe manner.
-
-  Attributes on the threadlocal variable must be set per-thread, thus we
-  cannot initialize these globally.
-
-  Returns:
-    A SharedObjectLoadingScope or NoopLoadingScope object.
-  """
-  return getattr(SHARED_OBJECT_LOADING, 'scope', NoopLoadingScope())
-
-
-class SharedObjectLoadingScope(object):
-  """A context manager for keeping track of loaded objects.
-
-  During the deserialization process, we may come across objects that are
-  shared across multiple layers. In order to accurately restore the network
-  structure to its original state, `SharedObjectLoadingScope` allows us to
-  re-use shared objects rather than cloning them.
-  """
-
-  def __enter__(self):
-    global SHARED_OBJECT_LOADING
-
-    SHARED_OBJECT_LOADING.scope = self
-    self._obj_ids_to_obj = {}
-    return self
-
-  def get(self, object_id):
-    """Given a shared object ID, returns a previously instantiated object.
-
-    Args:
-      object_id: shared object ID to use when attempting to find already-loaded
-        object.
-
-    Returns:
-      The object, if we've seen this ID before. Else, `None`.
-    """
-    # Explicitly check for `None` internally to make external calling code a
-    # bit cleaner.
-    if object_id is None:
-      return
-    return self._obj_ids_to_obj.get(object_id)
-
-  def set(self, object_id, obj):
-    """Stores an instantiated object for future lookup and sharing."""
-    if object_id is None:
-      return
-    self._obj_ids_to_obj[object_id] = obj
-
-  def __exit__(self, *args, **kwargs):
-    global SHARED_OBJECT_LOADING
-    SHARED_OBJECT_LOADING.scope = NoopLoadingScope()
-
-
-SHARED_OBJECT_SAVING = threading.local()
-
-
-def _shared_object_saving_scope():
-  """Get the current shared object saving scope in a threadsafe manner.
-
-  Attributes on the threadlocal variable must be set per-thread, thus we
-  cannot initialize these globally.
-
-  Returns:
-    A SharedObjectSavingScope object or None.
-  """
-  return getattr(SHARED_OBJECT_SAVING, 'scope', None)
-
-
-class SharedObjectConfig(dict):
-  """A configuration container that keeps track of references.
-
-  `SharedObjectConfig` will automatically attach a shared object ID to any
-  configs which are referenced more than once, allowing for proper shared
-  object reconstruction at load time.
-
-  In most cases, it would be more proper to subclass something like
-  `collections.UserDict` or `collections.Mapping` rather than `dict` directly.
-  Unfortunately, python's json encoder does not support `Mapping`s. This is
-  important functionality to retain, since we are dealing with serialization.
-
-  We should be safe to subclass `dict` here, since we aren't actually
-  overriding any core methods, only augmenting with a new one for reference
-  counting.
-  """
-
-  def __init__(self, base_config, object_id, **kwargs):
-    self.ref_count = 1
-    self.object_id = object_id
-    super(SharedObjectConfig, self).__init__(base_config, **kwargs)
-
-  def increment_ref_count(self):
-    # As soon as we've seen the object more than once, we want to attach the
-    # shared object ID. This allows us to only attach the shared object ID when
-    # it's strictly necessary, making backwards compatibility breakage less
-    # likely.
-    if self.ref_count == 1:
-      self[SHARED_OBJECT_KEY] = self.object_id
-    self.ref_count += 1
-
-
-class SharedObjectSavingScope(object):
-  """Keeps track of shared object configs when serializing."""
-
-  def __enter__(self):
-    global SHARED_OBJECT_SAVING
-
-    # Serialization can happen at a number of layers for a number of reasons.
-    # We may end up with a case where we're opening a saving scope within
-    # another saving scope. In that case, we'd like to use the outermost scope
-    # available and ignore inner scopes, since there is not (yet) a reasonable
-    # use case for having these nested and distinct.
-    if _shared_object_saving_scope() is not None:
-      self._passthrough = True
-      return _shared_object_saving_scope()
-    else:
-      self._passthrough = False
-
-    SHARED_OBJECT_SAVING.scope = self
-    self._shared_objects_config = weakref.WeakKeyDictionary()
-    self._next_id = 0
-    return self
-
-  def get_config(self, obj):
-    """Gets a `SharedObjectConfig` if one has already been seen for `obj`.
-
-    Args:
-      obj: The object for which to retrieve the `SharedObjectConfig`.
-
-    Returns:
-      The SharedObjectConfig for a given object, if already seen. Else,
-        `None`.
-    """
-    if obj in self._shared_objects_config:
-      shared_object_config = self._shared_objects_config[obj]
-      shared_object_config.increment_ref_count()
-      return shared_object_config
-
-  def create_config(self, base_config, obj):
-    shared_object_config = SharedObjectConfig(base_config, self._next_id)
-    self._next_id += 1
-    self._shared_objects_config[obj] = shared_object_config
-    return shared_object_config
-
-  def __exit__(self, *args, **kwargs):
-    if not self._passthrough:
-      global SHARED_OBJECT_SAVING
-      SHARED_OBJECT_SAVING.scope = None
-
-
-def serialize_keras_class_and_config(
-    cls_name, cls_config, obj=None, shared_object_id=None):
+def serialize_keras_class_and_config(cls_name, cls_config):
   """Returns the serialization of the class with the given config."""
-  base_config = {'class_name': cls_name, 'config': cls_config}
-
-  # We call `serialize_keras_class_and_config` for some branches of the load
-  # path. In that case, we may already have a shared object ID we'd like to
-  # retain.
-  if shared_object_id is not None:
-    base_config[SHARED_OBJECT_KEY] = shared_object_id
-
-  # If we have an active `SharedObjectSavingScope`, check whether we've already
-  # serialized this config. If so, just use that config. This will store an
-  # extra ID field in the config, allowing us to re-create the shared object
-  # relationship at load time.
-  if _shared_object_saving_scope() is not None and obj is not None:
-    shared_object_config = _shared_object_saving_scope().get_config(obj)
-    if shared_object_config is None:
-      return _shared_object_saving_scope().create_config(base_config, obj)
-    return shared_object_config
-
-  return base_config
+  return {'class_name': cls_name, 'config': cls_config}
 
 
 @keras_export('keras.utils.register_keras_serializable')
@@ -432,19 +234,7 @@
 
 @keras_export('keras.utils.serialize_keras_object')
 def serialize_keras_object(instance):
-  """Serialize a Keras object into a JSON-compatible representation.
-
-  Calls to `serialize_keras_object` while underneath the
-  `SharedObjectSavingScope` context manager will cause any objects re-used
-  across multiple layers to be saved with a special shared object ID. This
-  allows the network to be re-created properly during deserialization.
-
-  Args:
-    instance: The object to serialize.
-
-  Returns:
-    A dict-like, JSON-compatible representation of the object's config.
-  """
+  """Serialize a Keras object into a JSON-compatible representation."""
   _, instance = tf_decorator.unwrap(instance)
   if instance is None:
     return None
@@ -475,8 +265,7 @@
         serialization_config[key] = item
 
     name = get_registered_name(instance.__class__)
-    return serialize_keras_class_and_config(
-        name, serialization_config, instance)
+    return serialize_keras_class_and_config(name, serialization_config)
   if hasattr(instance, '__name__'):
     return get_registered_name(instance)
   raise ValueError('Cannot serialize', instance)
@@ -497,9 +286,8 @@
     custom_objects=None,
     printable_module_name='object'):
   """Returns the class name and config for a serialized keras object."""
-  if (not isinstance(config, dict)
-      or 'class_name' not in config
-      or 'config' not in config):
+  if (not isinstance(config, dict) or 'class_name' not in config or
+      'config' not in config):
     raise ValueError('Improper config format: ' + str(config))
 
   class_name = config['class_name']
@@ -553,24 +341,7 @@
                              module_objects=None,
                              custom_objects=None,
                              printable_module_name='object'):
-  """Turns the serialized form of a Keras object back into an actual object.
-
-  Calls to `deserialize_keras_object` while underneath the
-  `SharedObjectLoadingScope` context manager will cause any already-seen shared
-  objects to be returned as-is rather than creating a new object.
-
-  Args:
-    identifier: the serialized form of the object.
-    module_objects: A dictionary of custom objects to look the name up in.
-      Generally, module_objects is provided by midlevel library implementers.
-    custom_objects: A dictionary of custom objects to look the name up in.
-      Generally, custom_objects is provided by the user.
-    printable_module_name: A human-readable string representing the type of the
-      object. Printed in case of exception.
-
-  Returns:
-    The deserialized object.
-  """
+  """Turns the serialized form of a Keras object back into an actual object."""
   if identifier is None:
     return None
 
@@ -580,39 +351,25 @@
     (cls, cls_config) = class_and_config_for_serialized_keras_object(
         config, module_objects, custom_objects, printable_module_name)
 
-    # If this object has already been loaded (i.e. it's shared between multiple
-    # objects), return the already-loaded object.
-    shared_object_id = config.get(SHARED_OBJECT_KEY)
-    shared_object = _shared_object_loading_scope().get(shared_object_id)  # pylint: disable=assignment-from-none
-    if shared_object is not None:
-      return shared_object
-
     if hasattr(cls, 'from_config'):
       arg_spec = tf_inspect.getfullargspec(cls.from_config)
       custom_objects = custom_objects or {}
 
       if 'custom_objects' in arg_spec.args:
-        deserialized_obj = cls.from_config(
+        return cls.from_config(
             cls_config,
             custom_objects=dict(
                 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                 list(custom_objects.items())))
-      else:
-        with CustomObjectScope(custom_objects):
-          deserialized_obj = cls.from_config(cls_config)
+      with CustomObjectScope(custom_objects):
+        return cls.from_config(cls_config)
     else:
       # Then `cls` may be a function returning a class.
       # in this case by convention `config` holds
       # the kwargs of the function.
       custom_objects = custom_objects or {}
       with CustomObjectScope(custom_objects):
-        deserialized_obj = cls(**cls_config)
-
-    # Add object to shared objects, in case we find it referenced again.
-    _shared_object_loading_scope().set(shared_object_id, deserialized_obj)
-
-    return deserialized_obj
-
+        return cls(**cls_config)
   elif isinstance(identifier, six.string_types):
     object_name = identifier
     if custom_objects and object_name in custom_objects:
diff --git a/tensorflow/python/keras/utils/generic_utils_test.py b/tensorflow/python/keras/utils/generic_utils_test.py
index dd28b17..2dc2952 100644
--- a/tensorflow/python/keras/utils/generic_utils_test.py
+++ b/tensorflow/python/keras/utils/generic_utils_test.py
@@ -23,7 +23,6 @@
 import numpy as np
 
 from tensorflow.python import keras
-from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.platform import test
 
 
@@ -385,63 +384,5 @@
         [None, None, None])
 
 
-# object() alone isn't compatible with WeakKeyDictionary, which we use to
-# track shared configs.
-class MaybeSharedObject(object):
-  pass
-
-
-class SharedObjectScopeTest(test.TestCase):
-
-  def test_shared_object_saving_scope_single_object_doesnt_export_id(self):
-    with generic_utils.SharedObjectSavingScope() as scope:
-      single_object = MaybeSharedObject()
-      self.assertIsNone(scope.get_config(single_object))
-      single_object_config = scope.create_config({}, single_object)
-      self.assertIsNotNone(single_object_config)
-      self.assertNotIn(generic_utils.SHARED_OBJECT_KEY,
-                       single_object_config)
-
-  def test_shared_object_saving_scope_shared_object_exports_id(self):
-    with generic_utils.SharedObjectSavingScope() as scope:
-      shared_object = MaybeSharedObject()
-      self.assertIsNone(scope.get_config(shared_object))
-      scope.create_config({}, shared_object)
-      first_object_config = scope.get_config(shared_object)
-      second_object_config = scope.get_config(shared_object)
-      self.assertIn(generic_utils.SHARED_OBJECT_KEY,
-                    first_object_config)
-      self.assertIn(generic_utils.SHARED_OBJECT_KEY,
-                    second_object_config)
-      self.assertIs(first_object_config, second_object_config)
-
-  def test_shared_object_loading_scope_noop(self):
-    # Test that, without a context manager scope, adding configs will do
-    # nothing.
-    obj_id = 1
-    obj = MaybeSharedObject()
-    generic_utils._shared_object_loading_scope().set(obj_id, obj)
-    self.assertIsNone(generic_utils._shared_object_loading_scope().get(obj_id))
-
-  def test_shared_object_loading_scope_returns_shared_obj(self):
-    obj_id = 1
-    obj = MaybeSharedObject()
-    with generic_utils.SharedObjectLoadingScope() as scope:
-      scope.set(obj_id, obj)
-      self.assertIs(scope.get(obj_id), obj)
-
-  def test_nested_shared_object_saving_scopes(self):
-    my_obj = MaybeSharedObject()
-    with generic_utils.SharedObjectSavingScope() as scope_1:
-      scope_1.create_config({}, my_obj)
-      with generic_utils.SharedObjectSavingScope() as scope_2:
-        # Nesting saving scopes should return the original scope and should
-        # not clear any objects we're tracking.
-        self.assertIs(scope_1, scope_2)
-        self.assertIsNotNone(scope_2.get_config(my_obj))
-      self.assertIsNotNone(scope_1.get_config(my_obj))
-    self.assertIsNone(generic_utils._shared_object_saving_scope())
-
-
 if __name__ == '__main__':
   test.main()