Support calling `model.build()` with a shape passed a list of int/None for
subclassed models.

PiperOrigin-RevId: 323846785
Change-Id: I7f3260b4a3309527f665ef04b608a5c00a15c352
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 18dfc4c..15f77ab 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -393,6 +393,9 @@
       else:
         graph = backend.get_graph()
       with graph.as_default():
+        if (isinstance(input_shape, list) and
+            all(d is None or isinstance(d, int) for d in input_shape)):
+          input_shape = tuple(input_shape)
         if isinstance(input_shape, list):
           x = [base_layer_utils.generate_placeholders_from_shape(shape)
                for shape in input_shape]
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 93e9b66..15976c0 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -3604,5 +3604,52 @@
     self.assertEqual(sum(new_func_graph in log for log in logs.output), 9)
 
 
+class TestBuildCustomModel(keras_parameterized.TestCase):
+
+  @keras_parameterized.run_all_keras_modes
+  def test_build_list_of_inputs(self):
+
+    class MyModel(training_module.Model):
+
+      def __init__(self):
+        super(MyModel, self).__init__()
+        self.l1 = layers_module.Dense(1)
+        self.l2 = layers_module.Dense(2)
+
+      def call(self, x):
+        a, b = x
+        return self.l1(a) + self.l2(b)
+
+    # List of tuples
+    model = MyModel()
+    model.build([(None, 1), (None, 2)])
+    self.assertEqual(model.l1.kernel.shape.as_list(), [1, 1])
+    self.assertEqual(model.l2.kernel.shape.as_list(), [2, 2])
+    # List of lists
+    model = MyModel()
+    model.build([[None, 1], [None, 2]])
+    self.assertEqual(model.l1.kernel.shape.as_list(), [1, 1])
+    self.assertEqual(model.l2.kernel.shape.as_list(), [2, 2])
+
+  @keras_parameterized.run_all_keras_modes
+  def test_build_single_inputs(self):
+
+    class MyModel(training_module.Model):
+
+      def __init__(self):
+        super(MyModel, self).__init__()
+        self.l1 = layers_module.Dense(1)
+
+      def call(self, x):
+        return self.l1(x)
+
+    model = MyModel()
+    model.build((None, 1))
+    self.assertEqual(model.l1.kernel.shape.as_list(), [1, 1])
+    model = MyModel()
+    model.build([None, 1])
+    self.assertEqual(model.l1.kernel.shape.as_list(), [1, 1])
+
+
 if __name__ == '__main__':
   test.main()