Rollback of eb7946c4752f7babf3206f3faa13db13735cc08c

PiperOrigin-RevId: 277108294
Change-Id: I1586c5af06b7a7748199dc86a088e3be09eb27ae
diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py
index 8202931..f1bb1de 100644
--- a/tensorflow/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/layers/convolutional.py
@@ -145,12 +145,8 @@
 
   def build(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape)
-    channel_axis = self._get_channel_axis()
-    if input_shape.dims[channel_axis].value is None:
-      raise ValueError('The channel dimension of the inputs '
-                       'should be defined. Found `None`.')
-    input_dim = int(input_shape[channel_axis])
-    kernel_shape = self.kernel_size + (input_dim, self.filters)
+    input_channel = self._get_input_channel(input_shape)
+    kernel_shape = self.kernel_size + (input_channel, self.filters)
 
     self.kernel = self.add_weight(
         name='kernel',
@@ -171,19 +167,51 @@
           dtype=self.dtype)
     else:
       self.bias = None
+    channel_axis = self._get_channel_axis()
     self.input_spec = InputSpec(ndim=self.rank + 2,
-                                axes={channel_axis: input_dim})
+                                axes={channel_axis: input_channel})
+
+    self._build_conv_op_input_shape = input_shape
+    self._build_input_channel = input_channel
+    self._padding_op = self._get_padding_op()
+    self._conv_op_data_format = conv_utils.convert_data_format(
+        self.data_format, self.rank + 2)
     self._convolution_op = nn_ops.Convolution(
         input_shape,
         filter_shape=self.kernel.shape,
         dilation_rate=self.dilation_rate,
         strides=self.strides,
-        padding=self._get_padding_op(),
-        data_format=conv_utils.convert_data_format(self.data_format,
-                                                   self.rank + 2))
+        padding=self._padding_op,
+        data_format=self._conv_op_data_format)
     self.built = True
 
   def call(self, inputs):
+    # Check if the input_shape in call() is different from that in build().
+    # If they are different, recreate the _convolution_op to avoid the stateful
+    # behavior.
+    call_input_shape = inputs.get_shape()
+    call_input_channel = self._get_input_channel(call_input_shape)
+    if call_input_channel != self._build_input_channel:
+      raise ValueError(
+          'Expected input data with {} channels (in format {}), but got inputs '
+          'with shape: {}'.format(self._build_input_channel, self.data_format,
+                                  call_input_shape))
+    recreate_conv_op = (
+        call_input_shape[1:] != self._build_conv_op_input_shape[1:])
+
+    if recreate_conv_op:
+      self._convolution_op = nn_ops.Convolution(
+          call_input_shape,
+          filter_shape=self.kernel.shape,
+          dilation_rate=self.dilation_rate,
+          strides=self.strides,
+          padding=self._padding_op,
+          data_format=self._conv_op_data_format)
+
+    # Apply causal padding to inputs for Conv1D.
+    if self.padding == 'causal' and self.__class__.__name__ == 'Conv1D':
+      inputs = array_ops.pad(inputs, self._compute_causal_padding())
+
     outputs = self._convolution_op(inputs, self.kernel)
 
     if self.use_bias:
@@ -267,6 +295,13 @@
     else:
       return -1
 
+  def _get_input_channel(self, input_shape):
+    channel_axis = self._get_channel_axis()
+    if input_shape.dims[channel_axis].value is None:
+      raise ValueError('The channel dimension of the inputs '
+                       'should be defined. Found `None`.')
+    return int(input_shape[channel_axis])
+
   def _get_padding_op(self):
     if self.padding == 'causal':
       op_padding = 'valid'
@@ -386,11 +421,6 @@
         bias_constraint=constraints.get(bias_constraint),
         **kwargs)
 
-  def call(self, inputs):
-    if self.padding == 'causal':
-      inputs = array_ops.pad(inputs, self._compute_causal_padding())
-    return super(Conv1D, self).call(inputs)
-
 
 @keras_export('keras.layers.Conv2D', 'keras.layers.Convolution2D')
 class Conv2D(Conv):
