Revert D21632878: [quant] Support for fused ConvBn1d and ConvBnRelu1d modules

Test Plan: revert-hammer

Differential Revision:
D21632878

Original commit changeset: 0d73398b95d7

fbshipit-source-id: c4dd18a4220d175237f31f741a782f2596228009
diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst
index a4c685a..d8addf8 100644
--- a/docs/source/quantization.rst
+++ b/docs/source/quantization.rst
@@ -208,9 +208,7 @@
 * ``torch.nn.intrinsic`` — float versions of the modules, can be swapped with
   quantized version 1 to 1:
 
-  * :class:`~torch.nn.intrinsic.ConvBn1d` — Conv1d + BatchNorm1d
   * :class:`~torch.nn.intrinsic.ConvBn2d` — Conv2d + BatchNorm
-  * :class:`~torch.nn.intrinsic.ConvBnReLU1d` — Conv1d + BatchNorm1d + ReLU
   * :class:`~torch.nn.intrinsic.ConvBnReLU2d` — Conv2d + BatchNorm + ReLU
   * :class:`~torch.nn.intrinsic.ConvReLU1d` — Conv1d + ReLU
   * :class:`~torch.nn.intrinsic.ConvReLU2d` — Conv2d + ReLU
@@ -586,21 +584,11 @@
 
 .. automodule:: torch.nn.intrinsic
 
-ConvBn1d
-~~~~~~~~~~~~~~~
-.. autoclass:: ConvBn1d
-    :members:
-
 ConvBn2d
 ~~~~~~~~~~~~~~~
 .. autoclass:: ConvBn2d
     :members:
 
-ConvBnReLU1d
-~~~~~~~~~~~~~~~
-.. autoclass:: ConvBnReLU1d
-    :members:
-
 ConvBnReLU2d
 ~~~~~~~~~~~~~~~
 .. autoclass:: ConvBnReLU2d
diff --git a/test/quantization/test_quantize.py b/test/quantization/test_quantize.py
index ff3ceda..c47900d 100644
--- a/test/quantization/test_quantize.py
+++ b/test/quantization/test_quantize.py
@@ -1303,25 +1303,24 @@
             self.assertEqual(type(model.sub2.conv), nn.Conv2d)
             self.assertEqual(type(model.sub2.relu), nn.ReLU)
             test_only_eval_fn(model, self.img_data_1d)
-        with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
-            checkQuantized(model)
+        checkQuantized(model)
 
         model = ModelForFusion(default_qat_qconfig).train()
         model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
                              ['sub1.conv', 'sub1.bn']])
         model = quantize_qat(model, test_only_train_fn, self.img_data_1d)
-        with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
-            checkQuantized(model)
+        checkQuantized(model)
 
 
     def test_fuse_module_eval(self):
         model = ModelForFusion(default_qconfig)
         model.eval()
-        model = fuse_modules(model, [['conv3', 'bn3', 'relu4'],
+        model = fuse_modules(model, [['conv3', 'relu4'],
                              ['conv1', 'bn1', 'relu1'],
                              ['conv2', 'relu2'],
                              ['bn2', 'relu3'],
                              ['sub1.conv', 'sub1.bn']])
+
         self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                          "Fused Conv + BN + Relu first layer (BN is folded)")
         self.assertEqual(type(model.conv1[0]), nn.Conv2d,
@@ -1346,13 +1345,11 @@
                          "Fused Conv + BN + Relu second layer (Skipped Relu)")
 
         self.assertEqual(type(model.conv3), nni.ConvReLU1d,
-                         "Fused Conv + Relu for Conv1d (folded BN)")
+                         "Fused Conv + Relu for conv1d")
         self.assertEqual(type(model.conv3[0]), nn.Conv1d,
-                         "Fused Conv + Relu for Conv1d ")
+                         "Fused Conv + Relu for conv1d ")
         self.assertEqual(type(model.conv3[1]), nn.ReLU,
-                         "Fused Conv + Relu for Conv1d")
-        self.assertEqual(type(model.bn3), nn.Identity,
-                         "Fused Conv + BN + Relu for Conv1d (Skipped BN)")
+                         "Fused Conv + Relu for conv1d")
 
         self.assertEqual(type(model.sub1.conv), nn.Conv2d,
                          "Fused submodule Conv + folded BN")
