Add specific error for functions capturing Keras learning phase, and fix keras saving tests.

PiperOrigin-RevId: 262630016
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index b289ce5..94727c6 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -1153,6 +1153,11 @@
     ctx = context.context()
     executing_eagerly = ctx.executing_eagerly()
 
+    # Copy saveable status of function's graph to current FuncGraph.
+    default_graph = ops.get_default_graph()
+    if default_graph.building_function and not self._func_graph.saveable:
+      default_graph.mark_as_unsaveable(self._func_graph.saving_errors)
+
     if any(isinstance(a, composite_tensor.CompositeTensor) for a in args):
       raise AssertionError("Expected all args to be Tensors or Variables; "
                            "but got CompositeTensor: %r" % args)
diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py
index c8a9b69..30db860 100644
--- a/tensorflow/python/framework/func_graph.py
+++ b/tensorflow/python/framework/func_graph.py
@@ -249,6 +249,12 @@
     else:
       self._collections = collections
 
+    # Keep track of whether this FuncGraph is exportable to SavedModel. Use
+    # `graph.mark_as_unsaveable(reason)` to mark this FuncGraph and any
+    # dependent functions as unsaveable.
+    self._saveable = True
+    self._saving_errors = set()
+
   def __str__(self):
     return "FuncGraph(name=%s, id=%s)" % (self.name, id(self))
 
@@ -701,6 +707,31 @@
         if ops.tensor_id(v.handle) in self._captures
     }
 
+  def mark_as_unsaveable(self, error_message):
+    """Marks this FuncGraph as unsaveable.
+
+    Any attempts to export this FuncGraph will raise an error with the specified
+    message.
+
+    Args:
+      error_message: List or string containing the error message to be raised
+        when saving this FuncGraph to SavedModel.
+    """
+    self._saveable = False
+    if isinstance(error_message, str):
+      error_message = [error_message]
+    self._saving_errors.update(error_message)
+
+  @property
+  def saveable(self):
+    """Returns whether this FuncGraph is saveable."""
+    return self._saveable
+
+  @property
+  def saving_errors(self):
+    """Returns set of errors preventing this FuncGraph from being saved."""
+    return self._saving_errors
+
 
 def func_graph_from_py_func(name,
                             python_func,
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index e803c88..e6c06ee 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -271,10 +271,13 @@
   Returns:
       Learning phase (scalar integer tensor or Python integer).
   """
-  if ops.get_default_graph() is _GRAPH:
+  graph = ops.get_default_graph()
+  if graph is _GRAPH:
     # Don't enter an init_scope for the learning phase if eager execution
     # is enabled but we're inside the Keras workspace graph.
-    return symbolic_learning_phase()
+    learning_phase = symbolic_learning_phase()
+    _mark_func_graph_as_unsaveable(graph, learning_phase)
+    return learning_phase
   with ops.init_scope():
     # We always check & set the learning phase inside the init_scope,
     # otherwise the wrong default_graph will be used to look up the learning
@@ -288,13 +291,34 @@
         # Fallback to inference mode as default.
         return 0
       return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
-    return symbolic_learning_phase()
+    learning_phase = symbolic_learning_phase()
+    _mark_func_graph_as_unsaveable(graph, learning_phase)
+    return learning_phase
 
 
 def global_learning_phase_is_set():
   return _DUMMY_EAGER_GRAPH in _GRAPH_LEARNING_PHASES
 
 
+def _mark_func_graph_as_unsaveable(graph, learning_phase):
+  """Mark func graph as unsaveable due to use of symbolic keras learning phase.
+
+  Functions that capture the symbolic learning phase cannot be exported to
+  SavedModel. Mark the funcgraph as unsaveable, so that an error will be raised
+  if it is exported.
+
+  Args:
+    graph: Graph or FuncGraph object.
+    learning_phase: Learning phase placeholder or int defined in the graph.
+  """
+  if graph.building_function and is_placeholder(learning_phase):
+    graph.mark_as_unsaveable(
+        'The keras learning phase placeholder was used inside a function. '
+        'Exporting placeholders is not supported when saving out a SavedModel. '
+        'Please call `tf.keras.backend.set_learning_phase(0)` in the function '
+        'to set the learning phase to a constant value.')
+
+
 def symbolic_learning_phase():
   graph = get_graph()
   with graph.as_default():
diff --git a/tensorflow/python/keras/saving/saved_model/save.py b/tensorflow/python/keras/saving/saved_model/save.py
index e0557b0..b495a03 100644
--- a/tensorflow/python/keras/saving/saved_model/save.py
+++ b/tensorflow/python/keras/saving/saved_model/save.py
@@ -86,7 +86,10 @@
     orig_optimizer = model.optimizer
     model.optimizer = None
 
-  save_lib.save(model, filepath, signatures)
+  # Trace all functions and signatures with `training=0` instead of using the
+  # default learning phase placeholder.
+  with K.learning_phase_scope(0):
+    save_lib.save(model, filepath, signatures)
 
   if not include_optimizer:
     model.optimizer = orig_optimizer
diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py
index f2d4bcb..829d90d 100644
--- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py
+++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py
@@ -330,7 +330,7 @@
     self.evaluate(loaded(input_arr_1, training=False))
     self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0.12])
 
-  def save_with_signatures(self):
+  def testSaveWithSignatures(self):
     model = keras.models.Sequential()
     model.add(keras.layers.Dense(5, input_shape=(3,),
                                  kernel_regularizer=regularizers.get('l2')))
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index e12c03d..726180b 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -274,6 +274,11 @@
         self.captured_tensor_node_ids[obj.asset_path] = node_id
 
     for concrete_function in self.concrete_functions:
+      if not concrete_function.graph.saveable:
+        raise ValueError(
+            ("Unable to save function {name} for the following reason(s):\n" +
+             "\n".join(concrete_function.graph.saving_errors))
+            .format(name=concrete_function.name))
       for capture in concrete_function.captured_inputs:
         if (tensor_util.is_tensor(capture)
             and capture.dtype not in _UNCOPIABLE_DTYPES
diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py
index 566c508..a5200b0 100644
--- a/tensorflow/python/saved_model/save_test.py
+++ b/tensorflow/python/saved_model/save_test.py
@@ -141,6 +141,22 @@
       save.save(root, os.path.join(self.get_temp_dir(), "saved_model"),
                 signatures=root.f)
 
+  def test_unsaveable_func_graph(self):
+    root = module.Module()
+
+    @def_function.function(input_signature=[])
+    def nested_f():
+      ops.get_default_graph().mark_as_unsaveable("ERROR MSG")
+      return 1
+
+    @def_function.function(input_signature=[])
+    def f():
+      return nested_f()
+
+    root.f = f
+    with self.assertRaisesRegexp(ValueError, "ERROR MSG"):
+      save.save(root, os.path.join(self.get_temp_dir(), "saved_model"))
+
   def test_version_information_included(self):
     root = tracking.AutoTrackable()
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")