Save Keras metadata in a separate file (keras_metadata.pb) in the SavedModel directory, and raise deprecation warnings when loading a SavedModel that was saved with tf.saved_model.save().

PiperOrigin-RevId: 338359077
Change-Id: I93d8c345efb323cd8d4fd1fda4c8e5e86b37d620
diff --git a/tensorflow/python/keras/saving/BUILD b/tensorflow/python/keras/saving/BUILD
index 51095c1..7dcc9ae 100644
--- a/tensorflow/python/keras/saving/BUILD
+++ b/tensorflow/python/keras/saving/BUILD
@@ -49,6 +49,7 @@
     deps = [
         "//tensorflow/python:lib",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform",
         "//tensorflow/python:saver",
         "//tensorflow/python:tensor_spec",
         "//tensorflow/python/eager:def_function",
diff --git a/tensorflow/python/keras/saving/saved_model/constants.py b/tensorflow/python/keras/saving/saved_model/constants.py
index 3f1eca9..12265e0 100644
--- a/tensorflow/python/keras/saving/saved_model/constants.py
+++ b/tensorflow/python/keras/saving/saved_model/constants.py
@@ -26,3 +26,7 @@
 # Keys for the serialization cache.
 # Maps to the keras serialization dict {Layer --> SerializedAttributes object}
 KERAS_CACHE_KEY = 'keras_serialized_attributes'
+
+
+# Name of Keras metadata file stored in the SavedModel.
+SAVED_METADATA_PATH = 'keras_metadata.pb'
diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py
index cb6d340..43c1d2b 100644
--- a/tensorflow/python/keras/saving/saved_model/load.py
+++ b/tensorflow/python/keras/saving/saved_model/load.py
@@ -17,9 +17,12 @@
 from __future__ import division
 from __future__ import print_function
 
+import os
 import re
 import types
 
+from google.protobuf import message
+
 from tensorflow.core.framework import versions_pb2
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function as defun
@@ -38,6 +41,7 @@
 from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils import metrics_utils
 from tensorflow.python.keras.utils.generic_utils import LazyLoader
+from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import load as tf_load
 from tensorflow.python.saved_model import loader_impl
@@ -121,13 +125,26 @@
   # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
   # TODO(kathywu): Add code to load from objects that contain all endpoints
 
-  # The Keras metadata file is not yet saved, so create it from the SavedModel.
+  # Load the Keras metadata file if present; otherwise reconstruct the metadata
+  # from the SavedModel's object graph (legacy path, pre-TF 2.4).
   metadata = saved_metadata_pb2.SavedMetadata()
   meta_graph_def = loader_impl.parse_saved_model(path).meta_graphs[0]
   object_graph_def = meta_graph_def.object_graph_def
-  # TODO(kathywu): When the keras metadata file is saved, load it directly
-  # instead of calling the _read_legacy_metadata function.
-  _read_legacy_metadata(object_graph_def, metadata)
+  path_to_metadata_pb = os.path.join(path, constants.SAVED_METADATA_PATH)
+  if gfile.Exists(path_to_metadata_pb):
+    try:
+      with gfile.GFile(path_to_metadata_pb, 'rb') as f:
+        file_content = f.read()
+      metadata.ParseFromString(file_content)
+    except message.DecodeError as e:
+      raise IOError('Cannot parse keras metadata {}: {}.'
+                    .format(path_to_metadata_pb, str(e)))
+  else:
+    logging.warning('SavedModel saved prior to TF 2.4 detected when loading '
+                    'Keras model. Please ensure that you are saving the model '
+                    'with model.save() or tf.keras.models.save_model(), *NOT* '
+                    'tf.saved_model.save(). To confirm, there should be a file '
+                    'named "keras_metadata.pb" in the SavedModel directory.')
+    _read_legacy_metadata(object_graph_def, metadata)
 
   if not metadata.nodes:
     # When there are no Keras objects, return the results from the core loader
diff --git a/tensorflow/python/keras/saving/saved_model/save.py b/tensorflow/python/keras/saving/saved_model/save.py
index 16984a2..2ab7ebb 100644
--- a/tensorflow/python/keras/saving/saved_model/save.py
+++ b/tensorflow/python/keras/saving/saved_model/save.py
@@ -18,15 +18,21 @@
 from __future__ import print_function
 
 import os
+
+from tensorflow.core.framework import versions_pb2
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.protobuf import saved_metadata_pb2
 from tensorflow.python.keras.saving import saving_utils
+from tensorflow.python.keras.saving.saved_model import constants
 from tensorflow.python.keras.saving.saved_model import save_impl
 from tensorflow.python.keras.saving.saved_model import utils
 from tensorflow.python.keras.utils.generic_utils import LazyLoader
 from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
+from tensorflow.python.platform import gfile
 from tensorflow.python.saved_model import save as save_lib
 
+
 # To avoid circular dependencies between keras/engine and keras/saving,
 # code in keras/saving must delay imports.
 
@@ -86,7 +92,39 @@
     # we use the default replica context here.
     with distribution_strategy_context._get_default_replica_context():  # pylint: disable=protected-access
       with utils.keras_option_scope(save_traces):
-        save_lib.save(model, filepath, signatures, options)
+        saved_nodes, node_paths = save_lib.save_and_return_nodes(
+            model, filepath, signatures, options)
+
+    # Save all metadata to a separate file in the SavedModel directory.
+    metadata = generate_keras_metadata(saved_nodes, node_paths)
+
+  with gfile.GFile(
+      os.path.join(filepath, constants.SAVED_METADATA_PATH), "wb") as w:
+    w.write(metadata.SerializeToString(deterministic=True))
 
   if not include_optimizer:
     model.optimizer = orig_optimizer
+
+
+def generate_keras_metadata(saved_nodes, node_paths):
+  """Constructs a KerasMetadata proto with the metadata of each keras object."""
+  metadata = saved_metadata_pb2.SavedMetadata()
+
+  for node_id, node in enumerate(saved_nodes):
+    if isinstance(node, base_layer.Layer):
+      path = node_paths[node]
+      if not path:
+        node_path = "root"
+      else:
+        node_path = "root.{}".format(
+            ".".join([ref.name for ref in path]))
+
+      metadata.nodes.add(
+          node_id=node_id,
+          node_path=node_path,
+          version=versions_pb2.VersionDef(
+              producer=1, min_consumer=1, bad_consumers=[]),
+          identifier=node._object_identifier,  # pylint: disable=protected-access
+          metadata=node._tracking_metadata)  # pylint: disable=protected-access
+
+  return metadata
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index 27a2867..76af988 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -180,8 +180,9 @@
     """
     self.options = options
     self.checkpoint_view = checkpoint_view
-    trackable_objects, node_ids, slot_variables = (
-        self.checkpoint_view.objects_ids_and_slot_variables())
+    trackable_objects, path_to_root, node_ids, slot_variables = (
+        self.checkpoint_view.objects_ids_and_slot_variables_and_paths())
+    self.node_paths = path_to_root
     self.nodes = trackable_objects
     self.node_ids = node_ids
     self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
@@ -1021,6 +1022,30 @@
   May not be called from within a function body.
   @end_compatibility
   """
+  save_and_return_nodes(obj, export_dir, signatures, options,
+                        raise_metadata_warning=True)
+
+
+def save_and_return_nodes(obj, export_dir, signatures=None, options=None,
+                          raise_metadata_warning=False):
+  """Saves a SavedModel while returning all saved nodes and their paths.
+
+  Please see `tf.saved_model.save` for details.
+
+  Args:
+    obj: A trackable object to export.
+    export_dir: A directory in which to write the SavedModel.
+    signatures: A function or dictionary of functions to save in the SavedModel
+      as signatures.
+    options: `tf.saved_model.SaveOptions` object for configuring save options.
+    raise_metadata_warning: Whether to raise the metadata warning. This arg will
+      be removed in TF 2.5.
+
+  Returns:
+    A tuple of (a list of saved nodes in the order they are serialized to the
+      `SavedObjectGraph`, dictionary mapping nodes to one possible path from
+      the root node to the key node)
+  """
   options = options or save_options.SaveOptions()
   # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
   # compatible (no sessions) and share it with this export API rather than
@@ -1028,8 +1053,9 @@
   saved_model = saved_model_pb2.SavedModel()
   meta_graph_def = saved_model.meta_graphs.add()
 
-  _, exported_graph, object_saver, asset_info = _build_meta_graph(
-      obj, signatures, options, meta_graph_def)
+  _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
+      _build_meta_graph(obj, signatures, options, meta_graph_def,
+                        raise_metadata_warning))
   saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION
 
   # Write the checkpoint, copy assets into the assets directory, and write out
