add type annotations to torch.nn.quantized.modules.conv (#49702)

Summary:
closes gh-49700

No mypy issues were found in the first three entries deleted from `mypy.ini`:
```
[mypy-torch.nn.qat.modules.activations]
ignore_errors = True

[mypy-torch.nn.qat.modules.conv]
ignore_errors = True

[mypy-torch.nn.quantized.dynamic.modules.linear]
ignore_errors = True
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/49702

Reviewed By: walterddr, zou3519

Differential Revision: D25767119

Pulled By: ezyang

fbshipit-source-id: cb83e53549a299538e1b154cf8b79e3280f7392a
diff --git a/mypy.ini b/mypy.ini
index 6c579ee..0c99a9c 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -91,16 +91,7 @@
 [mypy-torch.nn.modules.pooling]
 ignore_errors = True
 
-[mypy-torch.nn.qat.modules.activations]
-ignore_errors = True
-
-[mypy-torch.nn.qat.modules.conv]
-ignore_errors = True
-
-[mypy-torch.nn.quantized.dynamic.modules.linear]
-ignore_errors = True
-
-[mypy-torch.nn.quantized.modules.conv]
+[mypy-torch.nn.parallel._functions]
 ignore_errors = True
 
 [mypy-torch._appdirs]
diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py
index 00ceba7..b3bc78f 100644
--- a/torch/nn/quantized/modules/conv.py
+++ b/torch/nn/quantized/modules/conv.py
@@ -1,7 +1,7 @@
 # coding=utf-8
 r"""Quantized convolution modules."""
 
-from typing import Optional, List
+from typing import Optional, List, TypeVar
 
 import torch
 import torch.nn as nn
@@ -16,11 +16,17 @@
 
 class _ConvNd(nn.Module):
 
-    def __init__(self, in_channels, out_channels, kernel_size, stride,
-                 padding, dilation,
-                 transposed, output_padding,
-                 groups, bias,
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
                  padding_mode='zeros'):
+        # All subclasses have this signature - See PR #49702s
+        raise NotImplementedError
+
+    def _init(self, in_channels, out_channels, kernel_size, stride,
+              padding, dilation,
+              transposed, output_padding,
+              groups, bias,
+              padding_mode='zeros'):
         super(_ConvNd, self).__init__()
         if padding_mode != 'zeros':
             raise NotImplementedError(
@@ -54,6 +60,15 @@
         self.scale = 1.0
         self.zero_point = 0
 
+    def set_weight_bias(self, qweight, bias_float):
+        raise NotImplementedError
+
+    def bias(self):
+        raise NotImplementedError
+
+    def _weight_bias(self):
+        raise NotImplementedError
+
     def extra_repr(self):
         s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
              ', stride={stride}, scale={scale}, zero_point={zero_point}')
@@ -155,7 +170,8 @@
         assert weight_post_process.dtype == torch.qint8, \
             'Weight observer must have a dtype of qint8'
         qweight = _quantize_weight(mod.weight.float(), weight_post_process)
-        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
+        # the __init__ call used is the one from derived classes and not the one from _ConvNd
+        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,  # type: ignore[call-arg]
                     mod.stride, mod.padding, mod.dilation, mod.groups,
                     mod.bias is not None, mod.padding_mode)
         qconv.set_weight_bias(qweight, mod.bias)
@@ -233,7 +249,9 @@
         padding = _pair_from_first(padding)
         dilation = _pair_from_first(dilation)
 
-        super(Conv1d, self).__init__(
+        # Subclasses of _ConvNd needs to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(Conv1d, self)._init(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _single(0), groups, bias, padding_mode)
 
@@ -319,7 +337,9 @@
         stride = _pair(stride)
         padding = _pair(padding)
         dilation = _pair(dilation)
-        super(Conv2d, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(Conv2d, self)._init(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _pair(0), groups, bias, padding_mode)
 
@@ -403,7 +423,9 @@
         stride = _triple(stride)
         padding = _triple(padding)
         dilation = _triple(dilation)
-        super(Conv3d, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(Conv3d, self)._init(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _triple(0), groups, bias, padding_mode)
 
@@ -450,15 +472,20 @@
         return cls.get_qconv(mod, activation_post_process)
 
 # === Transposed Convolutions ===
+MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)
 
 class _ConvTransposeNd(_ConvNd):
+
+    _FLOAT_MODULE = MOD
+
     def __init__(self, in_channels, out_channels, kernel_size, stride,
                  padding, dilation, transposed, output_padding,
                  groups, bias, padding_mode):
         if padding_mode != 'zeros':
             raise ValueError('Only "zeros" padding mode is supported for {}'.format(self.__class__.__name__))
-
-        super(_ConvTransposeNd, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(_ConvTransposeNd, self)._init(
             in_channels, out_channels, kernel_size, stride,
             padding, dilation, transposed, output_padding,
             groups, bias, padding_mode)
@@ -477,9 +504,10 @@
             mod (Module): a float module, either produced by torch.quantization
               utilities or provided by the user
         """
-        assert type(mod) == cls._FLOAT_MODULE, \
-            ' nnq.' + cls.__name__ + '.from_float only works for ' + \
-            cls._FLOAT_MODULE.__name__
+        # derived classes override cls._FLOAT_MODULE attribute
+        msg = ' nnq.' + cls.__name__ + '.from_float only works for ' + \
+              cls._FLOAT_MODULE.__name__
+        assert type(mod) == cls._FLOAT_MODULE, msg
         assert hasattr(mod, 'qconfig'), \
             'Input float module must have qconfig defined.'
         weight_post_process = mod.qconfig.weight()
@@ -488,7 +516,8 @@
         assert weight_post_process.dtype == torch.qint8, \
             'Weight observer must have a dtype of qint8'
         qweight = _quantize_weight(mod.weight.float(), weight_post_process)
-        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
+        # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd
+        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,  # type: ignore[call-arg]
                     mod.stride, mod.padding, mod.output_padding, mod.groups,
                     mod.bias is not None, mod.dilation, mod.padding_mode)
         qconv.set_weight_bias(qweight, mod.bias)