[quant][pt2e] Add reference representation rewrite for statically quantized linear (#107994)

Summary:
Adds a reference representation rewrite for statically quantized linear in PT2E. The
dequantize -> aten.linear -> quantize pattern is matched and replaced with an integer-only
reference implementation: activations and weights are upcast to int16, zero points are
subtracted, the matmul accumulates in int32 via the out_dtype higher-order op
(aten.linear.default is added to its ALLOWABLE_OPS), the bias is quantized with
x_scale * weight_scale and added in int32, and the result is requantized and clamped back
to int8. On the test side, test_representation_linear is added and the stale skip above
test_representation_conv2d is removed.
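
The core integer arithmetic of the reference pattern can be sketched on its own; the snippet
below mirrors the new _reference_quantized_linear with made-up quantization parameters (shapes,
scales and zero points are illustrative only, not taken from the tests in this PR):
```
import torch
from torch._higher_order_ops.out_dtype import out_dtype

# Hypothetical int8 activation/weight and quantization parameters.
x_i8 = torch.randint(-128, 127, (2, 5), dtype=torch.int8)
w_i8 = torch.randint(-127, 127, (5, 5), dtype=torch.int8)
bias_fp32 = torch.randn(5)
x_scale, x_zp = torch.tensor([0.04]), torch.tensor([0], dtype=torch.int)
w_scale, w_zp = torch.tensor([0.02]), torch.tensor([0], dtype=torch.int)
out_scale, out_zp = torch.tensor([0.1]), torch.tensor([0], dtype=torch.int)

# Upcast to int16 so subtracting the zero points cannot overflow int8;
# out_dtype keeps the matmul accumulation in int32 (this is why
# aten.linear.default is added to ALLOWABLE_OPS in this PR).
acc_i32 = out_dtype(
    torch.ops.aten.linear.default, torch.int32,
    x_i8.to(torch.int16) - x_zp, w_i8.to(torch.int16) - w_zp, None)
# Bias is quantized with x_scale * w_scale and added in int32.
acc_i32 = acc_i32 + out_dtype(
    torch.ops.aten.div.Tensor, torch.int32, bias_fp32, x_scale * w_scale)
# Requantize to the output scale/zero point and clamp back to the int8 range.
acc_i32 = out_dtype(
    torch.ops.aten.mul.Tensor, torch.int32, acc_i32, x_scale * w_scale / out_scale) + out_zp
out_i8 = torch.ops.aten.clamp(acc_i32, -128, 127).to(torch.int8)
```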

Test Plan:
```
python test/test_quantization.py TestQuantizePT2E.test_representation_linear
buck2 test 'fbcode//mode/opt' fbcode//caffe2/test:quantization_pt2e -- 'test_representation_linear'
```
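
As an additional rough sanity check (not part of the test plan above), the two private helpers
added in rewrite.py can be compared numerically by calling them directly. This is only a sketch,
assuming _qdq_quantized_linear and _reference_quantized_linear remain importable from
torch.ao.quantization.pt2e.representation.rewrite; the quantization parameters are made up, and
small elementwise differences are expected from rounding in the requantization step.
```
import torch
from torch.ao.quantization.pt2e.representation.rewrite import (
    _qdq_quantized_linear,
    _reference_quantized_linear,
)

# Hypothetical quantization parameters; scales/zero points are 1-element tensors,
# quant_min/max are plain ints.
args = (
    torch.randint(-128, 127, (2, 5), dtype=torch.int8),   # x_i8
    torch.tensor([0.04]), torch.tensor([0], dtype=torch.int), -128, 127,
    torch.randint(-127, 127, (5, 5), dtype=torch.int8),   # weight_i8
    torch.tensor([0.02]), torch.tensor([0], dtype=torch.int), -127, 127,
    torch.randn(5),                                        # bias_fp32
    torch.tensor([0.1]), torch.tensor([0], dtype=torch.int), -128, 127,
)
qdq_out = _qdq_quantized_linear(*args)
ref_out = _reference_quantized_linear(*args)
# Report the largest elementwise difference between the two representations.
print((qdq_out.to(torch.int32) - ref_out.to(torch.int32)).abs().max())
```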

Reviewed By: kimishpatel

Differential Revision: D48674862

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107994
Approved by: https://github.com/mcr229, https://github.com/guangy10
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index 2760a99..321b5a9 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -2114,7 +2114,28 @@
             M(), example_inputs, is_per_channel=True, verify_convert=True,
         )
 
-    @unittest.skip("some issues with conv2d rewrite, will fix in a separate PR")
+    def test_representation_linear(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        quantizer = XNNPACKQuantizer()
+        operator_config = get_symmetric_quantization_config(is_per_channel=False)
+        quantizer.set_global(operator_config)
+        example_inputs = (torch.randn(2, 5),)
+
+        self._test_representation(
+            M().eval(),
+            example_inputs,
+            quantizer,
+            ref_node_occurrence={},
+            non_ref_node_occurrence={}
+        )
+
     def test_representation_conv2d(self):
         class M(torch.nn.Module):
             def __init__(self):
diff --git a/torch/_higher_order_ops/out_dtype.py b/torch/_higher_order_ops/out_dtype.py
index 5183fd0..131f7c0 100644
--- a/torch/_higher_order_ops/out_dtype.py
+++ b/torch/_higher_order_ops/out_dtype.py
@@ -22,6 +22,7 @@
 
 # TODO to figure out a more generic approach
 ALLOWABLE_OPS = [
+    torch.ops.aten.linear.default,
     torch.ops.aten.mm.default,
     torch.ops.aten.conv2d.default,
     torch.ops.aten.convolution.default,
diff --git a/torch/ao/quantization/pt2e/representation/rewrite.py b/torch/ao/quantization/pt2e/representation/rewrite.py
index 5abc894..c7c2792 100644
--- a/torch/ao/quantization/pt2e/representation/rewrite.py
+++ b/torch/ao/quantization/pt2e/representation/rewrite.py
@@ -18,6 +18,73 @@
     "reference_representation_rewrite",
 ]
 
+
+_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
+    torch.randint(-128, 127, (2, 5), dtype=torch.int8),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-128], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+    torch.randint(-128, 127, (5, 5), dtype=torch.int8),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-127], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+    torch.randn(1, dtype=torch.float),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-128], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+)
+
+def _qdq_quantized_linear(
+    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
+    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
+    bias_fp32,
+    out_scale, out_zero_point, out_quant_min, out_quant_max
+):
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
+    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
+    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
+    return out_i8
+
+def _reference_quantized_linear(
+    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
+    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
+    bias_fp32,
+    out_scale, out_zero_point, out_quant_min, out_quant_max
+):
+    # Without quant_min/quant_max in a clamp, the traced graph will not have quant_min/quant_max
+    # as args, which makes the pattern fail to match.
+    # Therefore we call torch.ops.aten.clamp with them here.
+    x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
+    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)
+
+    x_i16 = x_i8.to(torch.int16)
+    weight_i16 = weight_i8.to(torch.int16)
+    # bias is always passed as None here so that the same representation works
+    # whether or not bias_scale == x_scale * weight_scale
+    acc_i32 = out_dtype(
+        torch.ops.aten.linear.default,
+        torch.int32,
+        x_i16 - x_zero_point,
+        weight_i16 - weight_zero_point,
+        None)
+    # TODO: change to mul.Scalar
+    # Note: we quantize the bias with these scales without a signal from the user, but it might be OK
+    bias_scale = x_scale * weight_scale
+    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
+    acc_i32 = acc_i32 + bias_i32
+    # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
+    acc_i32 = out_dtype(torch.ops.aten.mul.Tensor, torch.int32, acc_i32, x_scale * weight_scale / out_scale) + out_zero_point
+    out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
+    return out_i8
+
+
 _QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
     torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
     torch.randn(1, dtype=torch.float),
@@ -399,6 +466,13 @@
 
 _REWRITE_INFO_LIST = [
     _RewriteInfo(
+        _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
+        _qdq_quantized_linear,
+        _reference_quantized_linear,
+        _replace_literals_with_new_placeholders,
+        _replace_literals_with_new_placeholders,
+    ),
+    _RewriteInfo(
         _QUANTIZED_CONV2d_EXAMPLE_INPUTS,
         _qdq_quantized_conv2d,
         _reference_quantized_conv2d,