@@ -1069,6 +1095,8 @@
   # constants in the saved graph.
   ops.dismantle_graph(exported_graph)
 
+  return saved_nodes, node_paths
+
 
 def export_meta_graph(obj, filename, signatures=None, options=None):
   """Exports the MetaGraph proto of the `obj` to a file.
@@ -1095,7 +1123,7 @@
   """
   options = options or save_options.SaveOptions()
   export_dir = os.path.dirname(filename)
-  meta_graph_def, exported_graph, _, _ = _build_meta_graph(
+  meta_graph_def, exported_graph, _, _, _, _ = _build_meta_graph(
       obj, signatures, options)
 
   file_io.atomic_write_string_to_file(
@@ -1114,7 +1142,8 @@
 def _build_meta_graph_impl(obj,
                            signatures,
                            options,
-                           meta_graph_def=None):
+                           meta_graph_def=None,
+                           raise_metadata_warning=True):
   """Creates a MetaGraph containing the resources and functions of an object."""
   if ops.inside_function():
     raise AssertionError(
@@ -1162,7 +1191,7 @@
       saveable_view, asset_info.asset_index)
   meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
 
-  if saved_object_metadata:
+  if saved_object_metadata and raise_metadata_warning:
     tf_logging.warn(
         'FOR KERAS USERS: The object that you are saving contains one or more '
         'Keras models or layers. If you are loading the SavedModel with '
@@ -1178,13 +1207,15 @@
-        'metadta field will be deprecated soon, so please move the metadata to '
+        'metadata field will be deprecated soon, so please move the metadata to '
         'a different file.')
 
-  return (meta_graph_def, exported_graph, object_saver, asset_info)
+  return (meta_graph_def, exported_graph, object_saver, asset_info,
+          saveable_view.nodes, saveable_view.node_paths)
 
 
 def _build_meta_graph(obj,
                       signatures,
                       options,
-                      meta_graph_def=None):
+                      meta_graph_def=None,
+                      raise_metadata_warning=True):
   """Creates a MetaGraph under a save context.
 
   Args:
@@ -1197,6 +1228,8 @@
     options: `tf.saved_model.SaveOptions` object that specifies options for
       saving.
     meta_graph_def: Optional, the MetaGraphDef proto fill.
+    raise_metadata_warning: Whether to raise a warning when user objects contain
+      non-empty metadata.
 
   Raises:
     AssertionError: If `export_meta_graph` is executing inside a `tf.function`.
@@ -1210,4 +1243,5 @@
   """
 
   with save_context.save_context(options):
