Add support for shared objects in keras when (de)serializing.

Previously, objects shared between multiple layers would be duplicated when
saving and loading.  This commit adds a unique ID for keras objects when
serializing that allows us to correctly create only a single instance of
shared objects when deserializing.

PiperOrigin-RevId: 354996891
Change-Id: Idfd055f430c7ea7e25c459ed7715a370d7a632c9
diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py
index 743b4c0..a3aa265 100644
--- a/tensorflow/python/keras/engine/functional.py
+++ b/tensorflow/python/keras/engine/functional.py
@@ -671,12 +671,13 @@
     Raises:
         ValueError: In case of improperly formatted config dict.
     """
-    input_tensors, output_tensors, created_layers = reconstruct_from_config(
-        config, custom_objects)
-    model = cls(inputs=input_tensors, outputs=output_tensors,
-                name=config.get('name'))
-    connect_ancillary_layers(model, created_layers)
-    return model
+    with generic_utils.SharedObjectLoadingScope():
+      input_tensors, output_tensors, created_layers = reconstruct_from_config(
+          config, custom_objects)
+      model = cls(inputs=input_tensors, outputs=output_tensors,
+                  name=config.get('name'))
+      connect_ancillary_layers(model, created_layers)
+      return model
 
   def _validate_graph_inputs_and_outputs(self):
     """Validates the inputs and outputs of a Graph Network."""
@@ -1346,21 +1347,23 @@
         node_conversion_map[node_key] = kept_nodes
         kept_nodes += 1
   layer_configs = []
-  for layer in network.layers:  # From the earliest layers on.
-    filtered_inbound_nodes = []
-    for original_node_index, node in enumerate(layer._inbound_nodes):
-      node_key = _make_node_key(layer.name, original_node_index)
-      if node_key in network._network_nodes and not node.is_input:
-        # The node is relevant to the model:
-        # add to filtered_inbound_nodes.
-        node_data = node.serialize(_make_node_key, node_conversion_map)
-        filtered_inbound_nodes.append(node_data)
 
-    layer_config = serialize_layer_fn(layer)
-    layer_config['name'] = layer.name
-    layer_config['inbound_nodes'] = filtered_inbound_nodes
-    layer_configs.append(layer_config)
-  config['layers'] = layer_configs
+  with generic_utils.SharedObjectSavingScope():
+    for layer in network.layers:  # From the earliest layers on.
+      filtered_inbound_nodes = []
+      for original_node_index, node in enumerate(layer._inbound_nodes):
+        node_key = _make_node_key(layer.name, original_node_index)
+        if node_key in network._network_nodes and not node.is_input:
+          # The node is relevant to the model:
+          # add to filtered_inbound_nodes.
+          node_data = node.serialize(_make_node_key, node_conversion_map)
+          filtered_inbound_nodes.append(node_data)
+
+      layer_config = serialize_layer_fn(layer)
+      layer_config['name'] = layer.name
+      layer_config['inbound_nodes'] = filtered_inbound_nodes
+      layer_configs.append(layer_config)
+    config['layers'] = layer_configs
 
   # Gather info about inputs and outputs.
   model_inputs = []
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index b16e0d6..0b19f4e 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -393,6 +393,10 @@
   except that it creates new layers (and thus new weights) instead
   of sharing the weights of the existing layers.
 
+  `clone_model` will not preserve the uniqueness of shared objects within the
+  model (e.g. a single variable attached to two distinct layers will be
+  restored as two separate variables).
+
   Args:
       model: Instance of `Model`
           (could be a functional model or a Sequential model).
@@ -420,15 +424,16 @@
   Raises:
       ValueError: in case of invalid `model` argument value.
   """
-  if clone_function is None:
-    clone_function = _clone_layer
+  with generic_utils.DisableSharedObjectScope():
+    if clone_function is None:
+      clone_function = _clone_layer
 
-  if isinstance(model, Sequential):
-    return _clone_sequential_model(
-        model, input_tensors=input_tensors, layer_fn=clone_function)
-  else:
-    return _clone_functional_model(
-        model, input_tensors=input_tensors, layer_fn=clone_function)
+    if isinstance(model, Sequential):
+      return _clone_sequential_model(
+          model, input_tensors=input_tensors, layer_fn=clone_function)
+    else:
+      return _clone_functional_model(
+          model, input_tensors=input_tensors, layer_fn=clone_function)
 
 
 # "Clone" a subclassed model by reseting all of the attributes.
diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py
index 0ece5ac..12d1c39f 100644
--- a/tensorflow/python/keras/models_test.py
+++ b/tensorflow/python/keras/models_test.py
@@ -245,6 +245,28 @@
     loss = model.train_on_batch(x, y)
     self.assertEqual(float(loss), 0.)
 
+  def test_clone_rnn(self):
+    # Test cloning a model with multiple cells in an RNN.  This exercises a
+    # few "fancier" features such as the `Bidrectional` wrapper and
+    # `StackedRNNCells` under the hood.
+    inputs = keras.Input(shape=(3, 3))
+    cells = [
+        keras.layers.LSTMCell(
+            units=32,
+            enable_caching_device=True,
+            implementation=2,
+            activation='relu')]
+    rnn = keras.layers.RNN(cells, return_sequences=True)
+    outputs = keras.layers.Bidirectional(rnn)(inputs)
+    outputs = keras.layers.Dense(
+        12, activation='softmax', name='scores')(outputs)
+    model = keras.Model(inputs=inputs, outputs=outputs)
+    model.compile(
+        loss=keras.losses.CategoricalCrossentropy(),
+        optimizer=keras.optimizer_v2.rmsprop.RMSprop(lr=0.01),
+        metrics=['accuracy'])
+    keras.models.clone_model(model)
+
   def test_model_cloning_invalid_use_cases(self):
     seq_model = keras.models.Sequential()
     seq_model.add(keras.layers.Dense(4, input_shape=(4,)))
diff --git a/tensorflow/python/keras/saving/save.py b/tensorflow/python/keras/saving/save.py
index d4749fc..ef7f699 100644
--- a/tensorflow/python/keras/saving/save.py
+++ b/tensorflow/python/keras/saving/save.py
@@ -148,8 +148,9 @@
     hdf5_format.save_model_to_hdf5(
         model, filepath, overwrite, include_optimizer)
   else:
-    saved_model_save.save(model, filepath, overwrite, include_optimizer,
-                          signatures, options, save_traces)
+    with generic_utils.SharedObjectSavingScope():
+      saved_model_save.save(model, filepath, overwrite, include_optimizer,
+                            signatures, options, save_traces)
 
 
 @keras_export('keras.models.load_model')
