Remove unnecessary Reshape layers from MobileNetV3 architecture
PiperOrigin-RevId: 375977152
Change-Id: Iaafad61b8ca12f43bc7e1cca97c6b9be508b3fae
diff --git a/tensorflow/python/keras/applications/mobilenet_v3.py b/tensorflow/python/keras/applications/mobilenet_v3.py
index 8ab59f0..a0e82fe 100644
--- a/tensorflow/python/keras/applications/mobilenet_v3.py
+++ b/tensorflow/python/keras/applications/mobilenet_v3.py
@@ -292,11 +292,7 @@
axis=channel_axis, epsilon=1e-3,
momentum=0.999, name='Conv_1/BatchNorm')(x)
x = activation(x)
- x = layers.GlobalAveragePooling2D()(x)
- if channel_axis == 1:
- x = layers.Reshape((last_conv_ch, 1, 1))(x)
- else:
- x = layers.Reshape((1, 1, last_conv_ch))(x)
+ x = layers.GlobalAveragePooling2D(keepdims=True)(x)
x = layers.Conv2D(
last_point_ch,
kernel_size=1,
@@ -462,12 +458,9 @@
def _se_block(inputs, filters, se_ratio, prefix):
- x = layers.GlobalAveragePooling2D(name=prefix + 'squeeze_excite/AvgPool')(
- inputs)
- if backend.image_data_format() == 'channels_first':
- x = layers.Reshape((filters, 1, 1))(x)
- else:
- x = layers.Reshape((1, 1, filters))(x)
+ x = layers.GlobalAveragePooling2D(
+ keepdims=True, name=prefix + 'squeeze_excite/AvgPool')(
+ inputs)
x = layers.Conv2D(
_depth(filters * se_ratio),
kernel_size=1,