-    return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
+    return _build_meta_graph_impl(obj, signatures, options, meta_graph_def,
+                                  raise_metadata_warning)
diff --git a/tensorflow/python/training/tracking/graph_view.py b/tensorflow/python/training/tracking/graph_view.py
index 6aeb41b..61078cc 100644
--- a/tensorflow/python/training/tracking/graph_view.py
+++ b/tensorflow/python/training/tracking/graph_view.py
@@ -430,7 +430,7 @@
               name=base.OBJECT_GRAPH_PROTO_KEY))
     return named_saveable_objects
 
-  def objects_ids_and_slot_variables(self):
+  def objects_ids_and_slot_variables_and_paths(self):
     """Traverse the object graph and list all accessible objects.
 
     Looks for `Trackable` objects which are dependencies of
@@ -439,7 +439,8 @@
     (i.e. if they would be saved with a checkpoint).
 
     Returns:
-      A tuple of (trackable objects, object -> node id, slot variables)
+      A tuple of (trackable objects, paths from root for each object,
+                  object -> node id, slot variables)
     """
     trackable_objects, path_to_root = self._breadth_first_traversal()
     object_names = object_identity.ObjectIdentityDictionary()
@@ -452,6 +453,11 @@
         trackable_objects=trackable_objects,
         node_ids=node_ids,
         object_names=object_names)
+    return trackable_objects, path_to_root, node_ids, slot_variables
+
+  def objects_ids_and_slot_variables(self):
+    trackable_objects, _, node_ids, slot_variables = (
+        self.objects_ids_and_slot_variables_and_paths())
     return trackable_objects, node_ids, slot_variables
 
   def list_objects(self):