@@ -194,17 +195,18 @@
       ImportError: if loading from an hdf5 file and h5py is not available.
       IOError: In case of an invalid savefile.
   """
-  with generic_utils.CustomObjectScope(custom_objects or {}):
-    with load_context.load_context(options):
-      if (h5py is not None and
-          (isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
-        return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
-                                                compile)
+  with generic_utils.SharedObjectLoadingScope():
+    with generic_utils.CustomObjectScope(custom_objects or {}):
+      with load_context.load_context(options):
+        if (h5py is not None and
+            (isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
+          return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
+                                                  compile)
 
-      filepath = path_to_string(filepath)
-      if isinstance(filepath, six.string_types):
-        loader_impl.parse_saved_model(filepath)
-        return saved_model_load.load(filepath, compile, options)
+        filepath = path_to_string(filepath)
+        if isinstance(filepath, six.string_types):
+          loader_impl.parse_saved_model(filepath)
+          return saved_model_load.load(filepath, compile, options)
 
   raise IOError(
       'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
diff --git a/tensorflow/python/keras/saving/save_test.py b/tensorflow/python/keras/saving/save_test.py
index 00c7bb2..20a779b 100644
--- a/tensorflow/python/keras/saving/save_test.py
+++ b/tensorflow/python/keras/saving/save_test.py
@@ -18,6 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import os
 import shutil
 import sys
@@ -25,12 +26,14 @@
 
 from absl.testing import parameterized
 import numpy as np
+from six import string_types
 
 from tensorflow.python import keras
 from tensorflow.python import tf2
 from tensorflow.python.eager import context
 from tensorflow.python.feature_column import feature_column_lib
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.keras import combinations
@@ -859,6 +862,125 @@
     self.assertAllEqual(loaded_model.predict(args, batch_size=batch_size),
                         expected)
 
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_shared_objects(self):
+    class OuterLayer(keras.layers.Layer):
+
+      def __init__(self, inner_layer):
+        super(OuterLayer, self).__init__()
+        self.inner_layer = inner_layer
+
+      def call(self, inputs):
+        return self.inner_layer(inputs)
+
+      def get_config(self):
+        return {
+            'inner_layer': generic_utils.serialize_keras_object(
+                self.inner_layer)
+        }
+
+      @classmethod
+      def from_config(cls, config):
+        return cls(generic_utils.deserialize_keras_object(
+            config['inner_layer']))
+
+    class InnerLayer(keras.layers.Layer):
+
+      def __init__(self):
+        super(InnerLayer, self).__init__()
+        self.v = self.add_weight(name='v', shape=[], dtype=dtypes.float32)
+
+      def call(self, inputs):
+        return self.v + inputs
+
+      @classmethod
+      def from_config(cls, config):
+        return cls()
+
+    # Create a model with 2 output layers that share the same inner layer.
+    inner_layer = InnerLayer()
+    outer_layer_1 = OuterLayer(inner_layer)
+    outer_layer_2 = OuterLayer(inner_layer)
+    input_ = keras.Input(shape=(1,))
+    model = keras.Model(
+        inputs=input_, outputs=[outer_layer_1(input_), outer_layer_2(input_)])
+
+    # Changes to the shared layer should affect both outputs.
+    model.layers[1].inner_layer.v.assign(5)
+    self.assertAllEqual(model(1), [6.0, 6.0])
+    model.layers[1].inner_layer.v.assign(3)
+    self.assertAllEqual(model(1), [4.0, 4.0])
+
+    # After loading, changes to the shared layer should still affect both
+    # outputs.
+    def _do_assertions(loaded):
+      loaded.layers[1].inner_layer.v.assign(5)
+      self.assertAllEqual(loaded(1), [6.0, 6.0])
+      loaded.layers[1].inner_layer.v.assign(3)
+      self.assertAllEqual(loaded(1), [4.0, 4.0])
+      loaded.layers[2].inner_layer.v.assign(5)
+      self.assertAllEqual(loaded(1), [6.0, 6.0])
+      loaded.layers[2].inner_layer.v.assign(3)
+      self.assertAllEqual(loaded(1), [4.0, 4.0])
+
+    # We'd like to make sure we only attach shared object IDs when strictly
+    # necessary, so we'll recursively traverse the generated config to count
+    # whether we have the exact number we expect.
+    def _get_all_keys_recursive(dict_or_iterable):
+      if isinstance(dict_or_iterable, dict):
+        for key in dict_or_iterable.keys():
+          yield key
+        for key in _get_all_keys_recursive(dict_or_iterable.values()):
+          yield key
+      elif isinstance(dict_or_iterable, string_types):
+        return
+      else:
+        try:
+          for item in dict_or_iterable:
+            for key in _get_all_keys_recursive(item):
+              yield key
+        # Not an iterable or dictionary
+        except TypeError:
+          return
+
+    with generic_utils.CustomObjectScope({
+        'OuterLayer': OuterLayer, 'InnerLayer': InnerLayer}):
+
+      # Test saving and loading to disk
+      save_format = testing_utils.get_save_format()
+      saved_model_dir = self._save_model_dir()
+      keras.models.save_model(model, saved_model_dir, save_format=save_format)
+      loaded = keras.models.load_model(saved_model_dir)
+      _do_assertions(loaded)
+
+      # Test recreating directly from config
+      config = model.get_config()
+      key_count = collections.Counter(_get_all_keys_recursive(config))
+      self.assertEqual(key_count[generic_utils.SHARED_OBJECT_KEY], 2)
+      loaded = keras.Model.from_config(config)
+      _do_assertions(loaded)
+
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_shared_objects_wrapper(self):
+    """Tests that shared layers wrapped with `Wrapper` restore correctly."""
+    input_ = keras.Input(shape=(1,))
+    unwrapped = keras.layers.Layer(name='unwrapped')
+    wrapped = keras.layers.Wrapper(unwrapped, name='wrapped')
+    model = keras.Model(inputs=input_,
+                        outputs=[unwrapped(input_), wrapped(input_)])
+
+    # Test recreating directly from config
+    config = model.get_config()
+    loaded = keras.Model.from_config(config)
+    self.assertIs(loaded.layers[1], loaded.layers[2].layer)
+
+    # Test saving and loading to disk
+    save_format = testing_utils.get_save_format()
+    saved_model_dir = self._save_model_dir()
+    keras.models.save_model(model, saved_model_dir, save_format=save_format)
+    loaded = keras.models.load_model(saved_model_dir)
+    self.assertIs(loaded.layers[1], loaded.layers[2].layer)
+
 
 # Factory functions to create models that will be serialized inside a Network.
 def _make_graph_network(input_size, output_size):
diff --git a/tensorflow/python/keras/saving/saved_model/layer_serialization.py b/tensorflow/python/keras/saving/saved_model/layer_serialization.py
index e2776bc..3f59a8e 100644
--- a/tensorflow/python/keras/saving/saved_model/layer_serialization.py
+++ b/tensorflow/python/keras/saving/saved_model/layer_serialization.py
@@ -46,7 +46,6 @@
     # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
     # the python config serialization has caught up.
     metadata = dict(
-        class_name=generic_utils.get_registered_name(type(self.obj)),
         name=self.obj.name,
         trainable=self.obj.trainable,
         expects_training_arg=self.obj._expects_training_arg,  # pylint: disable=protected-access
@@ -56,7 +55,7 @@
         must_restore_from_config=self.obj._must_restore_from_config,  # pylint: disable=protected-access
     )
 
-    metadata.update(get_config(self.obj))
+    metadata.update(get_serialized(self.obj))
     if self.obj.input_spec is not None:
       # Layer's input_spec has already been type-checked in the property setter.
       metadata['input_spec'] = nest.map_structure(
@@ -110,16 +109,12 @@
 
 # TODO(kathywu): Move serialization utils (and related utils from
 # generic_utils.py) to a separate file.
-def get_config(obj):
+def get_serialized(obj):
   with generic_utils.skip_failed_serialization():
     # Store the config dictionary, which may be used when reviving the object.
     # When loading, the program will attempt to revive the object from config,
     # and if that fails, the object will be revived from the SavedModel.
-    config = generic_utils.serialize_keras_object(obj)['config']
-
-  if config is not None:
-    return {'config': config}
-  return {}
+    return generic_utils.serialize_keras_object(obj)
 
 
 class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py
index 217b124..fc34bf3 100644
--- a/tensorflow/python/keras/saving/saved_model/load.py
+++ b/tensorflow/python/keras/saving/saved_model/load.py
@@ -493,13 +493,15 @@
     #       found.
     class_name = metadata.get('class_name')
     config = metadata.get('config')
+    shared_object_id = metadata.get('shared_object_id')
     must_restore_from_config = metadata.get('must_restore_from_config')
     if not generic_utils.validate_config(config):
       return None
 
     try:
       obj = layers_module.deserialize(
-          generic_utils.serialize_keras_class_and_config(class_name, config))
+          generic_utils.serialize_keras_class_and_config(
+              class_name, config, shared_object_id=shared_object_id))
     except ValueError:
       if must_restore_from_config:
         raise RuntimeError(
diff --git a/tensorflow/python/keras/saving/saved_model/metric_serialization.py b/tensorflow/python/keras/saving/saved_model/metric_serialization.py
index fda341d..e2b6d36 100644
--- a/tensorflow/python/keras/saving/saved_model/metric_serialization.py
+++ b/tensorflow/python/keras/saving/saved_model/metric_serialization.py
@@ -36,7 +36,7 @@
         class_name=generic_utils.get_registered_name(type(self.obj)),
         name=self.obj.name,
         dtype=self.obj.dtype)
-    metadata.update(layer_serialization.get_config(self.obj))
+    metadata.update(layer_serialization.get_serialized(self.obj))
     if self.obj._build_input_shape is not None:  # pylint: disable=protected-access
       metadata['build_input_shape'] = self.obj._build_input_shape  # pylint: disable=protected-access
     return metadata
diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py
index ecf3824..89aeaf4 100644
--- a/tensorflow/python/keras/utils/generic_utils.py
+++ b/tensorflow/python/keras/utils/generic_utils.py
@@ -24,8 +24,10 @@
 import os
 import re
 import sys
+import threading
 import time
 import types as python_types
+import weakref
 
 import numpy as np
 import six
@@ -110,9 +112,235 @@
   return _GLOBAL_CUSTOM_OBJECTS
 
 
-def serialize_keras_class_and_config(cls_name, cls_config):
+# Store a unique, per-object ID for shared objects.
+#
+# We store a unique ID for each object so that we may, at loading time,
+# re-create the network properly.  Without this ID, we would have no way of
+# determining whether a config is a description of a new object that
+# should be created or is merely a reference to an already-created object.
+SHARED_OBJECT_KEY = 'shared_object_id'
+
+
+SHARED_OBJECT_DISABLED = threading.local()
+SHARED_OBJECT_LOADING = threading.local()
+SHARED_OBJECT_SAVING = threading.local()
+
+
+# Attributes on the threadlocal variable must be set per-thread, thus we
+# cannot initialize these globally. Instead, we have accessor functions with
+# default values.
+def _shared_object_disabled():
+  """Get whether shared object handling is disabled in a threadsafe manner."""
+  return getattr(SHARED_OBJECT_DISABLED, 'disabled', False)
+
+
+def _shared_object_loading_scope():
+  """Get the current shared object saving scope in a threadsafe manner."""
+  return getattr(SHARED_OBJECT_LOADING, 'scope', NoopLoadingScope())
+
+
+def _shared_object_saving_scope():
+  """Get the current shared object saving scope in a threadsafe manner."""
+  return getattr(SHARED_OBJECT_SAVING, 'scope', None)
+
+
+class DisableSharedObjectScope(object):
+  """A context manager for disabling handling of shared objects.
+
+  Disables shared object handling for both saving and loading.
+
+  Created primarily for use with `clone_model`, which does extra surgery that
+  is incompatible with shared objects.
+  """
+
+  def __enter__(self):
+    SHARED_OBJECT_DISABLED.disabled = True
+    self._orig_loading_scope = _shared_object_loading_scope()
+    self._orig_saving_scope = _shared_object_saving_scope()
+
+  def __exit__(self, *args, **kwargs):
+    SHARED_OBJECT_DISABLED.disabled = False
+    SHARED_OBJECT_LOADING.scope = self._orig_loading_scope
+    SHARED_OBJECT_SAVING.scope = self._orig_saving_scope
+
+
+class NoopLoadingScope(object):
+  """The default shared object loading scope. It does nothing.
+
+  Created to simplify serialization code that doesn't care about shared objects
+  (e.g. when serializing a single object).
+  """
+
+  def get(self, unused_object_id):
+    return None
+
+  def set(self, object_id, obj):
+    pass
+
+
+class SharedObjectLoadingScope(object):
+  """A context manager for keeping track of loaded objects.
+
+  During the deserialization process, we may come across objects that are
+  shared across multiple layers. In order to accurately restore the network
+  structure to its original state, `SharedObjectLoadingScope` allows us to
+  re-use shared objects rather than cloning them.
+  """
+
+  def __enter__(self):
+    if _shared_object_disabled():
+      return NoopLoadingScope()
+
+    global SHARED_OBJECT_LOADING
+    SHARED_OBJECT_LOADING.scope = self
+    self._obj_ids_to_obj = {}
+    return self
+
+  def get(self, object_id):
+    """Given a shared object ID, returns a previously instantiated object.
+
+    Args:
+      object_id: shared object ID to use when attempting to find already-loaded
+        object.
+
+    Returns:
+      The object, if we've seen this ID before. Else, `None`.
+    """
+    # Explicitly check for `None` internally to make external calling code a
+    # bit cleaner.
+    if object_id is None:
+      return
+    return self._obj_ids_to_obj.get(object_id)
+
+  def set(self, object_id, obj):
+    """Stores an instantiated object for future lookup and sharing."""
+    if object_id is None:
+      return
+    self._obj_ids_to_obj[object_id] = obj
+
+  def __exit__(self, *args, **kwargs):
+    global SHARED_OBJECT_LOADING
+    SHARED_OBJECT_LOADING.scope = NoopLoadingScope()
+
+
+class SharedObjectConfig(dict):
+  """A configuration container that keeps track of references.
+
+  `SharedObjectConfig` will automatically attach a shared object ID to any
+  configs which are referenced more than once, allowing for proper shared
+  object reconstruction at load time.
+
+  In most cases, it would be more proper to subclass something like
+  `collections.UserDict` or `collections.Mapping` rather than `dict` directly.
+  Unfortunately, python's json encoder does not support `Mapping`s. This is
+  important functionality to retain, since we are dealing with serialization.
+
+  We should be safe to subclass `dict` here, since we aren't actually
+  overriding any core methods, only augmenting with a new one for reference
+  counting.
+  """
+
+  def __init__(self, base_config, object_id, **kwargs):
+    self.ref_count = 1
+    self.object_id = object_id
+    super(SharedObjectConfig, self).__init__(base_config, **kwargs)
+
+  def increment_ref_count(self):
+    # As soon as we've seen the object more than once, we want to attach the
+    # shared object ID. This allows us to only attach the shared object ID when
+    # it's strictly necessary, making backwards compatibility breakage less
+    # likely.
+    if self.ref_count == 1:
+      self[SHARED_OBJECT_KEY] = self.object_id
+    self.ref_count += 1
+
+
+class SharedObjectSavingScope(object):
+  """Keeps track of shared object configs when serializing."""
+
+  def __enter__(self):
+    if _shared_object_disabled():
+      return None
+
+    global SHARED_OBJECT_SAVING
+
+    # Serialization can happen at a number of layers for a number of reasons.
+    # We may end up with a case where we're opening a saving scope within
+    # another saving scope. In that case, we'd like to use the outermost scope
+    # available and ignore inner scopes, since there is not (yet) a reasonable
+    # use case for having these nested and distinct.
+    if _shared_object_saving_scope() is not None:
+      self._passthrough = True
+      return _shared_object_saving_scope()
+    else:
+      self._passthrough = False
+
+    SHARED_OBJECT_SAVING.scope = self
+    self._shared_objects_config = weakref.WeakKeyDictionary()
+    self._next_id = 0
+    return self
+
+  def get_config(self, obj):
+    """Gets a `SharedObjectConfig` if one has already been seen for `obj`.
+
+    Args:
+      obj: The object for which to retrieve the `SharedObjectConfig`.
+
+    Returns:
+      The SharedObjectConfig for a given object, if already seen. Else,
+        `None`.
+    """
+    try:
+      shared_object_config = self._shared_objects_config[obj]
+    except (TypeError, KeyError):
+      # If the object is unhashable (e.g. a subclass of `AbstractBaseClass`
+      # that has not overridden `__hash__`), a `TypeError` will be thrown.
+      # We'll just continue on without shared object support.
+      return None
+    shared_object_config.increment_ref_count()
+    return shared_object_config
+
+  def create_config(self, base_config, obj):
+    """Create a new SharedObjectConfig for a given object."""
+    shared_object_config = SharedObjectConfig(base_config, self._next_id)
+    self._next_id += 1
+    try:
+      self._shared_objects_config[obj] = shared_object_config
+    except TypeError:
+      # If the object is unhashable (e.g. a subclass of `AbstractBaseClass`
+      # that has not overridden `__hash__`), a `TypeError` will be thrown.
+      # We'll just continue on without shared object support.
+      pass
+    return shared_object_config
+
+  def __exit__(self, *args, **kwargs):
+    if not getattr(self, '_passthrough', False):
+      global SHARED_OBJECT_SAVING
+      SHARED_OBJECT_SAVING.scope = None
+
+
+def serialize_keras_class_and_config(
+    cls_name, cls_config, obj=None, shared_object_id=None):
   """Returns the serialization of the class with the given config."""
-  return {'class_name': cls_name, 'config': cls_config}
+  base_config = {'class_name': cls_name, 'config': cls_config}
+
+  # We call `serialize_keras_class_and_config` for some branches of the load
+  # path. In that case, we may already have a shared object ID we'd like to
+  # retain.
+  if shared_object_id is not None:
+    base_config[SHARED_OBJECT_KEY] = shared_object_id
+
+  # If we have an active `SharedObjectSavingScope`, check whether we've already
+  # serialized this config. If so, just use that config. This will store an
+  # extra ID field in the config, allowing us to re-create the shared object
+  # relationship at load time.
+  if _shared_object_saving_scope() is not None and obj is not None:
+    shared_object_config = _shared_object_saving_scope().get_config(obj)
+    if shared_object_config is None:
+      return _shared_object_saving_scope().create_config(base_config, obj)
+    return shared_object_config
+
+  return base_config
 
 
 @keras_export('keras.utils.register_keras_serializable')
@@ -234,7 +462,19 @@
 
 @keras_export('keras.utils.serialize_keras_object')
 def serialize_keras_object(instance):
-  """Serialize a Keras object into a JSON-compatible representation."""
+  """Serialize a Keras object into a JSON-compatible representation.
+
+  Calls to `serialize_keras_object` while underneath the
+  `SharedObjectSavingScope` context manager will cause any objects re-used
+  across multiple layers to be saved with a special shared object ID. This
+  allows the network to be re-created properly during deserialization.
+
+  Args:
+    instance: The object to serialize.
+
+  Returns:
+    A dict-like, JSON-compatible representation of the object's config.
+  """
   _, instance = tf_decorator.unwrap(instance)
   if instance is None:
     return None
@@ -265,7 +505,8 @@
         serialization_config[key] = item
 
     name = get_registered_name(instance.__class__)
-    return serialize_keras_class_and_config(name, serialization_config)
+    return serialize_keras_class_and_config(
+        name, serialization_config, instance)
   if hasattr(instance, '__name__'):
     return get_registered_name(instance)
   raise ValueError('Cannot serialize', instance)
@@ -286,8 +527,9 @@
     custom_objects=None,
     printable_module_name='object'):
   """Returns the class name and config for a serialized keras object."""
-  if (not isinstance(config, dict) or 'class_name' not in config or
-      'config' not in config):
+  if (not isinstance(config, dict)
+      or 'class_name' not in config
+      or 'config' not in config):
     raise ValueError('Improper config format: ' + str(config))
 
   class_name = config['class_name']
@@ -341,7 +583,24 @@
                              module_objects=None,
                              custom_objects=None,
                              printable_module_name='object'):
-  """Turns the serialized form of a Keras object back into an actual object."""
+  """Turns the serialized form of a Keras object back into an actual object.
+
+  Calls to `deserialize_keras_object` while underneath the
+  `SharedObjectLoadingScope` context manager will cause any already-seen shared
+  objects to be returned as-is rather than creating a new object.
+
+  Args:
+    identifier: the serialized form of the object.
+    module_objects: A dictionary of custom objects to look the name up in.
+      Generally, module_objects is provided by midlevel library implementers.
+    custom_objects: A dictionary of custom objects to look the name up in.
+      Generally, custom_objects is provided by the user.
+    printable_module_name: A human-readable string representing the type of the
+      object. Printed in case of exception.
+
+  Returns:
+    The deserialized object.
+  """
   if identifier is None:
     return None
 
@@ -351,25 +610,39 @@
     (cls, cls_config) = class_and_config_for_serialized_keras_object(
         config, module_objects, custom_objects, printable_module_name)
 
+    # If this object has already been loaded (i.e. it's shared between multiple
+    # objects), return the already-loaded object.
+    shared_object_id = config.get(SHARED_OBJECT_KEY)
+    shared_object = _shared_object_loading_scope().get(shared_object_id)  # pylint: disable=assignment-from-none
+    if shared_object is not None:
+      return shared_object
+
     if hasattr(cls, 'from_config'):
       arg_spec = tf_inspect.getfullargspec(cls.from_config)
       custom_objects = custom_objects or {}
 
       if 'custom_objects' in arg_spec.args:
-        return cls.from_config(
+        deserialized_obj = cls.from_config(
             cls_config,
             custom_objects=dict(
                 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                 list(custom_objects.items())))
-      with CustomObjectScope(custom_objects):
-        return cls.from_config(cls_config)
+      else:
+        with CustomObjectScope(custom_objects):
+          deserialized_obj = cls.from_config(cls_config)
     else:
       # Then `cls` may be a function returning a class.
       # in this case by convention `config` holds
       # the kwargs of the function.
       custom_objects = custom_objects or {}
       with CustomObjectScope(custom_objects):
-        return cls(**cls_config)
+        deserialized_obj = cls(**cls_config)
+
+    # Add object to shared objects, in case we find it referenced again.
+    _shared_object_loading_scope().set(shared_object_id, deserialized_obj)
+
+    return deserialized_obj
+
   elif isinstance(identifier, six.string_types):
     object_name = identifier
     if custom_objects and object_name in custom_objects:
diff --git a/tensorflow/python/keras/utils/generic_utils_test.py b/tensorflow/python/keras/utils/generic_utils_test.py
index 2dc2952..dd28b17 100644
--- a/tensorflow/python/keras/utils/generic_utils_test.py
+++ b/tensorflow/python/keras/utils/generic_utils_test.py
@@ -23,6 +23,7 @@
 import numpy as np
 
 from tensorflow.python import keras
+from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.platform import test
 
 
@@ -384,5 +385,63 @@
         [None, None, None])
 
 
+# object() alone isn't compatible with WeakKeyDictionary, which we use to
+# track shared configs.
+class MaybeSharedObject(object):
+  pass
+
+
+class SharedObjectScopeTest(test.TestCase):
+
+  def test_shared_object_saving_scope_single_object_doesnt_export_id(self):
+    with generic_utils.SharedObjectSavingScope() as scope:
+      single_object = MaybeSharedObject()
+      self.assertIsNone(scope.get_config(single_object))
+      single_object_config = scope.create_config({}, single_object)
+      self.assertIsNotNone(single_object_config)
+      self.assertNotIn(generic_utils.SHARED_OBJECT_KEY,
+                       single_object_config)
+
+  def test_shared_object_saving_scope_shared_object_exports_id(self):
+    with generic_utils.SharedObjectSavingScope() as scope:
+      shared_object = MaybeSharedObject()
+      self.assertIsNone(scope.get_config(shared_object))
+      scope.create_config({}, shared_object)
+      first_object_config = scope.get_config(shared_object)
+      second_object_config = scope.get_config(shared_object)
+      self.assertIn(generic_utils.SHARED_OBJECT_KEY,
+                    first_object_config)
+      self.assertIn(generic_utils.SHARED_OBJECT_KEY,
+                    second_object_config)
+      self.assertIs(first_object_config, second_object_config)
+
+  def test_shared_object_loading_scope_noop(self):
+    # Test that, without a context manager scope, adding configs will do
+    # nothing.
+    obj_id = 1
+    obj = MaybeSharedObject()
+    generic_utils._shared_object_loading_scope().set(obj_id, obj)
+    self.assertIsNone(generic_utils._shared_object_loading_scope().get(obj_id))
+
+  def test_shared_object_loading_scope_returns_shared_obj(self):
+    obj_id = 1
+    obj = MaybeSharedObject()
+    with generic_utils.SharedObjectLoadingScope() as scope:
+      scope.set(obj_id, obj)
+      self.assertIs(scope.get(obj_id), obj)
+
+  def test_nested_shared_object_saving_scopes(self):
+    my_obj = MaybeSharedObject()
+    with generic_utils.SharedObjectSavingScope() as scope_1:
+      scope_1.create_config({}, my_obj)
+      with generic_utils.SharedObjectSavingScope() as scope_2:
+        # Nesting saving scopes should return the original scope and should
+        # not clear any objects we're tracking.
+        self.assertIs(scope_1, scope_2)
+        self.assertIsNotNone(scope_2.get_config(my_obj))
+      self.assertIsNotNone(scope_1.get_config(my_obj))
+    self.assertIsNone(generic_utils._shared_object_saving_scope())
+
+
 if __name__ == '__main__':
   test.main()