Generalizes argument handling in Keras layers further: functional models are now constructed if *any* tensor in the positional or keyword arguments has a Keras history, rather than only when *all* elements of the layer's first argument do.
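
A minimal sketch of the new behavior (mirroring the unit test added in
functional_test.py; `MaybeAdd` is an illustrative layer, not an existing API):

    import tensorflow as tf

    class MaybeAdd(tf.keras.layers.Layer):
      def call(self, x1, x2=None):
        return x1 if x2 is None else x1 + x2

    inp = tf.keras.Input(10)
    # The first argument is a plain constant; only the kwarg `x2` carries a
    # Keras history. This call now creates a functional-API node, whereas
    # previously every element of the first argument had to be symbolic.
    out = MaybeAdd()(3., x2=inp)
    model = tf.keras.Model(inp, out)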

PiperOrigin-RevId: 313718130
Change-Id: I77f65f49decf45f6a2b53ab0519d6d2ac38232d3
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 9958f70..0630199 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -830,14 +830,13 @@
     in_call = call_context.in_call
     input_list = nest.flatten(inputs)
 
-    # We will attempt to build a TF graph if & only if all inputs are symbolic.
-    # This is always the case in graph mode. It can also be the case in eager
-    # mode when all inputs can be traced back to `keras.Input()` (when building
-    # models using the functional API).
-    # TODO(kaftan): make this not special case inputs. Instead
-    # build a functional api model if *any* *arg or **kwarg is symbolic,
-    # even if part of the data structure in that arg is not symbolic.
-    build_graph = tf_utils.are_all_symbolic_tensors(input_list)
+    # We will attempt to trace the call in a graph if and only if inputs
+    # are symbolic. This is always the case when tracing a function. It
+    # can also be the case when running eagerly if any input can be
+    # traced back to `keras.Input()` (i.e. when using the functional API).
+    build_graph = tf_utils.are_all_symbolic_tensors(input_list) or (
+        any(map(tf_utils.is_symbolic_tensor, nest.flatten(
+            [input_list, args, kwargs]))) and context.executing_eagerly())
 
     # Accept NumPy and scalar inputs by converting to Tensors.
     if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):
@@ -890,11 +889,14 @@
             'training', training_value, args, kwargs)
         training_arg_passed_by_framework = True
 
-    # Only create Keras history if at least one tensor originates from a
-    # `keras.Input`. Otherwise this Layer may be being used outside the Keras
-    # framework.
-    # TODO(kaftan): make this not special case inputs
-    if build_graph and base_layer_utils.needs_keras_history(inputs):
+    # Turn inputs into TF op layers if necessary.
+    # This process is fragile and prone to bad interactions with inputs
+    # when calling nested layers with tf.functions floating around,
+    # and with nonsymbolic tensors.
+    # So, we limit it to the case where *all* inputs in the first arg
+    # are symbolic.
+    if (tf_utils.are_all_symbolic_tensors(input_list)
+        and base_layer_utils.needs_keras_history(inputs)):
       base_layer_utils.create_keras_history(inputs)
 
     with call_context.enter(self, inputs, build_graph, training_value):
@@ -968,8 +970,12 @@
             raise ValueError('A layer\'s `call` method should return a '
                              'Tensor or a list of Tensors, not None '
                              '(layer: ' + self.name + ').')
-          # TODO(kaftan): This should be 'any' and check all args
-          if base_layer_utils.have_all_keras_metadata(inputs):
+          # We configure connectivity metadata if all inputs in the first
+          # arg have Keras history, or if we're actively building the
+          # functional API outside of any outer Keras model.
+          if base_layer_utils.have_all_keras_metadata(inputs) or (
+              context.executing_eagerly() and
+              base_layer_utils.have_any_keras_metadata(inputs, args, kwargs)):
             if training_arg_passed_by_framework:
               args, kwargs = self._set_call_arg_value(
                   'training', None, args, kwargs, pop_kwarg_if_none=True)
diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py
index 6d25995..6508d64 100644
--- a/tensorflow/python/keras/engine/base_layer_utils.py
+++ b/tensorflow/python/keras/engine/base_layer_utils.py
@@ -165,6 +165,10 @@
   return all(hasattr(x, '_keras_history') for x in nest.flatten(tensors))
 
 
+def have_any_keras_metadata(*tensors):
+  return any(hasattr(x, '_keras_history') for x in nest.flatten(tensors))
+
+
 def generate_placeholders_from_shape(shape):
   return array_ops.placeholder(shape=shape, dtype=backend.floatx())
 
@@ -214,7 +218,10 @@
   for tensor in tensor_list:
     if getattr(tensor, '_keras_history', None) is not None:
       continue
-    op = tensor.op  # The Op that created this Tensor.
+    try:
+      op = tensor.op  # The Op that created this Tensor.
+    except AttributeError:
+      continue
     if op not in processed_ops:
       if op.type.startswith('Sparse'):
         lambda_example = """
@@ -392,7 +399,10 @@
   """
 
   def _mark_checked(tensor):
-    tensor._keras_history_checked = True  # pylint: disable=protected-access
+    try:
+      tensor._keras_history_checked = True  # pylint: disable=protected-access
+    except AttributeError:
+      pass
 
   nest.map_structure(_mark_checked, tensors)
 
diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py
index 4958990..eec1334 100644
--- a/tensorflow/python/keras/engine/functional.py
+++ b/tensorflow/python/keras/engine/functional.py
@@ -32,6 +32,7 @@
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.engine import input_layer as input_layer_module
+from tensorflow.python.keras.engine import node as node_module
 from tensorflow.python.keras.engine import training as training_lib
 from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.saving.saved_model import network_serialization
@@ -1111,19 +1112,28 @@
         kwargs = {}
       elif len(input_data) == 4:
         kwargs = input_data[3]
-        kwargs = _deserialize_keras_tensors(kwargs, created_layers)
+        try:
+          kwargs = _deserialize_keras_tensors(kwargs, created_layers)
+        except IndexError:
+          # Happens if keras tensors in kwargs are still unprocessed
+          add_unprocessed_node(layer, node_data)
+          return
       else:
         raise ValueError('Improperly formatted model config.')
 
-      inbound_layer = created_layers[inbound_layer_name]
-      inbound_node_index = get_node_index(inbound_layer, inbound_node_index)
+      if inbound_layer_name != node_module._CONSTANT_VALUE:
+        inbound_layer = created_layers[inbound_layer_name]
+        inbound_node_index = get_node_index(inbound_layer, inbound_node_index)
 
-      if inbound_node_index is None:
-        add_unprocessed_node(layer, node_data)
-        return
-      inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
-      input_tensors.append(
-          nest.flatten(inbound_node.outputs)[inbound_tensor_index])
+        if inbound_node_index is None:
+          add_unprocessed_node(layer, node_data)
+          return
+        inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
+        input_tensors.append(
+            nest.flatten(inbound_node.outputs)[inbound_tensor_index])
+      else:
+        # We received a constant with no Keras history attached.
+        input_tensors.append(inbound_tensor_index)
     input_tensors = nest.pack_sequence_as(node_data, input_tensors)
     # Call layer on its inputs, thus creating the node
     # and building the layer if needed.
diff --git a/tensorflow/python/keras/engine/functional_test.py b/tensorflow/python/keras/engine/functional_test.py
index 90fc9f2..e975bb8 100644
--- a/tensorflow/python/keras/engine/functional_test.py
+++ b/tensorflow/python/keras/engine/functional_test.py
@@ -965,6 +965,43 @@
     self.assertEqual(history.history['loss'][0], 0.0)
 
   @combinations.generate(combinations.keras_mode_combinations())
+  def test_call_kwarg_derived_from_keras_layer_and_first_arg_is_constant(self):
+
+    class MaybeAdd(layers.Layer):
+
+      def call(self, x1, x2=None):
+        if x2 is not None:
+          return x1 + x2
+        return x1
+
+    input2 = input_layer_lib.Input(10)
+    outputs = MaybeAdd()(3., x2=input2)
+    model = training_lib.Model([input2], outputs)
+    model.compile(
+        'sgd',
+        'mse',
+        run_eagerly=testing_utils.should_run_eagerly())
+    history = model.fit(
+        x=7 * np.ones((10, 10)),
+        y=10 * np.ones((10, 10)),
+        batch_size=2)
+    # Check that second input was correctly added to first.
+    self.assertEqual(history.history['loss'][0], 0.0)
+
+    model = training_lib.Model.from_config(
+        model.get_config(), custom_objects={'MaybeAdd': MaybeAdd})
+    model.compile(
+        'sgd',
+        'mse',
+        run_eagerly=testing_utils.should_run_eagerly())
+    history = model.fit(
+        x=7 * np.ones((10, 10)),
+        y=10 * np.ones((10, 10)),
+        batch_size=2)
+    # Check that second input was correctly added to first.
+    self.assertEqual(history.history['loss'][0], 0.0)
+
+  @combinations.generate(combinations.keras_mode_combinations())
   def test_composite_call_kwarg_derived_from_keras_layer(self):
 
     # Create a test layer that accepts composite tensor inputs.
@@ -1005,6 +1042,56 @@
     # Check that second input was correctly added to first.
     self.assertEqual(history.history['loss'][0], 0.0)
 
+  @combinations.generate(combinations.keras_mode_combinations(mode='eager'))
+  def test_call_some_not_all_nested_in_first_arg_derived_from_keras_layer(self):
+    # This functionality is unsupported in v1 graphs
+
+    class AddAll(layers.Layer):
+
+      def call(self, x1_x2, x3):
+        x1, x2 = x1_x2
+        out = x1 + x2
+        if x3 is not None:
+          for t in x3.values():
+            out += t
+        return out
+
+    input1 = input_layer_lib.Input(10)
+    input2 = input_layer_lib.Input(10)
+    input3 = input_layer_lib.Input(10)
+
+    outputs = AddAll()(
+        [input1, 4 * array_ops.ones((1, 10))],
+        x3={
+            'a': input2,
+            'b': input3,
+            'c': 5 * array_ops.ones((1, 10))
+        })
+    model = training_lib.Model([input1, input2, input3], outputs)
+    model.compile(
+        'sgd',
+        'mse',
+        run_eagerly=testing_utils.should_run_eagerly())
+    history = model.fit(
+        x=[np.ones((10, 10)), 2 * np.ones((10, 10)), 3 * np.ones((10, 10))],
+        y=15 * np.ones((10, 10)),
+        batch_size=2)
+    # Check that all inputs were correctly added.
+    self.assertEqual(history.history['loss'][0], 0.0)
+
+    model = training_lib.Model.from_config(
+        model.get_config(), custom_objects={'AddAll': AddAll})
+    model.compile(
+        'sgd',
+        'mse',
+        run_eagerly=testing_utils.should_run_eagerly())
+    history = model.fit(
+        x=[np.ones((10, 10)), 2 * np.ones((10, 10)), 3 * np.ones((10, 10))],
+        y=15 * np.ones((10, 10)),
+        batch_size=2)
+    # Check that all inputs were correctly added.
+    self.assertEqual(history.history['loss'][0], 0.0)
+
   @combinations.generate(combinations.keras_mode_combinations())
   def test_call_nested_arg_derived_from_keras_layer(self):
 
diff --git a/tensorflow/python/keras/engine/node.py b/tensorflow/python/keras/engine/node.py
index 7089048..1637e05 100644
--- a/tensorflow/python/keras/engine/node.py
+++ b/tensorflow/python/keras/engine/node.py
@@ -32,6 +32,8 @@
 from tensorflow.python.util import nest
 from tensorflow.python.util import serialization
 
+_CONSTANT_VALUE = '_CONSTANT_VALUE'
+
 
 class Node(object):
   """A `Node` describes the connectivity between two layers.
@@ -181,11 +183,14 @@
     # `kwargs` is added to each Tensor in the first arg. This should be
     # changed in a future version of the serialization format.
     def serialize_first_arg_tensor(t):
-      kh = t._keras_history
-      node_index = kh.node_index
-      node_key = make_node_key(kh.layer.name, node_index)
-      new_node_index = node_conversion_map.get(node_key, 0)
-      data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs]
+      if is_keras_tensor(t):
+        kh = t._keras_history
+        node_index = kh.node_index
+        node_key = make_node_key(kh.layer.name, node_index)
+        new_node_index = node_conversion_map.get(node_key, 0)
+        data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs]
+      else:
+        data = [_CONSTANT_VALUE, -1, _serialize_keras_tensor(t), kwargs]
       return tf_utils.ListWrapper(data)
 
     data = nest.map_structure(serialize_first_arg_tensor, inputs)