When exporting SavedModel, force functional and sequential models to save config even if error occurs when serializing layers.
This ensures that the network structure is saved even if a custom layer doesn't define its config.
PiperOrigin-RevId: 283633369
Change-Id: I59c4e7dcb9acca837bc6534af046c7f21663ff24
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index c3abf49..88b6165 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -1857,6 +1857,18 @@
)
tf_py_test(
+ name = "model_serialization_test",
+ size = "medium",
+ srcs = ["saving/saved_model/model_serialization_test.py"],
+ additional_deps = [
+ ":keras",
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/python/distribute:mirrored_strategy",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+tf_py_test(
name = "saving_utils_test",
size = "medium",
srcs = ["saving/saving_utils_test.py"],
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index 369cd31..522aed6 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -345,13 +345,17 @@
layer_configs = []
for layer in self.layers:
layer_configs.append(generic_utils.serialize_keras_object(layer))
- # When constructed using an `InputLayer` the first non-input layer may not
- # have the shape information to reconstruct `Sequential` as a graph network.
- if (self._is_graph_network and layer_configs and
- 'batch_input_shape' not in layer_configs[0]['config'] and
- isinstance(self._layers[0], input_layer.InputLayer)):
- batch_input_shape = self._layers[0]._batch_input_shape
- layer_configs[0]['config']['batch_input_shape'] = batch_input_shape
+
+ if layer_configs and layer_configs[0]['config'] is not None:
+ # layer_configs[0]['config'] may be None only when saving SavedModel.
+
+ # Check to see whether the first non-input layer has the shape information
+ # to reconstruct `Sequential` as a graph network. If not, add it.
+ if (self._is_graph_network and
+ 'batch_input_shape' not in layer_configs[0]['config'] and
+ isinstance(self._layers[0], input_layer.InputLayer)):
+ batch_input_shape = self._layers[0]._batch_input_shape
+ layer_configs[0]['config']['batch_input_shape'] = batch_input_shape
config = {
'name': self.name,
diff --git a/tensorflow/python/keras/saving/saved_model/layer_serialization.py b/tensorflow/python/keras/saving/saved_model/layer_serialization.py
index 054a01e..ab1edaa 100644
--- a/tensorflow/python/keras/saving/saved_model/layer_serialization.py
+++ b/tensorflow/python/keras/saving/saved_model/layer_serialization.py
@@ -23,7 +23,7 @@
from tensorflow.python.keras.saving.saved_model import constants
from tensorflow.python.keras.saving.saved_model import save_impl
from tensorflow.python.keras.saving.saved_model import serialized_attributes
-from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.util import nest
@@ -51,23 +51,22 @@
expects_training_arg=self.obj._expects_training_arg, # pylint: disable=protected-access
dtype=policy.serialize(self.obj._dtype_policy), # pylint: disable=protected-access
batch_input_shape=getattr(self.obj, '_batch_input_shape', None))
- try:
- # Store the config dictionary, which is only used by the revived object
- # to return the original config when revived_obj.get_config() is called.
- # It is not important for recreating the revived object.
- metadata['config'] = self.obj.get_config()
- except NotImplementedError:
- # in the case of a subclassed model, the get_config() method will throw
- # a NotImplementedError.
- pass
+
+ with generic_utils.skip_failed_serialization():
+ # Store the config dictionary, which may be used when reviving the object.
+ # When loading, the program will attempt to revive the object from config,
+ # and if that fails, the object will be revived from the SavedModel.
+ config = generic_utils.serialize_keras_object(self.obj)['config']
+ if config is not None:
+ metadata['config'] = config
if self.obj.input_spec is not None:
# Layer's input_spec has already been type-checked in the property setter.
metadata['input_spec'] = nest.map_structure(
- lambda x: None if x is None else serialize_keras_object(x),
+ lambda x: generic_utils.serialize_keras_object(x) if x else None,
self.obj.input_spec)
if (self.obj.activity_regularizer is not None and
hasattr(self.obj.activity_regularizer, 'get_config')):
- metadata['activity_regularizer'] = serialize_keras_object(
+ metadata['activity_regularizer'] = generic_utils.serialize_keras_object(
self.obj.activity_regularizer)
return metadata
diff --git a/tensorflow/python/keras/saving/saved_model/model_serialization_test.py b/tensorflow/python/keras/saving/saved_model/model_serialization_test.py
new file mode 100644
index 0000000..125ab2f
--- /dev/null
+++ b/tensorflow/python/keras/saving/saved_model/model_serialization_test.py
@@ -0,0 +1,48 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=protected-access
+"""Unit tests for serializing Keras models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python import keras
+from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.platform import test
+
+
+class CustomLayer(keras.layers.Layer):
+
+ def __init__(self, unused_a):
+ super(CustomLayer, self).__init__()
+
+
+class ModelSerializationTest(keras_parameterized.TestCase):
+
+ @keras_parameterized.run_with_all_model_types(exclude_models=['subclass'])
+ def test_model_config_always_saved(self):
+ layer = CustomLayer(None)
+ with self.assertRaisesRegexp(NotImplementedError,
+ 'must override `get_config`.'):
+ layer.get_config()
+ model = testing_utils.get_model_from_layers([layer], input_shape=(3,))
+ properties = model._trackable_saved_model_saver.python_properties
+ self.assertIsNotNone(properties['config'])
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py
index 8ff27a3..8b899dc 100644
--- a/tensorflow/python/keras/utils/generic_utils.py
+++ b/tensorflow/python/keras/utils/generic_utils.py
@@ -30,6 +30,7 @@
import six
from tensorflow.python.util import nest
+from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export
@@ -37,6 +38,11 @@
_GLOBAL_CUSTOM_OBJECTS = {}
_GLOBAL_CUSTOM_NAMES = {}
+# Flag that determines whether to skip the NotImplementedError when calling
+# get_config in custom models and layers. This is only enabled when saving to
+# SavedModel, when the config isn't required.
+_SKIP_FAILED_SERIALIZATION = False
+
@keras_export('keras.utils.CustomObjectScope')
class CustomObjectScope(object):
@@ -187,6 +193,17 @@
return obj.__name__
+@tf_contextlib.contextmanager
+def skip_failed_serialization():
+ global _SKIP_FAILED_SERIALIZATION
+ prev = _SKIP_FAILED_SERIALIZATION
+ try:
+ _SKIP_FAILED_SERIALIZATION = True
+ yield
+ finally:
+ _SKIP_FAILED_SERIALIZATION = prev
+
+
@keras_export('keras.utils.serialize_keras_object')
def serialize_keras_object(instance):
"""Serialize Keras object into JSON."""
@@ -195,7 +212,13 @@
return None
if hasattr(instance, 'get_config'):
- config = instance.get_config()
+ name = _get_name_or_custom_name(instance.__class__)
+ try:
+ config = instance.get_config()
+ except NotImplementedError as e:
+ if _SKIP_FAILED_SERIALIZATION:
+ return serialize_keras_class_and_config(name, None)
+ raise e
serialization_config = {}
for key, item in config.items():
if isinstance(item, six.string_types):
@@ -211,15 +234,13 @@
serialization_config[key] = serialized_item
except ValueError:
serialization_config[key] = item
-
- name = _get_name_or_custom_name(instance.__class__)
return serialize_keras_class_and_config(name, serialization_config)
if hasattr(instance, '__name__'):
return _get_name_or_custom_name(instance)
raise ValueError('Cannot serialize', instance)
-def _get_custom_objects_by_name(item, custom_objects=None):
+def get_custom_objects_by_name(item, custom_objects=None):
"""Returns the item if it is in either local or global custom objects."""
if item in _GLOBAL_CUSTOM_OBJECTS:
return _GLOBAL_CUSTOM_OBJECTS[item]
@@ -260,7 +281,7 @@
printable_module_name='config_item')
elif (isinstance(item, six.string_types) and
tf_inspect.isfunction(
- _get_custom_objects_by_name(item, custom_objects))):
+ get_custom_objects_by_name(item, custom_objects))):
# Handle custom functions here. When saving functions, we only save the
# function's name as a string. If we find a matching string in the custom
# objects during deserialization, we convert the string back to the
@@ -269,7 +290,7 @@
# conflict with a custom function name, but this should be a rare case.
# This issue does not occur if a string field has a naming conflict with
# a custom object, since the config of an object will always be a dict.
- deserialized_objects[key] = _get_custom_objects_by_name(
+ deserialized_objects[key] = get_custom_objects_by_name(
item, custom_objects)
for key, item in deserialized_objects.items():
cls_config[key] = deserialized_objects[key]