Generalizes the first argument in keras layers further. Now, functional models get constructed if *any* tensor in the arguments or keyword arguments has a keras history, rather than if *all* of the elements in the first argument to the layer do.
PiperOrigin-RevId: 313718130
Change-Id: I77f65f49decf45f6a2b53ab0519d6d2ac38232d3
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 9958f70..0630199 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -830,14 +830,13 @@
in_call = call_context.in_call
input_list = nest.flatten(inputs)
- # We will attempt to build a TF graph if & only if all inputs are symbolic.
- # This is always the case in graph mode. It can also be the case in eager
- # mode when all inputs can be traced back to `keras.Input()` (when building
- # models using the functional API).
- # TODO(kaftan): make this not special case inputs. Instead
- # build a functional api model if *any* *arg or **kwarg is symbolic,
- # even if part of the data structure in that arg is not symbolic.
- build_graph = tf_utils.are_all_symbolic_tensors(input_list)
+ # We will attempt to trace in a graph if & only if inputs are symbolic.
+ # This is always the case when tracing a function. It can also be the case
+ # when running eagerly if any input can be traced back to `keras.Input()`
+ # (when building models using the functional API).
+ build_graph = tf_utils.are_all_symbolic_tensors(input_list) or (
+ any(map(tf_utils.is_symbolic_tensor, nest.flatten(
+ [input_list, args, kwargs]))) and context.executing_eagerly())
# Accept NumPy and scalar inputs by converting to Tensors.
if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):
@@ -890,11 +889,14 @@
'training', training_value, args, kwargs)
training_arg_passed_by_framework = True
- # Only create Keras history if at least one tensor originates from a
- # `keras.Input`. Otherwise this Layer may be being used outside the Keras
- # framework.
- # TODO(kaftan): make this not special case inputs
- if build_graph and base_layer_utils.needs_keras_history(inputs):
+ # Turn inputs into TF op layers if necessary.
+ # This process is fragile and prone to bad interactions with inputs
+ # when calling nested layers with tf.functions floating around,
+ # and with nonsymbolic tensors.
+ # So, we limit it to the
+ # case where *all* inputs in the first arg are symbolic.
+ if (tf_utils.are_all_symbolic_tensors(input_list)
+ and base_layer_utils.needs_keras_history(inputs)):
base_layer_utils.create_keras_history(inputs)
with call_context.enter(self, inputs, build_graph, training_value):
@@ -968,8 +970,12 @@
raise ValueError('A layer\'s `call` method should return a '
'Tensor or a list of Tensors, not None '
'(layer: ' + self.name + ').')
- # TODO(kaftan): This should be 'any' and check all args
- if base_layer_utils.have_all_keras_metadata(inputs):
+ # We configure connectivity metadata if all inputs in the first
+ # arg have keras history, or if we're actively building the
+ # functional api outside of any outer keras model.
+ if base_layer_utils.have_all_keras_metadata(inputs) or (
+ context.executing_eagerly() and
+ base_layer_utils.have_any_keras_metadata(inputs, args, kwargs)):
if training_arg_passed_by_framework:
args, kwargs = self._set_call_arg_value(
'training', None, args, kwargs, pop_kwarg_if_none=True)
diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py
index 6d25995..6508d64 100644
--- a/tensorflow/python/keras/engine/base_layer_utils.py
+++ b/tensorflow/python/keras/engine/base_layer_utils.py
@@ -165,6 +165,10 @@
return all(hasattr(x, '_keras_history') for x in nest.flatten(tensors))
+def have_any_keras_metadata(*tensors):
+ return any(hasattr(x, '_keras_history') for x in nest.flatten(tensors))
+
+
def generate_placeholders_from_shape(shape):
return array_ops.placeholder(shape=shape, dtype=backend.floatx())
@@ -214,7 +218,10 @@
for tensor in tensor_list:
if getattr(tensor, '_keras_history', None) is not None:
continue
- op = tensor.op # The Op that created this Tensor.
+ try:
+ op = tensor.op # The Op that created this Tensor.
+ except AttributeError:
+ continue
if op not in processed_ops:
if op.type.startswith('Sparse'):
lambda_example = """
@@ -392,7 +399,10 @@
"""
def _mark_checked(tensor):
- tensor._keras_history_checked = True # pylint: disable=protected-access
+ try:
+ tensor._keras_history_checked = True # pylint: disable=protected-access
+ except AttributeError:
+ pass
nest.map_structure(_mark_checked, tensors)
diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py
index 4958990..eec1334 100644
--- a/tensorflow/python/keras/engine/functional.py
+++ b/tensorflow/python/keras/engine/functional.py
@@ -32,6 +32,7 @@
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_layer as input_layer_module
+from tensorflow.python.keras.engine import node as node_module
from tensorflow.python.keras.engine import training as training_lib
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.saving.saved_model import network_serialization
@@ -1111,19 +1112,28 @@
kwargs = {}
elif len(input_data) == 4:
kwargs = input_data[3]
- kwargs = _deserialize_keras_tensors(kwargs, created_layers)
+ try:
+ kwargs = _deserialize_keras_tensors(kwargs, created_layers)
+ except IndexError:
+ # Happens if keras tensors in kwargs are still unprocessed
+ add_unprocessed_node(layer, node_data)
+ return
else:
raise ValueError('Improperly formatted model config.')
- inbound_layer = created_layers[inbound_layer_name]
- inbound_node_index = get_node_index(inbound_layer, inbound_node_index)
+ if inbound_layer_name != node_module._CONSTANT_VALUE:
+ inbound_layer = created_layers[inbound_layer_name]
+ inbound_node_index = get_node_index(inbound_layer, inbound_node_index)
- if inbound_node_index is None:
- add_unprocessed_node(layer, node_data)
- return
- inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
- input_tensors.append(
- nest.flatten(inbound_node.outputs)[inbound_tensor_index])
+ if inbound_node_index is None:
+ add_unprocessed_node(layer, node_data)
+ return
+ inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
+ input_tensors.append(
+ nest.flatten(inbound_node.outputs)[inbound_tensor_index])
+ else:
+ # We received a constant w/ no Keras history attached
+ input_tensors.append(inbound_tensor_index)
input_tensors = nest.pack_sequence_as(node_data, input_tensors)
# Call layer on its inputs, thus creating the node
# and building the layer if needed.
diff --git a/tensorflow/python/keras/engine/functional_test.py b/tensorflow/python/keras/engine/functional_test.py
index 90fc9f2..e975bb8 100644
--- a/tensorflow/python/keras/engine/functional_test.py
+++ b/tensorflow/python/keras/engine/functional_test.py
@@ -965,6 +965,43 @@
self.assertEqual(history.history['loss'][0], 0.0)
@combinations.generate(combinations.keras_mode_combinations())
+ def test_call_kwarg_derived_from_keras_layer_and_first_arg_is_constant(self):
+
+ class MaybeAdd(layers.Layer):
+
+ def call(self, x1, x2=None):
+ if x2 is not None:
+ return x1 + x2
+ return x1
+
+ input2 = input_layer_lib.Input(10)
+ outputs = MaybeAdd()(3., x2=input2)
+ model = training_lib.Model([input2], outputs)
+ model.compile(
+ 'sgd',
+ 'mse',
+ run_eagerly=testing_utils.should_run_eagerly())
+ history = model.fit(
+ x=7 * np.ones((10, 10)),
+ y=10 * np.ones((10, 10)),
+ batch_size=2)
+ # Check that second input was correctly added to first.
+ self.assertEqual(history.history['loss'][0], 0.0)
+
+ model = training_lib.Model.from_config(
+ model.get_config(), custom_objects={'MaybeAdd': MaybeAdd})
+ model.compile(
+ 'sgd',
+ 'mse',
+ run_eagerly=testing_utils.should_run_eagerly())
+ history = model.fit(
+ x=7 * np.ones((10, 10)),
+ y=10 * np.ones((10, 10)),
+ batch_size=2)
+ # Check that second input was correctly added to first.
+ self.assertEqual(history.history['loss'][0], 0.0)
+
+ @combinations.generate(combinations.keras_mode_combinations())
def test_composite_call_kwarg_derived_from_keras_layer(self):
# Create a test layer that accepts composite tensor inputs.
@@ -1005,6 +1042,56 @@
# Check that second input was correctly added to first.
self.assertEqual(history.history['loss'][0], 0.0)
+ @combinations.generate(combinations.keras_mode_combinations(mode='eager'))
+ def test_call_some_not_all_nested_in_first_arg_derived_from_keras_layer(self):
+ # This functionality is unsupported in v1 graphs
+
+ class AddAll(layers.Layer):
+
+ def call(self, x1_x2, x3):
+ x1, x2 = x1_x2
+ out = x1 + x2
+ if x3 is not None:
+ for t in x3.values():
+ out += t
+ return out
+
+ input1 = input_layer_lib.Input(10)
+ input2 = input_layer_lib.Input(10)
+ input3 = input_layer_lib.Input(10)
+
+ outputs = AddAll()(
+ [input1, 4 * array_ops.ones((1, 10))],
+ x3={
+ 'a': input2,
+ 'b': input3,
+ 'c': 5 * array_ops.ones((1, 10))
+ })
+ model = training_lib.Model([input1, input2, input3], outputs)
+ model.compile(
+ 'sgd',
+ 'mse',
+ run_eagerly=testing_utils.should_run_eagerly())
+ history = model.fit(
+ x=[np.ones((10, 10)), 2 * np.ones((10, 10)), 3 * np.ones((10, 10))],
+ y=15 * np.ones((10, 10)),
+ batch_size=2)
+ # Check that all inputs were correctly added.
+ self.assertEqual(history.history['loss'][0], 0.0)
+
+ model = training_lib.Model.from_config(
+ model.get_config(), custom_objects={'AddAll': AddAll})
+ model.compile(
+ 'sgd',
+ 'mse',
+ run_eagerly=testing_utils.should_run_eagerly())
+ history = model.fit(
+ x=[np.ones((10, 10)), 2 * np.ones((10, 10)), 3 * np.ones((10, 10))],
+ y=15 * np.ones((10, 10)),
+ batch_size=2)
+ # Check that all inputs were correctly added.
+ self.assertEqual(history.history['loss'][0], 0.0)
+
@combinations.generate(combinations.keras_mode_combinations())
def test_call_nested_arg_derived_from_keras_layer(self):
diff --git a/tensorflow/python/keras/engine/node.py b/tensorflow/python/keras/engine/node.py
index 7089048..1637e05 100644
--- a/tensorflow/python/keras/engine/node.py
+++ b/tensorflow/python/keras/engine/node.py
@@ -32,6 +32,8 @@
from tensorflow.python.util import nest
from tensorflow.python.util import serialization
+_CONSTANT_VALUE = '_CONSTANT_VALUE'
+
class Node(object):
"""A `Node` describes the connectivity between two layers.
@@ -181,11 +183,14 @@
# `kwargs` is added to each Tensor in the first arg. This should be
# changed in a future version of the serialization format.
def serialize_first_arg_tensor(t):
- kh = t._keras_history
- node_index = kh.node_index
- node_key = make_node_key(kh.layer.name, node_index)
- new_node_index = node_conversion_map.get(node_key, 0)
- data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs]
+ if is_keras_tensor(t):
+ kh = t._keras_history
+ node_index = kh.node_index
+ node_key = make_node_key(kh.layer.name, node_index)
+ new_node_index = node_conversion_map.get(node_key, 0)
+ data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs]
+ else:
+ data = [_CONSTANT_VALUE, -1, _serialize_keras_tensor(t), kwargs]
return tf_utils.ListWrapper(data)
data = nest.map_structure(serialize_first_arg_tensor, inputs)