[quant][pt2e] Add reference representation rewrite for statically quantized linear (#107994)
Summary:
Add the reference representation rewrite for statically quantized linear in PT2E quantization: the `dequantize -> aten.linear -> quantize` pattern (`_qdq_quantized_linear`) is rewritten into an integer-arithmetic reference implementation (`_reference_quantized_linear`), and `torch.ops.aten.linear.default` is added to the `out_dtype` allowlist so the reference pattern can accumulate the integer matmul in int32.
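As context, a minimal numerical sketch of the equivalence the reference pattern relies on (not part of this PR; the shapes, scales, and zero points below are made up for illustration, and the actual rewrite uses `out_dtype` ops rather than plain tensor math): the dequantize -> linear -> quantize computation can be expressed as an int32 accumulation over zero-point-shifted int8 operands, with the bias folded in at `bias_scale = x_scale * weight_scale` and a single rescale to the output quantization parameters.
```
import torch

# Illustrative per-tensor int8 quantization parameters (made up for this sketch).
x_i8 = torch.randint(-128, 128, (2, 5), dtype=torch.int8)
w_i8 = torch.randint(-128, 128, (5, 5), dtype=torch.int8)
bias_fp32 = torch.randn(5)
x_scale, x_zp = 0.02, 3
w_scale, w_zp = 0.01, 0
out_scale, out_zp = 0.05, -1

# (1) q/dq form: dequantize, run fp32 linear, requantize
#     (this is what _qdq_quantized_linear matches).
x_fp32 = (x_i8.to(torch.float) - x_zp) * x_scale
w_fp32 = (w_i8.to(torch.float) - w_zp) * w_scale
out_fp32 = torch.nn.functional.linear(x_fp32, w_fp32, bias_fp32)
qdq_out = torch.clamp(torch.round(out_fp32 / out_scale) + out_zp, -128, 127).to(torch.int8)

# (2) reference form: shift out zero points, accumulate in int32, fold the bias
#     at x_scale * w_scale, then rescale once (this is what
#     _reference_quantized_linear emits, modulo the out_dtype ops).
acc_i32 = (x_i8.to(torch.int32) - x_zp) @ (w_i8.to(torch.int32) - w_zp).t()
acc_i32 = acc_i32 + torch.round(bias_fp32 / (x_scale * w_scale)).to(torch.int32)
ref_out = torch.clamp(
    torch.round(acc_i32 * (x_scale * w_scale / out_scale)) + out_zp, -128, 127
).to(torch.int8)

# The two paths agree up to small rounding differences.
print((qdq_out.to(torch.int32) - ref_out.to(torch.int32)).abs().max())
```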
Test Plan:
```
python test/test_quantization.py TestQuantizePT2E.test_representation_linear
buck2 test 'fbcode//mode/opt' fbcode//caffe2/test:quantization_pt2e -- 'test_representation_linear'
```
Reviewed By: kimishpatel
Differential Revision: D48674862
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107994
Approved by: https://github.com/mcr229, https://github.com/guangy10
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index 2760a99..321b5a9 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -2114,7 +2114,28 @@
M(), example_inputs, is_per_channel=True, verify_convert=True,
)
- @unittest.skip("some issues with conv2d rewrite, will fix in a separate PR")
+ def test_representation_linear(self):
+ class M(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = torch.nn.Linear(5, 5)
+
+ def forward(self, x):
+ return self.linear(x)
+
+ quantizer = XNNPACKQuantizer()
+ operator_config = get_symmetric_quantization_config(is_per_channel=False)
+ quantizer.set_global(operator_config)
+ example_inputs = (torch.randn(2, 5),)
+
+ self._test_representation(
+ M().eval(),
+ example_inputs,
+ quantizer,
+ ref_node_occurrence={},
+            non_ref_node_occurrence={},
+ )
+
def test_representation_conv2d(self):
class M(torch.nn.Module):
def __init__(self):
diff --git a/torch/_higher_order_ops/out_dtype.py b/torch/_higher_order_ops/out_dtype.py
index 5183fd0..131f7c0 100644
--- a/torch/_higher_order_ops/out_dtype.py
+++ b/torch/_higher_order_ops/out_dtype.py
@@ -22,6 +22,7 @@
# TODO to figure out a more generic approach
ALLOWABLE_OPS = [
+ torch.ops.aten.linear.default,
torch.ops.aten.mm.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.convolution.default,
diff --git a/torch/ao/quantization/pt2e/representation/rewrite.py b/torch/ao/quantization/pt2e/representation/rewrite.py
index 5abc894..c7c2792 100644
--- a/torch/ao/quantization/pt2e/representation/rewrite.py
+++ b/torch/ao/quantization/pt2e/representation/rewrite.py
@@ -18,6 +18,73 @@
"reference_representation_rewrite",
]
+
+_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
+ torch.randint(-128, 127, (2, 5), dtype=torch.int8),
+ torch.randn(1, dtype=torch.float),
+ torch.zeros(1, dtype=torch.int),
+ torch.tensor([-128], dtype=torch.int),
+ torch.tensor([127], dtype=torch.int),
+ torch.randint(-128, 127, (5, 5), dtype=torch.int8),
+ torch.randn(1, dtype=torch.float),
+ torch.zeros(1, dtype=torch.int),
+ torch.tensor([-127], dtype=torch.int),
+ torch.tensor([127], dtype=torch.int),
+ torch.randn(1, dtype=torch.float),
+ torch.randn(1, dtype=torch.float),
+ torch.zeros(1, dtype=torch.int),
+ torch.tensor([-128], dtype=torch.int),
+ torch.tensor([127], dtype=torch.int),
+)
+
+def _qdq_quantized_linear(
+ x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
+ weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
+ bias_fp32,
+ out_scale, out_zero_point, out_quant_min, out_quant_max
+):
+ x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+ x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
+ weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+ weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
+ out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+ out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+ out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
+ return out_i8
+
+def _reference_quantized_linear(
+ x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
+ weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
+ bias_fp32,
+ out_scale, out_zero_point, out_quant_min, out_quant_max
+):
+    # Without using quant_min/max in clamp, the traced graph will not have quant_min/max args,
+    # which results in a failure to match the pattern.
+    # Therefore, we call torch.ops.aten.clamp here explicitly.
+ x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
+ weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)
+
+ x_i16 = x_i8.to(torch.int16)
+ weight_i16 = weight_i8.to(torch.int16)
+    # Always set bias to None so that the same representation works regardless of
+    # whether bias_scale == x_scale * weight_scale.
+ acc_i32 = out_dtype(
+ torch.ops.aten.linear.default,
+ torch.int32,
+ x_i16 - x_zero_point,
+ weight_i16 - weight_zero_point,
+ None)
+ # TODO: change to mul.Scalar
+    # Note: we are quantizing the bias with this derived scale without any signal from the user, but it might be OK
+ bias_scale = x_scale * weight_scale
+ bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
+ acc_i32 = acc_i32 + bias_i32
+ # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
+ acc_i32 = out_dtype(torch.ops.aten.mul.Tensor, torch.int32, acc_i32, x_scale * weight_scale / out_scale) + out_zero_point
+ out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
+ return out_i8
+
+
_QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
@@ -399,6 +466,13 @@
_REWRITE_INFO_LIST = [
_RewriteInfo(
+ _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
+ _qdq_quantized_linear,
+ _reference_quantized_linear,
+ _replace_literals_with_new_placeholders,
+ _replace_literals_with_new_placeholders,
+ ),
+ _RewriteInfo(
_QUANTIZED_CONV2d_EXAMPLE_INPUTS,
_qdq_quantized_conv2d,
_reference_quantized_conv2d,