[quant][ao_migration] nn.intrinsic.qat migration to ao (#86171)
All quantization-related modules are being migrated to `torch.ao`. This PR migrates `nn.intrinsic.qat` to `torch.ao.nn.intrinsic.qat`, keeping the old location as a compatibility shim. Please see the [tracker](https://github.com/pytorch/pytorch/issues/81667) for the timeline.
```
python test/test_quantization.py TestAOMigrationNNIntrinsic
```
Differential Revision: [D39419993](https://our.internmc.facebook.com/intern/diff/D39419993/)
**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D39419993/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86171
Approved by: https://github.com/jerryzh168
diff --git a/docs/source/quantization-support.rst b/docs/source/quantization-support.rst
index a681e49..dfcc6f6 100644
--- a/docs/source/quantization-support.rst
+++ b/docs/source/quantization-support.rst
@@ -296,16 +296,16 @@
BNReLU2d
BNReLU3d
-torch.nn.intrinsic.qat
-~~~~~~~~~~~~~~~~~~~~~~
-.. automodule:: torch.nn.intrinsic.qat
-.. automodule:: torch.nn.intrinsic.qat.modules
+torch.ao.nn.intrinsic.qat
+~~~~~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: torch.ao.nn.intrinsic.qat
+.. automodule:: torch.ao.nn.intrinsic.qat.modules
This module implements the versions of those fused operations needed for
quantization aware training.
-.. currentmodule:: torch.nn.intrinsic.qat
+.. currentmodule:: torch.ao.nn.intrinsic.qat
.. autosummary::
:toctree: generated
@@ -597,3 +597,6 @@
:noindex:
.. automodule:: torch.ao.nn.quantized.reference.modules
:noindex:
+
+.. py:module:: torch.nn.intrinsic.qat
+.. py:module:: torch.nn.intrinsic.qat.modules
diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json
index d492734..02e707d 100644
--- a/test/allowlist_for_publicAPI.json
+++ b/test/allowlist_for_publicAPI.json
@@ -3,6 +3,11 @@
"torch.nn.intrinsic": "torch.ao.nn.intrinsic",
"torch.nn.intrinsic.modules": "torch.ao.nn.intrinsic.modules",
"torch.nn.intrinsic.modules.fused": "torch.ao.nn.intrinsic.modules.fused",
+ "torch.nn.intrinsic.qat": "torch.ao.nn.intrinsic.qat",
+ "torch.nn.intrinsic.qat.modules": "torch.ao.nn.intrinsic.qat.modules",
+ "torch.nn.intrinsic.qat.modules.conv_fused": "torch.ao.nn.intrinsic.qat.modules.conv_fused",
+ "torch.nn.intrinsic.qat.modules.linear_fused": "torch.ao.nn.intrinsic.qat.modules.linear_fused",
+ "torch.nn.intrinsic.qat.modules.linear_relu": "torch.ao.nn.intrinsic.qat.modules.linear_relu",
"torch.nn.qat": "torch.ao.nn.qat",
"torch.nn.qat.dynamic": "torch.ao.nn.qat.dynamic",
"torch.nn.qat.dynamic.modules": "torch.ao.nn.qat.dynamic.modules",
diff --git a/test/ao/sparsity/test_composability.py b/test/ao/sparsity/test_composability.py
index 366116e..f531dd2 100644
--- a/test/ao/sparsity/test_composability.py
+++ b/test/ao/sparsity/test_composability.py
@@ -514,7 +514,7 @@
# that none were lost during prepare
self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
self.assertTrue(hasattr(fqn_to_module(mod, "5"), "parametrizations"))
- self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.nn.intrinsic.qat.LinearReLU))
+ self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.qat.LinearReLU))
# check that correct observers were inserted and that matching
        # occurred successfully
diff --git a/test/quantization/ao_migration/test_ao_migration.py b/test/quantization/ao_migration/test_ao_migration.py
index 6b76a01..3d0e3b3 100644
--- a/test/quantization/ao_migration/test_ao_migration.py
+++ b/test/quantization/ao_migration/test_ao_migration.py
@@ -382,7 +382,6 @@
def test_package_import_nn_intrinsic(self):
skip = [
- 'qat',
'quantized',
]
self._test_package_import('intrinsic', base='nn', skip=skip)
@@ -407,7 +406,7 @@
]
self._test_function_import('intrinsic', module_list, base='nn')
- def test_modules_fused(self):
+ def test_modules_nn_intrinsic_fused(self):
function_list = [
'_FusedModule',
'ConvBn1d',
@@ -426,3 +425,57 @@
]
self._test_function_import('fused', function_list,
base='nn.intrinsic.modules')
+
+ def test_package_import_nn_intrinsic_qat(self):
+        r"""Tests the migration of the torch.nn.intrinsic.qat package"""
+ self._test_package_import('qat', base='nn.intrinsic')
+ self._test_package_import('qat.modules', base='nn.intrinsic')
+
+ def test_modules_import_nn_intrinsic_qat(self):
+ module_list = [
+ "LinearReLU",
+ "LinearBn1d",
+ "ConvReLU1d",
+ "ConvReLU2d",
+ "ConvReLU3d",
+ "ConvBn1d",
+ "ConvBn2d",
+ "ConvBn3d",
+ "ConvBnReLU1d",
+ "ConvBnReLU2d",
+ "ConvBnReLU3d",
+ "update_bn_stats",
+ "freeze_bn_stats",
+ ]
+ self._test_function_import('qat', module_list, base='nn.intrinsic')
+
+ def test_modules_intrinsic_qat_conv_fused(self):
+ function_list = [
+ 'ConvBn1d',
+ 'ConvBnReLU1d',
+ 'ConvReLU1d',
+ 'ConvBn2d',
+ 'ConvBnReLU2d',
+ 'ConvReLU2d',
+ 'ConvBn3d',
+ 'ConvBnReLU3d',
+ 'ConvReLU3d',
+ 'update_bn_stats',
+ 'freeze_bn_stats'
+ ]
+ self._test_function_import('conv_fused', function_list,
+ base='nn.intrinsic.qat.modules')
+
+ def test_modules_intrinsic_qat_linear_fused(self):
+ function_list = [
+ 'LinearBn1d',
+ ]
+ self._test_function_import('linear_fused', function_list,
+ base='nn.intrinsic.qat.modules')
+
+ def test_modules_intrinsic_qat_linear_relu(self):
+ function_list = [
+ 'LinearReLU',
+ ]
+ self._test_function_import('linear_relu', function_list,
+ base='nn.intrinsic.qat.modules')
diff --git a/test/quantization/core/test_workflow_module.py b/test/quantization/core/test_workflow_module.py
index ba93a79..5b46e6a 100644
--- a/test/quantization/core/test_workflow_module.py
+++ b/test/quantization/core/test_workflow_module.py
@@ -864,7 +864,7 @@
if epoch >= 1:
model.apply(torch.ao.quantization.disable_observer)
if epoch >= 2:
- model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+ model.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
quant_model = copy.deepcopy(model.module)
quant_model = torch.ao.quantization.convert(quant_model.eval().cpu(), inplace=False)
with torch.no_grad():
diff --git a/test/quantization/eager/test_fuse_eager.py b/test/quantization/eager/test_fuse_eager.py
index e397fa9..9f120b8 100644
--- a/test/quantization/eager/test_fuse_eager.py
+++ b/test/quantization/eager/test_fuse_eager.py
@@ -7,7 +7,7 @@
import torch.ao.nn.quantized as nnq
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
-import torch.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic.qat as nniqat
from torch.ao.quantization import (
quantize,
prepare,
diff --git a/test/quantization/eager/test_model_numerics.py b/test/quantization/eager/test_model_numerics.py
index b259e10..bcefb78 100644
--- a/test/quantization/eager/test_model_numerics.py
+++ b/test/quantization/eager/test_model_numerics.py
@@ -73,7 +73,7 @@
torch.ao.quantization.prepare_qat(fq_model)
fq_model.eval()
fq_model.apply(torch.ao.quantization.disable_fake_quant)
- fq_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+ fq_model.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
fq_model(calib_data)
fq_model.apply(torch.ao.quantization.enable_fake_quant)
fq_model.apply(torch.ao.quantization.disable_observer)
@@ -109,7 +109,7 @@
torch.ao.quantization.prepare_qat(fq_model)
fq_model.eval()
fq_model.apply(torch.ao.quantization.disable_fake_quant)
- fq_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+ fq_model.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
fq_model(calib_data)
fq_model.apply(torch.ao.quantization.enable_fake_quant)
fq_model.apply(torch.ao.quantization.disable_observer)
diff --git a/test/quantization/eager/test_quantize_eager_qat.py b/test/quantization/eager/test_quantize_eager_qat.py
index c625425..bc118a8 100644
--- a/test/quantization/eager/test_quantize_eager_qat.py
+++ b/test/quantization/eager/test_quantize_eager_qat.py
@@ -6,12 +6,12 @@
import torch.nn as nn
import torch.backends.mkldnn
from torch.nn import Conv2d, BatchNorm2d, ReLU, init
-from torch.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
+from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
from torch.nn.modules.utils import _pair
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.ao.nn.qat as nnqat
-import torch.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.qat.dynamic as nnqatd
from torch.ao.quantization import (
prepare,
@@ -819,9 +819,9 @@
qat_op.apply(torch.ao.quantization.disable_fake_quant)
if freeze_bn:
- qat_op.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+ qat_op.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
else:
- qat_op.apply(torch.nn.intrinsic.qat.update_bn_stats)
+ qat_op.apply(torch.ao.nn.intrinsic.qat.update_bn_stats)
# align inputs and internal parameters
input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True)
@@ -992,7 +992,7 @@
input_clone = input.clone().detach().requires_grad_()
if i > 2:
- qat_op.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+ qat_op.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
qat_ref_op.freeze_bn_stats()
if i > 3:
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index 595f265..4a49a07 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -746,7 +746,7 @@
],
}
prepared = prepare_qat_fx(model, qconfig_dict, example_inputs=(torch.randn(1, 5),))
- self.assertTrue(isinstance(getattr(prepared.mods1, "0").tmp, torch.nn.intrinsic.qat.LinearReLU))
+ self.assertTrue(isinstance(getattr(prepared.mods1, "0").tmp, torch.ao.nn.intrinsic.qat.LinearReLU))
def _get_conv_linear_test_cases(self, is_reference):
""" Returns a list of test cases, with format:
diff --git a/torch/ao/nn/intrinsic/__init__.py b/torch/ao/nn/intrinsic/__init__.py
index c919966..ef2582b 100644
--- a/torch/ao/nn/intrinsic/__init__.py
+++ b/torch/ao/nn/intrinsic/__init__.py
@@ -1,6 +1,9 @@
from .modules import * # noqa: F403
from .modules.fused import _FusedModule # noqa: F403
+# Subpackages
+from . import qat # noqa: F403
+
__all__ = [
'ConvBn1d',
'ConvBn2d',
diff --git a/torch/ao/nn/intrinsic/qat/__init__.py b/torch/ao/nn/intrinsic/qat/__init__.py
new file mode 100644
index 0000000..3d79bdb
--- /dev/null
+++ b/torch/ao/nn/intrinsic/qat/__init__.py
@@ -0,0 +1 @@
+from .modules import * # noqa: F403
diff --git a/torch/ao/nn/intrinsic/qat/modules/__init__.py b/torch/ao/nn/intrinsic/qat/modules/__init__.py
new file mode 100644
index 0000000..f44820c
--- /dev/null
+++ b/torch/ao/nn/intrinsic/qat/modules/__init__.py
@@ -0,0 +1,31 @@
+from .linear_relu import LinearReLU
+from .linear_fused import LinearBn1d
+from .conv_fused import (
+ ConvBn1d,
+ ConvBn2d,
+ ConvBn3d,
+ ConvBnReLU1d,
+ ConvBnReLU2d,
+ ConvBnReLU3d,
+ ConvReLU1d,
+ ConvReLU2d,
+ ConvReLU3d,
+ update_bn_stats,
+ freeze_bn_stats,
+)
+
+__all__ = [
+ "LinearReLU",
+ "LinearBn1d",
+ "ConvReLU1d",
+ "ConvReLU2d",
+ "ConvReLU3d",
+ "ConvBn1d",
+ "ConvBn2d",
+ "ConvBn3d",
+ "ConvBnReLU1d",
+ "ConvBnReLU2d",
+ "ConvBnReLU3d",
+ "update_bn_stats",
+ "freeze_bn_stats",
+]
diff --git a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py
new file mode 100644
index 0000000..6a6f4c1
--- /dev/null
+++ b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py
@@ -0,0 +1,828 @@
+import math
+import torch
+import torch.nn as nn
+import torch.ao.nn.intrinsic as nni
+import torch.ao.nn.qat as nnqat
+import torch.nn.functional as F
+from torch.nn import init
+from torch.nn.utils import fuse_conv_bn_weights
+from torch.nn.modules.utils import _single, _pair, _triple
+from torch.nn.parameter import Parameter
+from typing import TypeVar
+
+__all__ = ['ConvBn1d', 'ConvBnReLU1d', 'ConvReLU1d', 'ConvBn2d', 'ConvBnReLU2d', 'ConvReLU2d', 'ConvBn3d',
+ 'ConvBnReLU3d', 'ConvReLU3d', 'update_bn_stats', 'freeze_bn_stats']
+_BN_CLASS_MAP = {
+ 1: nn.BatchNorm1d,
+ 2: nn.BatchNorm2d,
+ 3: nn.BatchNorm3d,
+}
+
+
+MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)
+
+
+class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
+
+ _version = 2
+ _FLOAT_MODULE = MOD
+
+ def __init__(self,
+ # ConvNd args
+ in_channels, out_channels, kernel_size, stride,
+ padding, dilation, transposed, output_padding,
+ groups,
+ bias,
+ padding_mode,
+ # BatchNormNd args
+ # num_features: out_channels
+ eps=1e-05, momentum=0.1,
+ # affine: True
+ # track_running_stats: True
+ # Args for this module
+ freeze_bn=False,
+ qconfig=None,
+ dim=2):
+ nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
+ stride, padding, dilation, transposed,
+ output_padding, groups, False, padding_mode)
+ assert qconfig, 'qconfig must be provided for QAT module'
+ self.qconfig = qconfig
+ self.freeze_bn = freeze_bn if self.training else True
+ self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
+ self.weight_fake_quant = self.qconfig.weight()
+ if bias:
+ self.bias = Parameter(torch.empty(out_channels))
+ else:
+ self.register_parameter('bias', None)
+ self.reset_bn_parameters()
+
+ # this needs to be called after reset_bn_parameters,
+ # as they modify the same state
+ if self.training:
+ if freeze_bn:
+ self.freeze_bn_stats()
+ else:
+ self.update_bn_stats()
+ else:
+ self.freeze_bn_stats()
+
+ self._enable_slow_path_for_better_numerical_stability = False
+
+ def reset_running_stats(self):
+ self.bn.reset_running_stats()
+
+ def reset_bn_parameters(self):
+ self.bn.reset_running_stats()
+ init.uniform_(self.bn.weight)
+ init.zeros_(self.bn.bias)
+        # note: below is actually for conv, not BN
+ if self.bias is not None:
+ fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
+ bound = 1 / math.sqrt(fan_in)
+ init.uniform_(self.bias, -bound, bound)
+
+ def reset_parameters(self):
+ super(_ConvBnNd, self).reset_parameters()
+
+ def update_bn_stats(self):
+ self.freeze_bn = False
+ self.bn.training = True
+ return self
+
+ def freeze_bn_stats(self):
+ self.freeze_bn = True
+ self.bn.training = False
+ return self
+
+ def _forward(self, input):
+ if self._enable_slow_path_for_better_numerical_stability:
+ return self._forward_slow(input)
+ return self._forward_approximate(input)
+
+ def _forward_approximate(self, input):
+        """Approximate method to fuse conv and bn. It requires only one forward pass.
+ conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std
+ """
+ assert self.bn.running_var is not None
+ running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
+ scale_factor = self.bn.weight / running_std
+ weight_shape = [1] * len(self.weight.shape)
+ weight_shape[0] = -1
+ bias_shape = [1] * len(self.weight.shape)
+ bias_shape[1] = -1
+ scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
+ # using zero bias here since the bias for original conv
+ # will be added later
+ if self.bias is not None:
+ zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
+ else:
+ zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.dtype)
+ conv = self._conv_forward(input, scaled_weight, zero_bias)
+ conv_orig = conv / scale_factor.reshape(bias_shape)
+ if self.bias is not None:
+ conv_orig = conv_orig + self.bias.reshape(bias_shape)
+ conv = self.bn(conv_orig)
+ return conv
+
+ def _forward_slow(self, input):
+ """
+ A more accurate but slow method to compute conv bn fusion, following https://arxiv.org/pdf/1806.08342.pdf
+ It requires two forward passes but handles the case bn.weight == 0
+
+ Conv: Y = WX + B_c
+ Conv without bias: Y0 = WX = Y - B_c, Y = Y0 + B_c
+
+ Batch statistics:
+ mean_Y = Y.mean()
+ = Y0.mean() + B_c
+ var_Y = (Y - mean_Y)^2.mean()
+ = (Y0 - Y0.mean())^2.mean()
+ BN (r: bn.weight, beta: bn.bias):
+ Z = r * (Y - mean_Y) / sqrt(var_Y + eps) + beta
+ = r * (Y0 - Y0.mean()) / sqrt(var_Y + eps) + beta
+
+ Fused Conv BN training (std_Y = sqrt(var_Y + eps)):
+ Z = (r * W / std_Y) * X + r * (B_c - mean_Y) / std_Y + beta
+ = (r * W / std_Y) * X - r * Y0.mean() / std_Y + beta
+
+ Fused Conv BN inference (running_std = sqrt(running_var + eps)):
+ Z = (r * W / running_std) * X - r * (running_mean - B_c) / running_std + beta
+
+ QAT with fused conv bn:
+ Z_train = fake_quant(r * W / running_std) * X * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
+ = conv(X, fake_quant(r * W / running_std)) * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
+ Z_inference = conv(X, fake_quant(r * W / running_std)) - r * (running_mean - B_c) / running_std + beta
+ """
+
+ assert self.bn.running_var is not None
+ assert self.bn.running_mean is not None
+
+ # using zero bias here since the bias for original conv
+ # will be added later
+ zero_bias = torch.zeros(self.out_channels, device=self.weight.device, dtype=input.dtype)
+
+ weight_shape = [1] * len(self.weight.shape)
+ weight_shape[0] = -1
+ bias_shape = [1] * len(self.weight.shape)
+ bias_shape[1] = -1
+
+ if self.bn.training:
+ # needed to compute batch mean/std
+ conv_out = self._conv_forward(input, self.weight, zero_bias)
+ # update bn statistics
+ with torch.no_grad():
+ conv_out_bias = (
+ conv_out if self.bias is None else conv_out + self.bias.reshape(bias_shape)
+ )
+ self.bn(conv_out_bias)
+
+ # fused conv + bn without bias using bn running statistics
+ running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
+ scale_factor = self.bn.weight / running_std
+ scaled_weight = self.weight_fake_quant(
+ self.weight * scale_factor.reshape(weight_shape)
+ )
+ # fused conv without bias for inference: (r * W / running_std) * X
+ conv_bn = self._conv_forward(input, scaled_weight, zero_bias)
+
+ if self.bn.training:
+ avg_dims = [0] + list(range(2, len(self.weight.shape)))
+ batch_mean = conv_out.mean(avg_dims)
+ batch_var = torch.square(conv_out - batch_mean.reshape(bias_shape)).mean(
+ avg_dims
+ )
+ batch_std = torch.sqrt(batch_var + self.bn.eps)
+
+ # scale to use batch std in training mode
+ # conv(X, r * W / std_Y) = conv(X, r * W / running_std) * (running_std / std_Y)
+ unscale_factor = running_std / batch_std
+ conv_bn *= unscale_factor.reshape(bias_shape)
+
+ fused_mean = batch_mean
+ fused_std = batch_std
+ else:
+ fused_mean = self.bn.running_mean - (self.bias if self.bias is not None else 0)
+ fused_std = running_std
+
+ # fused bias = beta - r * mean / std
+ fused_bias = self.bn.bias - self.bn.weight * fused_mean / fused_std
+ conv_bn += fused_bias.reshape(bias_shape)
+
+        # HACK to let conv bias participate in loss to avoid DDP error (parameters
+ # were not used in producing loss)
+ if self.bias is not None:
+ conv_bn += (self.bias - self.bias).reshape(bias_shape)
+
+ return conv_bn
+
+ def extra_repr(self):
+ # TODO(jerryzh): extend
+ return super(_ConvBnNd, self).extra_repr()
+
+ def forward(self, input):
+ return self._forward(input)
+
+ def train(self, mode=True):
+ """
+        Batchnorm's training behavior is controlled by the self.training flag. Prevent
+ changing it if BN is frozen. This makes sure that calling `model.train()`
+ on a model with a frozen BN will behave properly.
+ """
+ self.training = mode
+ if not self.freeze_bn:
+ for module in self.children():
+ module.train(mode)
+ return self
+
+ # ===== Serialization version history =====
+ #
+ # Version 1/None
+ # self
+ # |--- weight : Tensor
+ # |--- bias : Tensor
+ # |--- gamma : Tensor
+ # |--- beta : Tensor
+ # |--- running_mean : Tensor
+ # |--- running_var : Tensor
+ # |--- num_batches_tracked : Tensor
+ #
+ # Version 2
+ # self
+ # |--- weight : Tensor
+ # |--- bias : Tensor
+ # |--- bn : Module
+ # |--- weight : Tensor (moved from v1.self.gamma)
+ # |--- bias : Tensor (moved from v1.self.beta)
+ # |--- running_mean : Tensor (moved from v1.self.running_mean)
+ # |--- running_var : Tensor (moved from v1.self.running_var)
+ # |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+ version = local_metadata.get('version', None)
+ if version is None or version == 1:
+ # BN related parameters and buffers were moved into the BN module for v2
+ v2_to_v1_names = {
+ 'bn.weight': 'gamma',
+ 'bn.bias': 'beta',
+ 'bn.running_mean': 'running_mean',
+ 'bn.running_var': 'running_var',
+ 'bn.num_batches_tracked': 'num_batches_tracked',
+ }
+ for v2_name, v1_name in v2_to_v1_names.items():
+ if prefix + v1_name in state_dict:
+ state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
+ state_dict.pop(prefix + v1_name)
+ elif prefix + v2_name in state_dict:
+ # there was a brief period where forward compatibility
+ # for this module was broken (between
+ # https://github.com/pytorch/pytorch/pull/38478
+ # and https://github.com/pytorch/pytorch/pull/38820)
+ # and modules emitted the v2 state_dict format while
+ # specifying that version == 1. This patches the forward
+ # compatibility issue by allowing the v2 style entries to
+ # be used.
+ pass
+ elif strict:
+ missing_keys.append(prefix + v2_name)
+
+ super(_ConvBnNd, self)._load_from_state_dict(
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+ @classmethod
+ def from_float(cls, mod):
+ r"""Create a qat module from a float module or qparams_dict
+
+ Args: `mod` a float module, either produced by torch.ao.quantization utilities
+ or directly from user
+ """
+ # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound
+ # has no __name__ (code is fine though)
+ assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
+ cls._FLOAT_MODULE.__name__ # type: ignore[attr-defined]
+ assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+ assert mod.qconfig, 'Input float module must have a valid qconfig'
+ qconfig = mod.qconfig
+ conv, bn = mod[0], mod[1]
+ qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
+ conv.stride, conv.padding, conv.dilation,
+ conv.groups, conv.bias is not None,
+ conv.padding_mode,
+ bn.eps, bn.momentum,
+ False,
+ qconfig)
+ qat_convbn.weight = conv.weight
+ qat_convbn.bias = conv.bias
+ qat_convbn.bn.weight = bn.weight
+ qat_convbn.bn.bias = bn.bias
+ qat_convbn.bn.running_mean = bn.running_mean
+ qat_convbn.bn.running_var = bn.running_var
+ # mypy error: Cannot determine type of 'num_batches_tracked'
+ qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked # type: ignore[has-type]
+ return qat_convbn
+
+ def to_float(self):
+ cls = type(self)
+ conv = cls._FLOAT_CONV_MODULE( # type: ignore[attr-defined]
+ self.in_channels,
+ self.out_channels,
+ self.kernel_size,
+ self.stride,
+ self.padding,
+ self.dilation,
+ self.groups,
+ self.bias is not None,
+ self.padding_mode)
+ conv.weight = torch.nn.Parameter(self.weight.detach())
+ if self.bias is not None:
+ conv.bias = torch.nn.Parameter(self.bias.detach())
+
+ if cls._FLOAT_BN_MODULE: # type: ignore[attr-defined]
+ # fuse bn into conv
+ conv.weight, conv.bias = fuse_conv_bn_weights(
+ conv.weight,
+ conv.bias,
+ self.bn.running_mean,
+ self.bn.running_var,
+ self.bn.eps,
+ self.bn.weight,
+ self.bn.bias
+ )
+
+ if cls._FLOAT_RELU_MODULE: # type: ignore[attr-defined]
+ modules = []
+ modules.append(conv)
+ relu = cls._FLOAT_RELU_MODULE() # type: ignore[attr-defined]
+ modules.append(relu)
+ conv_relu = cls._FUSED_FLOAT_MODULE(*modules) # type: ignore[attr-defined]
+ conv_relu.train(self.training)
+ return conv_relu
+ else:
+ conv.train(self.training)
+ return conv
+
+class ConvBn1d(_ConvBnNd, nn.Conv1d):
+ r"""
+ A ConvBn1d module is a module fused from Conv1d and BatchNorm1d,
+ attached with FakeQuantize modules for weight,
+ used in quantization aware training.
+
+ We combined the interface of :class:`torch.nn.Conv1d` and
+ :class:`torch.nn.BatchNorm1d`.
+
+ Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized
+ to default.
+
+ Attributes:
+        freeze_bn: whether the BatchNorm running statistics are frozen
+ weight_fake_quant: fake quant module for weight
+
+ """
+ _FLOAT_BN_MODULE = nn.BatchNorm1d
+ _FLOAT_RELU_MODULE = None
+ _FLOAT_MODULE = nni.ConvBn1d
+ _FLOAT_CONV_MODULE = nn.Conv1d
+
+ def __init__(self,
+ # Conv1d args
+ in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1,
+ bias=None,
+ padding_mode='zeros',
+ # BatchNorm1d args
+ # num_features: out_channels
+ eps=1e-05, momentum=0.1,
+ # affine: True
+ # track_running_stats: True
+ # Args for this module
+ freeze_bn=False,
+ qconfig=None):
+ kernel_size = _single(kernel_size)
+ stride = _single(stride)
+ padding = _single(padding)
+ dilation = _single(dilation)
+ _ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
+ padding, dilation, False, _single(0), groups, bias, padding_mode,
+ eps, momentum, freeze_bn, qconfig, dim=1)
+
+class ConvBnReLU1d(ConvBn1d):
+ r"""
+ A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU,
+ attached with FakeQuantize modules for weight,
+ used in quantization aware training.
+
+ We combined the interface of :class:`torch.nn.Conv1d` and
+ :class:`torch.nn.BatchNorm1d` and :class:`torch.nn.ReLU`.
+
+ Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to
+ default.
+
+ Attributes:
+ weight_fake_quant: fake quant module for weight
+
+ """
+ # base class defines _FLOAT_MODULE as "ConvBn1d"
+ _FLOAT_MODULE = nni.ConvBnReLU1d # type: ignore[assignment]
+ _FLOAT_CONV_MODULE = nn.Conv1d
+ _FLOAT_BN_MODULE = nn.BatchNorm1d
+ _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment]
+ # module class after fusing bn into conv
+ _FUSED_FLOAT_MODULE = nni.ConvReLU1d
+
+ def __init__(self,
+ # Conv1d args
+ in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1,
+ bias=None,
+ padding_mode='zeros',
+ # BatchNorm1d args
+ # num_features: out_channels
+ eps=1e-05, momentum=0.1,
+ # affine: True
+ # track_running_stats: True
+ # Args for this module
+ freeze_bn=False,
+ qconfig=None):
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias,
+ padding_mode, eps, momentum,
+ freeze_bn,
+ qconfig)
+
+ def forward(self, input):
+ return F.relu(ConvBn1d._forward(self, input))
+
+ @classmethod
+ def from_float(cls, mod):
+ return super(ConvBnReLU1d, cls).from_float(mod)
+
+class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
+ r"""A ConvReLU1d module is a fused module of Conv1d and ReLU, attached with
+    r"""A ConvReLU1d module is a fused module of Conv1d and ReLU, attached with
+    FakeQuantize modules for weight, used in
+    quantization aware training.
+
+    We combined the interface of :class:`~torch.nn.Conv1d` and
+    :class:`~torch.nn.ReLU`.
+ Attributes:
+ weight_fake_quant: fake quant module for weight
+
+ """
+ _FLOAT_MODULE = nni.ConvReLU1d
+ _FLOAT_CONV_MODULE = nn.Conv1d
+ _FLOAT_BN_MODULE = None
+ _FLOAT_RELU_MODULE = nn.ReLU
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1,
+ bias=True, padding_mode='zeros',
+ qconfig=None):
+ super(ConvReLU1d, self).__init__(in_channels, out_channels, kernel_size,
+ stride=stride, padding=padding, dilation=dilation,
+ groups=groups, bias=bias, padding_mode=padding_mode,
+ qconfig=qconfig)
+ assert qconfig, 'qconfig must be provided for QAT module'
+ self.qconfig = qconfig
+ self.weight_fake_quant = self.qconfig.weight()
+
+ def forward(self, input):
+ return F.relu(
+ self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))
+
+ @classmethod
+ def from_float(cls, mod):
+ return super(ConvReLU1d, cls).from_float(mod)
+
+class ConvBn2d(_ConvBnNd, nn.Conv2d):
+ r"""
+ A ConvBn2d module is a module fused from Conv2d and BatchNorm2d,
+ attached with FakeQuantize modules for weight,
+ used in quantization aware training.
+
+ We combined the interface of :class:`torch.nn.Conv2d` and
+ :class:`torch.nn.BatchNorm2d`.
+
+ Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized
+ to default.
+
+ Attributes:
+        freeze_bn: whether the BatchNorm running statistics are frozen
+ weight_fake_quant: fake quant module for weight
+
+ """
+ _FLOAT_MODULE = nni.ConvBn2d
+ _FLOAT_CONV_MODULE = nn.Conv2d
+ _FLOAT_BN_MODULE = nn.BatchNorm2d
+ _FLOAT_RELU_MODULE = None
+
+ def __init__(self,
+ # ConvNd args
+ in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1,
+ bias=None,
+ padding_mode='zeros',
+ # BatchNorm2d args
+ # num_features: out_channels
+ eps=1e-05, momentum=0.1,
+ # affine: True
+ # track_running_stats: True
+ # Args for this module
+ freeze_bn=False,
+ qconfig=None):
+ kernel_size = _pair(kernel_size)
+ stride = _pair(stride)
+ padding = _pair(padding)
+ dilation = _pair(dilation)
+ _ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
+ padding, dilation, False, _pair(0), groups, bias, padding_mode,
+ eps, momentum, freeze_bn, qconfig, dim=2)
+
+class ConvBnReLU2d(ConvBn2d):
+ r"""
+ A ConvBnReLU2d module is a module fused from Conv2d, BatchNorm2d and ReLU,
+ attached with FakeQuantize modules for weight,
+ used in quantization aware training.
+
+ We combined the interface of :class:`torch.nn.Conv2d` and
+ :class:`torch.nn.BatchNorm2d` and :class:`torch.nn.ReLU`.
+
+ Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
+ default.
+
+ Attributes:
+ weight_fake_quant: fake quant module for weight
+
+ """
+ # base class defines _FLOAT_MODULE as "ConvBn2d"
+ _FLOAT_MODULE = nni.ConvBnReLU2d # type: ignore[assignment]
+ _FLOAT_CONV_MODULE = nn.Conv2d
+ _FLOAT_BN_MODULE = nn.BatchNorm2d
+ _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment]
+ # module class after fusing bn into conv
+ _FUSED_FLOAT_MODULE = nni.ConvReLU2d
+
+ def __init__(self,
+ # Conv2d args
+ in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1,
+ bias=None,
+ padding_mode='zeros',
+ # BatchNorm2d args
+ # num_features: out_channels
+ eps=1e-05, momentum=0.1,
+ # affine: True
+ # track_running_stats: True
+ # Args for this module
+ freeze_bn=False,
+ qconfig=None):
+ super(ConvBnReLU2d, self).__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias,
+ padding_mode, eps, momentum,
+ freeze_bn,
+ qconfig)
+
+ def forward(self, input):
+ return F.relu(ConvBn2d._forward(self, input))
+
+ @classmethod
+ def from_float(cls, mod):
+ return super(ConvBnReLU2d, cls).from_float(mod)
+
+class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
+    r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
+    FakeQuantize modules for weight, used in
+    quantization aware training.
+
+    We combined the interface of :class:`~torch.nn.Conv2d` and
+    :class:`~torch.nn.ReLU`.
+
+ Attributes:
+ weight_fake_quant: fake quant module for weight
+
+ """
+ _FLOAT_MODULE = nni.ConvReLU2d
+ _FLOAT_CONV_MODULE = nn.Conv2d
+ _FLOAT_BN_MODULE = None
+ _FLOAT_RELU_MODULE = nn.ReLU
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1,
+ bias=True, padding_mode='zeros',
+ qconfig=None):
+ super(ConvReLU2d, self).__init__(in_channels, out_channels, kernel_size,
+ stride=stride, padding=padding, dilation=dilation,
+ groups=groups, bias=bias, padding_mode=padding_mode,
+ qconfig=qconfig)
+ assert qconfig, 'qconfig must be provided for QAT module'
+ self.qconfig = qconfig
+ self.weight_fake_quant = self.qconfig.weight()
+
+ def forward(self, input):
+ return F.relu(
+ self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))
+
+ @classmethod
+ def from_float(cls, mod):
+ return super(ConvReLU2d, cls).from_float(mod)
+
+class ConvBn3d(_ConvBnNd, nn.Conv3d):
+ r"""
+ A ConvBn3d module is a module fused from Conv3d and BatchNorm3d,
+ attached with FakeQuantize modules for weight,
+ used in quantization aware training.
+
+ We combined the interface of :class:`torch.nn.Conv3d` and
+ :class:`torch.nn.BatchNorm3d`.
+
+ Similar to :class:`torch.nn.Conv3d`, with FakeQuantize modules initialized
+ to default.
+
+ Attributes:
+        freeze_bn: whether the BatchNorm running statistics are frozen
+ weight_fake_quant: fake quant module for weight
+
+ """
+ _FLOAT_MODULE = nni.ConvBn3d
+ _FLOAT_CONV_MODULE = nn.Conv3d
+ _FLOAT_BN_MODULE = nn.BatchNorm3d
+ _FLOAT_RELU_MODULE = None
+
+ def __init__(
+ self,
+ # ConvNd args
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=None,
+ padding_mode="zeros",
+ # BatchNorm3d args
+ # num_features: out_channels
+ eps=1e-05,
+ momentum=0.1,
+ # affine: True
+ # track_running_stats: True
+ # Args for this module
+ freeze_bn=False,
+ qconfig=None,
+ ):
+ kernel_size = _triple(kernel_size)
+ stride = _triple(stride)
+ padding = _triple(padding)
+ dilation = _triple(dilation)
+ _ConvBnNd.__init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ False,
+ _triple(0),
+ groups,
+ bias,
+ padding_mode,
+ eps,
+ momentum,
+ freeze_bn,
+ qconfig,
+ dim=3,
+ )
+
+class ConvBnReLU3d(ConvBn3d):
+ r"""
+ A ConvBnReLU3d module is a module fused from Conv3d, BatchNorm3d and ReLU,
+ attached with FakeQuantize modules for weight,
+ used in quantization aware training.
+
+ We combined the interface of :class:`torch.nn.Conv3d` and
+ :class:`torch.nn.BatchNorm3d` and :class:`torch.nn.ReLU`.
+
+ Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
+ default.
+
+ Attributes:
+ weight_fake_quant: fake quant module for weight
+
+ """
+ _FLOAT_MODULE = nni.ConvBnReLU3d # type: ignore[assignment]
+ _FLOAT_CONV_MODULE = nn.Conv3d
+ _FLOAT_BN_MODULE = nn.BatchNorm3d
+ _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment]
+ # module class after fusing bn into conv
+ _FUSED_FLOAT_MODULE = nni.ConvReLU3d
+
+ def __init__(
+ self,
+ # Conv3d args
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=None,
+ padding_mode="zeros",
+ # BatchNorm3d args
+ # num_features: out_channels
+ eps=1e-05,
+ momentum=0.1,
+ # affine: True
+ # track_running_stats: True
+ # Args for this module
+ freeze_bn=False,
+ qconfig=None,
+ ):
+ super(ConvBnReLU3d, self).__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ groups,
+ bias,
+ padding_mode,
+ eps,
+ momentum,
+ freeze_bn,
+ qconfig,
+ )
+
+ def forward(self, input):
+ return F.relu(ConvBn3d._forward(self, input))
+
+ @classmethod
+ def from_float(cls, mod):
+ return super(ConvBnReLU3d, cls).from_float(mod)
+
+class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
+    r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with
+    FakeQuantize modules for weight, used in
+    quantization aware training.
+
+    We combined the interface of :class:`~torch.nn.Conv3d` and
+    :class:`~torch.nn.ReLU`.
+
+ Attributes:
+ weight_fake_quant: fake quant module for weight
+
+ """
+ _FLOAT_MODULE = nni.ConvReLU3d
+ _FLOAT_CONV_MODULE = nn.Conv3d
+ _FLOAT_BN_MODULE = None
+ _FLOAT_RELU_MODULE = nn.ReLU
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True,
+ padding_mode="zeros",
+ qconfig=None,
+ ):
+ super(ConvReLU3d, self).__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ padding_mode=padding_mode,
+ qconfig=qconfig,
+ )
+ assert qconfig, "qconfig must be provided for QAT module"
+ self.qconfig = qconfig
+ self.weight_fake_quant = self.qconfig.weight()
+
+ def forward(self, input):
+ return F.relu(
+ self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
+ )
+
+ @classmethod
+ def from_float(cls, mod):
+ return super(ConvReLU3d, cls).from_float(mod)
+
+def update_bn_stats(mod):
+ if type(mod) in set(
+ [ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d]
+ ):
+ mod.update_bn_stats()
+
+def freeze_bn_stats(mod):
+ if type(mod) in set(
+ [ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d]
+ ):
+ mod.freeze_bn_stats()
diff --git a/torch/ao/nn/intrinsic/qat/modules/linear_fused.py b/torch/ao/nn/intrinsic/qat/modules/linear_fused.py
new file mode 100644
index 0000000..f19dbd9
--- /dev/null
+++ b/torch/ao/nn/intrinsic/qat/modules/linear_fused.py
@@ -0,0 +1,167 @@
+import torch
+import torch.nn as nn
+import torch.ao.nn.intrinsic as nni
+import torch.nn.functional as F
+from torch.nn import init
+from torch.nn.parameter import Parameter
+from torch.nn.utils.fusion import fuse_linear_bn_weights
+
+
+class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
+ r"""
+ A LinearBn1d module is a module fused from Linear and BatchNorm1d, attached
+ with FakeQuantize modules for weight, used in quantization aware training.
+
+ We combined the interface of :class:`torch.nn.Linear` and
+    :class:`torch.nn.BatchNorm1d`.
+
+ Similar to :class:`torch.nn.Linear`, with FakeQuantize modules initialized
+ to default.
+
+ Attributes:
+        freeze_bn: whether the BatchNorm running statistics are frozen
+ weight_fake_quant: fake quant module for weight
+
+ """
+ def __init__(self,
+ # Linear args
+ in_features, out_features, bias=True,
+ # BatchNorm1d args
+ # num_features: out_features
+ eps=1e-05, momentum=0.1,
+ # affine: True
+ # track_running_stats: True
+ # Args for this module
+ freeze_bn=False,
+ qconfig=None):
+ nn.modules.linear.Linear.__init__(self, in_features, out_features, bias)
+        assert qconfig, 'qconfig must be provided for QAT module'
+ self.qconfig = qconfig
+ self.freeze_bn = freeze_bn if self.training else True
+ self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True)
+ self.weight_fake_quant = self.qconfig.weight()
+ if bias:
+ self.bias = Parameter(torch.empty(out_features))
+ else:
+ self.register_parameter('bias', None)
+ self.reset_bn_parameters()
+
+ # this needs to be called after reset_bn_parameters,
+ # as they modify the same state
+ if self.training:
+ if freeze_bn:
+ self.freeze_bn_stats()
+ else:
+ self.update_bn_stats()
+ else:
+ self.freeze_bn_stats()
+
+ def reset_running_stats(self):
+ self.bn.reset_running_stats()
+
+ def reset_bn_parameters(self):
+ self.bn.reset_running_stats()
+ init.uniform_(self.bn.weight)
+ init.zeros_(self.bn.bias)
+
+ def reset_parameters(self):
+ super(LinearBn1d, self).reset_parameters()
+
+ def update_bn_stats(self):
+ self.freeze_bn = False
+ self.bn.training = True
+ return self
+
+ def freeze_bn_stats(self):
+ self.freeze_bn = True
+ self.bn.training = False
+ return self
+
+ def forward(self, input):
+ assert self.bn.running_var is not None
+
+ # Scale the linear weights by BN's running statistics to reduce
+ # weight jitter, see https://arxiv.org/pdf/1806.08342.pdf, page 18
+ # for motivation.
+ #
+ # Instead of
+ #
+ # x1 = F.linear(x0, fq(w), b)
+ # x2 = self.bn(x1)
+ #
+ # We have
+ #
+ # # scale the weight by previous batch's running statistics
+ # scale_factor = bn.w / bn.running_std_from_prev_batch
+ # # do the linear transformation without bias
+ # x1_scaled = F.linear(x0, fq(w * scale_factor), 0)
+ # # reverse the scaling and add original bias
+ # x1_orig = x1_scaled / scale_factor + b
+ # x2 = self.bn(x1_orig)
+
+ running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
+ scale_factor = self.bn.weight / running_std
+ weight_shape = [1] * len(self.weight.shape)
+ weight_shape[0] = -1
+ bias_shape = [1] * len(self.weight.shape)
+ bias_shape[1] = -1
+ scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
+ if self.bias is not None:
+ zero_bias = torch.zeros_like(self.bias)
+ else:
+ zero_bias = torch.zeros(self.out_features, device=scaled_weight.device)
+ linear_out = F.linear(input, scaled_weight, zero_bias)
+ linear_out_orig = linear_out / scale_factor.reshape(bias_shape)
+ if self.bias is not None:
+ linear_out_orig = linear_out_orig + self.bias.reshape(bias_shape)
+ bn_out = self.bn(linear_out_orig)
+ return bn_out
+
+ def train(self, mode=True):
+ """
+        Batchnorm's training behavior is controlled by the self.training flag. Prevent
+ changing it if BN is frozen. This makes sure that calling `model.train()`
+ on a model with a frozen BN will behave properly.
+ """
+ self.training = mode
+ if not self.freeze_bn:
+ for module in self.children():
+ module.train(mode)
+ return self
+
+ @classmethod
+ def from_float(cls, mod):
+ r"""Create a qat module from a float module or qparams_dict
+
+        Args: `mod` a float module, either produced by torch.ao.quantization
+ utilities or directly from user
+ """
+ assert type(mod) == nni.LinearBn1d, 'qat.' + cls.__name__ + \
+ '.from_float only works for ' + nni.LinearBn1d.__name__
+ assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+        assert mod.qconfig, 'Input float module must have a valid qconfig'
+ qconfig = mod.qconfig
+ linear, bn = mod[0], mod[1]
+ qat_linearbn = cls(linear.in_features, linear.out_features, linear.bias is not None,
+ bn.eps, bn.momentum,
+ False, qconfig)
+ qat_linearbn.weight = linear.weight
+ qat_linearbn.bias = linear.bias
+ qat_linearbn.bn.weight = bn.weight
+ qat_linearbn.bn.bias = bn.bias
+ qat_linearbn.bn.running_mean = bn.running_mean
+ qat_linearbn.bn.running_var = bn.running_var
+ qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked
+ return qat_linearbn
+
+ def to_float(self):
+ linear = torch.nn.Linear(self.in_features, self.out_features)
+ linear.weight, linear.bias = fuse_linear_bn_weights(
+ self.weight,
+ self.bias,
+ self.bn.running_mean,
+ self.bn.running_var,
+ self.bn.eps,
+ self.bn.weight,
+ self.bn.bias)
+ return linear
diff --git a/torch/ao/nn/intrinsic/qat/modules/linear_relu.py b/torch/ao/nn/intrinsic/qat/modules/linear_relu.py
new file mode 100644
index 0000000..1c77965
--- /dev/null
+++ b/torch/ao/nn/intrinsic/qat/modules/linear_relu.py
@@ -0,0 +1,48 @@
+import torch
+import torch.ao.nn.qat as nnqat
+import torch.ao.nn.intrinsic as nni
+import torch.nn.functional as F
+
+class LinearReLU(nnqat.Linear, nni._FusedModule):
+ r"""
+ A LinearReLU module fused from Linear and ReLU modules, attached with
+ FakeQuantize modules for weight, used in
+ quantization aware training.
+
+ We adopt the same interface as :class:`torch.nn.Linear`.
+
+ Similar to `torch.nn.intrinsic.LinearReLU`, with FakeQuantize modules initialized to
+ default.
+
+ Attributes:
+        weight_fake_quant: fake quant module for weight
+
+ Examples::
+
+ >>> # xdoctest: +SKIP
+ >>> m = nn.qat.LinearReLU(20, 30)
+ >>> input = torch.randn(128, 20)
+ >>> output = m(input)
+ >>> print(output.size())
+ torch.Size([128, 30])
+ """
+ _FLOAT_MODULE = nni.LinearReLU
+
+ def __init__(self, in_features, out_features, bias=True,
+ qconfig=None):
+ super(LinearReLU, self).__init__(in_features, out_features, bias, qconfig)
+
+ def forward(self, input):
+ return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias))
+
+ @classmethod
+ def from_float(cls, mod):
+ return super(LinearReLU, cls).from_float(mod)
+
+ def to_float(self):
+ linear = torch.nn.Linear(self.in_features, self.out_features, self.bias is not None)
+ linear.weight = torch.nn.Parameter(self.weight.detach())
+ if self.bias is not None:
+ linear.bias = torch.nn.Parameter(self.bias.detach())
+ relu = torch.nn.ReLU()
+ return torch.nn.intrinsic.LinearReLU(linear, relu)
diff --git a/torch/ao/nn/quantized/modules/conv.py b/torch/ao/nn/quantized/modules/conv.py
index c0966e0..234048d 100644
--- a/torch/ao/nn/quantized/modules/conv.py
+++ b/torch/ao/nn/quantized/modules/conv.py
@@ -7,7 +7,7 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.intrinsic as nni
-import torch.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic.qat as nniqat
from torch._ops import ops
from torch.nn.common_types import _size_1_t
diff --git a/torch/ao/nn/quantized/modules/linear.py b/torch/ao/nn/quantized/modules/linear.py
index 85b0f8e..77adcd1 100644
--- a/torch/ao/nn/quantized/modules/linear.py
+++ b/torch/ao/nn/quantized/modules/linear.py
@@ -3,7 +3,7 @@
import torch.nn as nn
import torch.nn.intrinsic as nni
-import torch.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic.qat as nniqat
from torch.nn.utils.fusion import fuse_linear_bn_weights
from torch.nn.utils.parametrize import type_before_parametrizations
diff --git a/torch/ao/ns/fx/mappings.py b/torch/ao/ns/fx/mappings.py
index 7e9ffcc..321ec09 100644
--- a/torch/ao/ns/fx/mappings.py
+++ b/torch/ao/ns/fx/mappings.py
@@ -9,7 +9,7 @@
import torch.ao.nn.quantized.dynamic as nnqd
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.quantized.dynamic as nniqd
-import torch.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic.qat as nniqat
import torch.nn.intrinsic as nni
import torch.ao.nn.qat as nnqat
import torch.ao.nn.qat.dynamic as nnqatd
diff --git a/torch/ao/ns/fx/weight_utils.py b/torch/ao/ns/fx/weight_utils.py
index 553e385..e02d464 100644
--- a/torch/ao/ns/fx/weight_utils.py
+++ b/torch/ao/ns/fx/weight_utils.py
@@ -3,7 +3,7 @@
import torch.nn.functional as F
import torch.ao.nn.quantized.dynamic as nnqd
import torch.ao.nn.quantized as nnq
-import torch.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.qat as nnqat
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
diff --git a/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/torch/ao/quantization/backend_config/_common_operator_config_utils.py
index baf4686..bc6f678 100644
--- a/torch/ao/quantization/backend_config/_common_operator_config_utils.py
+++ b/torch/ao/quantization/backend_config/_common_operator_config_utils.py
@@ -3,7 +3,7 @@
import torch.nn.functional as F
import torch.nn as nn
import torch.nn.intrinsic as nni
-import torch.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic.qat as nniqat
import torch.nn.qat as nnqat
import torch.nn.quantized._reference as nnqr
from collections import namedtuple
diff --git a/torch/ao/quantization/quantization_mappings.py b/torch/ao/quantization/quantization_mappings.py
index 15d291b..5623d78 100644
--- a/torch/ao/quantization/quantization_mappings.py
+++ b/torch/ao/quantization/quantization_mappings.py
@@ -7,7 +7,7 @@
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.quantized.dynamic as nniqd
-import torch.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.reference as nnqr
import torch.ao.nn.quantized.dynamic as nnqd
diff --git a/torch/nn/intrinsic/__init__.py b/torch/nn/intrinsic/__init__.py
index f0be587..fbc89c1 100644
--- a/torch/nn/intrinsic/__init__.py
+++ b/torch/nn/intrinsic/__init__.py
@@ -13,8 +13,9 @@
from torch.ao.nn.intrinsic import LinearBn1d
from torch.ao.nn.intrinsic.modules.fused import _FusedModule # noqa: F401
-# Include the `module` in case user imports from it directly
-from . import modules
+# Include the subpackages in case user imports from it directly
+from . import modules # noqa: F401
+from . import qat # noqa: F401
__all__ = [
'ConvBn1d',
diff --git a/torch/nn/intrinsic/qat/modules/conv_fused.py b/torch/nn/intrinsic/qat/modules/conv_fused.py
index 6a6f4c1..ccd79bb 100644
--- a/torch/nn/intrinsic/qat/modules/conv_fused.py
+++ b/torch/nn/intrinsic/qat/modules/conv_fused.py
@@ -1,828 +1,37 @@
-import math
-import torch
-import torch.nn as nn
-import torch.ao.nn.intrinsic as nni
-import torch.ao.nn.qat as nnqat
-import torch.nn.functional as F
-from torch.nn import init
-from torch.nn.utils import fuse_conv_bn_weights
-from torch.nn.modules.utils import _single, _pair, _triple
-from torch.nn.parameter import Parameter
-from typing import TypeVar
+# flake8: noqa: F401
+r"""Intrinsic QAT Modules
-__all__ = ['ConvBn1d', 'ConvBnReLU1d', 'ConvReLU1d', 'ConvBn2d', 'ConvBnReLU2d', 'ConvReLU2d', 'ConvBn3d',
- 'ConvBnReLU3d', 'ConvReLU3d', 'update_bn_stats', 'freeze_bn_stats']
-_BN_CLASS_MAP = {
- 1: nn.BatchNorm1d,
- 2: nn.BatchNorm2d,
- 3: nn.BatchNorm3d,
-}
+This file is in the process of migration to `torch/ao/nn/intrinsic/qat`, and
+is kept here for compatibility while the migration process is ongoing.
+If you are adding a new entry/functionality, please add it to the
+appropriate file under `torch/ao/nn/intrinsic/qat/modules`,
+and add a corresponding import statement here.
+"""
+__all__ = [
+ # Modules
+ 'ConvBn1d',
+ 'ConvBnReLU1d',
+ 'ConvReLU1d',
+ 'ConvBn2d',
+ 'ConvBnReLU2d',
+ 'ConvReLU2d',
+ 'ConvBn3d',
+ 'ConvBnReLU3d',
+ 'ConvReLU3d',
+ # Utilities
+ 'freeze_bn_stats',
+ 'update_bn_stats',
+]
-MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)
-
-
-class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
-
- _version = 2
- _FLOAT_MODULE = MOD
-
- def __init__(self,
- # ConvNd args
- in_channels, out_channels, kernel_size, stride,
- padding, dilation, transposed, output_padding,
- groups,
- bias,
- padding_mode,
- # BatchNormNd args
- # num_features: out_channels
- eps=1e-05, momentum=0.1,
- # affine: True
- # track_running_stats: True
- # Args for this module
- freeze_bn=False,
- qconfig=None,
- dim=2):
- nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
- stride, padding, dilation, transposed,
- output_padding, groups, False, padding_mode)
- assert qconfig, 'qconfig must be provided for QAT module'
- self.qconfig = qconfig
- self.freeze_bn = freeze_bn if self.training else True
- self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
- self.weight_fake_quant = self.qconfig.weight()
- if bias:
- self.bias = Parameter(torch.empty(out_channels))
- else:
- self.register_parameter('bias', None)
- self.reset_bn_parameters()
-
- # this needs to be called after reset_bn_parameters,
- # as they modify the same state
- if self.training:
- if freeze_bn:
- self.freeze_bn_stats()
- else:
- self.update_bn_stats()
- else:
- self.freeze_bn_stats()
-
- self._enable_slow_path_for_better_numerical_stability = False
-
- def reset_running_stats(self):
- self.bn.reset_running_stats()
-
- def reset_bn_parameters(self):
- self.bn.reset_running_stats()
- init.uniform_(self.bn.weight)
- init.zeros_(self.bn.bias)
- # note: below is actully for conv, not BN
- if self.bias is not None:
- fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
- bound = 1 / math.sqrt(fan_in)
- init.uniform_(self.bias, -bound, bound)
-
- def reset_parameters(self):
- super(_ConvBnNd, self).reset_parameters()
-
- def update_bn_stats(self):
- self.freeze_bn = False
- self.bn.training = True
- return self
-
- def freeze_bn_stats(self):
- self.freeze_bn = True
- self.bn.training = False
- return self
-
- def _forward(self, input):
- if self._enable_slow_path_for_better_numerical_stability:
- return self._forward_slow(input)
- return self._forward_approximate(input)
-
- def _forward_approximate(self, input):
- """Approximated method to fuse conv and bn. It requires only one forward pass.
- conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std
- """
- assert self.bn.running_var is not None
- running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
- scale_factor = self.bn.weight / running_std
- weight_shape = [1] * len(self.weight.shape)
- weight_shape[0] = -1
- bias_shape = [1] * len(self.weight.shape)
- bias_shape[1] = -1
- scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
- # using zero bias here since the bias for original conv
- # will be added later
- if self.bias is not None:
- zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
- else:
- zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.dtype)
- conv = self._conv_forward(input, scaled_weight, zero_bias)
- conv_orig = conv / scale_factor.reshape(bias_shape)
- if self.bias is not None:
- conv_orig = conv_orig + self.bias.reshape(bias_shape)
- conv = self.bn(conv_orig)
- return conv
-
- def _forward_slow(self, input):
- """
- A more accurate but slow method to compute conv bn fusion, following https://arxiv.org/pdf/1806.08342.pdf
- It requires two forward passes but handles the case bn.weight == 0
-
- Conv: Y = WX + B_c
- Conv without bias: Y0 = WX = Y - B_c, Y = Y0 + B_c
-
- Batch statistics:
- mean_Y = Y.mean()
- = Y0.mean() + B_c
- var_Y = (Y - mean_Y)^2.mean()
- = (Y0 - Y0.mean())^2.mean()
- BN (r: bn.weight, beta: bn.bias):
- Z = r * (Y - mean_Y) / sqrt(var_Y + eps) + beta
- = r * (Y0 - Y0.mean()) / sqrt(var_Y + eps) + beta
-
- Fused Conv BN training (std_Y = sqrt(var_Y + eps)):
- Z = (r * W / std_Y) * X + r * (B_c - mean_Y) / std_Y + beta
- = (r * W / std_Y) * X - r * Y0.mean() / std_Y + beta
-
- Fused Conv BN inference (running_std = sqrt(running_var + eps)):
- Z = (r * W / running_std) * X - r * (running_mean - B_c) / running_std + beta
-
- QAT with fused conv bn:
- Z_train = fake_quant(r * W / running_std) * X * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
- = conv(X, fake_quant(r * W / running_std)) * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
- Z_inference = conv(X, fake_quant(r * W / running_std)) - r * (running_mean - B_c) / running_std + beta
- """
-
- assert self.bn.running_var is not None
- assert self.bn.running_mean is not None
-
- # using zero bias here since the bias for original conv
- # will be added later
- zero_bias = torch.zeros(self.out_channels, device=self.weight.device, dtype=input.dtype)
-
- weight_shape = [1] * len(self.weight.shape)
- weight_shape[0] = -1
- bias_shape = [1] * len(self.weight.shape)
- bias_shape[1] = -1
-
- if self.bn.training:
- # needed to compute batch mean/std
- conv_out = self._conv_forward(input, self.weight, zero_bias)
- # update bn statistics
- with torch.no_grad():
- conv_out_bias = (
- conv_out if self.bias is None else conv_out + self.bias.reshape(bias_shape)
- )
- self.bn(conv_out_bias)
-
- # fused conv + bn without bias using bn running statistics
- running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
- scale_factor = self.bn.weight / running_std
- scaled_weight = self.weight_fake_quant(
- self.weight * scale_factor.reshape(weight_shape)
- )
- # fused conv without bias for inference: (r * W / running_std) * X
- conv_bn = self._conv_forward(input, scaled_weight, zero_bias)
-
- if self.bn.training:
- avg_dims = [0] + list(range(2, len(self.weight.shape)))
- batch_mean = conv_out.mean(avg_dims)
- batch_var = torch.square(conv_out - batch_mean.reshape(bias_shape)).mean(
- avg_dims
- )
- batch_std = torch.sqrt(batch_var + self.bn.eps)
-
- # scale to use batch std in training mode
- # conv(X, r * W / std_Y) = conv(X, r * W / running_std) * (running_std / std_Y)
- unscale_factor = running_std / batch_std
- conv_bn *= unscale_factor.reshape(bias_shape)
-
- fused_mean = batch_mean
- fused_std = batch_std
- else:
- fused_mean = self.bn.running_mean - (self.bias if self.bias is not None else 0)
- fused_std = running_std
-
- # fused bias = beta - r * mean / std
- fused_bias = self.bn.bias - self.bn.weight * fused_mean / fused_std
- conv_bn += fused_bias.reshape(bias_shape)
-
- # HACK to let conv bias particpiate in loss to avoid DDP error (parameters
- # were not used in producing loss)
- if self.bias is not None:
- conv_bn += (self.bias - self.bias).reshape(bias_shape)
-
- return conv_bn
-
- def extra_repr(self):
- # TODO(jerryzh): extend
- return super(_ConvBnNd, self).extra_repr()
-
- def forward(self, input):
- return self._forward(input)
-
- def train(self, mode=True):
- """
- Batchnorm's training behavior is using the self.training flag. Prevent
- changing it if BN is frozen. This makes sure that calling `model.train()`
- on a model with a frozen BN will behave properly.
- """
- self.training = mode
- if not self.freeze_bn:
- for module in self.children():
- module.train(mode)
- return self
-
- # ===== Serialization version history =====
- #
- # Version 1/None
- # self
- # |--- weight : Tensor
- # |--- bias : Tensor
- # |--- gamma : Tensor
- # |--- beta : Tensor
- # |--- running_mean : Tensor
- # |--- running_var : Tensor
- # |--- num_batches_tracked : Tensor
- #
- # Version 2
- # self
- # |--- weight : Tensor
- # |--- bias : Tensor
- # |--- bn : Module
- # |--- weight : Tensor (moved from v1.self.gamma)
- # |--- bias : Tensor (moved from v1.self.beta)
- # |--- running_mean : Tensor (moved from v1.self.running_mean)
- # |--- running_var : Tensor (moved from v1.self.running_var)
- # |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
- def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
- version = local_metadata.get('version', None)
- if version is None or version == 1:
- # BN related parameters and buffers were moved into the BN module for v2
- v2_to_v1_names = {
- 'bn.weight': 'gamma',
- 'bn.bias': 'beta',
- 'bn.running_mean': 'running_mean',
- 'bn.running_var': 'running_var',
- 'bn.num_batches_tracked': 'num_batches_tracked',
- }
- for v2_name, v1_name in v2_to_v1_names.items():
- if prefix + v1_name in state_dict:
- state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
- state_dict.pop(prefix + v1_name)
- elif prefix + v2_name in state_dict:
- # there was a brief period where forward compatibility
- # for this module was broken (between
- # https://github.com/pytorch/pytorch/pull/38478
- # and https://github.com/pytorch/pytorch/pull/38820)
- # and modules emitted the v2 state_dict format while
- # specifying that version == 1. This patches the forward
- # compatibility issue by allowing the v2 style entries to
- # be used.
- pass
- elif strict:
- missing_keys.append(prefix + v2_name)
-
- super(_ConvBnNd, self)._load_from_state_dict(
- state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
- @classmethod
- def from_float(cls, mod):
- r"""Create a qat module from a float module or qparams_dict
-
- Args: `mod` a float module, either produced by torch.ao.quantization utilities
-        or directly from the user
- """
- # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound
- # has no __name__ (code is fine though)
- assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
- cls._FLOAT_MODULE.__name__ # type: ignore[attr-defined]
- assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
- assert mod.qconfig, 'Input float module must have a valid qconfig'
- qconfig = mod.qconfig
- conv, bn = mod[0], mod[1]
- qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
- conv.stride, conv.padding, conv.dilation,
- conv.groups, conv.bias is not None,
- conv.padding_mode,
- bn.eps, bn.momentum,
- False,
- qconfig)
- qat_convbn.weight = conv.weight
- qat_convbn.bias = conv.bias
- qat_convbn.bn.weight = bn.weight
- qat_convbn.bn.bias = bn.bias
- qat_convbn.bn.running_mean = bn.running_mean
- qat_convbn.bn.running_var = bn.running_var
- # mypy error: Cannot determine type of 'num_batches_tracked'
- qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked # type: ignore[has-type]
- return qat_convbn
-
- def to_float(self):
- cls = type(self)
- conv = cls._FLOAT_CONV_MODULE( # type: ignore[attr-defined]
- self.in_channels,
- self.out_channels,
- self.kernel_size,
- self.stride,
- self.padding,
- self.dilation,
- self.groups,
- self.bias is not None,
- self.padding_mode)
- conv.weight = torch.nn.Parameter(self.weight.detach())
- if self.bias is not None:
- conv.bias = torch.nn.Parameter(self.bias.detach())
-
- if cls._FLOAT_BN_MODULE: # type: ignore[attr-defined]
- # fuse bn into conv
- conv.weight, conv.bias = fuse_conv_bn_weights(
- conv.weight,
- conv.bias,
- self.bn.running_mean,
- self.bn.running_var,
- self.bn.eps,
- self.bn.weight,
- self.bn.bias
- )
-
- if cls._FLOAT_RELU_MODULE: # type: ignore[attr-defined]
- modules = []
- modules.append(conv)
- relu = cls._FLOAT_RELU_MODULE() # type: ignore[attr-defined]
- modules.append(relu)
- conv_relu = cls._FUSED_FLOAT_MODULE(*modules) # type: ignore[attr-defined]
- conv_relu.train(self.training)
- return conv_relu
- else:
- conv.train(self.training)
- return conv
-
-class ConvBn1d(_ConvBnNd, nn.Conv1d):
- r"""
- A ConvBn1d module is a module fused from Conv1d and BatchNorm1d,
- attached with FakeQuantize modules for weight,
- used in quantization aware training.
-
- We combined the interface of :class:`torch.nn.Conv1d` and
- :class:`torch.nn.BatchNorm1d`.
-
- Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized
- to default.
-
- Attributes:
-        freeze_bn: whether the BatchNorm statistics are frozen during training
- weight_fake_quant: fake quant module for weight
-
- """
- _FLOAT_BN_MODULE = nn.BatchNorm1d
- _FLOAT_RELU_MODULE = None
- _FLOAT_MODULE = nni.ConvBn1d
- _FLOAT_CONV_MODULE = nn.Conv1d
-
- def __init__(self,
- # Conv1d args
- in_channels, out_channels, kernel_size, stride=1,
- padding=0, dilation=1, groups=1,
- bias=None,
- padding_mode='zeros',
- # BatchNorm1d args
- # num_features: out_channels
- eps=1e-05, momentum=0.1,
- # affine: True
- # track_running_stats: True
- # Args for this module
- freeze_bn=False,
- qconfig=None):
- kernel_size = _single(kernel_size)
- stride = _single(stride)
- padding = _single(padding)
- dilation = _single(dilation)
- _ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
- padding, dilation, False, _single(0), groups, bias, padding_mode,
- eps, momentum, freeze_bn, qconfig, dim=1)
-
-class ConvBnReLU1d(ConvBn1d):
- r"""
- A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU,
- attached with FakeQuantize modules for weight,
- used in quantization aware training.
-
- We combined the interface of :class:`torch.nn.Conv1d` and
- :class:`torch.nn.BatchNorm1d` and :class:`torch.nn.ReLU`.
-
- Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to
- default.
-
- Attributes:
- weight_fake_quant: fake quant module for weight
-
- """
- # base class defines _FLOAT_MODULE as "ConvBn1d"
- _FLOAT_MODULE = nni.ConvBnReLU1d # type: ignore[assignment]
- _FLOAT_CONV_MODULE = nn.Conv1d
- _FLOAT_BN_MODULE = nn.BatchNorm1d
- _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment]
- # module class after fusing bn into conv
- _FUSED_FLOAT_MODULE = nni.ConvReLU1d
-
- def __init__(self,
- # Conv1d args
- in_channels, out_channels, kernel_size, stride=1,
- padding=0, dilation=1, groups=1,
- bias=None,
- padding_mode='zeros',
- # BatchNorm1d args
- # num_features: out_channels
- eps=1e-05, momentum=0.1,
- # affine: True
- # track_running_stats: True
- # Args for this module
- freeze_bn=False,
- qconfig=None):
- super().__init__(in_channels, out_channels, kernel_size, stride,
- padding, dilation, groups, bias,
- padding_mode, eps, momentum,
- freeze_bn,
- qconfig)
-
- def forward(self, input):
- return F.relu(ConvBn1d._forward(self, input))
-
- @classmethod
- def from_float(cls, mod):
- return super(ConvBnReLU1d, cls).from_float(mod)
-
-class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
- r"""A ConvReLU1d module is a fused module of Conv1d and ReLU, attached with
- FakeQuantize modules for weight for
- quantization aware training.
-
- We combined the interface of :class:`~torch.nn.Conv1d` and
- :class:`~torch.nn.BatchNorm1d`.
-
- Attributes:
- weight_fake_quant: fake quant module for weight
-
- """
- _FLOAT_MODULE = nni.ConvReLU1d
- _FLOAT_CONV_MODULE = nn.Conv1d
- _FLOAT_BN_MODULE = None
- _FLOAT_RELU_MODULE = nn.ReLU
-
- def __init__(self, in_channels, out_channels, kernel_size, stride=1,
- padding=0, dilation=1, groups=1,
- bias=True, padding_mode='zeros',
- qconfig=None):
- super(ConvReLU1d, self).__init__(in_channels, out_channels, kernel_size,
- stride=stride, padding=padding, dilation=dilation,
- groups=groups, bias=bias, padding_mode=padding_mode,
- qconfig=qconfig)
- assert qconfig, 'qconfig must be provided for QAT module'
- self.qconfig = qconfig
- self.weight_fake_quant = self.qconfig.weight()
-
- def forward(self, input):
- return F.relu(
- self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))
-
- @classmethod
- def from_float(cls, mod):
- return super(ConvReLU1d, cls).from_float(mod)
-
-class ConvBn2d(_ConvBnNd, nn.Conv2d):
- r"""
- A ConvBn2d module is a module fused from Conv2d and BatchNorm2d,
- attached with FakeQuantize modules for weight,
- used in quantization aware training.
-
- We combined the interface of :class:`torch.nn.Conv2d` and
- :class:`torch.nn.BatchNorm2d`.
-
- Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized
- to default.
-
- Attributes:
-        freeze_bn: whether the BatchNorm statistics are frozen during training
- weight_fake_quant: fake quant module for weight
-
- """
- _FLOAT_MODULE = nni.ConvBn2d
- _FLOAT_CONV_MODULE = nn.Conv2d
- _FLOAT_BN_MODULE = nn.BatchNorm2d
- _FLOAT_RELU_MODULE = None
-
- def __init__(self,
- # ConvNd args
- in_channels, out_channels, kernel_size, stride=1,
- padding=0, dilation=1, groups=1,
- bias=None,
- padding_mode='zeros',
- # BatchNorm2d args
- # num_features: out_channels
- eps=1e-05, momentum=0.1,
- # affine: True
- # track_running_stats: True
- # Args for this module
- freeze_bn=False,
- qconfig=None):
- kernel_size = _pair(kernel_size)
- stride = _pair(stride)
- padding = _pair(padding)
- dilation = _pair(dilation)
- _ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
- padding, dilation, False, _pair(0), groups, bias, padding_mode,
- eps, momentum, freeze_bn, qconfig, dim=2)
-
-class ConvBnReLU2d(ConvBn2d):
- r"""
- A ConvBnReLU2d module is a module fused from Conv2d, BatchNorm2d and ReLU,
- attached with FakeQuantize modules for weight,
- used in quantization aware training.
-
- We combined the interface of :class:`torch.nn.Conv2d` and
- :class:`torch.nn.BatchNorm2d` and :class:`torch.nn.ReLU`.
-
- Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
- default.
-
- Attributes:
- weight_fake_quant: fake quant module for weight
-
- """
- # base class defines _FLOAT_MODULE as "ConvBn2d"
- _FLOAT_MODULE = nni.ConvBnReLU2d # type: ignore[assignment]
- _FLOAT_CONV_MODULE = nn.Conv2d
- _FLOAT_BN_MODULE = nn.BatchNorm2d
- _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment]
- # module class after fusing bn into conv
- _FUSED_FLOAT_MODULE = nni.ConvReLU2d
-
- def __init__(self,
- # Conv2d args
- in_channels, out_channels, kernel_size, stride=1,
- padding=0, dilation=1, groups=1,
- bias=None,
- padding_mode='zeros',
- # BatchNorm2d args
- # num_features: out_channels
- eps=1e-05, momentum=0.1,
- # affine: True
- # track_running_stats: True
- # Args for this module
- freeze_bn=False,
- qconfig=None):
- super(ConvBnReLU2d, self).__init__(in_channels, out_channels, kernel_size, stride,
- padding, dilation, groups, bias,
- padding_mode, eps, momentum,
- freeze_bn,
- qconfig)
-
- def forward(self, input):
- return F.relu(ConvBn2d._forward(self, input))
-
- @classmethod
- def from_float(cls, mod):
- return super(ConvBnReLU2d, cls).from_float(mod)
-
-class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
- r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
- FakeQuantize modules for weight for
- quantization aware training.
-
- We combined the interface of :class:`~torch.nn.Conv2d` and
- :class:`~torch.nn.BatchNorm2d`.
-
- Attributes:
- weight_fake_quant: fake quant module for weight
-
- """
- _FLOAT_MODULE = nni.ConvReLU2d
- _FLOAT_CONV_MODULE = nn.Conv2d
- _FLOAT_BN_MODULE = None
- _FLOAT_RELU_MODULE = nn.ReLU
-
- def __init__(self, in_channels, out_channels, kernel_size, stride=1,
- padding=0, dilation=1, groups=1,
- bias=True, padding_mode='zeros',
- qconfig=None):
- super(ConvReLU2d, self).__init__(in_channels, out_channels, kernel_size,
- stride=stride, padding=padding, dilation=dilation,
- groups=groups, bias=bias, padding_mode=padding_mode,
- qconfig=qconfig)
- assert qconfig, 'qconfig must be provided for QAT module'
- self.qconfig = qconfig
- self.weight_fake_quant = self.qconfig.weight()
-
- def forward(self, input):
- return F.relu(
- self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))
-
- @classmethod
- def from_float(cls, mod):
- return super(ConvReLU2d, cls).from_float(mod)
-
-class ConvBn3d(_ConvBnNd, nn.Conv3d):
- r"""
- A ConvBn3d module is a module fused from Conv3d and BatchNorm3d,
- attached with FakeQuantize modules for weight,
- used in quantization aware training.
-
- We combined the interface of :class:`torch.nn.Conv3d` and
- :class:`torch.nn.BatchNorm3d`.
-
- Similar to :class:`torch.nn.Conv3d`, with FakeQuantize modules initialized
- to default.
-
- Attributes:
-        freeze_bn: whether the BatchNorm statistics are frozen during training
- weight_fake_quant: fake quant module for weight
-
- """
- _FLOAT_MODULE = nni.ConvBn3d
- _FLOAT_CONV_MODULE = nn.Conv3d
- _FLOAT_BN_MODULE = nn.BatchNorm3d
- _FLOAT_RELU_MODULE = None
-
- def __init__(
- self,
- # ConvNd args
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- padding=0,
- dilation=1,
- groups=1,
- bias=None,
- padding_mode="zeros",
- # BatchNorm3d args
- # num_features: out_channels
- eps=1e-05,
- momentum=0.1,
- # affine: True
- # track_running_stats: True
- # Args for this module
- freeze_bn=False,
- qconfig=None,
- ):
- kernel_size = _triple(kernel_size)
- stride = _triple(stride)
- padding = _triple(padding)
- dilation = _triple(dilation)
- _ConvBnNd.__init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- stride,
- padding,
- dilation,
- False,
- _triple(0),
- groups,
- bias,
- padding_mode,
- eps,
- momentum,
- freeze_bn,
- qconfig,
- dim=3,
- )
-
-class ConvBnReLU3d(ConvBn3d):
- r"""
- A ConvBnReLU3d module is a module fused from Conv3d, BatchNorm3d and ReLU,
- attached with FakeQuantize modules for weight,
- used in quantization aware training.
-
- We combined the interface of :class:`torch.nn.Conv3d` and
- :class:`torch.nn.BatchNorm3d` and :class:`torch.nn.ReLU`.
-
- Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
- default.
-
- Attributes:
- weight_fake_quant: fake quant module for weight
-
- """
- _FLOAT_MODULE = nni.ConvBnReLU3d # type: ignore[assignment]
- _FLOAT_CONV_MODULE = nn.Conv3d
- _FLOAT_BN_MODULE = nn.BatchNorm3d
- _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment]
- # module class after fusing bn into conv
- _FUSED_FLOAT_MODULE = nni.ConvReLU3d
-
- def __init__(
- self,
- # Conv3d args
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- padding=0,
- dilation=1,
- groups=1,
- bias=None,
- padding_mode="zeros",
- # BatchNorm3d args
- # num_features: out_channels
- eps=1e-05,
- momentum=0.1,
- # affine: True
- # track_running_stats: True
- # Args for this module
- freeze_bn=False,
- qconfig=None,
- ):
- super(ConvBnReLU3d, self).__init__(
- in_channels,
- out_channels,
- kernel_size,
- stride,
- padding,
- dilation,
- groups,
- bias,
- padding_mode,
- eps,
- momentum,
- freeze_bn,
- qconfig,
- )
-
- def forward(self, input):
- return F.relu(ConvBn3d._forward(self, input))
-
- @classmethod
- def from_float(cls, mod):
- return super(ConvBnReLU3d, cls).from_float(mod)
-
-class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
- r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with
- FakeQuantize modules for weight for
- quantization aware training.
-
- We combined the interface of :class:`~torch.nn.Conv3d` and
- :class:`~torch.nn.BatchNorm3d`.
-
- Attributes:
- weight_fake_quant: fake quant module for weight
-
- """
- _FLOAT_MODULE = nni.ConvReLU3d
- _FLOAT_CONV_MODULE = nn.Conv3d
- _FLOAT_BN_MODULE = None
- _FLOAT_RELU_MODULE = nn.ReLU
-
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- padding=0,
- dilation=1,
- groups=1,
- bias=True,
- padding_mode="zeros",
- qconfig=None,
- ):
- super(ConvReLU3d, self).__init__(
- in_channels,
- out_channels,
- kernel_size,
- stride=stride,
- padding=padding,
- dilation=dilation,
- groups=groups,
- bias=bias,
- padding_mode=padding_mode,
- qconfig=qconfig,
- )
- assert qconfig, "qconfig must be provided for QAT module"
- self.qconfig = qconfig
- self.weight_fake_quant = self.qconfig.weight()
-
- def forward(self, input):
- return F.relu(
- self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
- )
-
- @classmethod
- def from_float(cls, mod):
- return super(ConvReLU3d, cls).from_float(mod)
-
-def update_bn_stats(mod):
- if type(mod) in set(
- [ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d]
- ):
- mod.update_bn_stats()
-
-def freeze_bn_stats(mod):
- if type(mod) in set(
- [ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d]
- ):
- mod.freeze_bn_stats()
+from torch.ao.nn.intrinsic.qat import ConvBn1d
+from torch.ao.nn.intrinsic.qat import ConvBnReLU1d
+from torch.ao.nn.intrinsic.qat import ConvReLU1d
+from torch.ao.nn.intrinsic.qat import ConvBn2d
+from torch.ao.nn.intrinsic.qat import ConvBnReLU2d
+from torch.ao.nn.intrinsic.qat import ConvReLU2d
+from torch.ao.nn.intrinsic.qat import ConvBn3d
+from torch.ao.nn.intrinsic.qat import ConvBnReLU3d
+from torch.ao.nn.intrinsic.qat import ConvReLU3d
+from torch.ao.nn.intrinsic.qat import freeze_bn_stats
+from torch.ao.nn.intrinsic.qat import update_bn_stats
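Since torch/nn/intrinsic/qat/modules/conv_fused.py is now a pure re-export shim, both import paths resolve to the same class objects during the migration window. A minimal sanity check (assuming a build that contains this change):

```
import torch.ao.nn.intrinsic.qat as ao_qat
import torch.nn.intrinsic.qat as legacy_qat

# The legacy module re-exports the ao implementations, so the names are
# aliases of the same objects and existing isinstance checks keep passing.
assert legacy_qat.ConvBn2d is ao_qat.ConvBn2d
assert legacy_qat.ConvBnReLU2d is ao_qat.ConvBnReLU2d
assert legacy_qat.freeze_bn_stats is ao_qat.freeze_bn_stats
```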
diff --git a/torch/nn/intrinsic/qat/modules/linear_fused.py b/torch/nn/intrinsic/qat/modules/linear_fused.py
index f19dbd9..57fbf83 100644
--- a/torch/nn/intrinsic/qat/modules/linear_fused.py
+++ b/torch/nn/intrinsic/qat/modules/linear_fused.py
@@ -1,167 +1,15 @@
-import torch
-import torch.nn as nn
-import torch.ao.nn.intrinsic as nni
-import torch.nn.functional as F
-from torch.nn import init
-from torch.nn.parameter import Parameter
-from torch.nn.utils.fusion import fuse_linear_bn_weights
+# flake8: noqa: F401
+r"""Intrinsic QAT Modules
+This file is in the process of migration to `torch/ao/nn/intrinsic/qat`, and
+is kept here for compatibility while the migration process is ongoing.
+If you are adding a new entry/functionality, please add it to the
+appropriate file under `torch/ao/nn/intrinsic/qat/modules`,
+while adding an import statement here.
+"""
-class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
- r"""
- A LinearBn1d module is a module fused from Linear and BatchNorm1d, attached
- with FakeQuantize modules for weight, used in quantization aware training.
+__all__ = [
+ 'LinearBn1d',
+]
- We combined the interface of :class:`torch.nn.Linear` and
-    :class:`torch.nn.BatchNorm1d`.
-
- Similar to :class:`torch.nn.Linear`, with FakeQuantize modules initialized
- to default.
-
- Attributes:
-        freeze_bn: whether the BatchNorm statistics are frozen during training
- weight_fake_quant: fake quant module for weight
-
- """
- def __init__(self,
- # Linear args
- in_features, out_features, bias=True,
- # BatchNorm1d args
- # num_features: out_features
- eps=1e-05, momentum=0.1,
- # affine: True
- # track_running_stats: True
- # Args for this module
- freeze_bn=False,
- qconfig=None):
- nn.modules.linear.Linear.__init__(self, in_features, out_features, bias)
-        assert qconfig, 'qconfig must be provided for QAT module'
- self.qconfig = qconfig
- self.freeze_bn = freeze_bn if self.training else True
- self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True)
- self.weight_fake_quant = self.qconfig.weight()
- if bias:
- self.bias = Parameter(torch.empty(out_features))
- else:
- self.register_parameter('bias', None)
- self.reset_bn_parameters()
-
- # this needs to be called after reset_bn_parameters,
- # as they modify the same state
- if self.training:
- if freeze_bn:
- self.freeze_bn_stats()
- else:
- self.update_bn_stats()
- else:
- self.freeze_bn_stats()
-
- def reset_running_stats(self):
- self.bn.reset_running_stats()
-
- def reset_bn_parameters(self):
- self.bn.reset_running_stats()
- init.uniform_(self.bn.weight)
- init.zeros_(self.bn.bias)
-
- def reset_parameters(self):
- super(LinearBn1d, self).reset_parameters()
-
- def update_bn_stats(self):
- self.freeze_bn = False
- self.bn.training = True
- return self
-
- def freeze_bn_stats(self):
- self.freeze_bn = True
- self.bn.training = False
- return self
-
- def forward(self, input):
- assert self.bn.running_var is not None
-
- # Scale the linear weights by BN's running statistics to reduce
- # weight jitter, see https://arxiv.org/pdf/1806.08342.pdf, page 18
- # for motivation.
- #
- # Instead of
- #
- # x1 = F.linear(x0, fq(w), b)
- # x2 = self.bn(x1)
- #
- # We have
- #
- # # scale the weight by previous batch's running statistics
- # scale_factor = bn.w / bn.running_std_from_prev_batch
- # # do the linear transformation without bias
- # x1_scaled = F.linear(x0, fq(w * scale_factor), 0)
- # # reverse the scaling and add original bias
- # x1_orig = x1_scaled / scale_factor + b
- # x2 = self.bn(x1_orig)
-
- running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
- scale_factor = self.bn.weight / running_std
- weight_shape = [1] * len(self.weight.shape)
- weight_shape[0] = -1
- bias_shape = [1] * len(self.weight.shape)
- bias_shape[1] = -1
- scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
- if self.bias is not None:
- zero_bias = torch.zeros_like(self.bias)
- else:
- zero_bias = torch.zeros(self.out_features, device=scaled_weight.device)
- linear_out = F.linear(input, scaled_weight, zero_bias)
- linear_out_orig = linear_out / scale_factor.reshape(bias_shape)
- if self.bias is not None:
- linear_out_orig = linear_out_orig + self.bias.reshape(bias_shape)
- bn_out = self.bn(linear_out_orig)
- return bn_out
-
- def train(self, mode=True):
- """
-        BatchNorm's training behavior is controlled by the self.training flag. Prevent
- changing it if BN is frozen. This makes sure that calling `model.train()`
- on a model with a frozen BN will behave properly.
- """
- self.training = mode
- if not self.freeze_bn:
- for module in self.children():
- module.train(mode)
- return self
-
- @classmethod
- def from_float(cls, mod):
- r"""Create a qat module from a float module or qparams_dict
-
-        Args: `mod` a float module, either produced by torch.ao.quantization
-        utilities or directly from the user
- """
- assert type(mod) == nni.LinearBn1d, 'qat.' + cls.__name__ + \
- '.from_float only works for ' + nni.LinearBn1d.__name__
- assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
-        assert mod.qconfig, 'Input float module must have a valid qconfig'
- qconfig = mod.qconfig
- linear, bn = mod[0], mod[1]
- qat_linearbn = cls(linear.in_features, linear.out_features, linear.bias is not None,
- bn.eps, bn.momentum,
- False, qconfig)
- qat_linearbn.weight = linear.weight
- qat_linearbn.bias = linear.bias
- qat_linearbn.bn.weight = bn.weight
- qat_linearbn.bn.bias = bn.bias
- qat_linearbn.bn.running_mean = bn.running_mean
- qat_linearbn.bn.running_var = bn.running_var
- qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked
- return qat_linearbn
-
- def to_float(self):
- linear = torch.nn.Linear(self.in_features, self.out_features)
- linear.weight, linear.bias = fuse_linear_bn_weights(
- self.weight,
- self.bias,
- self.bn.running_mean,
- self.bn.running_var,
- self.bn.eps,
- self.bn.weight,
- self.bn.bias)
- return linear
+from torch.ao.nn.intrinsic.qat import LinearBn1d
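The deleted LinearBn1d.forward above implements the weight-scaling trick from https://arxiv.org/pdf/1806.08342.pdf (the implementation now lives under torch/ao). A self-contained sketch of the underlying algebra, treating the weight fake-quant as identity (all names below are local illustrations, not the module's API):

```
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8)
w = torch.randn(16, 8)
b = torch.randn(16)
gamma = torch.rand(16) + 0.5        # stands in for bn.weight
running_var = torch.rand(16) + 0.5  # stands in for bn.running_var
eps = 1e-5

# Scale the weight by gamma / running_std before the linear, then undo
# the scaling on the output: numerically a no-op, but it lets the
# fake-quant observe the BN-scaled weight.
running_std = torch.sqrt(running_var + eps)
scale = gamma / running_std
out_scaled = F.linear(x, w * scale.reshape(-1, 1))
out = out_scaled / scale.reshape(1, -1) + b

assert torch.allclose(out, F.linear(x, w, b), atol=1e-5)
```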
diff --git a/torch/nn/intrinsic/qat/modules/linear_relu.py b/torch/nn/intrinsic/qat/modules/linear_relu.py
index 1c77965..45afc33 100644
--- a/torch/nn/intrinsic/qat/modules/linear_relu.py
+++ b/torch/nn/intrinsic/qat/modules/linear_relu.py
@@ -1,48 +1,15 @@
-import torch
-import torch.ao.nn.qat as nnqat
-import torch.ao.nn.intrinsic as nni
-import torch.nn.functional as F
+# flake8: noqa: F401
+r"""Intrinsic QAT Modules
-class LinearReLU(nnqat.Linear, nni._FusedModule):
- r"""
- A LinearReLU module fused from Linear and ReLU modules, attached with
- FakeQuantize modules for weight, used in
- quantization aware training.
+This file is in the process of migration to `torch/ao/nn/intrinsic/qat`, and
+is kept here for compatibility while the migration process is ongoing.
+If you are adding a new entry/functionality, please add it to the
+appropriate file under `torch/ao/nn/intrinsic/qat/modules`,
+while adding an import statement here.
+"""
- We adopt the same interface as :class:`torch.nn.Linear`.
+__all__ = [
+ 'LinearReLU',
+]
- Similar to `torch.nn.intrinsic.LinearReLU`, with FakeQuantize modules initialized to
- default.
-
- Attributes:
- weight: fake quant module for weight
-
- Examples::
-
- >>> # xdoctest: +SKIP
- >>> m = nn.qat.LinearReLU(20, 30)
- >>> input = torch.randn(128, 20)
- >>> output = m(input)
- >>> print(output.size())
- torch.Size([128, 30])
- """
- _FLOAT_MODULE = nni.LinearReLU
-
- def __init__(self, in_features, out_features, bias=True,
- qconfig=None):
- super(LinearReLU, self).__init__(in_features, out_features, bias, qconfig)
-
- def forward(self, input):
- return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias))
-
- @classmethod
- def from_float(cls, mod):
- return super(LinearReLU, cls).from_float(mod)
-
- def to_float(self):
- linear = torch.nn.Linear(self.in_features, self.out_features, self.bias is not None)
- linear.weight = torch.nn.Parameter(self.weight.detach())
- if self.bias is not None:
- linear.bias = torch.nn.Parameter(self.bias.detach())
- relu = torch.nn.ReLU()
- return torch.nn.intrinsic.LinearReLU(linear, relu)
+from torch.ao.nn.intrinsic.qat import LinearReLU
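Usage of the migrated class is unchanged apart from the namespace; note that QAT modules require a qconfig, from which the weight fake-quant is constructed. A hedged sketch (the default QAT qconfig is chosen only for illustration):

```
import torch
from torch.ao.nn.intrinsic.qat import LinearReLU
from torch.ao.quantization import get_default_qat_qconfig

m = LinearReLU(20, 30, qconfig=get_default_qat_qconfig("fbgemm"))
out = m(torch.randn(128, 20))
print(out.size())  # torch.Size([128, 30])
```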
diff --git a/torch/nn/intrinsic/quantized/modules/bn_relu.py b/torch/nn/intrinsic/quantized/modules/bn_relu.py
index 62e4f4f..d8e6fcc 100644
--- a/torch/nn/intrinsic/quantized/modules/bn_relu.py
+++ b/torch/nn/intrinsic/quantized/modules/bn_relu.py
@@ -1,7 +1,7 @@
import torch
import torch.ao.nn.intrinsic
-import torch.nn.intrinsic.qat
+import torch.ao.nn.intrinsic.qat
import torch.ao.nn.quantized as nnq
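Downstream code that has to import these classes across releases from before and after the migration can prefer the ao namespace and fall back to the legacy one, e.g.:

```
try:
    from torch.ao.nn.intrinsic.qat import ConvBnReLU2d
except ImportError:  # older PyTorch without torch.ao.nn.intrinsic.qat
    from torch.nn.intrinsic.qat import ConvBnReLU2d
```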
diff --git a/torch/nn/intrinsic/quantized/modules/conv_relu.py b/torch/nn/intrinsic/quantized/modules/conv_relu.py
index 022c97e..ba6691a 100644
--- a/torch/nn/intrinsic/quantized/modules/conv_relu.py
+++ b/torch/nn/intrinsic/quantized/modules/conv_relu.py
@@ -1,7 +1,7 @@
import torch
import torch.ao.nn.intrinsic
-import torch.nn.intrinsic.qat
+import torch.ao.nn.intrinsic.qat
import torch.nn.functional as F
import torch.ao.nn.quantized as nnq
@@ -48,7 +48,7 @@
@classmethod
def from_float(cls, mod):
- if type(mod) == torch.nn.intrinsic.qat.ConvBnReLU1d:
+ if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
mod.bn.eps, mod.bn.weight, mod.bn.bias)
@@ -97,7 +97,7 @@
@classmethod
def from_float(cls, mod):
- if type(mod) == torch.nn.intrinsic.qat.ConvBnReLU2d:
+ if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
mod.bn.eps, mod.bn.weight, mod.bn.bias)
@@ -147,7 +147,7 @@
@classmethod
def from_float(cls, mod):
- if type(mod) == torch.nn.intrinsic.qat.ConvBnReLU3d:
+ if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight,
mod.bias,
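For reference, the fuse_conv_bn_weights step used by from_float above folds the batchnorm into the conv parameters using the running statistics. A minimal sketch of what it guarantees in eval mode:

```
import torch
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_weights

conv = nn.Conv2d(3, 8, 3).eval()
bn = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1, 1)   # pretend these were accumulated in training
bn.running_var.uniform_(0.5, 1.5)

x = torch.randn(2, 3, 16, 16)
ref = bn(conv(x))

# Fold BN's affine transform and running stats into the conv parameters.
with torch.no_grad():
    conv.weight, conv.bias = fuse_conv_bn_weights(
        conv.weight, conv.bias,
        bn.running_mean, bn.running_var, bn.eps,
        bn.weight, bn.bias,
    )
assert torch.allclose(conv(x), ref, atol=1e-4)
```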