Revert D16852280: Work around for bias quantization for conv and linear operators
Test Plan: revert-hammer
Differential Revision:
D16852280
Original commit changeset: 988f8ff91616
fbshipit-source-id: e2cf03e13dc8dcf0db22d43740d72fd8b069fd74
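
In short, this restores the pre-D16852280 bias handling: the quantized Conv2d/Linear modules (and the _intrinsic ConvReLU2d/LinearReLU variants) stop tracking weight_scale and stop re-quantizing the bias on every forward call with weight_scale * input.q_scale(); from_float goes back to quantizing the bias once with scale wt_scale * act_scale instead of the fixed act_scale / 2**16, and the matching re-quantization is also dropped from torch.nn.quantized.functional. A minimal sketch of the two schemes, using the torch.quantize_linear API from this revision of the codebase; the scale values below are illustrative, not taken from the diff:

    import torch

    float_bias = torch.randn(8)
    wt_scale, act_scale = 0.02, 0.1   # illustrative weight/activation scales
    input_scale = act_scale           # stands in for input.q_scale() at runtime

    # Behavior restored by this revert: quantize the bias once at module
    # conversion time and pass it to the fbgemm op unchanged in forward().
    qbias = torch.quantize_linear(float_bias, wt_scale * act_scale, 0, torch.qint32)

    # Workaround being reverted (D16852280): a fixed conversion-time scale,
    # then a re-quantization inside forward() against the actual input scale.
    qbias_fixed = torch.quantize_linear(float_bias, act_scale / 2 ** 16, 0, torch.qint32)
    qbias_runtime = torch.quantize_linear(
        qbias_fixed.dequantize(), wt_scale * input_scale, 0, torch.qint32)
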
diff --git a/test/test_quantized_nn_mods.py b/test/test_quantized_nn_mods.py
index e0230f2..8c97fe9 100644
--- a/test/test_quantized_nn_mods.py
+++ b/test/test_quantized_nn_mods.py
@@ -447,10 +447,7 @@
# Smoke test to make sure the module actually runs
quantized_float_conv(qX)
- # Check that bias is quantized based on output scale
- if use_bias:
- qbias = torch.quantize_linear(float_conv.bias, quantized_float_conv.scale / 2**16, 0, torch.qint32)
- self.assertEqual(quantized_float_conv.bias.dequantize(), qbias.dequantize())
+
# Smoke test extra_repr
str(quantized_float_conv)
diff --git a/torch/nn/_intrinsic/quantized/modules/conv_relu.py b/torch/nn/_intrinsic/quantized/modules/conv_relu.py
index 5bf6b17..2c14494 100644
--- a/torch/nn/_intrinsic/quantized/modules/conv_relu.py
+++ b/torch/nn/_intrinsic/quantized/modules/conv_relu.py
@@ -33,20 +33,14 @@
self.padding,
self.dilation,
self.groups)
- self.weight_scale = w.q_scale()
def forward(self, input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 4:
raise ValueError("Input shape must be `(N, C, H, W)`!")
- # Temporary work around for bias
- # see Issue:https://github.com/pytorch/pytorch/issues/23874
- bias = self.bias
- if bias is not None:
- bias = torch.quantize_linear(bias.dequantize(), float(self.weight_scale) * input.q_scale(), 0, torch.qint32)
output = torch.ops.quantized.fbgemm_conv2d_relu(input.permute([0, 2, 3, 1]),
- self._packed_weight, bias,
+ self._packed_weight, self.bias,
self.stride, self.padding,
self.dilation, self.groups,
float(self.scale), int(self.zero_point))
diff --git a/torch/nn/_intrinsic/quantized/modules/linear_relu.py b/torch/nn/_intrinsic/quantized/modules/linear_relu.py
index 0714e01..0511b70 100644
--- a/torch/nn/_intrinsic/quantized/modules/linear_relu.py
+++ b/torch/nn/_intrinsic/quantized/modules/linear_relu.py
@@ -26,13 +26,9 @@
super(LinearReLU, self).__init__(in_features, out_features, bias)
def forward(self, input):
- bias = self.bias
- if bias is not None:
- bias = torch.quantize_linear(bias.dequantize(), float(self.weight_scale) * input.q_scale(), 0, torch.qint32)
-
Y_q = torch.ops.quantized.fbgemm_linear_relu(
input, self._packed_weight,
- bias,
+ self.bias,
float(self.scale),
int(self.zero_point))
return Y_q
diff --git a/torch/nn/quantized/functional.py b/torch/nn/quantized/functional.py
index 08b963f..7954014 100644
--- a/torch/nn/quantized/functional.py
+++ b/torch/nn/quantized/functional.py
@@ -54,8 +54,6 @@
if zero_point is None:
zero_point = input.q_zero_point()
_packed_weight = torch.ops.quantized.fbgemm_linear_prepack(weight)
- if bias is not None:
- bias = torch.quantize_linear(bias.dequantize(), weight.q_scale() * input.q_scale(), 0, torch.qint32)
return torch.ops.quantized.fbgemm_linear(input, _packed_weight, bias, scale,
zero_point)
@@ -118,8 +116,6 @@
prepacked_weight = torch.ops.quantized.fbgemm_conv_prepack(
weight.permute([0, 2, 3, 1]), stride, padding, dilation, groups)
- if bias is not None:
- bias = torch.quantize_linear(bias.dequantize(), weight.q_scale() * input.q_scale(), 0, torch.qint32)
return torch.ops.quantized.fbgemm_conv2d(input.permute([0, 2, 3, 1]),
prepacked_weight, bias,
stride, padding, dilation,
diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py
index 418d61d..f23665d 100644
--- a/torch/nn/quantized/modules/conv.py
+++ b/torch/nn/quantized/modules/conv.py
@@ -83,15 +83,10 @@
[out_channels, in_channels // self.groups, self.kernel_size[0],
self.kernel_size[1]],
scale=1, zero_point=0, dtype=torch.qint8)
- self.weight_scale = 1.0
self.set_weight(qweight)
- if bias:
- self.bias = torch._empty_affine_quantized([out_channels],
- scale=1, zero_point=0,
- dtype=torch.qint32)
- else:
- self.bias = None
-
+ self.bias = torch._empty_affine_quantized([out_channels],
+ scale=1, zero_point=0,
+ dtype=torch.qint32)
self.scale = 1.0
self.zero_point = 0
@@ -111,7 +106,6 @@
def set_weight(self, w):
self._packed_weight = torch.ops.quantized.fbgemm_conv_prepack(
w.permute([0, 2, 3, 1]), self.stride, self.padding, self.dilation, self.groups)
- self.weight_scale = w.q_scale()
def weight(self):
return torch.ops.quantized.fbgemm_conv_unpack(
@@ -122,13 +116,8 @@
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 4:
raise ValueError("Input shape must be `(N, C, H, W)`!")
- # Temporary work around for bias
- # see Issue:https://github.com/pytorch/pytorch/issues/23874
- bias = self.bias
- if bias is not None:
- bias = torch.quantize_linear(bias.dequantize(), self.weight_scale * input.q_scale(), 0, torch.qint32)
output = ops.quantized.fbgemm_conv2d(input.permute([0, 2, 3, 1]),
- self._packed_weight, bias,
+ self._packed_weight, self.bias,
self.stride, self.padding,
self.dilation, self.groups,
self.scale, self.zero_point)
@@ -162,7 +151,7 @@
self.weight(),
self.bias,
self.scale,
- self.zero_point
+ self.zero_point,
)
# ===== Deserialization methods =====
@@ -238,8 +227,7 @@
act_scale, act_zp = activation_observer.calculate_qparams()
assert weight_observer.dtype == torch.qint8, 'Weight observer must have a dtype of qint8'
wt_scale, wt_zp = weight_observer.calculate_qparams()
- bias_scale = float(act_scale / (2**16))
-
+ bias_scale = float(wt_scale * act_scale)
qweight = torch.quantize_linear(
mod.weight.float(),
float(wt_scale), int(wt_zp), torch.qint8)
@@ -248,11 +236,9 @@
mod.bias is not None, mod.padding_mode)
qconv.set_weight(qweight)
if mod.bias is not None:
- qbias = torch.quantize_linear(mod.bias.float(), bias_scale, 0, torch.qint32)
+ qconv.bias = torch.quantize_linear(mod.bias.float(), bias_scale, 0, torch.qint32)
else:
- qbias = None
- qconv.bias = qbias
+ qconv.bias = None
qconv.scale = float(act_scale)
qconv.zero_point = int(act_zp)
-
return qconv
diff --git a/torch/nn/quantized/modules/linear.py b/torch/nn/quantized/modules/linear.py
index a31b29c..6c7a682 100644
--- a/torch/nn/quantized/modules/linear.py
+++ b/torch/nn/quantized/modules/linear.py
@@ -2,7 +2,7 @@
import torch
-from torch._jit_internal import Optional
+from torch._jit_internal import Optional, Tuple
import torch.nn as nn
import torch.nn._intrinsic as nni
from torch.nn.modules import Module
@@ -120,7 +120,6 @@
[out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8)
self.set_weight(qweight)
- self.weight_scale = 1.0
self.scale = 1.0
self.zero_point = 0
@@ -130,14 +129,8 @@
)
def forward(self, x):
- # Temporary work around for bias
- # see Issue:https://github.com/pytorch/pytorch/issues/23874
- bias = self.bias
- if bias is not None:
- bias = torch.quantize_linear(bias.dequantize(), float(self.weight_scale) * x.q_scale(), 0, torch.qint32)
-
return torch.ops.quantized.fbgemm_linear(
- x, self._packed_weight, bias, self.scale, self.zero_point)
+ x, self._packed_weight, self.bias, self.scale, self.zero_point)
# ===== Serialization methods =====
# The special consideration here is that we have to unpack the weights into their
@@ -199,7 +192,6 @@
def set_weight(self, w):
self._packed_weight = torch.ops.quantized.fbgemm_linear_prepack(w)
- self.weight_scale = w.q_scale()
@classmethod
def from_float(cls, mod):
@@ -230,7 +222,7 @@
act_scale, act_zp = activation_observer.calculate_qparams()
assert weight_observer.dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
wt_scale, wt_zp = weight_observer.calculate_qparams()
- bias_scale = float(act_scale / (2**16))
+ bias_scale = float(wt_scale * act_scale)
qweight = torch.quantize_linear(mod.weight.float(), float(wt_scale), int(wt_zp), torch.qint8)
if mod.bias is not None:
qbias = torch.quantize_linear(mod.bias.float(), bias_scale, 0, torch.qint32)