Add specific error for functions capturing Keras learning phase, and fix Keras saving tests.
PiperOrigin-RevId: 262630016
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index b289ce5..94727c6 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -1153,6 +1153,11 @@
ctx = context.context()
executing_eagerly = ctx.executing_eagerly()
+ # Copy saveable status of function's graph to current FuncGraph.
+ default_graph = ops.get_default_graph()
+ if default_graph.building_function and not self._func_graph.saveable:
+ default_graph.mark_as_unsaveable(self._func_graph.saving_errors)
+
if any(isinstance(a, composite_tensor.CompositeTensor) for a in args):
raise AssertionError("Expected all args to be Tensors or Variables; "
"but got CompositeTensor: %r" % args)
diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py
index c8a9b69..30db860 100644
--- a/tensorflow/python/framework/func_graph.py
+++ b/tensorflow/python/framework/func_graph.py
@@ -249,6 +249,12 @@
else:
self._collections = collections
+ # Keep track of whether this FuncGraph is exportable to SavedModel. Use
+ # `graph.mark_as_unsaveable(reason)` to mark this FuncGraph and any
+ # dependent functions as unsaveable.
+ self._saveable = True
+ self._saving_errors = set()
+
def __str__(self):
return "FuncGraph(name=%s, id=%s)" % (self.name, id(self))
@@ -701,6 +707,31 @@
if ops.tensor_id(v.handle) in self._captures
}
+ def mark_as_unsaveable(self, error_message):
+ """Marks this FuncGraph as unsaveable.
+
+ Any attempts to export this FuncGraph will raise an error with the specified
+ message.
+
+ Args:
+ error_message: List or string containing the error message to be raised
+ when saving this FuncGraph to SavedModel.
+ """
+ self._saveable = False
+ if isinstance(error_message, str):
+ error_message = [error_message]
+ self._saving_errors.update(error_message)
+
+ @property
+ def saveable(self):
+ """Returns whether this FuncGraph is saveable."""
+ return self._saveable
+
+ @property
+ def saving_errors(self):
+ """Returns set of errors preventing this FuncGraph from being saved."""
+ return self._saving_errors
+
def func_graph_from_py_func(name,
python_func,
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index e803c88..e6c06ee 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -271,10 +271,13 @@
Returns:
Learning phase (scalar integer tensor or Python integer).
"""
- if ops.get_default_graph() is _GRAPH:
+ graph = ops.get_default_graph()
+ if graph is _GRAPH:
# Don't enter an init_scope for the learning phase if eager execution
# is enabled but we're inside the Keras workspace graph.
- return symbolic_learning_phase()
+ learning_phase = symbolic_learning_phase()
+ _mark_func_graph_as_unsaveable(graph, learning_phase)
+ return learning_phase
with ops.init_scope():
# We always check & set the learning phase inside the init_scope,
# otherwise the wrong default_graph will be used to look up the learning
@@ -288,13 +291,34 @@
# Fallback to inference mode as default.
return 0
return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
- return symbolic_learning_phase()
+ learning_phase = symbolic_learning_phase()
+ _mark_func_graph_as_unsaveable(graph, learning_phase)
+ return learning_phase
def global_learning_phase_is_set():
return _DUMMY_EAGER_GRAPH in _GRAPH_LEARNING_PHASES
+def _mark_func_graph_as_unsaveable(graph, learning_phase):
+  """Mark func graph as unsaveable due to use of symbolic Keras learning phase.
+
+  Functions that capture the symbolic learning phase cannot be exported to
+  SavedModel. Mark the FuncGraph as unsaveable, so that an error will be raised
+ if it is exported.
+
+ Args:
+ graph: Graph or FuncGraph object.
+ learning_phase: Learning phase placeholder or int defined in the graph.
+ """
+ if graph.building_function and is_placeholder(learning_phase):
+ graph.mark_as_unsaveable(
+ 'The keras learning phase placeholder was used inside a function. '
+ 'Exporting placeholders is not supported when saving out a SavedModel. '
+ 'Please call `tf.keras.backend.set_learning_phase(0)` in the function '
+ 'to set the learning phase to a constant value.')
+
+
def symbolic_learning_phase():
graph = get_graph()
with graph.as_default():
diff --git a/tensorflow/python/keras/saving/saved_model/save.py b/tensorflow/python/keras/saving/saved_model/save.py
index e0557b0..b495a03 100644
--- a/tensorflow/python/keras/saving/saved_model/save.py
+++ b/tensorflow/python/keras/saving/saved_model/save.py
@@ -86,7 +86,10 @@
orig_optimizer = model.optimizer
model.optimizer = None
- save_lib.save(model, filepath, signatures)
+ # Trace all functions and signatures with `training=0` instead of using the
+ # default learning phase placeholder.
+ with K.learning_phase_scope(0):
+ save_lib.save(model, filepath, signatures)
if not include_optimizer:
model.optimizer = orig_optimizer
diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py
index f2d4bcb..829d90d 100644
--- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py
+++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py
@@ -330,7 +330,7 @@
self.evaluate(loaded(input_arr_1, training=False))
self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0.12])
- def save_with_signatures(self):
+ def testSaveWithSignatures(self):
model = keras.models.Sequential()
model.add(keras.layers.Dense(5, input_shape=(3,),
kernel_regularizer=regularizers.get('l2')))
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index e12c03d..726180b 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -274,6 +274,11 @@
self.captured_tensor_node_ids[obj.asset_path] = node_id
for concrete_function in self.concrete_functions:
+ if not concrete_function.graph.saveable:
+ raise ValueError(
+ ("Unable to save function {name} for the following reason(s):\n" +
+ "\n".join(concrete_function.graph.saving_errors))
+ .format(name=concrete_function.name))
for capture in concrete_function.captured_inputs:
if (tensor_util.is_tensor(capture)
and capture.dtype not in _UNCOPIABLE_DTYPES
diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py
index 566c508..a5200b0 100644
--- a/tensorflow/python/saved_model/save_test.py
+++ b/tensorflow/python/saved_model/save_test.py
@@ -141,6 +141,22 @@
save.save(root, os.path.join(self.get_temp_dir(), "saved_model"),
signatures=root.f)
+ def test_unsaveable_func_graph(self):
+ root = module.Module()
+
+ @def_function.function(input_signature=[])
+ def nested_f():
+ ops.get_default_graph().mark_as_unsaveable("ERROR MSG")
+ return 1
+
+ @def_function.function(input_signature=[])
+ def f():
+ return nested_f()
+
+ root.f = f
+ with self.assertRaisesRegexp(ValueError, "ERROR MSG"):
+ save.save(root, os.path.join(self.get_temp_dir(), "saved_model"))
+
def test_version_information_included(self):
root = tracking.AutoTrackable()
save_dir = os.path.join(self.get_temp_dir(), "saved_model")