Test group convolution
diff --git a/tensorflow/python/keras/layers/convolutional_test.py b/tensorflow/python/keras/layers/convolutional_test.py
index e494a3d..5fd5271 100644
--- a/tensorflow/python/keras/layers/convolutional_test.py
+++ b/tensorflow/python/keras/layers/convolutional_test.py
@@ -24,6 +24,7 @@
from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
@@ -297,6 +298,25 @@
with self.assertRaisesRegexp(ValueError, 'The number of input channels'):
layer(16, 3, groups=4).build((32, 12, 12, 3))
+ @parameterized.named_parameters(
+ ('Conv1D', keras.layers.Conv1D, (32, 12, 32)),
+ ('Conv2D', keras.layers.Conv2D, (32, 12, 12, 32)),
+ ('Conv3D', keras.layers.Conv3D, (32, 12, 12, 12, 32)),
+ )
+ def test_group_conv(self, layer, input_shape):
+ if test.is_gpu_available(cuda_only=True):
+ with self.cached_session(use_gpu=True):
+ inputs = np.random.uniform(size=input_shape)
+
+ outputs = layer(16, 3, groups=4, kernel_initializer="ones")(inputs)
+
+ input_slices = np.split(inputs, 4, axis=-1)
+ expected_outputs = array_ops.concat([
+ layer(16 // 4, 3, kernel_initializer="ones")(slice)
+ for slice in input_slices], axis=-1)
+
+ self.assertAllClose(outputs, expected_outputs)
+
@keras_parameterized.run_all_keras_modes
class Conv3DTransposeTest(keras_parameterized.TestCase):