[quant][ao_migration] nn.intrinsic.qat migration to ao (#86171)

All quantization-related modules are being migrated to `torch.ao`. This PR migrates `nn.intrinsic.qat` to `torch.ao.nn.intrinsic.qat`. Please see the [tracker](https://github.com/pytorch/pytorch/issues/81667) for the timeline.
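
For downstream code the change is only the import path: the legacy `torch.nn.intrinsic.qat` namespace is kept as a shim that re-imports everything from `torch.ao.nn.intrinsic.qat`. A minimal before/after sketch (the identity check assumes the shim re-exports rather than copies the classes, which is what the compatibility files in this PR do):

```python
# New canonical location introduced by this PR:
from torch.ao.nn.intrinsic.qat import LinearReLU

# Legacy location, still importable during the migration window:
import torch.nn.intrinsic.qat as nniqat

# The shim re-imports from torch.ao, so both names should resolve
# to the same class object while the migration is in flight.
assert nniqat.LinearReLU is LinearReLU
```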

```
python test/test_quantization.py TestAOMigrationNNIntrinsic
```

Differential Revision: [D39419993](https://our.internmc.facebook.com/intern/diff/D39419993/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D39419993/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86171
Approved by: https://github.com/jerryzh168
diff --git a/docs/source/quantization-support.rst b/docs/source/quantization-support.rst
index a681e49..dfcc6f6 100644
--- a/docs/source/quantization-support.rst
+++ b/docs/source/quantization-support.rst
@@ -296,16 +296,16 @@
     BNReLU2d
     BNReLU3d
 
-torch.nn.intrinsic.qat
-~~~~~~~~~~~~~~~~~~~~~~
-.. automodule:: torch.nn.intrinsic.qat
-.. automodule:: torch.nn.intrinsic.qat.modules
+torch.ao.nn.intrinsic.qat
+~~~~~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: torch.ao.nn.intrinsic.qat
+.. automodule:: torch.ao.nn.intrinsic.qat.modules
 
 
 This module implements the versions of those fused operations needed for
 quantization aware training.
 
-.. currentmodule:: torch.nn.intrinsic.qat
+.. currentmodule:: torch.ao.nn.intrinsic.qat
 
 .. autosummary::
     :toctree: generated
@@ -597,3 +597,6 @@
    :noindex:
 .. automodule:: torch.ao.nn.quantized.reference.modules
    :noindex:
+
+.. py:module:: torch.nn.intrinsic.qat
+.. py:module:: torch.nn.intrinsic.qat.modules
diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json
index d492734..02e707d 100644
--- a/test/allowlist_for_publicAPI.json
+++ b/test/allowlist_for_publicAPI.json
@@ -3,6 +3,11 @@
     "torch.nn.intrinsic": "torch.ao.nn.intrinsic",
     "torch.nn.intrinsic.modules": "torch.ao.nn.intrinsic.modules",
     "torch.nn.intrinsic.modules.fused": "torch.ao.nn.intrinsic.modules.fused",
+    "torch.nn.intrinsic.qat": "torch.ao.nn.intrinsic.qat",
+    "torch.nn.intrinsic.qat.modules": "torch.ao.nn.intrinsic.qat.modules",
+    "torch.nn.intrinsic.qat.modules.conv_fused": "torch.ao.nn.intrinsic.qat.modules.conv_fused",
+    "torch.nn.intrinsic.qat.modules.linear_fused": "torch.ao.nn.intrinsic.qat.modules.linear_fused",
+    "torch.nn.intrinsic.qat.modules.linear_relu": "torch.ao.nn.intrinsic.qat.modules.linear_relu",
     "torch.nn.qat": "torch.ao.nn.qat",
     "torch.nn.qat.dynamic": "torch.ao.nn.qat.dynamic",
     "torch.nn.qat.dynamic.modules": "torch.ao.nn.qat.dynamic.modules",
diff --git a/test/ao/sparsity/test_composability.py b/test/ao/sparsity/test_composability.py
index 366116e..f531dd2 100644
--- a/test/ao/sparsity/test_composability.py
+++ b/test/ao/sparsity/test_composability.py
@@ -514,7 +514,7 @@
         # that none were lost during prepare
         self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
         self.assertTrue(hasattr(fqn_to_module(mod, "5"), "parametrizations"))
-        self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.nn.intrinsic.qat.LinearReLU))
+        self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.qat.LinearReLU))
 
         # check that correct observers were inserted and that matching
         # occurred successfully
diff --git a/test/quantization/ao_migration/test_ao_migration.py b/test/quantization/ao_migration/test_ao_migration.py
index 6b76a01..3d0e3b3 100644
--- a/test/quantization/ao_migration/test_ao_migration.py
+++ b/test/quantization/ao_migration/test_ao_migration.py
@@ -382,7 +382,6 @@
 
     def test_package_import_nn_intrinsic(self):
         skip = [
-            'qat',
             'quantized',
         ]
         self._test_package_import('intrinsic', base='nn', skip=skip)
@@ -407,7 +406,7 @@
         ]
         self._test_function_import('intrinsic', module_list, base='nn')
 
-    def test_modules_fused(self):
+    def test_modules_nn_intrinsic_fused(self):
         function_list = [
             '_FusedModule',
             'ConvBn1d',
@@ -426,3 +425,57 @@
         ]
         self._test_function_import('fused', function_list,
                                    base='nn.intrinsic.modules')
+
+    def test_package_import_nn_intrinsic_qat(self):
+        r"""Tests the migration of the torch.nn.intrinsic.modules"""
+        self._test_package_import('qat', base='nn.intrinsic')
+        self._test_package_import('qat.modules', base='nn.intrinsic')
+
+    def test_modules_import_nn_intrinsic_qat(self):
+        module_list = [
+            "LinearReLU",
+            "LinearBn1d",
+            "ConvReLU1d",
+            "ConvReLU2d",
+            "ConvReLU3d",
+            "ConvBn1d",
+            "ConvBn2d",
+            "ConvBn3d",
+            "ConvBnReLU1d",
+            "ConvBnReLU2d",
+            "ConvBnReLU3d",
+            "update_bn_stats",
+            "freeze_bn_stats",
+        ]
+        self._test_function_import('qat', module_list, base='nn.intrinsic')
+
+    def test_modules_intrinsic_qat_conv_fused(self):
+        function_list = [
+            'ConvBn1d',
+            'ConvBnReLU1d',
+            'ConvReLU1d',
+            'ConvBn2d',
+            'ConvBnReLU2d',
+            'ConvReLU2d',
+            'ConvBn3d',
+            'ConvBnReLU3d',
+            'ConvReLU3d',
+            'update_bn_stats',
+            'freeze_bn_stats'
+        ]
+        self._test_function_import('conv_fused', function_list,
+                                   base='nn.intrinsic.qat.modules')
+
+    def test_modules_intrinsic_qat_linear_fused(self):
+        function_list = [
+            'LinearBn1d',
+        ]
+        self._test_function_import('linear_fused', function_list,
+                                   base='nn.intrinsic.qat.modules')
+
+    def test_modules_intrinsic_qat_linear_relu(self):
+        function_list = [
+            'LinearReLU',
+        ]
+        self._test_function_import('linear_relu', function_list,
+                                   base='nn.intrinsic.qat.modules')
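
The helpers `_test_package_import` and `_test_function_import` are defined elsewhere in the test suite, so their bodies are not part of this diff. A plausible reading of what they assert is sketched below (a hypothetical illustration, not the actual helper):

```python
# Hypothetical sketch of what the migration test helpers likely verify:
# every listed name is importable from both the old and the new
# namespaces, and both names resolve to the very same object.
import importlib

def check_function_import(package, function_list, base):
    old_mod = importlib.import_module(f"torch.{base}.{package}")
    new_mod = importlib.import_module(f"torch.ao.{base}.{package}")
    for name in function_list:
        assert getattr(old_mod, name) is getattr(new_mod, name)

check_function_import('qat', ['LinearReLU', 'ConvBn2d'], base='nn.intrinsic')
```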
diff --git a/test/quantization/core/test_workflow_module.py b/test/quantization/core/test_workflow_module.py
index ba93a79..5b46e6a 100644
--- a/test/quantization/core/test_workflow_module.py
+++ b/test/quantization/core/test_workflow_module.py
@@ -864,7 +864,7 @@
                 if epoch >= 1:
                     model.apply(torch.ao.quantization.disable_observer)
                 if epoch >= 2:
-                    model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+                    model.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
                 quant_model = copy.deepcopy(model.module)
                 quant_model = torch.ao.quantization.convert(quant_model.eval().cpu(), inplace=False)
                 with torch.no_grad():
diff --git a/test/quantization/eager/test_fuse_eager.py b/test/quantization/eager/test_fuse_eager.py
index e397fa9..9f120b8 100644
--- a/test/quantization/eager/test_fuse_eager.py
+++ b/test/quantization/eager/test_fuse_eager.py
@@ -7,7 +7,7 @@
 import torch.ao.nn.quantized as nnq
 import torch.nn.intrinsic as nni
 import torch.nn.intrinsic.quantized as nniq
-import torch.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic.qat as nniqat
 from torch.ao.quantization import (
     quantize,
     prepare,
diff --git a/test/quantization/eager/test_model_numerics.py b/test/quantization/eager/test_model_numerics.py
index b259e10..bcefb78 100644
--- a/test/quantization/eager/test_model_numerics.py
+++ b/test/quantization/eager/test_model_numerics.py
@@ -73,7 +73,7 @@
                 torch.ao.quantization.prepare_qat(fq_model)
                 fq_model.eval()
                 fq_model.apply(torch.ao.quantization.disable_fake_quant)
-                fq_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+                fq_model.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
                 fq_model(calib_data)
                 fq_model.apply(torch.ao.quantization.enable_fake_quant)
                 fq_model.apply(torch.ao.quantization.disable_observer)
@@ -109,7 +109,7 @@
                     torch.ao.quantization.prepare_qat(fq_model)
                     fq_model.eval()
                     fq_model.apply(torch.ao.quantization.disable_fake_quant)
-                    fq_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+                    fq_model.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
                     fq_model(calib_data)
                     fq_model.apply(torch.ao.quantization.enable_fake_quant)
                     fq_model.apply(torch.ao.quantization.disable_observer)
diff --git a/test/quantization/eager/test_quantize_eager_qat.py b/test/quantization/eager/test_quantize_eager_qat.py
index c625425..bc118a8 100644
--- a/test/quantization/eager/test_quantize_eager_qat.py
+++ b/test/quantization/eager/test_quantize_eager_qat.py
@@ -6,12 +6,12 @@
 import torch.nn as nn
 import torch.backends.mkldnn
 from torch.nn import Conv2d, BatchNorm2d, ReLU, init
-from torch.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
+from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
 from torch.nn.modules.utils import _pair
 import torch.ao.nn.quantized as nnq
 import torch.ao.nn.quantized.dynamic as nnqd
 import torch.ao.nn.qat as nnqat
-import torch.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic.qat as nniqat
 import torch.ao.nn.qat.dynamic as nnqatd
 from torch.ao.quantization import (
     prepare,
@@ -819,9 +819,9 @@
 
         qat_op.apply(torch.ao.quantization.disable_fake_quant)
         if freeze_bn:
-            qat_op.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+            qat_op.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
         else:
-            qat_op.apply(torch.nn.intrinsic.qat.update_bn_stats)
+            qat_op.apply(torch.ao.nn.intrinsic.qat.update_bn_stats)
 
         # align inputs and internal parameters
         input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True)
@@ -992,7 +992,7 @@
             input_clone = input.clone().detach().requires_grad_()
 
             if i > 2:
-                qat_op.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+                qat_op.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
                 qat_ref_op.freeze_bn_stats()
 
             if i > 3:
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index 595f265..4a49a07 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -746,7 +746,7 @@
             ],
         }
         prepared = prepare_qat_fx(model, qconfig_dict, example_inputs=(torch.randn(1, 5),))
-        self.assertTrue(isinstance(getattr(prepared.mods1, "0").tmp, torch.nn.intrinsic.qat.LinearReLU))
+        self.assertTrue(isinstance(getattr(prepared.mods1, "0").tmp, torch.ao.nn.intrinsic.qat.LinearReLU))
 
     def _get_conv_linear_test_cases(self, is_reference):
         """ Returns a list of test cases, with format:
diff --git a/torch/ao/nn/intrinsic/__init__.py b/torch/ao/nn/intrinsic/__init__.py
index c919966..ef2582b 100644
--- a/torch/ao/nn/intrinsic/__init__.py
+++ b/torch/ao/nn/intrinsic/__init__.py
@@ -1,6 +1,9 @@
 from .modules import *  # noqa: F403
 from .modules.fused import _FusedModule  # noqa: F403
 
+# Subpackages
+from . import qat  # noqa: F403
+
 __all__ = [
     'ConvBn1d',
     'ConvBn2d',
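
Because the package `__init__` above now imports `qat` eagerly (the `from . import qat` line), the subpackage becomes reachable as an attribute of `torch.ao.nn.intrinsic` without a separate import; a quick sketch:

```python
import torch.ao.nn.intrinsic as nni

# The `from . import qat` in the package __init__ populates the
# submodule attribute, so no explicit `import torch.ao.nn.intrinsic.qat`
# is needed before touching it.
linear_relu_cls = nni.qat.LinearReLU
```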
diff --git a/torch/ao/nn/intrinsic/qat/__init__.py b/torch/ao/nn/intrinsic/qat/__init__.py
new file mode 100644
index 0000000..3d79bdb
--- /dev/null
+++ b/torch/ao/nn/intrinsic/qat/__init__.py
@@ -0,0 +1 @@
+from .modules import *  # noqa: F403
diff --git a/torch/ao/nn/intrinsic/qat/modules/__init__.py b/torch/ao/nn/intrinsic/qat/modules/__init__.py
new file mode 100644
index 0000000..f44820c
--- /dev/null
+++ b/torch/ao/nn/intrinsic/qat/modules/__init__.py
@@ -0,0 +1,31 @@
+from .linear_relu import LinearReLU
+from .linear_fused import LinearBn1d
+from .conv_fused import (
+    ConvBn1d,
+    ConvBn2d,
+    ConvBn3d,
+    ConvBnReLU1d,
+    ConvBnReLU2d,
+    ConvBnReLU3d,
+    ConvReLU1d,
+    ConvReLU2d,
+    ConvReLU3d,
+    update_bn_stats,
+    freeze_bn_stats,
+)
+
+__all__ = [
+    "LinearReLU",
+    "LinearBn1d",
+    "ConvReLU1d",
+    "ConvReLU2d",
+    "ConvReLU3d",
+    "ConvBn1d",
+    "ConvBn2d",
+    "ConvBn3d",
+    "ConvBnReLU1d",
+    "ConvBnReLU2d",
+    "ConvBnReLU3d",
+    "update_bn_stats",
+    "freeze_bn_stats",
+]
diff --git a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py
new file mode 100644
index 0000000..6a6f4c1
--- /dev/null
+++ b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py
@@ -0,0 +1,828 @@
+import math
+import torch
+import torch.nn as nn
+import torch.ao.nn.intrinsic as nni
+import torch.ao.nn.qat as nnqat
+import torch.nn.functional as F
+from torch.nn import init
+from torch.nn.utils import fuse_conv_bn_weights
+from torch.nn.modules.utils import _single, _pair, _triple
+from torch.nn.parameter import Parameter
+from typing import TypeVar
+
+__all__ = ['ConvBn1d', 'ConvBnReLU1d', 'ConvReLU1d', 'ConvBn2d', 'ConvBnReLU2d', 'ConvReLU2d', 'ConvBn3d',
+           'ConvBnReLU3d', 'ConvReLU3d', 'update_bn_stats', 'freeze_bn_stats']
+_BN_CLASS_MAP = {
+    1: nn.BatchNorm1d,
+    2: nn.BatchNorm2d,
+    3: nn.BatchNorm3d,
+}
+
+
+MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)
+
+
+class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
+
+    _version = 2
+    _FLOAT_MODULE = MOD
+
+    def __init__(self,
+                 # ConvNd args
+                 in_channels, out_channels, kernel_size, stride,
+                 padding, dilation, transposed, output_padding,
+                 groups,
+                 bias,
+                 padding_mode,
+                 # BatchNormNd args
+                 # num_features: out_channels
+                 eps=1e-05, momentum=0.1,
+                 # affine: True
+                 # track_running_stats: True
+                 # Args for this module
+                 freeze_bn=False,
+                 qconfig=None,
+                 dim=2):
+        nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
+                                         stride, padding, dilation, transposed,
+                                         output_padding, groups, False, padding_mode)
+        assert qconfig, 'qconfig must be provided for QAT module'
+        self.qconfig = qconfig
+        self.freeze_bn = freeze_bn if self.training else True
+        self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
+        self.weight_fake_quant = self.qconfig.weight()
+        if bias:
+            self.bias = Parameter(torch.empty(out_channels))
+        else:
+            self.register_parameter('bias', None)
+        self.reset_bn_parameters()
+
+        # this needs to be called after reset_bn_parameters,
+        # as they modify the same state
+        if self.training:
+            if freeze_bn:
+                self.freeze_bn_stats()
+            else:
+                self.update_bn_stats()
+        else:
+            self.freeze_bn_stats()
+
+        self._enable_slow_path_for_better_numerical_stability = False
+
+    def reset_running_stats(self):
+        self.bn.reset_running_stats()
+
+    def reset_bn_parameters(self):
+        self.bn.reset_running_stats()
+        init.uniform_(self.bn.weight)
+        init.zeros_(self.bn.bias)
+        # note: below is actually for conv, not BN
+        if self.bias is not None:
+            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
+            bound = 1 / math.sqrt(fan_in)
+            init.uniform_(self.bias, -bound, bound)
+
+    def reset_parameters(self):
+        super(_ConvBnNd, self).reset_parameters()
+
+    def update_bn_stats(self):
+        self.freeze_bn = False
+        self.bn.training = True
+        return self
+
+    def freeze_bn_stats(self):
+        self.freeze_bn = True
+        self.bn.training = False
+        return self
+
+    def _forward(self, input):
+        if self._enable_slow_path_for_better_numerical_stability:
+            return self._forward_slow(input)
+        return self._forward_approximate(input)
+
+    def _forward_approximate(self, input):
+        """Approximated method to fuse conv and bn. It requires only one forward pass.
+        conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std
+        """
+        assert self.bn.running_var is not None
+        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
+        scale_factor = self.bn.weight / running_std
+        weight_shape = [1] * len(self.weight.shape)
+        weight_shape[0] = -1
+        bias_shape = [1] * len(self.weight.shape)
+        bias_shape[1] = -1
+        scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
+        # using zero bias here since the bias for original conv
+        # will be added later
+        if self.bias is not None:
+            zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
+        else:
+            zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.dtype)
+        conv = self._conv_forward(input, scaled_weight, zero_bias)
+        conv_orig = conv / scale_factor.reshape(bias_shape)
+        if self.bias is not None:
+            conv_orig = conv_orig + self.bias.reshape(bias_shape)
+        conv = self.bn(conv_orig)
+        return conv
+
+    def _forward_slow(self, input):
+        """
+        A more accurate but slow method to compute conv bn fusion, following https://arxiv.org/pdf/1806.08342.pdf
+        It requires two forward passes but handles the case where bn.weight == 0.
+
+        Conv: Y = WX + B_c
+        Conv without bias: Y0 = WX = Y - B_c, Y = Y0 + B_c
+
+        Batch statistics:
+          mean_Y = Y.mean()
+                 = Y0.mean() + B_c
+          var_Y = (Y - mean_Y)^2.mean()
+                = (Y0 - Y0.mean())^2.mean()
+        BN (r: bn.weight, beta: bn.bias):
+          Z = r * (Y - mean_Y) / sqrt(var_Y + eps) + beta
+            = r * (Y0 - Y0.mean()) / sqrt(var_Y + eps) + beta
+
+        Fused Conv BN training (std_Y = sqrt(var_Y + eps)):
+          Z = (r * W / std_Y) * X + r * (B_c - mean_Y) / std_Y + beta
+            = (r * W / std_Y) * X - r * Y0.mean() / std_Y + beta
+
+        Fused Conv BN inference (running_std = sqrt(running_var + eps)):
+          Z = (r * W / running_std) * X - r * (running_mean - B_c) / running_std + beta
+
+        QAT with fused conv bn:
+          Z_train = fake_quant(r * W / running_std) * X * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
+                  = conv(X, fake_quant(r * W / running_std)) * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
+          Z_inference = conv(X, fake_quant(r * W / running_std)) - r * (running_mean - B_c) / running_std + beta
+        """
+
+        assert self.bn.running_var is not None
+        assert self.bn.running_mean is not None
+
+        # using zero bias here since the bias for original conv
+        # will be added later
+        zero_bias = torch.zeros(self.out_channels, device=self.weight.device, dtype=input.dtype)
+
+        weight_shape = [1] * len(self.weight.shape)
+        weight_shape[0] = -1
+        bias_shape = [1] * len(self.weight.shape)
+        bias_shape[1] = -1
+
+        if self.bn.training:
+            # needed to compute batch mean/std
+            conv_out = self._conv_forward(input, self.weight, zero_bias)
+            # update bn statistics
+            with torch.no_grad():
+                conv_out_bias = (
+                    conv_out if self.bias is None else conv_out + self.bias.reshape(bias_shape)
+                )
+                self.bn(conv_out_bias)
+
+        # fused conv + bn without bias using bn running statistics
+        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
+        scale_factor = self.bn.weight / running_std
+        scaled_weight = self.weight_fake_quant(
+            self.weight * scale_factor.reshape(weight_shape)
+        )
+        # fused conv without bias for inference: (r * W / running_std) * X
+        conv_bn = self._conv_forward(input, scaled_weight, zero_bias)
+
+        if self.bn.training:
+            avg_dims = [0] + list(range(2, len(self.weight.shape)))
+            batch_mean = conv_out.mean(avg_dims)
+            batch_var = torch.square(conv_out - batch_mean.reshape(bias_shape)).mean(
+                avg_dims
+            )
+            batch_std = torch.sqrt(batch_var + self.bn.eps)
+
+            # scale to use batch std in training mode
+            # conv(X, r * W / std_Y) = conv(X, r * W / running_std) * (running_std / std_Y)
+            unscale_factor = running_std / batch_std
+            conv_bn *= unscale_factor.reshape(bias_shape)
+
+            fused_mean = batch_mean
+            fused_std = batch_std
+        else:
+            fused_mean = self.bn.running_mean - (self.bias if self.bias is not None else 0)
+            fused_std = running_std
+
+        # fused bias = beta - r * mean / std
+        fused_bias = self.bn.bias - self.bn.weight * fused_mean / fused_std
+        conv_bn += fused_bias.reshape(bias_shape)
+
+        # HACK to let conv bias participate in loss to avoid DDP error (parameters
+        #   were not used in producing loss)
+        if self.bias is not None:
+            conv_bn += (self.bias - self.bias).reshape(bias_shape)
+
+        return conv_bn
+
+    def extra_repr(self):
+        # TODO(jerryzh): extend
+        return super(_ConvBnNd, self).extra_repr()
+
+    def forward(self, input):
+        return self._forward(input)
+
+    def train(self, mode=True):
+        """
+        BatchNorm's training behavior is controlled by the self.training flag. Prevent
+        changing it if BN is frozen. This makes sure that calling `model.train()`
+        on a model with a frozen BN will behave properly.
+        """
+        self.training = mode
+        if not self.freeze_bn:
+            for module in self.children():
+                module.train(mode)
+        return self
+
+    # ===== Serialization version history =====
+    #
+    # Version 1/None
+    #   self
+    #   |--- weight : Tensor
+    #   |--- bias : Tensor
+    #   |--- gamma : Tensor
+    #   |--- beta : Tensor
+    #   |--- running_mean : Tensor
+    #   |--- running_var : Tensor
+    #   |--- num_batches_tracked : Tensor
+    #
+    # Version 2
+    #   self
+    #   |--- weight : Tensor
+    #   |--- bias : Tensor
+    #   |--- bn : Module
+    #        |--- weight : Tensor (moved from v1.self.gamma)
+    #        |--- bias : Tensor (moved from v1.self.beta)
+    #        |--- running_mean : Tensor (moved from v1.self.running_mean)
+    #        |--- running_var : Tensor (moved from v1.self.running_var)
+    #        |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+        version = local_metadata.get('version', None)
+        if version is None or version == 1:
+            # BN related parameters and buffers were moved into the BN module for v2
+            v2_to_v1_names = {
+                'bn.weight': 'gamma',
+                'bn.bias': 'beta',
+                'bn.running_mean': 'running_mean',
+                'bn.running_var': 'running_var',
+                'bn.num_batches_tracked': 'num_batches_tracked',
+            }
+            for v2_name, v1_name in v2_to_v1_names.items():
+                if prefix + v1_name in state_dict:
+                    state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
+                    state_dict.pop(prefix + v1_name)
+                elif prefix + v2_name in state_dict:
+                    # there was a brief period where forward compatibility
+                    # for this module was broken (between
+                    # https://github.com/pytorch/pytorch/pull/38478
+                    # and https://github.com/pytorch/pytorch/pull/38820)
+                    # and modules emitted the v2 state_dict format while
+                    # specifying that version == 1. This patches the forward
+                    # compatibility issue by allowing the v2 style entries to
+                    # be used.
+                    pass
+                elif strict:
+                    missing_keys.append(prefix + v2_name)
+
+        super(_ConvBnNd, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+    @classmethod
+    def from_float(cls, mod):
+        r"""Create a qat module from a float module or qparams_dict
+
+            Args: `mod` a float module, either produced by torch.ao.quantization utilities
+            or directly from the user
+        """
+        # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound
+        # has no __name__ (code is fine though)
+        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
+            cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
+        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+        assert mod.qconfig, 'Input float module must have a valid qconfig'
+        qconfig = mod.qconfig
+        conv, bn = mod[0], mod[1]
+        qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
+                         conv.stride, conv.padding, conv.dilation,
+                         conv.groups, conv.bias is not None,
+                         conv.padding_mode,
+                         bn.eps, bn.momentum,
+                         False,
+                         qconfig)
+        qat_convbn.weight = conv.weight
+        qat_convbn.bias = conv.bias
+        qat_convbn.bn.weight = bn.weight
+        qat_convbn.bn.bias = bn.bias
+        qat_convbn.bn.running_mean = bn.running_mean
+        qat_convbn.bn.running_var = bn.running_var
+        # mypy error: Cannot determine type of 'num_batches_tracked'
+        qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked  # type: ignore[has-type]
+        return qat_convbn
+
+    def to_float(self):
+        cls = type(self)
+        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined]
+            self.in_channels,
+            self.out_channels,
+            self.kernel_size,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+            self.bias is not None,
+            self.padding_mode)
+        conv.weight = torch.nn.Parameter(self.weight.detach())
+        if self.bias is not None:
+            conv.bias = torch.nn.Parameter(self.bias.detach())
+
+        if cls._FLOAT_BN_MODULE:  # type: ignore[attr-defined]
+            # fuse bn into conv
+            conv.weight, conv.bias = fuse_conv_bn_weights(
+                conv.weight,
+                conv.bias,
+                self.bn.running_mean,
+                self.bn.running_var,
+                self.bn.eps,
+                self.bn.weight,
+                self.bn.bias
+            )
+
+        if cls._FLOAT_RELU_MODULE:  # type: ignore[attr-defined]
+            modules = []
+            modules.append(conv)
+            relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
+            modules.append(relu)
+            conv_relu = cls._FUSED_FLOAT_MODULE(*modules)  # type: ignore[attr-defined]
+            conv_relu.train(self.training)
+            return conv_relu
+        else:
+            conv.train(self.training)
+            return conv
+
+class ConvBn1d(_ConvBnNd, nn.Conv1d):
+    r"""
+    A ConvBn1d module is a module fused from Conv1d and BatchNorm1d,
+    attached with FakeQuantize modules for weight,
+    used in quantization aware training.
+
+    We combined the interface of :class:`torch.nn.Conv1d` and
+    :class:`torch.nn.BatchNorm1d`.
+
+    Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized
+    to default.
+
+    Attributes:
+        freeze_bn: flag indicating whether the BatchNorm statistics are frozen
+        weight_fake_quant: fake quant module for weight
+
+    """
+    _FLOAT_BN_MODULE = nn.BatchNorm1d
+    _FLOAT_RELU_MODULE = None
+    _FLOAT_MODULE = nni.ConvBn1d
+    _FLOAT_CONV_MODULE = nn.Conv1d
+
+    def __init__(self,
+                 # Conv1d args
+                 in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1,
+                 bias=None,
+                 padding_mode='zeros',
+                 # BatchNorm1d args
+                 # num_features: out_channels
+                 eps=1e-05, momentum=0.1,
+                 # affine: True
+                 # track_running_stats: True
+                 # Args for this module
+                 freeze_bn=False,
+                 qconfig=None):
+        kernel_size = _single(kernel_size)
+        stride = _single(stride)
+        padding = _single(padding)
+        dilation = _single(dilation)
+        _ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
+                           padding, dilation, False, _single(0), groups, bias, padding_mode,
+                           eps, momentum, freeze_bn, qconfig, dim=1)
+
+class ConvBnReLU1d(ConvBn1d):
+    r"""
+    A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU,
+    attached with FakeQuantize modules for weight,
+    used in quantization aware training.
+
+    We combined the interface of :class:`torch.nn.Conv1d`,
+    :class:`torch.nn.BatchNorm1d`, and :class:`torch.nn.ReLU`.
+
+    Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to
+    default.
+
+    Attributes:
+        weight_fake_quant: fake quant module for weight
+
+    """
+    # base class defines _FLOAT_MODULE as "ConvBn1d"
+    _FLOAT_MODULE = nni.ConvBnReLU1d  # type: ignore[assignment]
+    _FLOAT_CONV_MODULE = nn.Conv1d
+    _FLOAT_BN_MODULE = nn.BatchNorm1d
+    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
+    # module class after fusing bn into conv
+    _FUSED_FLOAT_MODULE = nni.ConvReLU1d
+
+    def __init__(self,
+                 # Conv1d args
+                 in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1,
+                 bias=None,
+                 padding_mode='zeros',
+                 # BatchNorm1d args
+                 # num_features: out_channels
+                 eps=1e-05, momentum=0.1,
+                 # affine: True
+                 # track_running_stats: True
+                 # Args for this module
+                 freeze_bn=False,
+                 qconfig=None):
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias,
+                         padding_mode, eps, momentum,
+                         freeze_bn,
+                         qconfig)
+
+    def forward(self, input):
+        return F.relu(ConvBn1d._forward(self, input))
+
+    @classmethod
+    def from_float(cls, mod):
+        return super(ConvBnReLU1d, cls).from_float(mod)
+
+class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
+    r"""A ConvReLU1d module is a fused module of Conv1d and ReLU, attached with
+    FakeQuantize modules for weight for
+    quantization aware training.
+
+    We combined the interface of :class:`~torch.nn.Conv1d` and
+    :class:`~torch.nn.BatchNorm1d`.
+
+    Attributes:
+        weight_fake_quant: fake quant module for weight
+
+    """
+    _FLOAT_MODULE = nni.ConvReLU1d
+    _FLOAT_CONV_MODULE = nn.Conv1d
+    _FLOAT_BN_MODULE = None
+    _FLOAT_RELU_MODULE = nn.ReLU
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1,
+                 bias=True, padding_mode='zeros',
+                 qconfig=None):
+        super(ConvReLU1d, self).__init__(in_channels, out_channels, kernel_size,
+                                         stride=stride, padding=padding, dilation=dilation,
+                                         groups=groups, bias=bias, padding_mode=padding_mode,
+                                         qconfig=qconfig)
+        assert qconfig, 'qconfig must be provided for QAT module'
+        self.qconfig = qconfig
+        self.weight_fake_quant = self.qconfig.weight()
+
+    def forward(self, input):
+        return F.relu(
+            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))
+
+    @classmethod
+    def from_float(cls, mod):
+        return super(ConvReLU1d, cls).from_float(mod)
+
+class ConvBn2d(_ConvBnNd, nn.Conv2d):
+    r"""
+    A ConvBn2d module is a module fused from Conv2d and BatchNorm2d,
+    attached with FakeQuantize modules for weight,
+    used in quantization aware training.
+
+    We combined the interface of :class:`torch.nn.Conv2d` and
+    :class:`torch.nn.BatchNorm2d`.
+
+    Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized
+    to default.
+
+    Attributes:
+        freeze_bn: flag indicating whether the BatchNorm statistics are frozen
+        weight_fake_quant: fake quant module for weight
+
+    """
+    _FLOAT_MODULE = nni.ConvBn2d
+    _FLOAT_CONV_MODULE = nn.Conv2d
+    _FLOAT_BN_MODULE = nn.BatchNorm2d
+    _FLOAT_RELU_MODULE = None
+
+    def __init__(self,
+                 # ConvNd args
+                 in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1,
+                 bias=None,
+                 padding_mode='zeros',
+                 # BatchNorm2d args
+                 # num_features: out_channels
+                 eps=1e-05, momentum=0.1,
+                 # affine: True
+                 # track_running_stats: True
+                 # Args for this module
+                 freeze_bn=False,
+                 qconfig=None):
+        kernel_size = _pair(kernel_size)
+        stride = _pair(stride)
+        padding = _pair(padding)
+        dilation = _pair(dilation)
+        _ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
+                           padding, dilation, False, _pair(0), groups, bias, padding_mode,
+                           eps, momentum, freeze_bn, qconfig, dim=2)
+
+class ConvBnReLU2d(ConvBn2d):
+    r"""
+    A ConvBnReLU2d module is a module fused from Conv2d, BatchNorm2d and ReLU,
+    attached with FakeQuantize modules for weight,
+    used in quantization aware training.
+
+    We combined the interface of :class:`torch.nn.Conv2d`,
+    :class:`torch.nn.BatchNorm2d`, and :class:`torch.nn.ReLU`.
+
+    Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
+    default.
+
+    Attributes:
+        weight_fake_quant: fake quant module for weight
+
+    """
+    # base class defines _FLOAT_MODULE as "ConvBn2d"
+    _FLOAT_MODULE = nni.ConvBnReLU2d  # type: ignore[assignment]
+    _FLOAT_CONV_MODULE = nn.Conv2d
+    _FLOAT_BN_MODULE = nn.BatchNorm2d
+    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
+    # module class after fusing bn into conv
+    _FUSED_FLOAT_MODULE = nni.ConvReLU2d
+
+    def __init__(self,
+                 # Conv2d args
+                 in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1,
+                 bias=None,
+                 padding_mode='zeros',
+                 # BatchNorm2d args
+                 # num_features: out_channels
+                 eps=1e-05, momentum=0.1,
+                 # affine: True
+                 # track_running_stats: True
+                 # Args for this module
+                 freeze_bn=False,
+                 qconfig=None):
+        super(ConvBnReLU2d, self).__init__(in_channels, out_channels, kernel_size, stride,
+                                           padding, dilation, groups, bias,
+                                           padding_mode, eps, momentum,
+                                           freeze_bn,
+                                           qconfig)
+
+    def forward(self, input):
+        return F.relu(ConvBn2d._forward(self, input))
+
+    @classmethod
+    def from_float(cls, mod):
+        return super(ConvBnReLU2d, cls).from_float(mod)
+
+class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
+    r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
+    FakeQuantize modules for weight for
+    quantization aware training.
+
+    We combined the interface of :class:`~torch.nn.Conv2d` and
+    :class:`~torch.nn.BatchNorm2d`.
+
+    Attributes:
+        weight_fake_quant: fake quant module for weight
+
+    """
+    _FLOAT_MODULE = nni.ConvReLU2d
+    _FLOAT_CONV_MODULE = nn.Conv2d
+    _FLOAT_BN_MODULE = None
+    _FLOAT_RELU_MODULE = nn.ReLU
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1,
+                 bias=True, padding_mode='zeros',
+                 qconfig=None):
+        super(ConvReLU2d, self).__init__(in_channels, out_channels, kernel_size,
+                                         stride=stride, padding=padding, dilation=dilation,
+                                         groups=groups, bias=bias, padding_mode=padding_mode,
+                                         qconfig=qconfig)
+        assert qconfig, 'qconfig must be provided for QAT module'
+        self.qconfig = qconfig
+        self.weight_fake_quant = self.qconfig.weight()
+
+    def forward(self, input):
+        return F.relu(
+            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))
+
+    @classmethod
+    def from_float(cls, mod):
+        return super(ConvReLU2d, cls).from_float(mod)
+
+class ConvBn3d(_ConvBnNd, nn.Conv3d):
+    r"""
+    A ConvBn3d module is a module fused from Conv3d and BatchNorm3d,
+    attached with FakeQuantize modules for weight,
+    used in quantization aware training.
+
+    We combined the interface of :class:`torch.nn.Conv3d` and
+    :class:`torch.nn.BatchNorm3d`.
+
+    Similar to :class:`torch.nn.Conv3d`, with FakeQuantize modules initialized
+    to default.
+
+    Attributes:
+        freeze_bn: flag indicating whether the BatchNorm statistics are frozen
+        weight_fake_quant: fake quant module for weight
+
+    """
+    _FLOAT_MODULE = nni.ConvBn3d
+    _FLOAT_CONV_MODULE = nn.Conv3d
+    _FLOAT_BN_MODULE = nn.BatchNorm3d
+    _FLOAT_RELU_MODULE = None
+
+    def __init__(
+        self,
+        # ConvNd args
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=None,
+        padding_mode="zeros",
+        # BatchNorm3d args
+        # num_features: out_channels
+        eps=1e-05,
+        momentum=0.1,
+        # affine: True
+        # track_running_stats: True
+        # Args for this module
+        freeze_bn=False,
+        qconfig=None,
+    ):
+        kernel_size = _triple(kernel_size)
+        stride = _triple(stride)
+        padding = _triple(padding)
+        dilation = _triple(dilation)
+        _ConvBnNd.__init__(
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            False,
+            _triple(0),
+            groups,
+            bias,
+            padding_mode,
+            eps,
+            momentum,
+            freeze_bn,
+            qconfig,
+            dim=3,
+        )
+
+class ConvBnReLU3d(ConvBn3d):
+    r"""
+    A ConvBnReLU3d module is a module fused from Conv3d, BatchNorm3d and ReLU,
+    attached with FakeQuantize modules for weight,
+    used in quantization aware training.
+
+    We combined the interface of :class:`torch.nn.Conv3d`,
+    :class:`torch.nn.BatchNorm3d`, and :class:`torch.nn.ReLU`.
+
+    Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
+    default.
+
+    Attributes:
+        weight_fake_quant: fake quant module for weight
+
+    """
+    _FLOAT_MODULE = nni.ConvBnReLU3d  # type: ignore[assignment]
+    _FLOAT_CONV_MODULE = nn.Conv3d
+    _FLOAT_BN_MODULE = nn.BatchNorm3d
+    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
+    # module class after fusing bn into conv
+    _FUSED_FLOAT_MODULE = nni.ConvReLU3d
+
+    def __init__(
+        self,
+        # Conv3d args
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=None,
+        padding_mode="zeros",
+        # BatchNorm3d args
+        # num_features: out_channels
+        eps=1e-05,
+        momentum=0.1,
+        # affine: True
+        # track_running_stats: True
+        # Args for this module
+        freeze_bn=False,
+        qconfig=None,
+    ):
+        super(ConvBnReLU3d, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            bias,
+            padding_mode,
+            eps,
+            momentum,
+            freeze_bn,
+            qconfig,
+        )
+
+    def forward(self, input):
+        return F.relu(ConvBn3d._forward(self, input))
+
+    @classmethod
+    def from_float(cls, mod):
+        return super(ConvBnReLU3d, cls).from_float(mod)
+
+class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
+    r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with
+    FakeQuantize modules for weight for
+    quantization aware training.
+
+    We combined the interface of :class:`~torch.nn.Conv3d` and
+    :class:`~torch.nn.BatchNorm3d`.
+
+    Attributes:
+        weight_fake_quant: fake quant module for weight
+
+    """
+    _FLOAT_MODULE = nni.ConvReLU3d
+    _FLOAT_CONV_MODULE = nn.Conv3d
+    _FLOAT_BN_MODULE = None
+    _FLOAT_RELU_MODULE = nn.ReLU
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=True,
+        padding_mode="zeros",
+        qconfig=None,
+    ):
+        super(ConvReLU3d, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+            padding_mode=padding_mode,
+            qconfig=qconfig,
+        )
+        assert qconfig, "qconfig must be provided for QAT module"
+        self.qconfig = qconfig
+        self.weight_fake_quant = self.qconfig.weight()
+
+    def forward(self, input):
+        return F.relu(
+            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
+        )
+
+    @classmethod
+    def from_float(cls, mod):
+        return super(ConvReLU3d, cls).from_float(mod)
+
+def update_bn_stats(mod):
+    if type(mod) in set(
+        [ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d]
+    ):
+        mod.update_bn_stats()
+
+def freeze_bn_stats(mod):
+    if type(mod) in set(
+        [ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d]
+    ):
+        mod.freeze_bn_stats()
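
The docstrings of `_forward_approximate` and `_forward_slow` above lean on one folding identity: scaling the conv weight by `bn.weight / running_std` and dividing the conv output by the same per-channel factor is a no-op when the weight fake-quant is the identity. A standalone numeric check of that identity (my sketch, not part of the diff):

```python
# Standalone check of the folding identity used by _forward_approximate:
# conv(x, w * s) / s == conv(x, w) per output channel, so folding BN's
# scale into the weight before fake-quantization changes nothing when
# the fake-quant is the identity.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, 3)
bn = nn.BatchNorm2d(8).eval()  # use running stats, as in the frozen-BN path
with torch.no_grad():          # randomize stats so the check is non-trivial
    bn.weight.uniform_(0.5, 1.5)
    bn.running_var.uniform_(0.5, 2.0)
x = torch.randn(2, 3, 16, 16)

running_std = torch.sqrt(bn.running_var + bn.eps)
scale = bn.weight / running_std                      # scale_factor in the code
scaled_w = conv.weight * scale.reshape(-1, 1, 1, 1)  # the weight_shape reshape

y = F.conv2d(x, scaled_w)                            # conv with zero bias
y = y / scale.reshape(1, -1, 1, 1)                   # undo the scaling
y = y + conv.bias.reshape(1, -1, 1, 1)               # add the original bias
assert torch.allclose(bn(y), bn(conv(x)), atol=1e-5)
```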
diff --git a/torch/ao/nn/intrinsic/qat/modules/linear_fused.py b/torch/ao/nn/intrinsic/qat/modules/linear_fused.py
new file mode 100644
index 0000000..f19dbd9
--- /dev/null
+++ b/torch/ao/nn/intrinsic/qat/modules/linear_fused.py
@@ -0,0 +1,167 @@
+import torch
+import torch.nn as nn
+import torch.ao.nn.intrinsic as nni
+import torch.nn.functional as F
+from torch.nn import init
+from torch.nn.parameter import Parameter
+from torch.nn.utils.fusion import fuse_linear_bn_weights
+
+
+class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
+    r"""
+    A LinearBn1d module is a module fused from Linear and BatchNorm1d, attached
+    with FakeQuantize modules for weight, used in quantization aware training.
+
+    We combined the interface of :class:`torch.nn.Linear` and
+    :class:`torch.nn.BatchNorm1d`.
+
+    Similar to :class:`torch.nn.Linear`, with FakeQuantize modules initialized
+    to default.
+
+    Attributes:
+        freeze_bn: flag indicating whether the BatchNorm statistics are frozen
+        weight_fake_quant: fake quant module for weight
+
+    """
+    def __init__(self,
+                 # Linear args
+                 in_features, out_features, bias=True,
+                 # BatchNorm1d args
+                 # num_features: out_features
+                 eps=1e-05, momentum=0.1,
+                 # affine: True
+                 # track_running_stats: True
+                 # Args for this module
+                 freeze_bn=False,
+                 qconfig=None):
+        nn.modules.linear.Linear.__init__(self, in_features, out_features, bias)
+        assert qconfig, 'qconfig must be provided for QAT module'
+        self.qconfig = qconfig
+        self.freeze_bn = freeze_bn if self.training else True
+        self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True)
+        self.weight_fake_quant = self.qconfig.weight()
+        if bias:
+            self.bias = Parameter(torch.empty(out_features))
+        else:
+            self.register_parameter('bias', None)
+        self.reset_bn_parameters()
+
+        # this needs to be called after reset_bn_parameters,
+        # as they modify the same state
+        if self.training:
+            if freeze_bn:
+                self.freeze_bn_stats()
+            else:
+                self.update_bn_stats()
+        else:
+            self.freeze_bn_stats()
+
+    def reset_running_stats(self):
+        self.bn.reset_running_stats()
+
+    def reset_bn_parameters(self):
+        self.bn.reset_running_stats()
+        init.uniform_(self.bn.weight)
+        init.zeros_(self.bn.bias)
+
+    def reset_parameters(self):
+        super(LinearBn1d, self).reset_parameters()
+
+    def update_bn_stats(self):
+        self.freeze_bn = False
+        self.bn.training = True
+        return self
+
+    def freeze_bn_stats(self):
+        self.freeze_bn = True
+        self.bn.training = False
+        return self
+
+    def forward(self, input):
+        assert self.bn.running_var is not None
+
+        # Scale the linear weights by BN's running statistics to reduce
+        # weight jitter, see https://arxiv.org/pdf/1806.08342.pdf, page 18
+        # for motivation.
+        #
+        # Instead of
+        #
+        #   x1 = F.linear(x0, fq(w), b)
+        #   x2 = self.bn(x1)
+        #
+        # We have
+        #
+        #   # scale the weight by previous batch's running statistics
+        #   scale_factor = bn.w / bn.running_std_from_prev_batch
+        #   # do the linear transformation without bias
+        #   x1_scaled = F.linear(x0, fq(w * scale_factor), 0)
+        #   # reverse the scaling and add original bias
+        #   x1_orig = x1_scaled / scale_factor + b
+        #   x2 = self.bn(x1_orig)
+
+        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
+        scale_factor = self.bn.weight / running_std
+        weight_shape = [1] * len(self.weight.shape)
+        weight_shape[0] = -1
+        bias_shape = [1] * len(self.weight.shape)
+        bias_shape[1] = -1
+        scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
+        if self.bias is not None:
+            zero_bias = torch.zeros_like(self.bias)
+        else:
+            zero_bias = torch.zeros(self.out_features, device=scaled_weight.device)
+        linear_out = F.linear(input, scaled_weight, zero_bias)
+        linear_out_orig = linear_out / scale_factor.reshape(bias_shape)
+        if self.bias is not None:
+            linear_out_orig = linear_out_orig + self.bias.reshape(bias_shape)
+        bn_out = self.bn(linear_out_orig)
+        return bn_out
+
+    def train(self, mode=True):
+        """
+        BatchNorm's training behavior is controlled by the self.training flag. Prevent
+        changing it if BN is frozen. This makes sure that calling `model.train()`
+        on a model with a frozen BN will behave properly.
+        """
+        self.training = mode
+        if not self.freeze_bn:
+            for module in self.children():
+                module.train(mode)
+        return self
+
+    @classmethod
+    def from_float(cls, mod):
+        r"""Create a qat module from a float module or qparams_dict
+
+            Args: `mod` a float module, either produced by torch.ao.quantization
+            utilities or directly from the user
+        """
+        assert type(mod) == nni.LinearBn1d, 'qat.' + cls.__name__ + \
+            '.from_float only works for ' + nni.LinearBn1d.__name__
+        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+        assert mod.qconfig, 'Input float module must have a valid qconfig'
+        qconfig = mod.qconfig
+        linear, bn = mod[0], mod[1]
+        qat_linearbn = cls(linear.in_features, linear.out_features, linear.bias is not None,
+                           bn.eps, bn.momentum,
+                           False, qconfig)
+        qat_linearbn.weight = linear.weight
+        qat_linearbn.bias = linear.bias
+        qat_linearbn.bn.weight = bn.weight
+        qat_linearbn.bn.bias = bn.bias
+        qat_linearbn.bn.running_mean = bn.running_mean
+        qat_linearbn.bn.running_var = bn.running_var
+        qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked
+        return qat_linearbn
+
+    def to_float(self):
+        linear = torch.nn.Linear(self.in_features, self.out_features)
+        linear.weight, linear.bias = fuse_linear_bn_weights(
+            self.weight,
+            self.bias,
+            self.bn.running_mean,
+            self.bn.running_var,
+            self.bn.eps,
+            self.bn.weight,
+            self.bn.bias)
+        return linear
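
The comment block in `LinearBn1d.forward` describes the same scale/unscale trick for the linear case; mirroring the conv check above (again a sketch of my own, with the fake-quant taken as identity):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
linear = nn.Linear(4, 6)
bn = nn.BatchNorm1d(6).eval()
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)
    bn.running_var.uniform_(0.5, 2.0)
x = torch.randn(8, 4)

# scale_factor = bn.w / bn.running_std, as in the forward() comment
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
x1_scaled = F.linear(x, linear.weight * scale.reshape(-1, 1))  # no bias
x1_orig = x1_scaled / scale + linear.bias                      # reverse scaling
assert torch.allclose(bn(x1_orig), bn(linear(x)), atol=1e-5)
```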
diff --git a/torch/ao/nn/intrinsic/qat/modules/linear_relu.py b/torch/ao/nn/intrinsic/qat/modules/linear_relu.py
new file mode 100644
index 0000000..1c77965
--- /dev/null
+++ b/torch/ao/nn/intrinsic/qat/modules/linear_relu.py
@@ -0,0 +1,48 @@
+import torch
+import torch.ao.nn.qat as nnqat
+import torch.ao.nn.intrinsic as nni
+import torch.nn.functional as F
+
+class LinearReLU(nnqat.Linear, nni._FusedModule):
+    r"""
+    A LinearReLU module fused from Linear and ReLU modules, attached with
+    FakeQuantize modules for weight, used in
+    quantization aware training.
+
+    We adopt the same interface as :class:`torch.nn.Linear`.
+
+    Similar to `torch.nn.intrinsic.LinearReLU`, with FakeQuantize modules initialized to
+    default.
+
+    Attributes:
+        weight_fake_quant: fake quant module for weight
+
+    Examples::
+
+        >>> # xdoctest: +SKIP
+        >>> m = torch.ao.nn.intrinsic.qat.LinearReLU(20, 30)
+        >>> input = torch.randn(128, 20)
+        >>> output = m(input)
+        >>> print(output.size())
+        torch.Size([128, 30])
+    """
+    _FLOAT_MODULE = nni.LinearReLU
+
+    def __init__(self, in_features, out_features, bias=True,
+                 qconfig=None):
+        super(LinearReLU, self).__init__(in_features, out_features, bias, qconfig)
+
+    def forward(self, input):
+        return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias))
+
+    @classmethod
+    def from_float(cls, mod):
+        return super(LinearReLU, cls).from_float(mod)
+
+    def to_float(self):
+        linear = torch.nn.Linear(self.in_features, self.out_features, self.bias is not None)
+        linear.weight = torch.nn.Parameter(self.weight.detach())
+        if self.bias is not None:
+            linear.bias = torch.nn.Parameter(self.bias.detach())
+        relu = torch.nn.ReLU()
+        return torch.nn.intrinsic.LinearReLU(linear, relu)
diff --git a/torch/ao/nn/quantized/modules/conv.py b/torch/ao/nn/quantized/modules/conv.py
index c0966e0..234048d 100644
--- a/torch/ao/nn/quantized/modules/conv.py
+++ b/torch/ao/nn/quantized/modules/conv.py
@@ -7,7 +7,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.intrinsic as nni
-import torch.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic.qat as nniqat
 
 from torch._ops import ops
 from torch.nn.common_types import _size_1_t
diff --git a/torch/ao/nn/quantized/modules/linear.py b/torch/ao/nn/quantized/modules/linear.py
index 85b0f8e..77adcd1 100644
--- a/torch/ao/nn/quantized/modules/linear.py
+++ b/torch/ao/nn/quantized/modules/linear.py
@@ -3,7 +3,7 @@
 
 import torch.nn as nn
 import torch.nn.intrinsic as nni
-import torch.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic.qat as nniqat
 from torch.nn.utils.fusion import fuse_linear_bn_weights
 from torch.nn.utils.parametrize import type_before_parametrizations
 
diff --git a/torch/ao/ns/fx/mappings.py b/torch/ao/ns/fx/mappings.py
index 7e9ffcc..321ec09 100644
--- a/torch/ao/ns/fx/mappings.py
+++ b/torch/ao/ns/fx/mappings.py
@@ -9,7 +9,7 @@
 import torch.ao.nn.quantized.dynamic as nnqd
 import torch.nn.intrinsic.quantized as nniq
 import torch.nn.intrinsic.quantized.dynamic as nniqd
-import torch.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic.qat as nniqat
 import torch.nn.intrinsic as nni
 import torch.ao.nn.qat as nnqat
 import torch.ao.nn.qat.dynamic as nnqatd
diff --git a/torch/ao/ns/fx/weight_utils.py b/torch/ao/ns/fx/weight_utils.py
index 553e385..e02d464 100644
--- a/torch/ao/ns/fx/weight_utils.py
+++ b/torch/ao/ns/fx/weight_utils.py
@@ -3,7 +3,7 @@
 import torch.nn.functional as F
 import torch.ao.nn.quantized.dynamic as nnqd
 import torch.ao.nn.quantized as nnq
-import torch.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic.qat as nniqat
 import torch.ao.nn.qat as nnqat
 import torch.nn.intrinsic as nni
 import torch.nn.intrinsic.quantized as nniq
diff --git a/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/torch/ao/quantization/backend_config/_common_operator_config_utils.py
index baf4686..bc6f678 100644
--- a/torch/ao/quantization/backend_config/_common_operator_config_utils.py
+++ b/torch/ao/quantization/backend_config/_common_operator_config_utils.py
@@ -3,7 +3,7 @@
 import torch.nn.functional as F
 import torch.nn as nn
 import torch.nn.intrinsic as nni
-import torch.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic.qat as nniqat
 import torch.nn.qat as nnqat
 import torch.nn.quantized._reference as nnqr
 from collections import namedtuple
diff --git a/torch/ao/quantization/quantization_mappings.py b/torch/ao/quantization/quantization_mappings.py
index 15d291b..5623d78 100644
--- a/torch/ao/quantization/quantization_mappings.py
+++ b/torch/ao/quantization/quantization_mappings.py
@@ -7,7 +7,7 @@
 import torch.nn.intrinsic as nni
 import torch.nn.intrinsic.quantized as nniq
 import torch.nn.intrinsic.quantized.dynamic as nniqd
-import torch.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic.qat as nniqat
 import torch.ao.nn.quantized as nnq
 import torch.ao.nn.quantized.reference as nnqr
 import torch.ao.nn.quantized.dynamic as nnqd
diff --git a/torch/nn/intrinsic/__init__.py b/torch/nn/intrinsic/__init__.py
index f0be587..fbc89c1 100644
--- a/torch/nn/intrinsic/__init__.py
+++ b/torch/nn/intrinsic/__init__.py
@@ -13,8 +13,9 @@
 from torch.ao.nn.intrinsic import LinearBn1d
 from torch.ao.nn.intrinsic.modules.fused import _FusedModule  # noqa: F401
 
-# Include the `module` in case user imports from it directly
-from . import modules
+# Include the subpackages in case user imports from it directly
+from . import modules  # noqa: F401
+from . import qat  # noqa: F401
 
 __all__ = [
     'ConvBn1d',
diff --git a/torch/nn/intrinsic/qat/modules/conv_fused.py b/torch/nn/intrinsic/qat/modules/conv_fused.py
index 6a6f4c1..ccd79bb 100644
--- a/torch/nn/intrinsic/qat/modules/conv_fused.py
+++ b/torch/nn/intrinsic/qat/modules/conv_fused.py
@@ -1,828 +1,37 @@
-import math
-import torch
-import torch.nn as nn
-import torch.ao.nn.intrinsic as nni
-import torch.ao.nn.qat as nnqat
-import torch.nn.functional as F
-from torch.nn import init
-from torch.nn.utils import fuse_conv_bn_weights
-from torch.nn.modules.utils import _single, _pair, _triple
-from torch.nn.parameter import Parameter
-from typing import TypeVar
+# flake8: noqa: F401
+r"""Intrinsic QAT Modules
 
-__all__ = ['ConvBn1d', 'ConvBnReLU1d', 'ConvReLU1d', 'ConvBn2d', 'ConvBnReLU2d', 'ConvReLU2d', 'ConvBn3d',
-           'ConvBnReLU3d', 'ConvReLU3d', 'update_bn_stats', 'freeze_bn_stats']
-_BN_CLASS_MAP = {
-    1: nn.BatchNorm1d,
-    2: nn.BatchNorm2d,
-    3: nn.BatchNorm3d,
-}
+This file is in the process of migrating to `torch/ao/nn/intrinsic/qat`, and
+is kept here for compatibility while the migration is in progress.
+If you are adding a new entry or functionality, please add it to the
+appropriate file under `torch/ao/nn/intrinsic/qat/modules`,
+and add a corresponding import statement here.
+"""
 
+__all__ = [
+    # Modules
+    'ConvBn1d',
+    'ConvBnReLU1d',
+    'ConvReLU1d',
+    'ConvBn2d',
+    'ConvBnReLU2d',
+    'ConvReLU2d',
+    'ConvBn3d',
+    'ConvBnReLU3d',
+    'ConvReLU3d',
+    # Utilities
+    'freeze_bn_stats',
+    'update_bn_stats',
+]
 
-MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)
-
-
-class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
-
-    _version = 2
-    _FLOAT_MODULE = MOD
-
-    def __init__(self,
-                 # ConvNd args
-                 in_channels, out_channels, kernel_size, stride,
-                 padding, dilation, transposed, output_padding,
-                 groups,
-                 bias,
-                 padding_mode,
-                 # BatchNormNd args
-                 # num_features: out_channels
-                 eps=1e-05, momentum=0.1,
-                 # affine: True
-                 # track_running_stats: True
-                 # Args for this module
-                 freeze_bn=False,
-                 qconfig=None,
-                 dim=2):
-        nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
-                                         stride, padding, dilation, transposed,
-                                         output_padding, groups, False, padding_mode)
-        assert qconfig, 'qconfig must be provided for QAT module'
-        self.qconfig = qconfig
-        self.freeze_bn = freeze_bn if self.training else True
-        self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
-        self.weight_fake_quant = self.qconfig.weight()
-        if bias:
-            self.bias = Parameter(torch.empty(out_channels))
-        else:
-            self.register_parameter('bias', None)
-        self.reset_bn_parameters()
-
-        # this needs to be called after reset_bn_parameters,
-        # as they modify the same state
-        if self.training:
-            if freeze_bn:
-                self.freeze_bn_stats()
-            else:
-                self.update_bn_stats()
-        else:
-            self.freeze_bn_stats()
-
-        self._enable_slow_path_for_better_numerical_stability = False
-
-    def reset_running_stats(self):
-        self.bn.reset_running_stats()
-
-    def reset_bn_parameters(self):
-        self.bn.reset_running_stats()
-        init.uniform_(self.bn.weight)
-        init.zeros_(self.bn.bias)
-        # note: below is actually for conv, not BN
-        if self.bias is not None:
-            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
-            bound = 1 / math.sqrt(fan_in)
-            init.uniform_(self.bias, -bound, bound)
-
-    def reset_parameters(self):
-        super(_ConvBnNd, self).reset_parameters()
-
-    def update_bn_stats(self):
-        self.freeze_bn = False
-        self.bn.training = True
-        return self
-
-    def freeze_bn_stats(self):
-        self.freeze_bn = True
-        self.bn.training = False
-        return self
-
-    def _forward(self, input):
-        if self._enable_slow_path_for_better_numerical_stability:
-            return self._forward_slow(input)
-        return self._forward_approximate(input)
-
-    def _forward_approximate(self, input):
-        """Approximated method to fuse conv and bn. It requires only one forward pass.
-        conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std
-        """
-        assert self.bn.running_var is not None
-        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
-        scale_factor = self.bn.weight / running_std
-        weight_shape = [1] * len(self.weight.shape)
-        weight_shape[0] = -1
-        bias_shape = [1] * len(self.weight.shape)
-        bias_shape[1] = -1
-        scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
-        # using zero bias here since the bias for original conv
-        # will be added later
-        if self.bias is not None:
-            zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
-        else:
-            zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.dtype)
-        conv = self._conv_forward(input, scaled_weight, zero_bias)
-        conv_orig = conv / scale_factor.reshape(bias_shape)
-        if self.bias is not None:
-            conv_orig = conv_orig + self.bias.reshape(bias_shape)
-        conv = self.bn(conv_orig)
-        return conv
-
-    def _forward_slow(self, input):
-        """
-        A more accurate but slower method to compute conv bn fusion, following https://arxiv.org/pdf/1806.08342.pdf.
-        It requires two forward passes but handles the case bn.weight == 0.
-
-        Conv: Y = WX + B_c
-        Conv without bias: Y0 = WX = Y - B_c, Y = Y0 + B_c
-
-        Batch statistics:
-          mean_Y = Y.mean()
-                 = Y0.mean() + B_c
-          var_Y = (Y - mean_Y)^2.mean()
-                = (Y0 - Y0.mean())^2.mean()
-        BN (r: bn.weight, beta: bn.bias):
-          Z = r * (Y - mean_Y) / sqrt(var_Y + eps) + beta
-            = r * (Y0 - Y0.mean()) / sqrt(var_Y + eps) + beta
-
-        Fused Conv BN training (std_Y = sqrt(var_Y + eps)):
-          Z = (r * W / std_Y) * X + r * (B_c - mean_Y) / std_Y + beta
-            = (r * W / std_Y) * X - r * Y0.mean() / std_Y + beta
-
-        Fused Conv BN inference (running_std = sqrt(running_var + eps)):
-          Z = (r * W / running_std) * X - r * (running_mean - B_c) / running_std + beta
-
-        QAT with fused conv bn:
-          Z_train = fake_quant(r * W / running_std) * X * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
-                  = conv(X, fake_quant(r * W / running_std)) * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
-          Z_inference = conv(X, fake_quant(r * W / running_std)) - r * (running_mean - B_c) / running_std + beta
-        """
-
-        assert self.bn.running_var is not None
-        assert self.bn.running_mean is not None
-
-        # using zero bias here since the bias for original conv
-        # will be added later
-        zero_bias = torch.zeros(self.out_channels, device=self.weight.device, dtype=input.dtype)
-
-        weight_shape = [1] * len(self.weight.shape)
-        weight_shape[0] = -1
-        bias_shape = [1] * len(self.weight.shape)
-        bias_shape[1] = -1
-
-        if self.bn.training:
-            # needed to compute batch mean/std
-            conv_out = self._conv_forward(input, self.weight, zero_bias)
-            # update bn statistics
-            with torch.no_grad():
-                conv_out_bias = (
-                    conv_out if self.bias is None else conv_out + self.bias.reshape(bias_shape)
-                )
-                self.bn(conv_out_bias)
-
-        # fused conv + bn without bias using bn running statistics
-        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
-        scale_factor = self.bn.weight / running_std
-        scaled_weight = self.weight_fake_quant(
-            self.weight * scale_factor.reshape(weight_shape)
-        )
-        # fused conv without bias for inference: (r * W / running_std) * X
-        conv_bn = self._conv_forward(input, scaled_weight, zero_bias)
-
-        if self.bn.training:
-            avg_dims = [0] + list(range(2, len(self.weight.shape)))
-            batch_mean = conv_out.mean(avg_dims)
-            batch_var = torch.square(conv_out - batch_mean.reshape(bias_shape)).mean(
-                avg_dims
-            )
-            batch_std = torch.sqrt(batch_var + self.bn.eps)
-
-            # scale to use batch std in training mode
-            # conv(X, r * W / std_Y) = conv(X, r * W / running_std) * (running_std / std_Y)
-            unscale_factor = running_std / batch_std
-            conv_bn *= unscale_factor.reshape(bias_shape)
-
-            fused_mean = batch_mean
-            fused_std = batch_std
-        else:
-            fused_mean = self.bn.running_mean - (self.bias if self.bias is not None else 0)
-            fused_std = running_std
-
-        # fused bias = beta - r * mean / std
-        fused_bias = self.bn.bias - self.bn.weight * fused_mean / fused_std
-        conv_bn += fused_bias.reshape(bias_shape)
-
-        # HACK to let conv bias participate in loss to avoid DDP error (parameters
-        #   were not used in producing loss)
-        if self.bias is not None:
-            conv_bn += (self.bias - self.bias).reshape(bias_shape)
-
-        return conv_bn
-
-    def extra_repr(self):
-        # TODO(jerryzh): extend
-        return super(_ConvBnNd, self).extra_repr()
-
-    def forward(self, input):
-        return self._forward(input)
-
-    def train(self, mode=True):
-        """
-        Batchnorm's training behavior is controlled by the self.training flag. Prevent
-        changing it if BN is frozen. This makes sure that calling `model.train()`
-        on a model with a frozen BN will behave properly.
-        """
-        self.training = mode
-        if not self.freeze_bn:
-            for module in self.children():
-                module.train(mode)
-        return self
-
-    # ===== Serialization version history =====
-    #
-    # Version 1/None
-    #   self
-    #   |--- weight : Tensor
-    #   |--- bias : Tensor
-    #   |--- gamma : Tensor
-    #   |--- beta : Tensor
-    #   |--- running_mean : Tensor
-    #   |--- running_var : Tensor
-    #   |--- num_batches_tracked : Tensor
-    #
-    # Version 2
-    #   self
-    #   |--- weight : Tensor
-    #   |--- bias : Tensor
-    #   |--- bn : Module
-    #        |--- weight : Tensor (moved from v1.self.gamma)
-    #        |--- bias : Tensor (moved from v1.self.beta)
-    #        |--- running_mean : Tensor (moved from v1.self.running_mean)
-    #        |--- running_var : Tensor (moved from v1.self.running_var)
-    #        |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
-        version = local_metadata.get('version', None)
-        if version is None or version == 1:
-            # BN related parameters and buffers were moved into the BN module for v2
-            v2_to_v1_names = {
-                'bn.weight': 'gamma',
-                'bn.bias': 'beta',
-                'bn.running_mean': 'running_mean',
-                'bn.running_var': 'running_var',
-                'bn.num_batches_tracked': 'num_batches_tracked',
-            }
-            for v2_name, v1_name in v2_to_v1_names.items():
-                if prefix + v1_name in state_dict:
-                    state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
-                    state_dict.pop(prefix + v1_name)
-                elif prefix + v2_name in state_dict:
-                    # there was a brief period where forward compatibility
-                    # for this module was broken (between
-                    # https://github.com/pytorch/pytorch/pull/38478
-                    # and https://github.com/pytorch/pytorch/pull/38820)
-                    # and modules emitted the v2 state_dict format while
-                    # specifying that version == 1. This patches the forward
-                    # compatibility issue by allowing the v2 style entries to
-                    # be used.
-                    pass
-                elif strict:
-                    missing_keys.append(prefix + v2_name)
-
-        super(_ConvBnNd, self)._load_from_state_dict(
-            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
-    @classmethod
-    def from_float(cls, mod):
-        r"""Create a qat module from a float module or qparams_dict
-
-            Args: `mod` a float module, either produced by torch.ao.quantization utilities
-            or directly from user
-        """
-        # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound
-        # has no __name__ (code is fine though)
-        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
-            cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
-        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
-        assert mod.qconfig, 'Input float module must have a valid qconfig'
-        qconfig = mod.qconfig
-        conv, bn = mod[0], mod[1]
-        qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
-                         conv.stride, conv.padding, conv.dilation,
-                         conv.groups, conv.bias is not None,
-                         conv.padding_mode,
-                         bn.eps, bn.momentum,
-                         False,
-                         qconfig)
-        qat_convbn.weight = conv.weight
-        qat_convbn.bias = conv.bias
-        qat_convbn.bn.weight = bn.weight
-        qat_convbn.bn.bias = bn.bias
-        qat_convbn.bn.running_mean = bn.running_mean
-        qat_convbn.bn.running_var = bn.running_var
-        # mypy error: Cannot determine type of 'num_batches_tracked'
-        qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked  # type: ignore[has-type]
-        return qat_convbn
-
-    def to_float(self):
-        cls = type(self)
-        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined]
-            self.in_channels,
-            self.out_channels,
-            self.kernel_size,
-            self.stride,
-            self.padding,
-            self.dilation,
-            self.groups,
-            self.bias is not None,
-            self.padding_mode)
-        conv.weight = torch.nn.Parameter(self.weight.detach())
-        if self.bias is not None:
-            conv.bias = torch.nn.Parameter(self.bias.detach())
-
-        if cls._FLOAT_BN_MODULE:  # type: ignore[attr-defined]
-            # fuse bn into conv
-            conv.weight, conv.bias = fuse_conv_bn_weights(
-                conv.weight,
-                conv.bias,
-                self.bn.running_mean,
-                self.bn.running_var,
-                self.bn.eps,
-                self.bn.weight,
-                self.bn.bias
-            )
-
-        if cls._FLOAT_RELU_MODULE:  # type: ignore[attr-defined]
-            modules = []
-            modules.append(conv)
-            relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
-            modules.append(relu)
-            conv_relu = cls._FUSED_FLOAT_MODULE(*modules)  # type: ignore[attr-defined]
-            conv_relu.train(self.training)
-            return conv_relu
-        else:
-            conv.train(self.training)
-            return conv
-
-class ConvBn1d(_ConvBnNd, nn.Conv1d):
-    r"""
-    A ConvBn1d module is a module fused from Conv1d and BatchNorm1d,
-    attached with FakeQuantize modules for weight,
-    used in quantization aware training.
-
-    We combined the interface of :class:`torch.nn.Conv1d` and
-    :class:`torch.nn.BatchNorm1d`.
-
-    Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized
-    to default.
-
-    Attributes:
-        freeze_bn:
-        weight_fake_quant: fake quant module for weight
-
-    """
-    _FLOAT_BN_MODULE = nn.BatchNorm1d
-    _FLOAT_RELU_MODULE = None
-    _FLOAT_MODULE = nni.ConvBn1d
-    _FLOAT_CONV_MODULE = nn.Conv1d
-
-    def __init__(self,
-                 # Conv1d args
-                 in_channels, out_channels, kernel_size, stride=1,
-                 padding=0, dilation=1, groups=1,
-                 bias=None,
-                 padding_mode='zeros',
-                 # BatchNorm1d args
-                 # num_features: out_channels
-                 eps=1e-05, momentum=0.1,
-                 # affine: True
-                 # track_running_stats: True
-                 # Args for this module
-                 freeze_bn=False,
-                 qconfig=None):
-        kernel_size = _single(kernel_size)
-        stride = _single(stride)
-        padding = _single(padding)
-        dilation = _single(dilation)
-        _ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
-                           padding, dilation, False, _single(0), groups, bias, padding_mode,
-                           eps, momentum, freeze_bn, qconfig, dim=1)
-
-class ConvBnReLU1d(ConvBn1d):
-    r"""
-    A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU,
-    attached with FakeQuantize modules for weight,
-    used in quantization aware training.
-
-    We combined the interface of :class:`torch.nn.Conv1d`,
-    :class:`torch.nn.BatchNorm1d`, and :class:`torch.nn.ReLU`.
-
-    Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to
-    default.
-
-    Attributes:
-        weight_fake_quant: fake quant module for weight
-
-    """
-    # base class defines _FLOAT_MODULE as "ConvBn1d"
-    _FLOAT_MODULE = nni.ConvBnReLU1d  # type: ignore[assignment]
-    _FLOAT_CONV_MODULE = nn.Conv1d
-    _FLOAT_BN_MODULE = nn.BatchNorm1d
-    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
-    # module class after fusing bn into conv
-    _FUSED_FLOAT_MODULE = nni.ConvReLU1d
-
-    def __init__(self,
-                 # Conv1d args
-                 in_channels, out_channels, kernel_size, stride=1,
-                 padding=0, dilation=1, groups=1,
-                 bias=None,
-                 padding_mode='zeros',
-                 # BatchNorm1d args
-                 # num_features: out_channels
-                 eps=1e-05, momentum=0.1,
-                 # affine: True
-                 # track_running_stats: True
-                 # Args for this module
-                 freeze_bn=False,
-                 qconfig=None):
-        super().__init__(in_channels, out_channels, kernel_size, stride,
-                         padding, dilation, groups, bias,
-                         padding_mode, eps, momentum,
-                         freeze_bn,
-                         qconfig)
-
-    def forward(self, input):
-        return F.relu(ConvBn1d._forward(self, input))
-
-    @classmethod
-    def from_float(cls, mod):
-        return super(ConvBnReLU1d, cls).from_float(mod)
-
-class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
-    r"""A ConvReLU1d module is a fused module of Conv1d and ReLU, attached with
-    FakeQuantize modules for weight for
-    quantization aware training.
-
-    We combined the interface of :class:`~torch.nn.Conv1d` and
-    :class:`~torch.nn.BatchNorm1d`.
-
-    Attributes:
-        weight_fake_quant: fake quant module for weight
-
-    """
-    _FLOAT_MODULE = nni.ConvReLU1d
-    _FLOAT_CONV_MODULE = nn.Conv1d
-    _FLOAT_BN_MODULE = None
-    _FLOAT_RELU_MODULE = nn.ReLU
-
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-                 padding=0, dilation=1, groups=1,
-                 bias=True, padding_mode='zeros',
-                 qconfig=None):
-        super(ConvReLU1d, self).__init__(in_channels, out_channels, kernel_size,
-                                         stride=stride, padding=padding, dilation=dilation,
-                                         groups=groups, bias=bias, padding_mode=padding_mode,
-                                         qconfig=qconfig)
-        assert qconfig, 'qconfig must be provided for QAT module'
-        self.qconfig = qconfig
-        self.weight_fake_quant = self.qconfig.weight()
-
-    def forward(self, input):
-        return F.relu(
-            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))
-
-    @classmethod
-    def from_float(cls, mod):
-        return super(ConvReLU1d, cls).from_float(mod)
-
-class ConvBn2d(_ConvBnNd, nn.Conv2d):
-    r"""
-    A ConvBn2d module is a module fused from Conv2d and BatchNorm2d,
-    attached with FakeQuantize modules for weight,
-    used in quantization aware training.
-
-    We combined the interface of :class:`torch.nn.Conv2d` and
-    :class:`torch.nn.BatchNorm2d`.
-
-    Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized
-    to default.
-
-    Attributes:
-        freeze_bn:
-        weight_fake_quant: fake quant module for weight
-
-    """
-    _FLOAT_MODULE = nni.ConvBn2d
-    _FLOAT_CONV_MODULE = nn.Conv2d
-    _FLOAT_BN_MODULE = nn.BatchNorm2d
-    _FLOAT_RELU_MODULE = None
-
-    def __init__(self,
-                 # ConvNd args
-                 in_channels, out_channels, kernel_size, stride=1,
-                 padding=0, dilation=1, groups=1,
-                 bias=None,
-                 padding_mode='zeros',
-                 # BatchNorm2d args
-                 # num_features: out_channels
-                 eps=1e-05, momentum=0.1,
-                 # affine: True
-                 # track_running_stats: True
-                 # Args for this module
-                 freeze_bn=False,
-                 qconfig=None):
-        kernel_size = _pair(kernel_size)
-        stride = _pair(stride)
-        padding = _pair(padding)
-        dilation = _pair(dilation)
-        _ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
-                           padding, dilation, False, _pair(0), groups, bias, padding_mode,
-                           eps, momentum, freeze_bn, qconfig, dim=2)
-
-class ConvBnReLU2d(ConvBn2d):
-    r"""
-    A ConvBnReLU2d module is a module fused from Conv2d, BatchNorm2d and ReLU,
-    attached with FakeQuantize modules for weight,
-    used in quantization aware training.
-
-    We combined the interface of :class:`torch.nn.Conv2d`,
-    :class:`torch.nn.BatchNorm2d`, and :class:`torch.nn.ReLU`.
-
-    Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
-    default.
-
-    Attributes:
-        weight_fake_quant: fake quant module for weight
-
-    """
-    # base class defines _FLOAT_MODULE as "ConvBn2d"
-    _FLOAT_MODULE = nni.ConvBnReLU2d  # type: ignore[assignment]
-    _FLOAT_CONV_MODULE = nn.Conv2d
-    _FLOAT_BN_MODULE = nn.BatchNorm2d
-    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
-    # module class after fusing bn into conv
-    _FUSED_FLOAT_MODULE = nni.ConvReLU2d
-
-    def __init__(self,
-                 # Conv2d args
-                 in_channels, out_channels, kernel_size, stride=1,
-                 padding=0, dilation=1, groups=1,
-                 bias=None,
-                 padding_mode='zeros',
-                 # BatchNorm2d args
-                 # num_features: out_channels
-                 eps=1e-05, momentum=0.1,
-                 # affine: True
-                 # track_running_stats: True
-                 # Args for this module
-                 freeze_bn=False,
-                 qconfig=None):
-        super(ConvBnReLU2d, self).__init__(in_channels, out_channels, kernel_size, stride,
-                                           padding, dilation, groups, bias,
-                                           padding_mode, eps, momentum,
-                                           freeze_bn,
-                                           qconfig)
-
-    def forward(self, input):
-        return F.relu(ConvBn2d._forward(self, input))
-
-    @classmethod
-    def from_float(cls, mod):
-        return super(ConvBnReLU2d, cls).from_float(mod)
-
-class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
-    r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
-    FakeQuantize modules for weight for
-    quantization aware training.
-
-    We combined the interface of :class:`~torch.nn.Conv2d` and
-    :class:`~torch.nn.BatchNorm2d`.
-
-    Attributes:
-        weight_fake_quant: fake quant module for weight
-
-    """
-    _FLOAT_MODULE = nni.ConvReLU2d
-    _FLOAT_CONV_MODULE = nn.Conv2d
-    _FLOAT_BN_MODULE = None
-    _FLOAT_RELU_MODULE = nn.ReLU
-
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-                 padding=0, dilation=1, groups=1,
-                 bias=True, padding_mode='zeros',
-                 qconfig=None):
-        super(ConvReLU2d, self).__init__(in_channels, out_channels, kernel_size,
-                                         stride=stride, padding=padding, dilation=dilation,
-                                         groups=groups, bias=bias, padding_mode=padding_mode,
-                                         qconfig=qconfig)
-        assert qconfig, 'qconfig must be provided for QAT module'
-        self.qconfig = qconfig
-        self.weight_fake_quant = self.qconfig.weight()
-
-    def forward(self, input):
-        return F.relu(
-            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))
-
-    @classmethod
-    def from_float(cls, mod):
-        return super(ConvReLU2d, cls).from_float(mod)
-
-class ConvBn3d(_ConvBnNd, nn.Conv3d):
-    r"""
-    A ConvBn3d module is a module fused from Conv3d and BatchNorm3d,
-    attached with FakeQuantize modules for weight,
-    used in quantization aware training.
-
-    We combined the interface of :class:`torch.nn.Conv3d` and
-    :class:`torch.nn.BatchNorm3d`.
-
-    Similar to :class:`torch.nn.Conv3d`, with FakeQuantize modules initialized
-    to default.
-
-    Attributes:
-        freeze_bn:
-        weight_fake_quant: fake quant module for weight
-
-    """
-    _FLOAT_MODULE = nni.ConvBn3d
-    _FLOAT_CONV_MODULE = nn.Conv3d
-    _FLOAT_BN_MODULE = nn.BatchNorm3d
-    _FLOAT_RELU_MODULE = None
-
-    def __init__(
-        self,
-        # ConvNd args
-        in_channels,
-        out_channels,
-        kernel_size,
-        stride=1,
-        padding=0,
-        dilation=1,
-        groups=1,
-        bias=None,
-        padding_mode="zeros",
-        # BatchNorm3d args
-        # num_features: out_channels
-        eps=1e-05,
-        momentum=0.1,
-        # affine: True
-        # track_running_stats: True
-        # Args for this module
-        freeze_bn=False,
-        qconfig=None,
-    ):
-        kernel_size = _triple(kernel_size)
-        stride = _triple(stride)
-        padding = _triple(padding)
-        dilation = _triple(dilation)
-        _ConvBnNd.__init__(
-            self,
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride,
-            padding,
-            dilation,
-            False,
-            _triple(0),
-            groups,
-            bias,
-            padding_mode,
-            eps,
-            momentum,
-            freeze_bn,
-            qconfig,
-            dim=3,
-        )
-
-class ConvBnReLU3d(ConvBn3d):
-    r"""
-    A ConvBnReLU3d module is a module fused from Conv3d, BatchNorm3d and ReLU,
-    attached with FakeQuantize modules for weight,
-    used in quantization aware training.
-
-    We combined the interface of :class:`torch.nn.Conv3d`,
-    :class:`torch.nn.BatchNorm3d`, and :class:`torch.nn.ReLU`.
-
-    Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
-    default.
-
-    Attributes:
-        weight_fake_quant: fake quant module for weight
-
-    """
-    _FLOAT_MODULE = nni.ConvBnReLU3d  # type: ignore[assignment]
-    _FLOAT_CONV_MODULE = nn.Conv3d
-    _FLOAT_BN_MODULE = nn.BatchNorm3d
-    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
-    # module class after fusing bn into conv
-    _FUSED_FLOAT_MODULE = nni.ConvReLU3d
-
-    def __init__(
-        self,
-        # Conv3d args
-        in_channels,
-        out_channels,
-        kernel_size,
-        stride=1,
-        padding=0,
-        dilation=1,
-        groups=1,
-        bias=None,
-        padding_mode="zeros",
-        # BatchNorm3d args
-        # num_features: out_channels
-        eps=1e-05,
-        momentum=0.1,
-        # affine: True
-        # track_running_stats: True
-        # Args for this module
-        freeze_bn=False,
-        qconfig=None,
-    ):
-        super(ConvBnReLU3d, self).__init__(
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride,
-            padding,
-            dilation,
-            groups,
-            bias,
-            padding_mode,
-            eps,
-            momentum,
-            freeze_bn,
-            qconfig,
-        )
-
-    def forward(self, input):
-        return F.relu(ConvBn3d._forward(self, input))
-
-    @classmethod
-    def from_float(cls, mod):
-        return super(ConvBnReLU3d, cls).from_float(mod)
-
-class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
-    r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with
-    FakeQuantize modules for weight for
-    quantization aware training.
-
-    We combined the interface of :class:`~torch.nn.Conv3d` and
-    :class:`~torch.nn.BatchNorm3d`.
-
-    Attributes:
-        weight_fake_quant: fake quant module for weight
-
-    """
-    _FLOAT_MODULE = nni.ConvReLU3d
-    _FLOAT_CONV_MODULE = nn.Conv3d
-    _FLOAT_BN_MODULE = None
-    _FLOAT_RELU_MODULE = nn.ReLU
-
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size,
-        stride=1,
-        padding=0,
-        dilation=1,
-        groups=1,
-        bias=True,
-        padding_mode="zeros",
-        qconfig=None,
-    ):
-        super(ConvReLU3d, self).__init__(
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            groups=groups,
-            bias=bias,
-            padding_mode=padding_mode,
-            qconfig=qconfig,
-        )
-        assert qconfig, "qconfig must be provided for QAT module"
-        self.qconfig = qconfig
-        self.weight_fake_quant = self.qconfig.weight()
-
-    def forward(self, input):
-        return F.relu(
-            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
-        )
-
-    @classmethod
-    def from_float(cls, mod):
-        return super(ConvReLU3d, cls).from_float(mod)
-
-def update_bn_stats(mod):
-    if type(mod) in set(
-        [ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d]
-    ):
-        mod.update_bn_stats()
-
-def freeze_bn_stats(mod):
-    if type(mod) in set(
-        [ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d]
-    ):
-        mod.freeze_bn_stats()
+from torch.ao.nn.intrinsic.qat import ConvBn1d
+from torch.ao.nn.intrinsic.qat import ConvBnReLU1d
+from torch.ao.nn.intrinsic.qat import ConvReLU1d
+from torch.ao.nn.intrinsic.qat import ConvBn2d
+from torch.ao.nn.intrinsic.qat import ConvBnReLU2d
+from torch.ao.nn.intrinsic.qat import ConvReLU2d
+from torch.ao.nn.intrinsic.qat import ConvBn3d
+from torch.ao.nn.intrinsic.qat import ConvBnReLU3d
+from torch.ao.nn.intrinsic.qat import ConvReLU3d
+from torch.ao.nn.intrinsic.qat import freeze_bn_stats
+from torch.ao.nn.intrinsic.qat import update_bn_stats
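
For readers tracing the removed `_forward_approximate` above: the one-pass conv+BN fusion trick can be sketched standalone. This is a simplified, 2d-only sketch (the real module threads an explicit zero bias through the conv and carries a DDP workaround); `fake_quant` stands in for the module's `weight_fake_quant`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def fused_conv_bn_approx(x, conv: nn.Conv2d, bn: nn.BatchNorm2d,
                         fake_quant=lambda w: w):
    # conv_orig = conv / scale_factor, where scale_factor = bn.weight / running_std
    running_std = torch.sqrt(bn.running_var + bn.eps)
    scale_factor = bn.weight / running_std
    scaled_weight = fake_quant(conv.weight * scale_factor.reshape(-1, 1, 1, 1))
    # single forward pass with the scaled, fake-quantized weight
    y = F.conv2d(x, scaled_weight, None, conv.stride,
                 conv.padding, conv.dilation, conv.groups)
    # undo the scaling, re-apply the original conv bias, then run BN as usual
    y = y / scale_factor.reshape(1, -1, 1, 1)
    if conv.bias is not None:
        y = y + conv.bias.reshape(1, -1, 1, 1)
    return bn(y)
```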
diff --git a/torch/nn/intrinsic/qat/modules/linear_fused.py b/torch/nn/intrinsic/qat/modules/linear_fused.py
index f19dbd9..57fbf83 100644
--- a/torch/nn/intrinsic/qat/modules/linear_fused.py
+++ b/torch/nn/intrinsic/qat/modules/linear_fused.py
@@ -1,167 +1,15 @@
-import torch
-import torch.nn as nn
-import torch.ao.nn.intrinsic as nni
-import torch.nn.functional as F
-from torch.nn import init
-from torch.nn.parameter import Parameter
-from torch.nn.utils.fusion import fuse_linear_bn_weights
+# flake8: noqa: F401
+r"""Intrinsic QAT Modules
 
+This file is in the process of migrating to `torch/ao/nn/intrinsic/qat`, and
+is kept here for compatibility while the migration is in progress.
+If you are adding a new entry or functionality, please add it to the
+appropriate file under `torch/ao/nn/intrinsic/qat/modules`,
+and add a corresponding import statement here.
+"""
 
-class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
-    r"""
-    A LinearBn1d module is a module fused from Linear and BatchNorm1d, attached
-    with FakeQuantize modules for weight, used in quantization aware training.
+__all__ = [
+    'LinearBn1d',
+]
 
-    We combined the interface of :class:`torch.nn.Linear` and
-    :class:`torch.nn.BatchNorm1d`.
-
-    Similar to :class:`torch.nn.Linear`, with FakeQuantize modules initialized
-    to default.
-
-    Attributes:
-        freeze_bn:
-        weight_fake_quant: fake quant module for weight
-
-    """
-    def __init__(self,
-                 # Linear args
-                 in_features, out_features, bias=True,
-                 # BatchNorm1d args
-                 # num_features: out_features
-                 eps=1e-05, momentum=0.1,
-                 # affine: True
-                 # track_running_stats: True
-                 # Args for this module
-                 freeze_bn=False,
-                 qconfig=None):
-        nn.modules.linear.Linear.__init__(self, in_features, out_features, bias)
-        assert qconfig, 'qconfig must be provided for QAT module'
-        self.qconfig = qconfig
-        self.freeze_bn = freeze_bn if self.training else True
-        self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True)
-        self.weight_fake_quant = self.qconfig.weight()
-        if bias:
-            self.bias = Parameter(torch.empty(out_features))
-        else:
-            self.register_parameter('bias', None)
-        self.reset_bn_parameters()
-
-        # this needs to be called after reset_bn_parameters,
-        # as they modify the same state
-        if self.training:
-            if freeze_bn:
-                self.freeze_bn_stats()
-            else:
-                self.update_bn_stats()
-        else:
-            self.freeze_bn_stats()
-
-    def reset_running_stats(self):
-        self.bn.reset_running_stats()
-
-    def reset_bn_parameters(self):
-        self.bn.reset_running_stats()
-        init.uniform_(self.bn.weight)
-        init.zeros_(self.bn.bias)
-
-    def reset_parameters(self):
-        super(LinearBn1d, self).reset_parameters()
-
-    def update_bn_stats(self):
-        self.freeze_bn = False
-        self.bn.training = True
-        return self
-
-    def freeze_bn_stats(self):
-        self.freeze_bn = True
-        self.bn.training = False
-        return self
-
-    def forward(self, input):
-        assert self.bn.running_var is not None
-
-        # Scale the linear weights by BN's running statistics to reduce
-        # weight jitter, see https://arxiv.org/pdf/1806.08342.pdf, page 18
-        # for motivation.
-        #
-        # Instead of
-        #
-        #   x1 = F.linear(x0, fq(w), b)
-        #   x2 = self.bn(x1)
-        #
-        # We have
-        #
-        #   # scale the weight by previous batch's running statistics
-        #   scale_factor = bn.w / bn.running_std_from_prev_batch
-        #   # do the linear transformation without bias
-        #   x1_scaled = F.linear(x0, fq(w * scale_factor), 0)
-        #   # reverse the scaling and add original bias
-        #   x1_orig = x1_scaled / scale_factor + b
-        #   x2 = self.bn(x1_orig)
-
-        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
-        scale_factor = self.bn.weight / running_std
-        weight_shape = [1] * len(self.weight.shape)
-        weight_shape[0] = -1
-        bias_shape = [1] * len(self.weight.shape)
-        bias_shape[1] = -1
-        scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
-        if self.bias is not None:
-            zero_bias = torch.zeros_like(self.bias)
-        else:
-            zero_bias = torch.zeros(self.out_features, device=scaled_weight.device)
-        linear_out = F.linear(input, scaled_weight, zero_bias)
-        linear_out_orig = linear_out / scale_factor.reshape(bias_shape)
-        if self.bias is not None:
-            linear_out_orig = linear_out_orig + self.bias.reshape(bias_shape)
-        bn_out = self.bn(linear_out_orig)
-        return bn_out
-
-    def train(self, mode=True):
-        """
-        Batchnorm's training behavior is controlled by the self.training flag. Prevent
-        changing it if BN is frozen. This makes sure that calling `model.train()`
-        on a model with a frozen BN will behave properly.
-        """
-        self.training = mode
-        if not self.freeze_bn:
-            for module in self.children():
-                module.train(mode)
-        return self
-
-    @classmethod
-    def from_float(cls, mod):
-        r"""Create a qat module from a float module or qparams_dict
-
-            Args: `mod` a float module, either produced by torch.ao.quantization
-            utilities or directly from user
-        """
-        assert type(mod) == nni.LinearBn1d, 'qat.' + cls.__name__ + \
-            '.from_float only works for ' + nni.LinearBn1d.__name__
-        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
-        assert mod.qconfig, 'Input float module must have a valid qconfig'
-        qconfig = mod.qconfig
-        linear, bn = mod[0], mod[1]
-        qat_linearbn = cls(linear.in_features, linear.out_features, linear.bias is not None,
-                           bn.eps, bn.momentum,
-                           False, qconfig)
-        qat_linearbn.weight = linear.weight
-        qat_linearbn.bias = linear.bias
-        qat_linearbn.bn.weight = bn.weight
-        qat_linearbn.bn.bias = bn.bias
-        qat_linearbn.bn.running_mean = bn.running_mean
-        qat_linearbn.bn.running_var = bn.running_var
-        qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked
-        return qat_linearbn
-
-    def to_float(self):
-        linear = torch.nn.Linear(self.in_features, self.out_features)
-        linear.weight, linear.bias = fuse_linear_bn_weights(
-            self.weight,
-            self.bias,
-            self.bn.running_mean,
-            self.bn.running_var,
-            self.bn.eps,
-            self.bn.weight,
-            self.bn.bias)
-        return linear
+from torch.ao.nn.intrinsic.qat import LinearBn1d
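
Usage of the migrated `LinearBn1d` through the ao path is unchanged; a hedged sketch (assuming a build containing this PR, with `get_default_qat_qconfig("fbgemm")` as one valid qconfig choice):

```python
import torch
import torch.ao.nn.intrinsic.qat as nniqat
from torch.ao.quantization import get_default_qat_qconfig

# qconfig is mandatory for QAT modules (see the assert in __init__)
qat_mod = nniqat.LinearBn1d(20, 30, qconfig=get_default_qat_qconfig("fbgemm"))
out = qat_mod(torch.randn(128, 20))  # fake-quantized weight, live BN stats
float_mod = qat_mod.to_float()       # folds BN back into a plain nn.Linear
```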
diff --git a/torch/nn/intrinsic/qat/modules/linear_relu.py b/torch/nn/intrinsic/qat/modules/linear_relu.py
index 1c77965..45afc33 100644
--- a/torch/nn/intrinsic/qat/modules/linear_relu.py
+++ b/torch/nn/intrinsic/qat/modules/linear_relu.py
@@ -1,48 +1,15 @@
-import torch
-import torch.ao.nn.qat as nnqat
-import torch.ao.nn.intrinsic as nni
-import torch.nn.functional as F
+# flake8: noqa: F401
+r"""Intrinsic QAT Modules
 
-class LinearReLU(nnqat.Linear, nni._FusedModule):
-    r"""
-    A LinearReLU module fused from Linear and ReLU modules, attached with
-    FakeQuantize modules for weight, used in
-    quantization aware training.
+This file is in the process of migrating to `torch/ao/nn/intrinsic/qat`, and
+is kept here for compatibility while the migration is in progress.
+If you are adding a new entry or functionality, please add it to the
+appropriate file under `torch/ao/nn/intrinsic/qat/modules`,
+and add a corresponding import statement here.
+"""
 
-    We adopt the same interface as :class:`torch.nn.Linear`.
+__all__ = [
+    'LinearReLU',
+]
 
-    Similar to `torch.nn.intrinsic.LinearReLU`, with FakeQuantize modules initialized to
-    default.
-
-    Attributes:
-        weight: fake quant module for weight
-
-    Examples::
-
-        >>> # xdoctest: +SKIP
-        >>> m = nn.intrinsic.qat.LinearReLU(20, 30)
-        >>> input = torch.randn(128, 20)
-        >>> output = m(input)
-        >>> print(output.size())
-        torch.Size([128, 30])
-    """
-    _FLOAT_MODULE = nni.LinearReLU
-
-    def __init__(self, in_features, out_features, bias=True,
-                 qconfig=None):
-        super(LinearReLU, self).__init__(in_features, out_features, bias, qconfig)
-
-    def forward(self, input):
-        return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias))
-
-    @classmethod
-    def from_float(cls, mod):
-        return super(LinearReLU, cls).from_float(mod)
-
-    def to_float(self):
-        linear = torch.nn.Linear(self.in_features, self.out_features, self.bias is not None)
-        linear.weight = torch.nn.Parameter(self.weight.detach())
-        if self.bias is not None:
-            linear.bias = torch.nn.Parameter(self.bias.detach())
-        relu = torch.nn.ReLU()
-        return torch.nn.intrinsic.LinearReLU(linear, relu)
+from torch.ao.nn.intrinsic.qat import LinearReLU
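
The docstring example from the removed class still applies via the new path; a sketch with the qconfig supplied explicitly, since the module asserts on it (`default_qat_qconfig` is one option):

```python
import torch
import torch.ao.nn.intrinsic.qat as nniqat
from torch.ao.quantization import default_qat_qconfig

m = nniqat.LinearReLU(20, 30, qconfig=default_qat_qconfig)
output = m(torch.randn(128, 20))
print(output.size())        # torch.Size([128, 30])
assert (output >= 0).all()  # ReLU is fused into the forward
```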
diff --git a/torch/nn/intrinsic/quantized/modules/bn_relu.py b/torch/nn/intrinsic/quantized/modules/bn_relu.py
index 62e4f4f..d8e6fcc 100644
--- a/torch/nn/intrinsic/quantized/modules/bn_relu.py
+++ b/torch/nn/intrinsic/quantized/modules/bn_relu.py
@@ -1,7 +1,7 @@
 
 import torch
 import torch.ao.nn.intrinsic
-import torch.nn.intrinsic.qat
+import torch.ao.nn.intrinsic.qat
 import torch.ao.nn.quantized as nnq
 
 
diff --git a/torch/nn/intrinsic/quantized/modules/conv_relu.py b/torch/nn/intrinsic/quantized/modules/conv_relu.py
index 022c97e..ba6691a 100644
--- a/torch/nn/intrinsic/quantized/modules/conv_relu.py
+++ b/torch/nn/intrinsic/quantized/modules/conv_relu.py
@@ -1,7 +1,7 @@
 
 import torch
 import torch.ao.nn.intrinsic
-import torch.nn.intrinsic.qat
+import torch.ao.nn.intrinsic.qat
 import torch.nn.functional as F
 import torch.ao.nn.quantized as nnq
 
@@ -48,7 +48,7 @@
 
     @classmethod
     def from_float(cls, mod):
-        if type(mod) == torch.nn.intrinsic.qat.ConvBnReLU1d:
+        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
             mod.weight, mod.bias = fuse_conv_bn_weights(
                 mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
                 mod.bn.eps, mod.bn.weight, mod.bn.bias)
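
The folding performed in `from_float` here relies on `torch.nn.utils.fuse_conv_bn_weights`; a minimal standalone sketch of that step:

```python
import torch.nn as nn
from torch.nn.utils import fuse_conv_bn_weights

conv = nn.Conv2d(3, 8, kernel_size=3)
bn = nn.BatchNorm2d(8).eval()  # folding consumes the running statistics
conv.weight, conv.bias = fuse_conv_bn_weights(
    conv.weight, conv.bias,
    bn.running_mean, bn.running_var, bn.eps,
    bn.weight, bn.bias,
)
# `conv` alone now computes what conv -> bn computed at inference time
```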
@@ -97,7 +97,7 @@
 
     @classmethod
     def from_float(cls, mod):
-        if type(mod) == torch.nn.intrinsic.qat.ConvBnReLU2d:
+        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
             mod.weight, mod.bias = fuse_conv_bn_weights(
                 mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
                 mod.bn.eps, mod.bn.weight, mod.bn.bias)
@@ -147,7 +147,7 @@
 
     @classmethod
     def from_float(cls, mod):
-        if type(mod) == torch.nn.intrinsic.qat.ConvBnReLU3d:
+        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
             mod.weight, mod.bias = fuse_conv_bn_weights(
                 mod.weight,
                 mod.bias,