Generalize functional model serialization to support cases where elements in the first argument of a layer call are not produced from keras inputs. This will be needed to support generalizing functional model construction to cases where only elements outside of the first call argument have keras history.

PiperOrigin-RevId: 313820929
Change-Id: Ie9ef0cbc6b8caab189534faf09de897c361f5c08
diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py
index 4958990..eec1334 100644
--- a/tensorflow/python/keras/engine/functional.py
+++ b/tensorflow/python/keras/engine/functional.py
@@ -32,6 +32,7 @@
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.engine import input_layer as input_layer_module
+from tensorflow.python.keras.engine import node as node_module
 from tensorflow.python.keras.engine import training as training_lib
 from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.saving.saved_model import network_serialization
@@ -1111,19 +1112,28 @@
         kwargs = {}
       elif len(input_data) == 4:
         kwargs = input_data[3]
-        kwargs = _deserialize_keras_tensors(kwargs, created_layers)
+        try:
+          kwargs = _deserialize_keras_tensors(kwargs, created_layers)
+        except IndexError:
+          # Happens if keras tensors in kwargs are still unprocessed
+          add_unprocessed_node(layer, node_data)
+          return
       else:
         raise ValueError('Improperly formatted model config.')
 
-      inbound_layer = created_layers[inbound_layer_name]
-      inbound_node_index = get_node_index(inbound_layer, inbound_node_index)
+      if inbound_layer_name != node_module._CONSTANT_VALUE:
+        inbound_layer = created_layers[inbound_layer_name]
+        inbound_node_index = get_node_index(inbound_layer, inbound_node_index)
 
-      if inbound_node_index is None:
-        add_unprocessed_node(layer, node_data)
-        return
-      inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
-      input_tensors.append(
-          nest.flatten(inbound_node.outputs)[inbound_tensor_index])
+        if inbound_node_index is None:
+          add_unprocessed_node(layer, node_data)
+          return
+        inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
+        input_tensors.append(
+            nest.flatten(inbound_node.outputs)[inbound_tensor_index])
+      else:
+        # We received a constant w/ no Keras history attached
+        input_tensors.append(inbound_tensor_index)
     input_tensors = nest.pack_sequence_as(node_data, input_tensors)
     # Call layer on its inputs, thus creating the node
     # and building the layer if needed.
diff --git a/tensorflow/python/keras/engine/node.py b/tensorflow/python/keras/engine/node.py
index 7089048..e8d9838 100644
--- a/tensorflow/python/keras/engine/node.py
+++ b/tensorflow/python/keras/engine/node.py
@@ -32,6 +32,8 @@
 from tensorflow.python.util import nest
 from tensorflow.python.util import serialization
 
+_CONSTANT_VALUE = '_CONSTANT_VALUE'
+
 
 class Node(object):
   """A `Node` describes the connectivity between two layers.
@@ -181,11 +183,18 @@
     # `kwargs` is added to each Tensor in the first arg. This should be
     # changed in a future version of the serialization format.
     def serialize_first_arg_tensor(t):
-      kh = t._keras_history
-      node_index = kh.node_index
-      node_key = make_node_key(kh.layer.name, node_index)
-      new_node_index = node_conversion_map.get(node_key, 0)
-      data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs]
+      if is_keras_tensor(t):
+        kh = t._keras_history
+        node_index = kh.node_index
+        node_key = make_node_key(kh.layer.name, node_index)
+        new_node_index = node_conversion_map.get(node_key, 0)
+        data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs]
+      else:
+        # If an element in the first call argument did not originate as a
+        # keras tensor and is a constant value, we save it using the format
+        # ['_CONSTANT_VALUE', -1, serializaed_tensor_or_python_constant]
+        # (potentially including serialized kwargs in an optional 4th argument
+        data = [_CONSTANT_VALUE, -1, _serialize_keras_tensor(t), kwargs]
       return tf_utils.ListWrapper(data)
 
     data = nest.map_structure(serialize_first_arg_tensor, inputs)