Whitelist and fusion support for quantized::linear - matmul (with bias) (#26204)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26204
Support quant fusion for `matmul` with bias: `aten::add_` is whitelisted for observation, the intermediate values of the `matmul` + `add_` pattern are skipped, and the dequantize -> `matmul` -> `add_` -> quantize chain is now rewritten to `quantized::linear`.
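As a sketch of the pattern being targeted (hypothetical module, not taken from this PR), scripting a module like the one below yields the `aten::t` -> `aten::matmul` -> `aten::add_` chain that the pass now matches:

import torch

class MatmulWithBias(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super(MatmulWithBias, self).__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        self.bias = torch.nn.Parameter(torch.randn(out_features))

    def forward(self, x):
        # scripts to aten::t, aten::matmul, then an in-place aten::add_
        out = torch.matmul(x, self.weight.t())
        out += self.bias
        return out

m = torch.jit.script(MatmulWithBias(5, 10))
print(m.graph)  # the t/matmul/add_ chain that QuantFusion rewrites to quantized::linear

After observation and quantization, the dequantize/quantize nodes around this chain form the pattern replaced by `quantized::linear` in the quantization.cpp changes below.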
Test Plan:
python test/test_jit.py 'TestJit.test_quant_fusion'
Imported from OSS
Differential Revision: D17380073
fbshipit-source-id: 00014469a852cc5d5b66469fc4b8d05eafba1e3e
diff --git a/test/test_jit.py b/test/test_jit.py
index 0d76c19..bcae2db 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1180,6 +1180,38 @@
# CHECK: aten::_dequantize_linear
%r_dequant = aten::_dequantize_linear(%r_intrepr, %r_scale, %r_zero_point, %r_dtype)
return (%r_dequant)""",
+ # matmul(with bias) -> quantized::linear
+ """
+graph(%a, %w, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype,
+%b_scale, %b_zero_point, %b_dtype, %r_scale, %r_zero_point, %r_dtype, %4):
+ %a_quant = aten::quantize_linear(%a, %a_scale, %a_zero_point, %a_dtype)
+ # CHECK-NOT: aten::int_repr
+ %a_intrepr = aten::int_repr(%a_quant)
+ # CHECK-NOT: aten::_dequantize_linear
+ %a_dequant = aten::_dequantize_linear(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
+ %w_quant = aten::quantize_linear(%w, %w_scale, %w_zero_point, %w_dtype)
+ # CHECK-NOT: aten::int_repr
+ %w_intrepr = aten::int_repr(%w_quant)
+ # CHECK-NOT: aten::_dequantize_linear
+ %w_dequant = aten::_dequantize_linear(%w_intrepr, %w_scale, %w_zero_point, %w_dtype)
+  %b_quant = aten::quantize_linear(%b, %b_scale, %b_zero_point, %b_dtype)
+  # CHECK-NOT: aten::int_repr
+  %b_intrepr = aten::int_repr(%b_quant)
+ # CHECK-NOT: aten::_dequantize_linear
+ %b_dequant = aten::_dequantize_linear(%b_intrepr, %b_scale, %b_zero_point, %b_dtype)
+ # CHECK: aten::t
+ # CHECK: quantized::fbgemm_linear_prepack
+ # CHECK: quantized::fbgemm_linear
+  # CHECK-NOT: aten::matmul
+ %output = aten::matmul(%a_dequant, %w_dequant)
+ %r = aten::add_(%output, %b_dequant, %4)
+ # CHECK-NOT: aten::quantize_linear
+ %r_quant = aten::quantize_linear(%r, %r_scale, %r_zero_point, %r_dtype)
+ # CHECK: aten::int_repr
+ %r_intrepr = aten::int_repr(%r_quant)
+ # CHECK: aten::_dequantize_linear
+ %r_dequant = aten::_dequantize_linear(%r_intrepr, %r_scale, %r_zero_point, %r_dtype)
+ return (%r_dequant)""",
# matmul(without bias) -> quantized::linear
"""
graph(%a, %w, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype,
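For reference, a minimal sketch of how pattern strings like the ones above are exercised, following the shape of `test_quant_fusion`; the tiny IR string and the `_jit_pass_quant_fusion` binding name are illustrative assumptions, not part of this diff:

import torch
from torch.testing import FileCheck

# Illustrative IR string with an embedded CHECK annotation, in the same
# style as the test patterns above (assumed era-appropriate op names).
input_str = """
graph(%a, %a_scale, %a_zero_point, %a_dtype):
  # CHECK: aten::quantize_linear
  %a_quant = aten::quantize_linear(%a, %a_scale, %a_zero_point, %a_dtype)
  return (%a_quant)"""

graph = torch._C.parse_ir(input_str)    # build a JIT graph from the IR string
torch._C._jit_pass_quant_fusion(graph)  # run the fusion pass under test (assumed binding)
FileCheck().run(input_str, graph)       # CHECK/CHECK-NOT lines are read from input_str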
diff --git a/torch/csrc/jit/passes/quantization.cpp b/torch/csrc/jit/passes/quantization.cpp
index 92ea996..26c0460 100644
--- a/torch/csrc/jit/passes/quantization.cpp
+++ b/torch/csrc/jit/passes/quantization.cpp
@@ -58,7 +58,14 @@
%relu = match::module[name="ReLU"](%self)
%r = prim::CallMethod[name="forward"](%relu, %output)
return (%r))";
- std::vector<std::string> patterns = {conv_functional_relu, conv_relu_module};
+ std::string matmul_add = R"(
+graph(%input, %weight, %bias, %4):
+ %weight_t = aten::t(%weight)
+ %output = aten::matmul(%input, %weight_t)
+ %res = aten::add_(%output, %bias, %4)
+ return (%res))";
+ std::vector<std::string> patterns = {
+ conv_functional_relu, conv_relu_module, matmul_add};
for (const auto& pattern : patterns) {
findValuesInPattern(*graph, pattern, values_to_skip);
@@ -81,8 +88,8 @@
"linear",
"relu",
};
- std::vector<Symbol> aten_funcs = {Symbol::aten("addmm"),
- Symbol::aten("matmul")};
+ std::vector<Symbol> aten_funcs = {
+ Symbol::aten("addmm"), Symbol::aten("matmul"), Symbol::aten("add_")};
std::transform(
call_funcs.begin(),
call_funcs.end(),
@@ -652,6 +659,22 @@
}
void QuantFusion(std::shared_ptr<Graph>& graph) {
+ const std::string quantized_linear_with_bias =
+ R"(
+graph(%a_quant, %w_quant, %b_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %b_scale, %b_zero_point, %b_dtype, %r_scale, %r_zero_point, %r_dtype, %4):
+ %0 : int = prim::Constant[value=0]()
+ %1 : int = prim::Constant[value=1]()
+ %2 : int = prim::Constant[value=2]()
+ %3 : int = prim::Constant[value=3]()
+ %in_param : int[] = prim::ListConstruct(%0, %2, %3, %1)
+ %a_perm : Tensor = aten::permute(%a_quant, %in_param)
+ %w_quant_t = aten::t(%w_quant)
+ %w_perm : Tensor = aten::permute(%w_quant_t, %in_param)
+ %w_packed = quantized::fbgemm_linear_prepack(%w_perm)
+ %r = quantized::fbgemm_linear(%a_perm, %w_packed, %b_quant, %r_scale, %r_zero_point)
+ %out_param : int[] = prim::ListConstruct(%0, %3, %1, %2)
+ %r_perm = aten::permute(%r, %out_param)
+ return (%r_perm))";
const std::unordered_map<std::string, std::string> pattern_and_replacements =
{// quantized::conv2d
{R"(
@@ -691,21 +714,21 @@
%r = aten::addmm(%b_dequant, %a_dequant, %w_dequant, %4, %4)
%r_quant = aten::quantize_linear(%r, %r_scale, %r_zero_point, %r_dtype)
return (%r_quant))",
- R"(
+ quantized_linear_with_bias},
+ // matmul(with bias) -> quantized::linear
+ {R"(
graph(%a_quant, %w_quant, %b_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %b_scale, %b_zero_point, %b_dtype, %r_scale, %r_zero_point, %r_dtype, %4):
- %0 : int = prim::Constant[value=0]()
- %1 : int = prim::Constant[value=1]()
- %2 : int = prim::Constant[value=2]()
- %3 : int = prim::Constant[value=3]()
- %in_param : int[] = prim::ListConstruct(%0, %2, %3, %1)
- %a_perm : Tensor = aten::permute(%a_quant, %in_param)
- %w_quant_t = aten::t(%w_quant)
- %w_perm : Tensor = aten::permute(%w_quant_t, %in_param)
- %w_packed = quantized::fbgemm_linear_prepack(%w_perm)
- %r = quantized::fbgemm_linear(%a_perm, %w_packed, %b_quant, %r_scale, %r_zero_point)
- %out_param : int[] = prim::ListConstruct(%0, %3, %1, %2)
- %r_perm = aten::permute(%r, %out_param)
- return (%r_perm))"},
+ %a_intrepr = aten::int_repr(%a_quant)
+ %a_dequant = aten::_dequantize_linear(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
+ %w_intrepr = aten::int_repr(%w_quant)
+ %w_dequant = aten::_dequantize_linear(%w_intrepr, %w_scale, %w_zero_point, %w_dtype)
+ %b_intrepr = aten::int_repr(%b_quant)
+ %b_dequant = aten::_dequantize_linear(%b_intrepr, %b_scale, %b_zero_point, %b_dtype)
+ %output = aten::matmul(%a_dequant, %w_dequant)
+ %r = aten::add_(%output, %b_dequant, %4)
+ %r_quant = aten::quantize_linear(%r, %r_scale, %r_zero_point, %r_dtype)
+ return (%r_quant))",
+ quantized_linear_with_bias},
// matmul(without bias) -> quantized::linear
{R"(
graph(%a_quant, %w_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype):