Revert D21632878: [quant] Support for fused ConvBn1d and ConvBnRelu1d modules
Test Plan: revert-hammer
Differential Revision: D21632878
Original commit changeset: 0d73398b95d7
fbshipit-source-id: c4dd18a4220d175237f31f741a782f2596228009
diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst
index a4c685a..d8addf8 100644
--- a/docs/source/quantization.rst
+++ b/docs/source/quantization.rst
@@ -208,9 +208,7 @@
* ``torch.nn.intrinsic`` — float versions of the modules, can be swapped with
quantized version 1 to 1:
- * :class:`~torch.nn.intrinsic.ConvBn1d` — Conv1d + BatchNorm1d
* :class:`~torch.nn.intrinsic.ConvBn2d` — Conv2d + BatchNorm
- * :class:`~torch.nn.intrinsic.ConvBnReLU1d` — Conv1d + BatchNorm1d + ReLU
* :class:`~torch.nn.intrinsic.ConvBnReLU2d` — Conv2d + BatchNorm + ReLU
* :class:`~torch.nn.intrinsic.ConvReLU1d` — Conv1d + ReLU
* :class:`~torch.nn.intrinsic.ConvReLU2d` — Conv2d + ReLU
@@ -586,21 +584,11 @@
.. automodule:: torch.nn.intrinsic
-ConvBn1d
-~~~~~~~~~~~~~~~
-.. autoclass:: ConvBn1d
- :members:
-
ConvBn2d
~~~~~~~~~~~~~~~
.. autoclass:: ConvBn2d
:members:
-ConvBnReLU1d
-~~~~~~~~~~~~~~~
-.. autoclass:: ConvBnReLU1d
- :members:
-
ConvBnReLU2d
~~~~~~~~~~~~~~~
.. autoclass:: ConvBnReLU2d
diff --git a/test/quantization/test_quantize.py b/test/quantization/test_quantize.py
index ff3ceda..c47900d 100644
--- a/test/quantization/test_quantize.py
+++ b/test/quantization/test_quantize.py
@@ -1303,25 +1303,24 @@
self.assertEqual(type(model.sub2.conv), nn.Conv2d)
self.assertEqual(type(model.sub2.relu), nn.ReLU)
test_only_eval_fn(model, self.img_data_1d)
- with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
- checkQuantized(model)
+ checkQuantized(model)
model = ModelForFusion(default_qat_qconfig).train()
model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
['sub1.conv', 'sub1.bn']])
model = quantize_qat(model, test_only_train_fn, self.img_data_1d)
- with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
- checkQuantized(model)
+ checkQuantized(model)
def test_fuse_module_eval(self):
model = ModelForFusion(default_qconfig)
model.eval()
- model = fuse_modules(model, [['conv3', 'bn3', 'relu4'],
+ model = fuse_modules(model, [['conv3', 'relu4'],
['conv1', 'bn1', 'relu1'],
['conv2', 'relu2'],
['bn2', 'relu3'],
['sub1.conv', 'sub1.bn']])
+
self.assertEqual(type(model.conv1), nni.ConvReLU2d,
"Fused Conv + BN + Relu first layer (BN is folded)")
self.assertEqual(type(model.conv1[0]), nn.Conv2d,
@@ -1346,13 +1345,11 @@
"Fused Conv + BN + Relu second layer (Skipped Relu)")
self.assertEqual(type(model.conv3), nni.ConvReLU1d,
- "Fused Conv + Relu for Conv1d (folded BN)")
+ "Fused Conv + Relu for conv1d")
self.assertEqual(type(model.conv3[0]), nn.Conv1d,
- "Fused Conv + Relu for Conv1d ")
+ "Fused Conv + Relu for conv1d ")
self.assertEqual(type(model.conv3[1]), nn.ReLU,
- "Fused Conv + Relu for Conv1d")
- self.assertEqual(type(model.bn3), nn.Identity,
- "Fused Conv + BN + Relu for Conv1d (Skipped BN)")
+ "Fused Conv + Relu for conv1d")
self.assertEqual(type(model.sub1.conv), nn.Conv2d,
"Fused submodule Conv + folded BN")
@@ -1386,7 +1383,7 @@
['conv2', 'relu2'],
['bn2', 'relu3'],
['sub1.conv', 'sub1.bn'],
- ['conv3', 'bn3', 'relu4']])
+ ['conv3', 'relu4']])
model = quantize(model, test_only_eval_fn, self.img_data_1d)
checkQuantized(model)
diff --git a/torch/nn/intrinsic/__init__.py b/torch/nn/intrinsic/__init__.py
index ba4514e..592917a 100644
--- a/torch/nn/intrinsic/__init__.py
+++ b/torch/nn/intrinsic/__init__.py
@@ -1,8 +1,6 @@
-from .modules import ConvBn1d
from .modules import ConvBn2d
from .modules import ConvBn3d
-from .modules import ConvBnReLU1d
from .modules import ConvBnReLU2d
from .modules import ConvBnReLU3d
from .modules import ConvReLU1d
@@ -13,11 +11,9 @@
from .modules import BNReLU3d
__all__ = [
- 'ConvBn1d',
'ConvBn2d',
'ConvBn3d',
'ConvBnReLU2d',
- 'ConvBnReLU1d',
'ConvBnReLU3d',
'ConvReLU1d',
'ConvReLU2d',
diff --git a/torch/nn/intrinsic/modules/__init__.py b/torch/nn/intrinsic/modules/__init__.py
index 51b2c10..a79fd39 100644
--- a/torch/nn/intrinsic/modules/__init__.py
+++ b/torch/nn/intrinsic/modules/__init__.py
@@ -1,8 +1,6 @@
-from .fused import ConvBn1d
from .fused import ConvBn2d
from .fused import ConvBn3d
-from .fused import ConvBnReLU1d
from .fused import ConvBnReLU2d
from .fused import ConvBnReLU3d
from .fused import ConvReLU1d
@@ -14,10 +12,8 @@
__all__ = [
- 'ConvBn1d',
'ConvBn2d',
'ConvBn3d',
- 'ConvBnReLU1d',
'ConvBnReLU2d',
'ConvBnReLU3d',
'ConvReLU1d',
diff --git a/torch/nn/intrinsic/modules/fused.py b/torch/nn/intrinsic/modules/fused.py
index 71a724d..5f9cd0a 100644
--- a/torch/nn/intrinsic/modules/fused.py
+++ b/torch/nn/intrinsic/modules/fused.py
@@ -38,15 +38,6 @@
type(linear), type(relu))
super(LinearReLU, self).__init__(linear, relu)
-class ConvBn1d(torch.nn.Sequential):
- r"""This is a sequential container which calls the Conv 1d and Batch Norm 1d modules.
- During quantization this will be replaced with the corresponding fused module."""
- def __init__(self, conv, bn):
- assert type(conv) == Conv1d and type(bn) == BatchNorm1d, \
- 'Incorrect types for input modules{}{}'.format(
- type(conv), type(bn))
- super(ConvBn1d, self).__init__(conv, bn)
-
class ConvBn2d(torch.nn.Sequential):
r"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules.
During quantization this will be replaced with the corresponding fused module."""
@@ -56,15 +47,6 @@
type(conv), type(bn))
super(ConvBn2d, self).__init__(conv, bn)
-class ConvBnReLU1d(torch.nn.Sequential):
- r"""This is a sequential container which calls the Conv 1d, Batch Norm 1d, and ReLU modules.
- During quantization this will be replaced with the corresponding fused module."""
- def __init__(self, conv, bn, relu):
- assert type(conv) == Conv1d and type(bn) == BatchNorm1d and \
- type(relu) == ReLU, 'Incorrect types for input modules{}{}{}' \
- .format(type(conv), type(bn), type(relu))
- super(ConvBnReLU1d, self).__init__(conv, bn, relu)
-
class ConvBnReLU2d(torch.nn.Sequential):
r"""This is a sequential container which calls the Conv 2d, Batch Norm 2d, and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
diff --git a/torch/quantization/fuse_modules.py b/torch/quantization/fuse_modules.py
index 856548a..1e5cd6b 100644
--- a/torch/quantization/fuse_modules.py
+++ b/torch/quantization/fuse_modules.py
@@ -47,30 +47,19 @@
"""
assert(conv.training == bn.training == relu.training),\
"Conv and BN both must be in the same mode (train or eval)."
+ is_3d = isinstance(conv, torch.nn.Conv3d)
if conv.training:
- map_to_fused_module_train = {
- torch.nn.Conv2d: torch_fused.ConvBnReLU2d,
- torch.nn.Conv3d: torch_fused.ConvBnReLU3d,
- }
assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'
- fused_module = map_to_fused_module_train.get(type(conv))
- if fused_module is not None:
- return fused_module(conv, bn, relu)
- else:
- raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
+
+ return torch_fused.ConvBnReLU3d(conv, bn, relu) if is_3d \
+ else torch_fused.ConvBnReLU2d(conv, bn, relu)
else:
- map_to_fused_module_eval = {
- torch.nn.Conv1d: torch_fused.ConvReLU1d,
- torch.nn.Conv2d: torch_fused.ConvReLU2d,
- torch.nn.Conv3d: torch_fused.ConvReLU3d,
- }
- fused_module = map_to_fused_module_eval[type(conv)]
- if fused_module is not None:
- return fused_module(torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu)
- else:
- raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))
+ return torch_fused.ConvReLU3d(
+ torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu) if is_3d \
+ else torch_fused.ConvReLU2d(
+ torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu)
# Generalization of getattr
def _get_module(model, submodule_key):
@@ -104,8 +93,6 @@
"""
OP_LIST_TO_FUSER_METHOD = {
- (torch.nn.Conv1d, torch.nn.BatchNorm1d): fuse_conv_bn,
- (torch.nn.Conv1d, torch.nn.BatchNorm1d, torch.nn.ReLU): fuse_conv_bn_relu,
(torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn,
(torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU): fuse_conv_bn_relu,
(torch.nn.Conv3d, torch.nn.BatchNorm3d): fuse_conv_bn,
diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py
index d36a2cc..8c0315f 100644
--- a/torch/quantization/quantize.py
+++ b/torch/quantization/quantize.py
@@ -321,9 +321,7 @@
nni.LinearReLU,
nni.BNReLU2d,
nni.BNReLU3d,
- nni.ConvBn1d,
nni.ConvReLU1d,
- nni.ConvBnReLU1d,
nni.ConvReLU2d,
nni.ConvReLU3d)
diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py
index d03ec4c..eacbdb8 100644
--- a/torch/testing/_internal/common_quantization.py
+++ b/torch/testing/_internal/common_quantization.py
@@ -632,7 +632,6 @@
self.bn2 = nn.BatchNorm3d(2).to(dtype=torch.float)
self.relu3 = nn.ReLU(inplace=True).to(dtype=torch.float)
self.conv3 = nn.Conv1d(3, 3, 2).to(dtype=torch.float)
- self.bn3 = nn.BatchNorm1d(3).to(dtype=torch.float)
self.relu4 = nn.ReLU(inplace=True).to(dtype=torch.float)
# don't quantize sub2
self.sub2.qconfig = None
@@ -642,7 +641,6 @@
x = x.squeeze(2)
x = self.quant(x)
x = self.conv3(x)
- x = self.bn3(x)
x = self.relu4(x)
x = x.unsqueeze(2)
y = x.unsqueeze(2)