@@ -1386,7 +1383,7 @@
                              ['conv2', 'relu2'],
                              ['bn2', 'relu3'],
                              ['sub1.conv', 'sub1.bn'],
-                             ['conv3', 'bn3', 'relu4']])
+                             ['conv3', 'relu4']])
         model = quantize(model, test_only_eval_fn, self.img_data_1d)
         checkQuantized(model)
 
diff --git a/torch/nn/intrinsic/__init__.py b/torch/nn/intrinsic/__init__.py
index ba4514e..592917a 100644
--- a/torch/nn/intrinsic/__init__.py
+++ b/torch/nn/intrinsic/__init__.py
@@ -1,8 +1,6 @@
 
-from .modules import ConvBn1d
 from .modules import ConvBn2d
 from .modules import ConvBn3d
-from .modules import ConvBnReLU1d
 from .modules import ConvBnReLU2d
 from .modules import ConvBnReLU3d
 from .modules import ConvReLU1d
@@ -13,11 +11,9 @@
 from .modules import BNReLU3d
 
 __all__ = [
-    'ConvBn1d',
     'ConvBn2d',
     'ConvBn3d',
     'ConvBnReLU2d',
-    'ConvBnReLU1d',
     'ConvBnReLU3d',
     'ConvReLU1d',
     'ConvReLU2d',
diff --git a/torch/nn/intrinsic/modules/__init__.py b/torch/nn/intrinsic/modules/__init__.py
index 51b2c10..a79fd39 100644
--- a/torch/nn/intrinsic/modules/__init__.py
+++ b/torch/nn/intrinsic/modules/__init__.py
@@ -1,8 +1,6 @@
 
-from .fused import ConvBn1d
 from .fused import ConvBn2d
 from .fused import ConvBn3d
-from .fused import ConvBnReLU1d
 from .fused import ConvBnReLU2d
 from .fused import ConvBnReLU3d
 from .fused import ConvReLU1d
@@ -14,10 +12,8 @@
 
 
 __all__ = [
-    'ConvBn1d',
     'ConvBn2d',
     'ConvBn3d',
-    'ConvBnReLU1d',
     'ConvBnReLU2d',
     'ConvBnReLU3d',
     'ConvReLU1d',
diff --git a/torch/nn/intrinsic/modules/fused.py b/torch/nn/intrinsic/modules/fused.py
index 71a724d..5f9cd0a 100644
--- a/torch/nn/intrinsic/modules/fused.py
+++ b/torch/nn/intrinsic/modules/fused.py
@@ -38,15 +38,6 @@
                 type(linear), type(relu))
         super(LinearReLU, self).__init__(linear, relu)
 
-class ConvBn1d(torch.nn.Sequential):
-    r"""This is a sequential container which calls the Conv 1d and Batch Norm 1d modules.
-    During quantization this will be replaced with the corresponding fused module."""
-    def __init__(self, conv, bn):
-        assert type(conv) == Conv1d and type(bn) == BatchNorm1d, \
-            'Incorrect types for input modules{}{}'.format(
-                type(conv), type(bn))
-        super(ConvBn1d, self).__init__(conv, bn)
-
 class ConvBn2d(torch.nn.Sequential):
     r"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules.
     During quantization this will be replaced with the corresponding fused module."""
@@ -56,15 +47,6 @@
                 type(conv), type(bn))
         super(ConvBn2d, self).__init__(conv, bn)
 
-class ConvBnReLU1d(torch.nn.Sequential):
-    r"""This is a sequential container which calls the Conv 1d, Batch Norm 1d, and ReLU modules.
-    During quantization this will be replaced with the corresponding fused module."""
-    def __init__(self, conv, bn, relu):
-        assert type(conv) == Conv1d and type(bn) == BatchNorm1d and \
-            type(relu) == ReLU, 'Incorrect types for input modules{}{}{}' \
-            .format(type(conv), type(bn), type(relu))
-        super(ConvBnReLU1d, self).__init__(conv, bn, relu)
-
 class ConvBnReLU2d(torch.nn.Sequential):
     r"""This is a sequential container which calls the Conv 2d, Batch Norm 2d, and ReLU modules.
     During quantization this will be replaced with the corresponding fused module."""
diff --git a/torch/quantization/fuse_modules.py b/torch/quantization/fuse_modules.py
index 856548a..1e5cd6b 100644
--- a/torch/quantization/fuse_modules.py
+++ b/torch/quantization/fuse_modules.py
@@ -47,30 +47,19 @@
     """
     assert(conv.training == bn.training == relu.training),\
         "Conv and BN both must be in the same mode (train or eval)."
+    is_3d = isinstance(conv, torch.nn.Conv3d)
     if conv.training:
-        map_to_fused_module_train = {
-            torch.nn.Conv2d: torch_fused.ConvBnReLU2d,
-            torch.nn.Conv3d: torch_fused.ConvBnReLU3d,
-        }
         assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
         assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
         assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'
-        fused_module = map_to_fused_module_train.get(type(conv))
-        if fused_module is not None:
-            return fused_module(conv, bn, relu)
-        else:
-            raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
+
+        return torch_fused.ConvBnReLU3d(conv, bn, relu) if is_3d \
+            else torch_fused.ConvBnReLU2d(conv, bn, relu)
     else:
-        map_to_fused_module_eval = {
-            torch.nn.Conv1d: torch_fused.ConvReLU1d,
-            torch.nn.Conv2d: torch_fused.ConvReLU2d,
-            torch.nn.Conv3d: torch_fused.ConvReLU3d,
-        }
-        fused_module = map_to_fused_module_eval[type(conv)]
-        if fused_module is not None:
-            return fused_module(torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu)
-        else:
-            raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))
+        return torch_fused.ConvReLU3d(
+            torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu) if is_3d \
+            else torch_fused.ConvReLU2d(
+                torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu)
 
 # Generalization of getattr
 def _get_module(model, submodule_key):
