Support calling `model.build()` with a shape passed a list of int/None for
subclassed models.
PiperOrigin-RevId: 323846785
Change-Id: I7f3260b4a3309527f665ef04b608a5c00a15c352
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 18dfc4c..15f77ab 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -393,6 +393,9 @@
else:
graph = backend.get_graph()
with graph.as_default():
+ if (isinstance(input_shape, list) and
+ all(d is None or isinstance(d, int) for d in input_shape)):
+ input_shape = tuple(input_shape)
if isinstance(input_shape, list):
x = [base_layer_utils.generate_placeholders_from_shape(shape)
for shape in input_shape]
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 93e9b66..15976c0 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -3604,5 +3604,52 @@
self.assertEqual(sum(new_func_graph in log for log in logs.output), 9)
+class TestBuildCustomModel(keras_parameterized.TestCase):
+
+ @keras_parameterized.run_all_keras_modes
+ def test_build_list_of_inputs(self):
+
+ class MyModel(training_module.Model):
+
+ def __init__(self):
+ super(MyModel, self).__init__()
+ self.l1 = layers_module.Dense(1)
+ self.l2 = layers_module.Dense(2)
+
+ def call(self, x):
+ a, b = x
+ return self.l1(a) + self.l2(b)
+
+ # List of tuples
+ model = MyModel()
+ model.build([(None, 1), (None, 2)])
+ self.assertEqual(model.l1.kernel.shape.as_list(), [1, 1])
+ self.assertEqual(model.l2.kernel.shape.as_list(), [2, 2])
+ # List of lists
+ model = MyModel()
+ model.build([[None, 1], [None, 2]])
+ self.assertEqual(model.l1.kernel.shape.as_list(), [1, 1])
+ self.assertEqual(model.l2.kernel.shape.as_list(), [2, 2])
+
+ @keras_parameterized.run_all_keras_modes
+ def test_build_single_inputs(self):
+
+ class MyModel(training_module.Model):
+
+ def __init__(self):
+ super(MyModel, self).__init__()
+ self.l1 = layers_module.Dense(1)
+
+ def call(self, x):
+ return self.l1(x)
+
+ model = MyModel()
+ model.build((None, 1))
+ self.assertEqual(model.l1.kernel.shape.as_list(), [1, 1])
+ model = MyModel()
+ model.build([None, 1])
+ self.assertEqual(model.l1.kernel.shape.as_list(), [1, 1])
+
+
if __name__ == '__main__':
test.main()