diff --git a/tensorflow/python/keras/layers/convolutional_test.py b/tensorflow/python/keras/layers/convolutional_test.py
index 22127b0..33e887f 100644
--- a/tensorflow/python/keras/layers/convolutional_test.py
+++ b/tensorflow/python/keras/layers/convolutional_test.py
@@ -259,6 +259,66 @@
 
 
 @keras_parameterized.run_all_keras_modes
+class ConvSequentialTest(keras_parameterized.TestCase):
+
+  def _run_test(self, conv_layer_cls, kwargs, input_shape1, input_shape2,
+                expected_output_shape1, expected_output_shape2):
+    kwargs['filters'] = 1
+    kwargs['kernel_size'] = 3
+    kwargs['dilation_rate'] = 2
+    with self.cached_session(use_gpu=True):
+      layer = conv_layer_cls(**kwargs)
+      output1 = layer(np.zeros(input_shape1))
+      self.assertEqual(output1.shape, expected_output_shape1)
+      output2 = layer(np.zeros(input_shape2))
+      self.assertEqual(output2.shape, expected_output_shape2)
+
+  @parameterized.named_parameters(
+      ('padding_valid', {'padding': 'valid'},
+       (1, 8, 2), (1, 5, 2), (1, 4, 1), (1, 1, 1)),
+      ('padding_same', {'padding': 'same'},
+       (1, 8, 2), (1, 5, 2), (1, 8, 1), (1, 5, 1)),
+      ('padding_causal', {'padding': 'causal'},
+       (1, 8, 2), (1, 5, 2), (1, 8, 1), (1, 5, 1)),
+  )
+  def test_conv1d(self, kwargs, input_shape1, input_shape2,
+                  expected_output_shape1, expected_output_shape2):
+    self._run_test(keras.layers.Conv1D, kwargs, input_shape1, input_shape2,
+                   expected_output_shape1, expected_output_shape2)
+
+  @parameterized.named_parameters(
+      ('padding_valid', {'padding': 'valid'},
+       (1, 7, 6, 2), (1, 6, 5, 2), (1, 3, 2, 1), (1, 2, 1, 1)),
+      ('padding_same', {'padding': 'same'},
+       (1, 7, 6, 2), (1, 6, 5, 2), (1, 7, 6, 1), (1, 6, 5, 1)),
+  )
+  def test_conv2d(self, kwargs, input_shape1, input_shape2,
+                  expected_output_shape1, expected_output_shape2):
+    self._run_test(keras.layers.Conv2D, kwargs, input_shape1, input_shape2,
+                   expected_output_shape1, expected_output_shape2)
+
+  @parameterized.named_parameters(
+      ('padding_valid', {'padding': 'valid'},
+       (1, 5, 7, 6, 2), (1, 8, 6, 5, 2), (1, 1, 3, 2, 1), (1, 4, 2, 1, 1)),
+      ('padding_same', {'padding': 'same'},
+       (1, 5, 7, 6, 2), (1, 8, 6, 5, 2), (1, 5, 7, 6, 1), (1, 8, 6, 5, 1)),
+  )
+  def test_conv3d(self, kwargs, input_shape1, input_shape2,
+                  expected_output_shape1, expected_output_shape2):
+    self._run_test(keras.layers.Conv3D, kwargs, input_shape1, input_shape2,
+                   expected_output_shape1, expected_output_shape2)
+
+  def test_invalid_channel_dim(self):
+    with self.cached_session(use_gpu=True):
+      layer = keras.layers.Conv1D(1, 3, dilation_rate=2, padding='valid')
+      output1 = layer(np.zeros((1, 8, 2)))
+      self.assertEqual(output1.shape, (1, 4, 1))
+      with self.assertRaisesRegexp(
+          ValueError, 'Expected input data with 2 channels'):
+        layer(np.zeros((1, 5, 3)))
+
+
+@keras_parameterized.run_all_keras_modes
 class ZeroPaddingTest(keras_parameterized.TestCase):
 
   def test_zero_padding_1d(self):