Update Sequential model to raise error for multi-output layer in the deferred mode.

This is a update for #36624, which we should show explicit error, rather than let the code proceed and failed down the road.

PiperOrigin-RevId: 295248095
Change-Id: I9cf9d0267f222c3994f3199bd9b49d86196ebb3b
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index 8609449..a86084f 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -39,6 +39,11 @@
 from tensorflow.python.util.tf_export import keras_export
 
 
+SINGLE_LAYER_OUTPUT_ERROR_MSG = ('All layers in a Sequential model should have '
+                                 'a single output tensor. For multi-output '
+                                 'layers, use the functional API.')
+
+
 @keras_export('keras.Sequential', 'keras.models.Sequential')
 class Sequential(training.Model):
   """`Sequential` groups a linear stack of layers into a `tf.keras.Model`.
@@ -195,10 +200,7 @@
       if set_inputs:
         # If an input layer (placeholder) is available.
         if len(nest.flatten(layer._inbound_nodes[-1].output_tensors)) != 1:
-          raise ValueError('All layers in a Sequential model '
-                           'should have a single output tensor. '
-                           'For multi-output layers, '
-                           'use the functional API.')
+          raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)
         self.outputs = [
             nest.flatten(layer._inbound_nodes[-1].output_tensors)[0]
         ]
@@ -209,10 +211,7 @@
       # refresh its output.
       output_tensor = layer(self.outputs[0])
       if len(nest.flatten(output_tensor)) != 1:
-        raise TypeError('All layers in a Sequential model '
-                        'should have a single output tensor. '
-                        'For multi-output layers, '
-                        'use the functional API.')
+        raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)
       self.outputs = [output_tensor]
 
     if self.outputs:
@@ -286,6 +285,8 @@
 
       outputs = layer(inputs, **kwargs)
 
+      if len(nest.flatten(outputs)) != 1:
+        raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)
       # `outputs` will be the inputs to the next layer.
       inputs = outputs
       mask = outputs._keras_mask
diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py
index fa516eb..65e58fd 100644
--- a/tensorflow/python/keras/engine/sequential_test.py
+++ b/tensorflow/python/keras/engine/sequential_test.py
@@ -215,23 +215,6 @@
       model = keras.models.Sequential()
       model.add(None)
 
-    # Added layers cannot have multiple outputs
-    class MyLayer(keras.layers.Layer):
-
-      def call(self, inputs):
-        return [3 * inputs, 2 * inputs]
-
-      def compute_output_shape(self, input_shape):
-        return [input_shape, input_shape]
-
-    with self.assertRaises(ValueError):
-      model = keras.models.Sequential()
-      model.add(MyLayer(input_shape=(3,)))
-    with self.assertRaises(TypeError):
-      model = keras.models.Sequential()
-      model.add(keras.layers.Dense(1, input_dim=1))
-      model.add(MyLayer())
-
   @keras_parameterized.run_all_keras_modes
   def test_nested_sequential_trainability(self):
     input_dim = 20
@@ -405,6 +388,17 @@
         ValueError, 'should have a single output tensor'):
       keras.Sequential([MultiOutputLayer(input_shape=(3,))])
 
+    with self.assertRaisesRegexp(
+        ValueError, 'should have a single output tensor'):
+      keras.Sequential([
+          keras.layers.Dense(1, input_shape=(3,)),
+          MultiOutputLayer()])
+
+    # Should also raise error in a deferred build mode
+    with self.assertRaisesRegexp(
+        ValueError, 'should have a single output tensor'):
+      keras.Sequential([MultiOutputLayer()])(np.zeros((10, 10)))
+
   @keras_parameterized.run_all_keras_modes
   def test_layer_add_after_compile_deferred(self):
     model = keras.Sequential([keras.layers.Dense(3)])