[quant][fx] Allow incrementally removing the items in quantization_patterns.py (#74210)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74210
This PR adds a codepath that builds the patterns (quantize handlers) from the native backend's backend_config_dict when
backend_config_dict is None. This allows us to incrementally define the backend_config_dict for the
PyTorch native backend and gradually remove the corresponding entries in quantization_patterns.py.
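As a usage illustration (a hypothetical sketch, not part of this PR; it assumes
prepare_fx forwards backend_config_dict to prepare.py as in the diff below):

import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx
from torch.ao.quantization.fx.backend_config import get_native_backend_config_dict

model = torch.nn.Sequential(torch.nn.Linear(4, 4)).eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}

# default path: backend_config_dict is None, so prepare combines the legacy
# patterns from quantization_patterns.py with the handlers generated from
# get_native_backend_config_dict()
prepared = prepare_fx(model, qconfig_dict)

# explicit path: pass the (partially defined) native backend_config_dict directly
prepared_native = prepare_fx(
    model, qconfig_dict,
    backend_config_dict=get_native_backend_config_dict())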
Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
Imported from OSS
Reviewed By: dzdang
Differential Revision: D34899783
fbshipit-source-id: 7f31292948d7fc4566e51e175b41511f52d0a880
(cherry picked from commit a9f6ebd6478f362d5bb9c5ae04e02369e00f550c)
diff --git a/torch/ao/quantization/fx/backend_config/__init__.py b/torch/ao/quantization/fx/backend_config/__init__.py
index b595b66..3fc6762 100644
--- a/torch/ao/quantization/fx/backend_config/__init__.py
+++ b/torch/ao/quantization/fx/backend_config/__init__.py
@@ -1,4 +1,5 @@
from .tensorrt import get_tensorrt_backend_config_dict
+from .native import get_native_backend_config_dict
# TODO: add more validations
def validate_backend_config_dict(backend_config_dict):
diff --git a/torch/ao/quantization/fx/backend_config/native.py b/torch/ao/quantization/fx/backend_config/native.py
new file mode 100644
index 0000000..f743ec6
--- /dev/null
+++ b/torch/ao/quantization/fx/backend_config/native.py
@@ -0,0 +1,44 @@
+import torch
+from .observation_type import ObservationType
+import torch.nn.qat as nnqat
+
+def get_native_backend_config_dict():
+ """ Get backend for PyTorch Native backend_config_dict (fbgemm/qnnpack)
+ """
+ # dtype configs
+
+ # weighted op int8 config
+ # activation: quint8, weight: qint8, bias: float
+ weighted_op_int8_dtype_config = {
+ # optional, input activation dtype
+ "input_dtype": torch.quint8,
+ # optional, weight dtype
+ "weight_dtype": torch.qint8,
+ # optional, bias dtype
+ "bias_dtype": torch.float,
+ # optional, output activation dtype
+ "output_dtype": torch.quint8
+ }
+ # operator (module/functional/torch ops) configs
+ linear_module_config = {
+ # Please see README under this folder for pattern format
+ "pattern": torch.nn.Linear,
+ "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+ "dtype_configs": [
+ weighted_op_int8_dtype_config,
+ ],
+ # the root module for the pattern, used to query the reference quantized module
+ # e.g. for a (torch.nn.ReLU, torch.nn.Linear) pattern, the root will be torch.nn.Linear
+ "root_module": torch.nn.Linear,
+ # the corresponding reference quantized module for the root module
+ "reference_quantized_module_for_root": torch.nn.quantized._reference.Linear,
+ "qat_module": nnqat.Linear,
+ }
+
+ return {
+ # optional
+ "name": "native",
+ "configs": [
+ linear_module_config,
+ ],
+ }
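A hypothetical follow-up entry (an assumption for illustration, not part of this PR) shows how the migration can continue incrementally: another weighted op is described in the same format inside native.py, reusing the imports and weighted_op_int8_dtype_config defined above, and its legacy registration is dropped from quantization_patterns.py.

# hypothetical follow-up sketch, not part of this change
conv_module_config = {
    "pattern": torch.nn.Conv2d,
    "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
    "dtype_configs": [
        weighted_op_int8_dtype_config,
    ],
    "root_module": torch.nn.Conv2d,
    "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
    "qat_module": nnqat.Conv2d,
}
# ... appended to the "configs" list returned by get_native_backend_config_dict(),
# while @register_quant_pattern(torch.nn.Conv2d) is removed from
# quantization_patterns.py, mirroring the torch.nn.Linear removal below.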
diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py
index bda39ff..1fd5f8e 100644
--- a/torch/ao/quantization/fx/prepare.py
+++ b/torch/ao/quantization/fx/prepare.py
@@ -90,6 +90,9 @@
get_pattern_to_input_type_to_index,
get_module_to_qat_module,
)
+from .backend_config import (
+ get_native_backend_config_dict,
+)
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Set
from collections import defaultdict
@@ -1357,11 +1360,17 @@
# ((<function relu at 0x7f766a7360d0>, <built-in function add>):
# <class 'torch.ao.quantization.fx.quantize.Add'>),
# }
+ # TODO: rename to pattern_to_quantize_handler
patterns: Dict[Pattern, QuantizeHandler] = {}
if backend_config_dict is None:
quant_patterns = get_default_quant_patterns()
patterns = get_combined_dict(
quant_patterns, additional_quant_patterns)
+ # TODO: currently we just extend the quantize handlers generated from
+ # `get_native_backend_config_dict`
+ # in the future we can just assign backend_config_dict when everything is defined
+ for pattern, quantize_handler in get_pattern_to_quantize_handlers(get_native_backend_config_dict()).items():
+ patterns[pattern] = quantize_handler
else:
patterns = get_pattern_to_quantize_handlers(backend_config_dict)
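The None branch above can be read as the following precedence (a rough sketch, not the actual implementation): handlers generated from the native backend_config_dict overwrite the legacy entries on conflicts.

# rough precedence sketch (assumed reading of the code above)
patterns = dict(get_default_quant_patterns())      # legacy quantization_patterns.py entries
patterns.update(additional_quant_patterns)         # user-supplied additions
patterns.update(get_pattern_to_quantize_handlers(  # entries derived from the native
    get_native_backend_config_dict()))             # backend_config_dict win on conflicts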
diff --git a/torch/ao/quantization/fx/quantization_patterns.py b/torch/ao/quantization/fx/quantization_patterns.py
index e36ca7a..df7a95e 100644
--- a/torch/ao/quantization/fx/quantization_patterns.py
+++ b/torch/ao/quantization/fx/quantization_patterns.py
@@ -781,7 +781,6 @@
# conv2d_dynamic branch
raise Exception("Only static quant is supported for conv")
-@register_quant_pattern(torch.nn.Linear)
@register_quant_pattern(torch.nn.functional.linear)
@register_quant_pattern(torch.nn.qat.Linear)
@register_quant_pattern(torch.nn.intrinsic.LinearReLU)