Rollback of eb7946c4752f7babf3206f3faa13db13735cc08c
PiperOrigin-RevId: 277108294
Change-Id: I1586c5af06b7a7748199dc86a088e3be09eb27ae
diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py
index 8202931..f1bb1de 100644
--- a/tensorflow/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/layers/convolutional.py
@@ -145,12 +145,8 @@
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
- channel_axis = self._get_channel_axis()
- if input_shape.dims[channel_axis].value is None:
- raise ValueError('The channel dimension of the inputs '
- 'should be defined. Found `None`.')
- input_dim = int(input_shape[channel_axis])
- kernel_shape = self.kernel_size + (input_dim, self.filters)
+ input_channel = self._get_input_channel(input_shape)
+ kernel_shape = self.kernel_size + (input_channel, self.filters)
self.kernel = self.add_weight(
name='kernel',
@@ -171,19 +167,51 @@
dtype=self.dtype)
else:
self.bias = None
+ channel_axis = self._get_channel_axis()
self.input_spec = InputSpec(ndim=self.rank + 2,
- axes={channel_axis: input_dim})
+ axes={channel_axis: input_channel})
+
+ self._build_conv_op_input_shape = input_shape
+ self._build_input_channel = input_channel
+ self._padding_op = self._get_padding_op()
+ self._conv_op_data_format = conv_utils.convert_data_format(
+ self.data_format, self.rank + 2)
self._convolution_op = nn_ops.Convolution(
input_shape,
filter_shape=self.kernel.shape,
dilation_rate=self.dilation_rate,
strides=self.strides,
- padding=self._get_padding_op(),
- data_format=conv_utils.convert_data_format(self.data_format,
- self.rank + 2))
+ padding=self._padding_op,
+ data_format=self._conv_op_data_format)
self.built = True
def call(self, inputs):
+ # Check if the input_shape in call() is different from that in build().
+ # If they are different, recreate the _convolution_op to avoid the stateful
+ # behavior.
+ call_input_shape = inputs.get_shape()
+ call_input_channel = self._get_input_channel(call_input_shape)
+ if call_input_channel != self._build_input_channel:
+ raise ValueError(
+ 'Expected input data with {} channels (in format {}), but got inputs '
+ 'with shape: {}'.format(self._build_input_channel, self.data_format,
+ call_input_shape))
+ recreate_conv_op = (
+ call_input_shape[1:] != self._build_conv_op_input_shape[1:])
+
+ if recreate_conv_op:
+ self._convolution_op = nn_ops.Convolution(
+ call_input_shape,
+ filter_shape=self.kernel.shape,
+ dilation_rate=self.dilation_rate,
+ strides=self.strides,
+ padding=self._padding_op,
+ data_format=self._conv_op_data_format)
+
+ # Apply causal padding to inputs for Conv1D.
+ if self.padding == 'causal' and self.__class__.__name__ == 'Conv1D':
+ inputs = array_ops.pad(inputs, self._compute_causal_padding())
+
outputs = self._convolution_op(inputs, self.kernel)
if self.use_bias:
@@ -267,6 +295,13 @@
else:
return -1
+ def _get_input_channel(self, input_shape):
+ channel_axis = self._get_channel_axis()
+ if input_shape.dims[channel_axis].value is None:
+ raise ValueError('The channel dimension of the inputs '
+ 'should be defined. Found `None`.')
+ return int(input_shape[channel_axis])
+
def _get_padding_op(self):
if self.padding == 'causal':
op_padding = 'valid'
@@ -386,11 +421,6 @@
bias_constraint=constraints.get(bias_constraint),
**kwargs)
- def call(self, inputs):
- if self.padding == 'causal':
- inputs = array_ops.pad(inputs, self._compute_causal_padding())
- return super(Conv1D, self).call(inputs)
-
@keras_export('keras.layers.Conv2D', 'keras.layers.Convolution2D')
class Conv2D(Conv):
diff --git a/tensorflow/python/keras/layers/convolutional_test.py b/tensorflow/python/keras/layers/convolutional_test.py
index 22127b0..33e887f 100644
--- a/tensorflow/python/keras/layers/convolutional_test.py
+++ b/tensorflow/python/keras/layers/convolutional_test.py
@@ -259,6 +259,66 @@
@keras_parameterized.run_all_keras_modes
+class ConvSequentialTest(keras_parameterized.TestCase):
+
+ def _run_test(self, conv_layer_cls, kwargs, input_shape1, input_shape2,
+ expected_output_shape1, expected_output_shape2):
+ kwargs['filters'] = 1
+ kwargs['kernel_size'] = 3
+ kwargs['dilation_rate'] = 2
+ with self.cached_session(use_gpu=True):
+ layer = conv_layer_cls(**kwargs)
+ output1 = layer(np.zeros(input_shape1))
+ self.assertEqual(output1.shape, expected_output_shape1)
+ output2 = layer(np.zeros(input_shape2))
+ self.assertEqual(output2.shape, expected_output_shape2)
+
+ @parameterized.named_parameters(
+ ('padding_valid', {'padding': 'valid'},
+ (1, 8, 2), (1, 5, 2), (1, 4, 1), (1, 1, 1)),
+ ('padding_same', {'padding': 'same'},
+ (1, 8, 2), (1, 5, 2), (1, 8, 1), (1, 5, 1)),
+ ('padding_causal', {'padding': 'causal'},
+ (1, 8, 2), (1, 5, 2), (1, 8, 1), (1, 5, 1)),
+ )
+ def test_conv1d(self, kwargs, input_shape1, input_shape2,
+ expected_output_shape1, expected_output_shape2):
+ self._run_test(keras.layers.Conv1D, kwargs, input_shape1, input_shape2,
+ expected_output_shape1, expected_output_shape2)
+
+ @parameterized.named_parameters(
+ ('padding_valid', {'padding': 'valid'},
+ (1, 7, 6, 2), (1, 6, 5, 2), (1, 3, 2, 1), (1, 2, 1, 1)),
+ ('padding_same', {'padding': 'same'},
+ (1, 7, 6, 2), (1, 6, 5, 2), (1, 7, 6, 1), (1, 6, 5, 1)),
+ )
+ def test_conv2d(self, kwargs, input_shape1, input_shape2,
+ expected_output_shape1, expected_output_shape2):
+ self._run_test(keras.layers.Conv2D, kwargs, input_shape1, input_shape2,
+ expected_output_shape1, expected_output_shape2)
+
+ @parameterized.named_parameters(
+ ('padding_valid', {'padding': 'valid'},
+ (1, 5, 7, 6, 2), (1, 8, 6, 5, 2), (1, 1, 3, 2, 1), (1, 4, 2, 1, 1)),
+ ('padding_same', {'padding': 'same'},
+ (1, 5, 7, 6, 2), (1, 8, 6, 5, 2), (1, 5, 7, 6, 1), (1, 8, 6, 5, 1)),
+ )
+ def test_conv3d(self, kwargs, input_shape1, input_shape2,
+ expected_output_shape1, expected_output_shape2):
+ self._run_test(keras.layers.Conv3D, kwargs, input_shape1, input_shape2,
+ expected_output_shape1, expected_output_shape2)
+
+ def test_invalid_channel_dim(self):
+ with self.cached_session(use_gpu=True):
+ layer = keras.layers.Conv1D(1, 3, dilation_rate=2, padding='valid')
+ output1 = layer(np.zeros((1, 8, 2)))
+ self.assertEqual(output1.shape, (1, 4, 1))
+ with self.assertRaisesRegexp(
+ ValueError, 'Expected input data with 2 channels'):
+ layer(np.zeros((1, 5, 3)))
+
+
+@keras_parameterized.run_all_keras_modes
class ZeroPaddingTest(keras_parameterized.TestCase):
def test_zero_padding_1d(self):