| import torch.nn as nn |
| from torch.nn.intrinsic import ConvReLU2d |
| |
class Conv2d(nn.Conv2d):
    r"""
    A Conv2d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Conv2d`, please see
    https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d
    for documentation.

    Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        qconfig: the quantization configuration given to the constructor
        weight_fake_quant: fake quant module for weight
    """
    # Float module type that `from_float` accepts; subclasses may override it.
    _FLOAT_MODULE = nn.Conv2d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros', qconfig=None):
        r"""Build the convolution and attach the weight fake-quant module.

        Args:
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode: forwarded unchanged to `nn.Conv2d`.
            qconfig: required despite the `None` default. Must be truthy and
                expose a `weight` callable; the object that callable returns
                is stored as `weight_fake_quant`.

        Raises:
            AssertionError: if `qconfig` is missing or falsy.
        """
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size,
                                     stride=stride, padding=padding, dilation=dilation,
                                     groups=groups, bias=bias, padding_mode=padding_mode)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight()

    def forward(self, input):
        r"""Run the convolution on `input` using the fake-quantized weight."""
        # NOTE(review): relies on the private `nn.Conv2d._conv_forward(input, weight)`
        # helper, whose signature is version-specific -- confirm against the
        # PyTorch release this file is pinned to.
        return self._conv_forward(input, self.weight_fake_quant(self.weight))

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module

        Args: `mod` a float module, either produced by torch.quantization utilities
        or directly from user. Its type must be exactly `cls._FLOAT_MODULE` and it
        must carry a truthy `qconfig` attribute.

        Returns: a new `cls` instance built from the hyper-parameters of `mod`
        (or of its first child when `mod` is a fused `ConvReLU2d`). The `weight`
        and `bias` objects are assigned from the float conv, not copied.

        Raises: AssertionError when the type or qconfig checks fail.
        """
        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        # Remember the qconfig that was just validated *before* unwrapping, so a
        # fused module whose inner conv carries no qconfig still converts.
        qconfig = mod.qconfig
        if type(mod) == ConvReLU2d:
            # A fused module keeps the convolution as its first child.
            mod = mod[0]
            # The inner conv's own qconfig still takes precedence when it has a
            # truthy one; otherwise fall back to the validated fused-module one.
            qconfig = getattr(mod, 'qconfig', None) or qconfig
        qat_conv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
                       stride=mod.stride, padding=mod.padding, dilation=mod.dilation,
                       groups=mod.groups, bias=mod.bias is not None,
                       padding_mode=mod.padding_mode, qconfig=qconfig)
        # Reuse the float module's parameter objects instead of the freshly
        # initialised ones created by the constructor.
        qat_conv.weight = mod.weight
        qat_conv.bias = mod.bias
        return qat_conv