Whitelist and fusion support for quantized::linear - matmul(with bias) (#26204)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26204

Support quant fusion of `matmul` with bias into `quantized::linear`.
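Roughly, the float pattern this fusion targets is a matmul against a transposed weight followed by an in-place bias add. A minimal sketch (the function and tensor names here are illustrative only; in practice the fusion runs on the TorchScript IR after the quantization passes have inserted quantize/dequantize nodes, as in the patterns in the diff below):

```python
import torch

@torch.jit.script
def linear_with_bias(x, w, b):
    # matmul with the transposed weight, then an in-place bias add (aten::add_);
    # once quantize/dequantize nodes are inserted around it, QuantFusion rewrites
    # this subgraph into quantized::fbgemm_linear_prepack followed by
    # quantized::fbgemm_linear.
    out = torch.matmul(x, w.t())
    out.add_(b)
    return out
```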

Test Plan:
python test/test_jit.py 'TestJit.test_quant_fusion'

Imported from OSS

Differential Revision: D17380073

fbshipit-source-id: 00014469a852cc5d5b66469fc4b8d05eafba1e3e
diff --git a/test/test_jit.py b/test/test_jit.py
index 0d76c19..bcae2db 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1180,6 +1180,38 @@
         # CHECK: aten::_dequantize_linear
         %r_dequant = aten::_dequantize_linear(%r_intrepr, %r_scale, %r_zero_point, %r_dtype)
         return (%r_dequant)""",
+            # matmul(with bias) -> quantized::linear
+            """
+graph(%a, %w, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype,
+%b_scale, %b_zero_point, %b_dtype, %r_scale, %r_zero_point, %r_dtype, %4):
+        %a_quant = aten::quantize_linear(%a, %a_scale, %a_zero_point, %a_dtype)
+        # CHECK-NOT: aten::int_repr
+        %a_intrepr = aten::int_repr(%a_quant)
+        # CHECK-NOT: aten::_dequantize_linear
+        %a_dequant = aten::_dequantize_linear(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
+        %w_quant = aten::quantize_linear(%w, %w_scale, %w_zero_point, %w_dtype)
+        # CHECK-NOT: aten::int_repr
+        %w_intrepr = aten::int_repr(%w_quant)
+        # CHECK-NOT: aten::_dequantize_linear
+        %w_dequant = aten::_dequantize_linear(%w_intrepr, %w_scale, %w_zero_point, %w_dtype)
+        %b_quant = aten::quantize_linear(%b, %b_scale, %b_zero_point, %b_dtype)
+        # CHECK-NOT: aten::int_repr
+        %b_intrepr = aten::int_repr(%b_quant)
+        # CHECK-NOT: aten::_dequantize_linear
+        %b_dequant = aten::_dequantize_linear(%b_intrepr, %b_scale, %b_zero_point, %b_dtype)
+        # CHECK: aten::t
+        # CHECK: quantized::fbgemm_linear_prepack
+        # CHECK: quantized::fbgemm_linear
+        # CHECK-NOT: aten::addmm
+        %output = aten::matmul(%a_dequant, %w_dequant)
+        %r = aten::add_(%output, %b_dequant, %4)
+        # CHECK-NOT: aten::quantize_linear
+        %r_quant = aten::quantize_linear(%r, %r_scale, %r_zero_point, %r_dtype)
+        # CHECK: aten::int_repr
+        %r_intrepr = aten::int_repr(%r_quant)
+        # CHECK: aten::_dequantize_linear
+        %r_dequant = aten::_dequantize_linear(%r_intrepr, %r_scale, %r_zero_point, %r_dtype)
+        return (%r_dequant)""",
             # matmul(without bias) -> quantized::linear
             """
 graph(%a, %w, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype,
diff --git a/torch/csrc/jit/passes/quantization.cpp b/torch/csrc/jit/passes/quantization.cpp
index 92ea996..26c0460 100644
--- a/torch/csrc/jit/passes/quantization.cpp
+++ b/torch/csrc/jit/passes/quantization.cpp
@@ -58,7 +58,14 @@
     %relu = match::module[name="ReLU"](%self)
     %r = prim::CallMethod[name="forward"](%relu, %output)
     return (%r))";
-  std::vector<std::string> patterns = {conv_functional_relu, conv_relu_module};
+  std::string matmul_add = R"(
+graph(%input, %weight, %bias, %4):
+     %weight_t = aten::t(%weight)
+     %output = aten::matmul(%input, %weight_t)
+     %res = aten::add_(%output, %bias, %4)
+     return (%res))";
+  std::vector<std::string> patterns = {
+      conv_functional_relu, conv_relu_module, matmul_add};
 
   for (const auto& pattern : patterns) {
     findValuesInPattern(*graph, pattern, values_to_skip);
@@ -81,8 +88,8 @@
       "linear",
       "relu",
   };
-  std::vector<Symbol> aten_funcs = {Symbol::aten("addmm"),
-                                    Symbol::aten("matmul")};
+  std::vector<Symbol> aten_funcs = {
+      Symbol::aten("addmm"), Symbol::aten("matmul"), Symbol::aten("add_")};
   std::transform(
       call_funcs.begin(),
       call_funcs.end(),
@@ -652,6 +659,22 @@
 }
 
 void QuantFusion(std::shared_ptr<Graph>& graph) {
+  const std::string quantized_linear_with_bias =
+      R"(
+graph(%a_quant, %w_quant, %b_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %b_scale, %b_zero_point, %b_dtype, %r_scale, %r_zero_point, %r_dtype, %4):
+        %0 : int = prim::Constant[value=0]()
+        %1 : int = prim::Constant[value=1]()
+        %2 : int = prim::Constant[value=2]()
+        %3 : int = prim::Constant[value=3]()
+        %in_param : int[] = prim::ListConstruct(%0, %2, %3, %1)
+        %a_perm : Tensor = aten::permute(%a_quant, %in_param)
+        %w_quant_t = aten::t(%w_quant)
+        %w_perm : Tensor = aten::permute(%w_quant_t, %in_param)
+        %w_packed = quantized::fbgemm_linear_prepack(%w_perm)
+        %r = quantized::fbgemm_linear(%a_perm, %w_packed, %b_quant, %r_scale, %r_zero_point)
+        %out_param : int[] = prim::ListConstruct(%0, %3, %1, %2)
+        %r_perm = aten::permute(%r, %out_param)
+        return (%r_perm))";
   const std::unordered_map<std::string, std::string> pattern_and_replacements =
       {// quantized::conv2d
        {R"(
@@ -691,21 +714,21 @@
         %r = aten::addmm(%b_dequant, %a_dequant, %w_dequant, %4, %4)
         %r_quant = aten::quantize_linear(%r, %r_scale, %r_zero_point, %r_dtype)
         return (%r_quant))",
-        R"(
+        quantized_linear_with_bias},
+       // matmul(with bias) -> quantized::linear
+       {R"(
 graph(%a_quant, %w_quant, %b_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %b_scale, %b_zero_point, %b_dtype, %r_scale, %r_zero_point, %r_dtype, %4):
-        %0 : int = prim::Constant[value=0]()
-        %1 : int = prim::Constant[value=1]()
-        %2 : int = prim::Constant[value=2]()
-        %3 : int = prim::Constant[value=3]()
-        %in_param : int[] = prim::ListConstruct(%0, %2, %3, %1)
-        %a_perm : Tensor = aten::permute(%a_quant, %in_param)
-        %w_quant_t = aten::t(%w_quant)
-        %w_perm : Tensor = aten::permute(%w_quant_t, %in_param)
-        %w_packed = quantized::fbgemm_linear_prepack(%w_perm)
-        %r = quantized::fbgemm_linear(%a_perm, %w_packed, %b_quant, %r_scale, %r_zero_point)
-        %out_param : int[] = prim::ListConstruct(%0, %3, %1, %2)
-        %r_perm = aten::permute(%r, %out_param)
-        return (%r_perm))"},
+        %a_intrepr = aten::int_repr(%a_quant)
+        %a_dequant = aten::_dequantize_linear(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
+        %w_intrepr = aten::int_repr(%w_quant)
+        %w_dequant = aten::_dequantize_linear(%w_intrepr, %w_scale, %w_zero_point, %w_dtype)
+        %b_intrepr = aten::int_repr(%b_quant)
+        %b_dequant = aten::_dequantize_linear(%b_intrepr, %b_scale, %b_zero_point, %b_dtype)
+        %output = aten::matmul(%a_dequant, %w_dequant)
+        %r = aten::add_(%output, %b_dequant, %4)
+        %r_quant = aten::quantize_linear(%r, %r_scale, %r_zero_point, %r_dtype)
+        return (%r_quant))",
+        quantized_linear_with_bias},
        // matmul(without bias) -> quantized::linear
        {R"(
 graph(%a_quant, %w_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype):