add type annotations to torch.nn.quantized.modules.conv (#49702)
Summary:
closes gh-49700
No mypy issues were found for the first three modules whose entries were deleted from `mypy.ini`:
```
[mypy-torch.nn.qat.modules.activations]
ignore_errors = True
[mypy-torch.nn.qat.modules.conv]
ignore_errors = True
[mypy-torch.nn.quantized.dynamic.modules.linear]
ignore_errors = True
```
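The fourth entry, `torch.nn.quantized.modules.conv`, needed actual code changes: `_ConvNd` and its subclasses declared incompatible `__init__` signatures, which mypy cannot reconcile. The base `__init__` now carries the signature shared by all subclasses and simply raises `NotImplementedError`, while the real initialization moved to a `_init` method that subclasses call directly; call sites such as `cls(...)` in `from_float` are silenced with `# type: ignore[call-arg]`. A minimal sketch of the pattern (illustrative names, not the actual module code):
```
import torch.nn as nn

class _Base(nn.Module):
    # Signature shared by every subclass; the base class is never
    # instantiated directly, so this exists only for the type checker.
    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        raise NotImplementedError

    # The real initializer, with the wider internal signature.
    def _init(self, in_channels, out_channels, kernel_size, stride,
              transposed, output_padding):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.transposed = transposed
        self.output_padding = output_padding

class Child(_Base):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        # Subclasses call _init rather than super().__init__, so every
        # __init__ in the hierarchy presents the same signature to mypy.
        super()._init(in_channels, out_channels, kernel_size, stride,
                      False, 0)

# Instantiating a subclass never reaches the base __init__:
conv = Child(3, 8, kernel_size=3)
```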
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49702
Reviewed By: walterddr, zou3519
Differential Revision: D25767119
Pulled By: ezyang
fbshipit-source-id: cb83e53549a299538e1b154cf8b79e3280f7392a
diff --git a/mypy.ini b/mypy.ini
index 6c579ee..0c99a9c 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -91,16 +91,7 @@
 [mypy-torch.nn.modules.pooling]
 ignore_errors = True
 
-[mypy-torch.nn.qat.modules.activations]
-ignore_errors = True
-
-[mypy-torch.nn.qat.modules.conv]
-ignore_errors = True
-
-[mypy-torch.nn.quantized.dynamic.modules.linear]
-ignore_errors = True
-
-[mypy-torch.nn.quantized.modules.conv]
+[mypy-torch.nn.parallel._functions]
 ignore_errors = True
 
 [mypy-torch._appdirs]
diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py
index 00ceba7..b3bc78f 100644
--- a/torch/nn/quantized/modules/conv.py
+++ b/torch/nn/quantized/modules/conv.py
@@ -1,7 +1,7 @@
 # coding=utf-8
 r"""Quantized convolution modules."""
 
-from typing import Optional, List
+from typing import Optional, List, TypeVar
 
 import torch
 import torch.nn as nn
@@ -16,11 +16,17 @@
 
 class _ConvNd(nn.Module):
 
-    def __init__(self, in_channels, out_channels, kernel_size, stride,
-                 padding, dilation,
-                 transposed, output_padding,
-                 groups, bias,
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
                  padding_mode='zeros'):
+        # All subclasses have this signature - see PR #49702
+        raise NotImplementedError
+
+    def _init(self, in_channels, out_channels, kernel_size, stride,
+              padding, dilation,
+              transposed, output_padding,
+              groups, bias,
+              padding_mode='zeros'):
         super(_ConvNd, self).__init__()
         if padding_mode != 'zeros':
             raise NotImplementedError(
@@ -54,6 +60,15 @@
         self.scale = 1.0
         self.zero_point = 0
 
+    def set_weight_bias(self, qweight, bias_float):
+        raise NotImplementedError
+
+    def bias(self):
+        raise NotImplementedError
+
+    def _weight_bias(self):
+        raise NotImplementedError
+
     def extra_repr(self):
         s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
              ', stride={stride}, scale={scale}, zero_point={zero_point}')
@@ -155,7 +170,8 @@
         assert weight_post_process.dtype == torch.qint8, \
             'Weight observer must have a dtype of qint8'
         qweight = _quantize_weight(mod.weight.float(), weight_post_process)
-        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
+        # the __init__ called here is the one from the derived classes, not the one from _ConvNd
+        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,  # type: ignore[call-arg]
                     mod.stride, mod.padding, mod.dilation, mod.groups,
                     mod.bias is not None, mod.padding_mode)
         qconv.set_weight_bias(qweight, mod.bias)
@@ -233,7 +249,9 @@
         padding = _pair_from_first(padding)
         dilation = _pair_from_first(dilation)
 
-        super(Conv1d, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(Conv1d, self)._init(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _single(0), groups, bias, padding_mode)
 
@@ -319,7 +337,9 @@
         stride = _pair(stride)
         padding = _pair(padding)
         dilation = _pair(dilation)
-        super(Conv2d, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(Conv2d, self)._init(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _pair(0), groups, bias, padding_mode)
 
@@ -403,7 +423,9 @@
         stride = _triple(stride)
         padding = _triple(padding)
         dilation = _triple(dilation)
-        super(Conv3d, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(Conv3d, self)._init(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _triple(0), groups, bias, padding_mode)
 
@@ -450,15 +472,20 @@
         return cls.get_qconv(mod, activation_post_process)
 
 # === Transposed Convolutions ===
+MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)
 
 class _ConvTransposeNd(_ConvNd):
+
+    _FLOAT_MODULE = MOD
+
     def __init__(self, in_channels, out_channels, kernel_size, stride,
                  padding, dilation, transposed, output_padding,
                  groups, bias, padding_mode):
         if padding_mode != 'zeros':
             raise ValueError('Only "zeros" padding mode is supported for {}'.format(self.__class__.__name__))
-
-        super(_ConvTransposeNd, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(_ConvTransposeNd, self)._init(
             in_channels, out_channels, kernel_size, stride,
             padding, dilation, transposed, output_padding,
             groups, bias, padding_mode)
@@ -477,9 +504,10 @@
             mod (Module): a float module, either produced by torch.quantization
               utilities or provided by the user
         """
-        assert type(mod) == cls._FLOAT_MODULE, \
-            ' nnq.' + cls.__name__ + '.from_float only works for ' + \
-            cls._FLOAT_MODULE.__name__
+        # derived classes override the cls._FLOAT_MODULE attribute
+        msg = ' nnq.' + cls.__name__ + '.from_float only works for ' + \
+            cls._FLOAT_MODULE.__name__
+        assert type(mod) == cls._FLOAT_MODULE, msg
         assert hasattr(mod, 'qconfig'), \
             'Input float module must have qconfig defined.'
         weight_post_process = mod.qconfig.weight()
@@ -488,7 +516,8 @@
         assert weight_post_process.dtype == torch.qint8, \
             'Weight observer must have a dtype of qint8'
         qweight = _quantize_weight(mod.weight.float(), weight_post_process)
-        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
+        # the __init__ called here is the one from the derived classes, not the one from _ConvTransposeNd
+        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,  # type: ignore[call-arg]
                     mod.stride, mod.padding, mod.output_padding, mod.groups,
                     mod.bias is not None, mod.dilation, mod.padding_mode)
         qconv.set_weight_bias(qweight, mod.bias)
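For reference, the `_FLOAT_MODULE = MOD` line above types the class attribute with a `TypeVar` bound to `nn.modules.conv._ConvNd`, so each concrete quantized class can override it with its own float counterpart while `from_float` still reads `cls._FLOAT_MODULE.__name__`. A minimal sketch of the idea (hypothetical class names, not the module's real hierarchy):
```
from typing import TypeVar

import torch.nn as nn

# Any float conv module type is an acceptable value for _FLOAT_MODULE.
MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)

class _QuantizedConvBase:
    # Placeholder overridden by each concrete class; a TypeVar also has
    # a __name__, so cls._FLOAT_MODULE.__name__ type-checks on the base.
    _FLOAT_MODULE = MOD

class QuantizedConv2d(_QuantizedConvBase):
    _FLOAT_MODULE = nn.Conv2d  # satisfies the bound, so mypy accepts the override
```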