When exporting a SavedModel, force functional and sequential models to save their config even if an error occurs while serializing individual layers.

This ensures that the network structure is saved even if a custom layer doesn't implement `get_config()`.
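
For illustration, a minimal sketch of the behavior this enables (not part of this change; assumes TensorFlow 2.x with this patch applied, and a hypothetical `NoConfigLayer`):

    import tensorflow as tf

    class NoConfigLayer(tf.keras.layers.Layer):
      """Has an extra __init__ argument and no get_config() override."""

      def __init__(self, scale):
        super(NoConfigLayer, self).__init__()
        self.scale = scale

      def call(self, inputs):
        return inputs * self.scale

    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(3,)),
        NoConfigLayer(2.0),
    ])
    # NoConfigLayer.get_config() raises NotImplementedError, but the model's
    # sequential config is still written to the SavedModel instead of being
    # dropped entirely.
    model.save('/tmp/no_config_model', save_format='tf')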

PiperOrigin-RevId: 283633369
Change-Id: I59c4e7dcb9acca837bc6534af046c7f21663ff24
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index c3abf49..88b6165 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -1857,6 +1857,18 @@
 )
 
 tf_py_test(
+    name = "model_serialization_test",
+    size = "medium",
+    srcs = ["saving/saved_model/model_serialization_test.py"],
+    additional_deps = [
+        ":keras",
+        "@absl_py//absl/testing:parameterized",
+        "//tensorflow/python/distribute:mirrored_strategy",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+tf_py_test(
     name = "saving_utils_test",
     size = "medium",
     srcs = ["saving/saving_utils_test.py"],
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index 369cd31..522aed6 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -345,13 +345,17 @@
     layer_configs = []
     for layer in self.layers:
       layer_configs.append(generic_utils.serialize_keras_object(layer))
-    # When constructed using an `InputLayer` the first non-input layer may not
-    # have the shape information to reconstruct `Sequential` as a graph network.
-    if (self._is_graph_network and layer_configs and
-        'batch_input_shape' not in layer_configs[0]['config'] and
-        isinstance(self._layers[0], input_layer.InputLayer)):
-      batch_input_shape = self._layers[0]._batch_input_shape
-      layer_configs[0]['config']['batch_input_shape'] = batch_input_shape
+
+    if layer_configs and layer_configs[0]['config'] is not None:
+      # layer_configs[0]['config'] may be None only when saving SavedModel.
+
+      # Check to see whether the first non-input layer has the shape information
+      # to reconstruct `Sequential` as a graph network. If not, add it.
+      if (self._is_graph_network and
+          'batch_input_shape' not in layer_configs[0]['config'] and
+          isinstance(self._layers[0], input_layer.InputLayer)):
+        batch_input_shape = self._layers[0]._batch_input_shape
+        layer_configs[0]['config']['batch_input_shape'] = batch_input_shape
 
     config = {
         'name': self.name,
diff --git a/tensorflow/python/keras/saving/saved_model/layer_serialization.py b/tensorflow/python/keras/saving/saved_model/layer_serialization.py
index 054a01e..ab1edaa 100644
--- a/tensorflow/python/keras/saving/saved_model/layer_serialization.py
+++ b/tensorflow/python/keras/saving/saved_model/layer_serialization.py
@@ -23,7 +23,7 @@
 from tensorflow.python.keras.saving.saved_model import constants
 from tensorflow.python.keras.saving.saved_model import save_impl
 from tensorflow.python.keras.saving.saved_model import serialized_attributes
-from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.util import nest
 
 
@@ -51,23 +51,22 @@
         expects_training_arg=self.obj._expects_training_arg,  # pylint: disable=protected-access
         dtype=policy.serialize(self.obj._dtype_policy),  # pylint: disable=protected-access
         batch_input_shape=getattr(self.obj, '_batch_input_shape', None))
-    try:
-      # Store the config dictionary, which is only used by the revived object
-      # to return the original config when revived_obj.get_config() is called.
-      # It is not important for recreating the revived object.
-      metadata['config'] = self.obj.get_config()
-    except NotImplementedError:
-      # in the case of a subclassed model, the get_config() method will throw
-      # a NotImplementedError.
-      pass
+
+    with generic_utils.skip_failed_serialization():
+      # Store the config dictionary, which may be used when reviving the object.
+      # When loading, the program will attempt to revive the object from config,
+      # and if that fails, the object will be revived from the SavedModel.
+      config = generic_utils.serialize_keras_object(self.obj)['config']
+      if config is not None:
+        metadata['config'] = config
     if self.obj.input_spec is not None:
       # Layer's input_spec has already been type-checked in the property setter.
       metadata['input_spec'] = nest.map_structure(
-          lambda x: None if x is None else serialize_keras_object(x),
+          lambda x: generic_utils.serialize_keras_object(x) if x else None,
           self.obj.input_spec)
     if (self.obj.activity_regularizer is not None and
         hasattr(self.obj.activity_regularizer, 'get_config')):
-      metadata['activity_regularizer'] = serialize_keras_object(
+      metadata['activity_regularizer'] = generic_utils.serialize_keras_object(
           self.obj.activity_regularizer)
     return metadata
 
diff --git a/tensorflow/python/keras/saving/saved_model/model_serialization_test.py b/tensorflow/python/keras/saving/saved_model/model_serialization_test.py
new file mode 100644
index 0000000..125ab2f
--- /dev/null
+++ b/tensorflow/python/keras/saving/saved_model/model_serialization_test.py
@@ -0,0 +1,48 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=protected-access
+"""Unit tests for serializing Keras models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python import keras
+from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.platform import test
+
+
+class CustomLayer(keras.layers.Layer):
+
+  def __init__(self, unused_a):
+    super(CustomLayer, self).__init__()
+
+
+class ModelSerializationTest(keras_parameterized.TestCase):
+
+  @keras_parameterized.run_with_all_model_types(exclude_models=['subclass'])
+  def test_model_config_always_saved(self):
+    layer = CustomLayer(None)
+    with self.assertRaisesRegexp(NotImplementedError,
+                                 'must override `get_config`.'):
+      layer.get_config()
+    model = testing_utils.get_model_from_layers([layer], input_shape=(3,))
+    properties = model._trackable_saved_model_saver.python_properties
+    self.assertIsNotNone(properties['config'])
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py
index 8ff27a3..8b899dc 100644
--- a/tensorflow/python/keras/utils/generic_utils.py
+++ b/tensorflow/python/keras/utils/generic_utils.py
@@ -30,6 +30,7 @@
 import six
 
 from tensorflow.python.util import nest
+from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util import tf_decorator
 from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import keras_export
@@ -37,6 +38,11 @@
 _GLOBAL_CUSTOM_OBJECTS = {}
 _GLOBAL_CUSTOM_NAMES = {}
 
+# Flag that determines whether to skip the NotImplementedError when calling
+# get_config in custom models and layers. This is only enabled when saving to
+# SavedModel, when the config isn't required.
+_SKIP_FAILED_SERIALIZATION = False
+
 
 @keras_export('keras.utils.CustomObjectScope')
 class CustomObjectScope(object):
@@ -187,6 +193,17 @@
     return obj.__name__
 
 
+@tf_contextlib.contextmanager
+def skip_failed_serialization():
+  global _SKIP_FAILED_SERIALIZATION
+  prev = _SKIP_FAILED_SERIALIZATION
+  try:
+    _SKIP_FAILED_SERIALIZATION = True
+    yield
+  finally:
+    _SKIP_FAILED_SERIALIZATION = prev
+
+
 @keras_export('keras.utils.serialize_keras_object')
 def serialize_keras_object(instance):
   """Serialize Keras object into JSON."""
@@ -195,7 +212,13 @@
     return None
 
   if hasattr(instance, 'get_config'):
-    config = instance.get_config()
+    name = _get_name_or_custom_name(instance.__class__)
+    try:
+      config = instance.get_config()
+    except NotImplementedError as e:
+      if _SKIP_FAILED_SERIALIZATION:
+        return serialize_keras_class_and_config(name, None)
+      raise e
     serialization_config = {}
     for key, item in config.items():
       if isinstance(item, six.string_types):
@@ -211,15 +234,13 @@
         serialization_config[key] = serialized_item
       except ValueError:
         serialization_config[key] = item
-
-    name = _get_name_or_custom_name(instance.__class__)
     return serialize_keras_class_and_config(name, serialization_config)
   if hasattr(instance, '__name__'):
     return _get_name_or_custom_name(instance)
   raise ValueError('Cannot serialize', instance)
 
 
-def _get_custom_objects_by_name(item, custom_objects=None):
+def get_custom_objects_by_name(item, custom_objects=None):
   """Returns the item if it is in either local or global custom objects."""
   if item in _GLOBAL_CUSTOM_OBJECTS:
     return _GLOBAL_CUSTOM_OBJECTS[item]
@@ -260,7 +281,7 @@
           printable_module_name='config_item')
     elif (isinstance(item, six.string_types) and
           tf_inspect.isfunction(
-              _get_custom_objects_by_name(item, custom_objects))):
+              get_custom_objects_by_name(item, custom_objects))):
       # Handle custom functions here. When saving functions, we only save the
       # function's name as a string. If we find a matching string in the custom
       # objects during deserialization, we convert the string back to the
@@ -269,7 +290,7 @@
       # conflict with a custom function name, but this should be a rare case.
       # This issue does not occur if a string field has a naming conflict with
       # a custom object, since the config of an object will always be a dict.
-      deserialized_objects[key] = _get_custom_objects_by_name(
+      deserialized_objects[key] = get_custom_objects_by_name(
           item, custom_objects)
   for key, item in deserialized_objects.items():
     cls_config[key] = deserialized_objects[key]
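
The new `skip_failed_serialization()` context manager is a Keras-internal helper; a rough, hypothetical sketch of how the layer serializer above uses it (same assumed `NoConfigLayer` style as in the example at the top):

    import tensorflow as tf
    from tensorflow.python.keras.utils import generic_utils

    class NoConfigLayer(tf.keras.layers.Layer):
      # Extra __init__ argument and no get_config() override, so the base
      # Layer.get_config() raises NotImplementedError.

      def __init__(self, scale):
        super(NoConfigLayer, self).__init__()
        self.scale = scale

    layer = NoConfigLayer(2.0)

    # Outside the context manager, the error still propagates:
    #   generic_utils.serialize_keras_object(layer)  -> NotImplementedError
    # Inside it, the failure is swallowed and the config comes back as None,
    # which is how layer_serialization.py decides to omit the 'config' key
    # from the SavedModel metadata.
    with generic_utils.skip_failed_serialization():
      serialized = generic_utils.serialize_keras_object(layer)
      # serialized == {'class_name': 'NoConfigLayer', 'config': None}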