Revert D16852280: Work around for bias quantization for conv and linear operators

Test Plan: revert-hammer

Differential Revision:
D16852280

Original commit changeset: 988f8ff91616

fbshipit-source-id: e2cf03e13dc8dcf0db22d43740d72fd8b069fd74
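What the revert restores, in short: the quantized Conv2d/Linear modules (and the fused ReLU variants) stop requantizing the bias inside forward() against weight_scale * input.q_scale(), and from_float() goes back to quantizing the bias once with scale = wt_scale * act_scale instead of act_scale / 2**16 (the workaround for https://github.com/pytorch/pytorch/issues/23874). A minimal sketch of the two schemes, using the torch.quantize_linear call that appears in the diff (renamed torch.quantize_per_tensor in later releases) and hypothetical scale values:

```python
import torch

act_scale, wt_scale = 0.1, 0.05     # hypothetical observer outputs
float_bias = torch.randn(8)         # hypothetical float bias

# Behaviour restored by this revert: quantize the bias once, at module
# conversion time, with scale = weight_scale * activation_scale.
qbias_restored = torch.quantize_linear(
    float_bias, float(wt_scale * act_scale), 0, torch.qint32)

# Behaviour introduced by D16852280 (now removed): store the bias at
# act_scale / 2**16 and requantize it on every forward() call against
# weight_scale * input.q_scale().
qbias_workaround = torch.quantize_linear(
    float_bias, float(act_scale / 2 ** 16), 0, torch.qint32)
```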
diff --git a/test/test_quantized_nn_mods.py b/test/test_quantized_nn_mods.py
index e0230f2..8c97fe9 100644
--- a/test/test_quantized_nn_mods.py
+++ b/test/test_quantized_nn_mods.py
@@ -447,10 +447,7 @@
 
         # Smoke test to make sure the module actually runs
         quantized_float_conv(qX)
-        # Check that bias is quantized based on output scale
-        if use_bias:
-            qbias = torch.quantize_linear(float_conv.bias, quantized_float_conv.scale / 2**16, 0, torch.qint32)
-            self.assertEqual(quantized_float_conv.bias.dequantize(), qbias.dequantize())
+
         # Smoke test extra_repr
         str(quantized_float_conv)
 
diff --git a/torch/nn/_intrinsic/quantized/modules/conv_relu.py b/torch/nn/_intrinsic/quantized/modules/conv_relu.py
index 5bf6b17..2c14494 100644
--- a/torch/nn/_intrinsic/quantized/modules/conv_relu.py
+++ b/torch/nn/_intrinsic/quantized/modules/conv_relu.py
@@ -33,20 +33,14 @@
                                                                       self.padding,
                                                                       self.dilation,
                                                                       self.groups)
-        self.weight_scale = w.q_scale()
 
     def forward(self, input):
         # Temporarily using len(shape) instead of ndim due to JIT issue
         # https://github.com/pytorch/pytorch/issues/23890
         if len(input.shape) != 4:
             raise ValueError("Input shape must be `(N, C, H, W)`!")
-        # Temporary work around for bias
-        # see Issue:https://github.com/pytorch/pytorch/issues/23874
-        bias = self.bias
-        if bias is not None:
-            bias = torch.quantize_linear(bias.dequantize(), float(self.weight_scale) * input.q_scale(), 0, torch.qint32)
         output = torch.ops.quantized.fbgemm_conv2d_relu(input.permute([0, 2, 3, 1]),
-                                                        self._packed_weight, bias,
+                                                        self._packed_weight, self.bias,
                                                         self.stride, self.padding,
                                                         self.dilation, self.groups,
                                                         float(self.scale), int(self.zero_point))
diff --git a/torch/nn/_intrinsic/quantized/modules/linear_relu.py b/torch/nn/_intrinsic/quantized/modules/linear_relu.py
index 0714e01..0511b70 100644
--- a/torch/nn/_intrinsic/quantized/modules/linear_relu.py
+++ b/torch/nn/_intrinsic/quantized/modules/linear_relu.py
@@ -26,13 +26,9 @@
         super(LinearReLU, self).__init__(in_features, out_features, bias)
 
     def forward(self, input):
-        bias = self.bias
-        if bias is not None:
-            bias = torch.quantize_linear(bias.dequantize(), float(self.weight_scale) * input.q_scale(), 0, torch.qint32)
-
         Y_q = torch.ops.quantized.fbgemm_linear_relu(
             input, self._packed_weight,
-            bias,
+            self.bias,
             float(self.scale),
             int(self.zero_point))
         return Y_q
diff --git a/torch/nn/quantized/functional.py b/torch/nn/quantized/functional.py
index 08b963f..7954014 100644
--- a/torch/nn/quantized/functional.py
+++ b/torch/nn/quantized/functional.py
@@ -54,8 +54,6 @@
     if zero_point is None:
         zero_point = input.q_zero_point()
     _packed_weight = torch.ops.quantized.fbgemm_linear_prepack(weight)
-    if bias is not None:
-        bias = torch.quantize_linear(bias.dequantize(), weight.q_scale() * input.q_scale(), 0, torch.qint32)
     return torch.ops.quantized.fbgemm_linear(input, _packed_weight, bias, scale,
                                              zero_point)
 
@@ -118,8 +116,6 @@
 
     prepacked_weight = torch.ops.quantized.fbgemm_conv_prepack(
         weight.permute([0, 2, 3, 1]), stride, padding, dilation, groups)
-    if bias is not None:
-        bias = torch.quantize_linear(bias.dequantize(), weight.q_scale() * input.q_scale(), 0, torch.qint32)
     return torch.ops.quantized.fbgemm_conv2d(input.permute([0, 2, 3, 1]),
                                              prepacked_weight, bias,
                                              stride, padding, dilation,
diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py
index 418d61d..f23665d 100644
--- a/torch/nn/quantized/modules/conv.py
+++ b/torch/nn/quantized/modules/conv.py
@@ -83,15 +83,10 @@
             [out_channels, in_channels // self.groups, self.kernel_size[0],
                 self.kernel_size[1]],
             scale=1, zero_point=0, dtype=torch.qint8)
-        self.weight_scale = 1.0
         self.set_weight(qweight)
-        if bias:
-            self.bias = torch._empty_affine_quantized([out_channels],
-                                                      scale=1, zero_point=0,
-                                                      dtype=torch.qint32)
-        else:
-            self.bias = None
-
+        self.bias = torch._empty_affine_quantized([out_channels],
+                                                  scale=1, zero_point=0,
+                                                  dtype=torch.qint32)
         self.scale = 1.0
         self.zero_point = 0
 
@@ -111,7 +106,6 @@
     def set_weight(self, w):
         self._packed_weight = torch.ops.quantized.fbgemm_conv_prepack(
             w.permute([0, 2, 3, 1]), self.stride, self.padding, self.dilation, self.groups)
-        self.weight_scale = w.q_scale()
 
     def weight(self):
         return torch.ops.quantized.fbgemm_conv_unpack(
@@ -122,13 +116,8 @@
         # https://github.com/pytorch/pytorch/issues/23890
         if len(input.shape) != 4:
             raise ValueError("Input shape must be `(N, C, H, W)`!")
-        # Temporary work around for bias
-        # see Issue:https://github.com/pytorch/pytorch/issues/23874
-        bias = self.bias
-        if bias is not None:
-            bias = torch.quantize_linear(bias.dequantize(), self.weight_scale * input.q_scale(), 0, torch.qint32)
         output = ops.quantized.fbgemm_conv2d(input.permute([0, 2, 3, 1]),
-                                             self._packed_weight, bias,
+                                             self._packed_weight, self.bias,
                                              self.stride, self.padding,
                                              self.dilation, self.groups,
                                              self.scale, self.zero_point)
@@ -162,7 +151,7 @@
             self.weight(),
             self.bias,
             self.scale,
-            self.zero_point
+            self.zero_point,
         )
 
     # ===== Deserialization methods =====
@@ -238,8 +227,7 @@
         act_scale, act_zp = activation_observer.calculate_qparams()
         assert weight_observer.dtype == torch.qint8, 'Weight observer must have a dtype of qint8'
         wt_scale, wt_zp = weight_observer.calculate_qparams()
-        bias_scale = float(act_scale / (2**16))
-
+        bias_scale = float(wt_scale * act_scale)
         qweight = torch.quantize_linear(
             mod.weight.float(),
             float(wt_scale), int(wt_zp), torch.qint8)
@@ -248,11 +236,9 @@
                     mod.bias is not None, mod.padding_mode)
         qconv.set_weight(qweight)
         if mod.bias is not None:
-            qbias = torch.quantize_linear(mod.bias.float(), bias_scale, 0, torch.qint32)
+            qconv.bias = torch.quantize_linear(mod.bias.float(), bias_scale, 0, torch.qint32)
         else:
-            qbias = None
-        qconv.bias = qbias
+            qconv.bias = None
         qconv.scale = float(act_scale)
         qconv.zero_point = int(act_zp)
-
         return qconv
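The bias_scale change above in Conv2d.from_float is mirrored in Linear.from_float below. A minimal numeric sketch (hypothetical values, zero_point 0 everywhere) of why wt_scale * act_scale is the conventional bias scale: the int32 accumulator of a quantized matmul already carries that scale, so a bias quantized at wt_scale * act_scale can be added to the accumulator without any rescaling.

```python
import torch

act_scale, wt_scale = 0.1, 0.05
x = torch.tensor([1.2, -0.7])            # hypothetical activations
w = torch.tensor([0.4, 0.9])             # hypothetical weights
b = 0.3                                  # hypothetical bias

qx = torch.round(x / act_scale)          # integer activations (zero_point 0)
qw = torch.round(w / wt_scale)           # integer weights (zero_point 0)
qb = round(b / (wt_scale * act_scale))   # bias at scale wt_scale * act_scale

acc = (qx * qw).sum().item() + qb        # int32-style accumulation, bias added as-is
print(acc * wt_scale * act_scale)        # ~0.15, dequantized accumulator
print((x * w).sum().item() + b)          # ~0.15, float reference
```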
diff --git a/torch/nn/quantized/modules/linear.py b/torch/nn/quantized/modules/linear.py
index a31b29c..6c7a682 100644
--- a/torch/nn/quantized/modules/linear.py
+++ b/torch/nn/quantized/modules/linear.py
@@ -2,7 +2,7 @@
 
 import torch
 
-from torch._jit_internal import Optional
+from torch._jit_internal import Optional, Tuple
 import torch.nn as nn
 import torch.nn._intrinsic as nni
 from torch.nn.modules import Module
@@ -120,7 +120,6 @@
             [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8)
 
         self.set_weight(qweight)
-        self.weight_scale = 1.0
         self.scale = 1.0
         self.zero_point = 0
 
@@ -130,14 +129,8 @@
         )
 
     def forward(self, x):
-        # Temporary work around for bias
-        # see Issue:https://github.com/pytorch/pytorch/issues/23874
-        bias = self.bias
-        if bias is not None:
-            bias = torch.quantize_linear(bias.dequantize(), float(self.weight_scale) * x.q_scale(), 0, torch.qint32)
-
         return torch.ops.quantized.fbgemm_linear(
-            x, self._packed_weight, bias, self.scale, self.zero_point)
+            x, self._packed_weight, self.bias, self.scale, self.zero_point)
 
     # ===== Serialization methods =====
     # The special consideration here is that we have to unpack the weights into their
@@ -199,7 +192,6 @@
 
     def set_weight(self, w):
         self._packed_weight = torch.ops.quantized.fbgemm_linear_prepack(w)
-        self.weight_scale = w.q_scale()
 
     @classmethod
     def from_float(cls, mod):
@@ -230,7 +222,7 @@
         act_scale, act_zp = activation_observer.calculate_qparams()
         assert weight_observer.dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
         wt_scale, wt_zp = weight_observer.calculate_qparams()
-        bias_scale = float(act_scale / (2**16))
+        bias_scale = float(wt_scale * act_scale)
         qweight = torch.quantize_linear(mod.weight.float(), float(wt_scale), int(wt_zp), torch.qint8)
         if mod.bias is not None:
             qbias = torch.quantize_linear(mod.bias.float(), bias_scale, 0, torch.qint32)