@@ -104,8 +93,6 @@
     """
 
     OP_LIST_TO_FUSER_METHOD = {
-        (torch.nn.Conv1d, torch.nn.BatchNorm1d): fuse_conv_bn,
-        (torch.nn.Conv1d, torch.nn.BatchNorm1d, torch.nn.ReLU): fuse_conv_bn_relu,
         (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn,
         (torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU): fuse_conv_bn_relu,
         (torch.nn.Conv3d, torch.nn.BatchNorm3d): fuse_conv_bn,
diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py
index d36a2cc..8c0315f 100644
--- a/torch/quantization/quantize.py
+++ b/torch/quantization/quantize.py
@@ -321,9 +321,7 @@
                          nni.LinearReLU,
                          nni.BNReLU2d,
                          nni.BNReLU3d,
-                         nni.ConvBn1d,
                          nni.ConvReLU1d,
-                         nni.ConvBnReLU1d,
                          nni.ConvReLU2d,
                          nni.ConvReLU3d)
 
diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py
index d03ec4c..eacbdb8 100644
--- a/torch/testing/_internal/common_quantization.py
+++ b/torch/testing/_internal/common_quantization.py
@@ -632,7 +632,6 @@
         self.bn2 = nn.BatchNorm3d(2).to(dtype=torch.float)
         self.relu3 = nn.ReLU(inplace=True).to(dtype=torch.float)
         self.conv3 = nn.Conv1d(3, 3, 2).to(dtype=torch.float)
-        self.bn3 = nn.BatchNorm1d(3).to(dtype=torch.float)
         self.relu4 = nn.ReLU(inplace=True).to(dtype=torch.float)
         # don't quantize sub2
         self.sub2.qconfig = None
@@ -642,7 +641,6 @@
         x = x.squeeze(2)
         x = self.quant(x)
         x = self.conv3(x)
-        x = self.bn3(x)
         x = self.relu4(x)
         x = x.unsqueeze(2)
         y = x.unsqueeze(2)