[Quant] Add dynamic quantization config for x86 inductor backend (#115337)
**Description**
Add a dynamic quantization config for the X86 Inductor backend, enabled via `get_default_x86_inductor_quantization_config(is_dynamic=True)` (optionally combined with `is_qat=True`).
To support the QKV structure in self-attention, we removed an assertion in the port-metadata pass that required a single dequantize node after each quantize node. The duplicate-DQ pass likewise now skips dequantize nodes that come from the dynamic quantization pattern (choose_qparams -> getitem -> quantize -> dequantize), so they are not duplicated.
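For reference, a minimal usage sketch of the new config with the PT2E flow is shown below. Only `is_dynamic=True` comes from this PR; the toy model, the `capture_pre_autograd_graph` export step, and the `prepare_pt2e`/`convert_pt2e` calls are assumptions based on the existing PT2E APIs.
```python
# A minimal sketch (not part of this PR's diff): dynamically quantize a toy
# linear model with the new config. `ToyLinear` is a hypothetical stand-in.
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer


class ToyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 64, bias=False)

    def forward(self, x):
        return self.linear(x)


m = ToyLinear().eval()
example_inputs = (torch.randn(1, 4, 64),)

# New in this PR: is_dynamic=True selects dynamic activation quantization.
quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config(is_dynamic=True)
)

exported = capture_pre_autograd_graph(m, example_inputs)
prepared = prepare_pt2e(exported, quantizer)
# One forward pass lets the per-channel weight observers record ranges;
# activation qparams are chosen at runtime via choose_qparams.
prepared(*example_inputs)
converted = convert_pt2e(prepared)
```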
**Test plan**
```
python test/test_quantization.py -k TestQuantizePT2EX86Inductor.test_dynamic_quant_linear
python test/test_quantization.py -k TestQuantizePT2EX86Inductor.test_qat_dynamic_quant_linear
```
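The QAT variant exercised by the second command follows the same flow but uses `prepare_qat_pt2e` and `is_qat=True`; a sketch continuing from the one above (the fine-tuning loop is omitted, and `m`/`example_inputs` are reused from that sketch):
```python
# Sketch of the QAT + dynamic combination; reuses `m` and `example_inputs`
# from the sketch above and omits the actual QAT fine-tuning loop.
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e

quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config(is_qat=True, is_dynamic=True)
)
exported = capture_pre_autograd_graph(m, example_inputs)
prepared = prepare_qat_pt2e(exported, quantizer)
# ... QAT fine-tuning of `prepared` would go here ...
converted = convert_pt2e(prepared)
```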
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115337
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py
index 6112626..c2bf116 100644
--- a/test/quantization/pt2e/test_x86inductor_quantizer.py
+++ b/test/quantization/pt2e/test_x86inductor_quantizer.py
@@ -282,6 +282,24 @@
tmp = self.bn(self.conv(x))
return tmp + self.bn2(self.conv2(tmp))
+ class SelfAttnLikeModule(torch.nn.Module):
+ def __init__(self, input_dim) -> None:
+ super().__init__()
+ self.input_dim = input_dim
+ self.q_proj = nn.Linear(input_dim, input_dim, bias=False)
+ self.k_proj = nn.Linear(input_dim, input_dim, bias=False)
+ self.v_proj = nn.Linear(input_dim, input_dim, bias=False)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x):
+ q = self.q_proj(x)
+ k = self.k_proj(x)
+ v = self.v_proj(x)
+ scores = torch.bmm(q, k.transpose(1, 2)) / (self.input_dim ** 0.5)
+ attention = self.softmax(scores)
+ weighted = torch.bmm(attention, v)
+ return weighted
+
class X86InductorQuantTestCase(QuantizationTestCase):
def _test_quantizer(
self,
@@ -1199,3 +1217,75 @@
node_list,
is_qat=True,
)
+
+ @skipIfNoX86
+ def test_dynamic_quant_linear(self):
+ """
+ Test pattern of dynamic quantization of linear with X86InductorQuantizer.
+ """
+ with override_quantized_engine("x86"), torch.no_grad():
+ m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval()
+ example_inputs = (torch.randn(1, 4, 64),)
+ quantizer = X86InductorQuantizer().set_global(
+ xiq.get_default_x86_inductor_quantization_config(is_dynamic=True)
+ )
+ node_occurrence = {
+ torch.ops.quantized_decomposed.choose_qparams.tensor: 1,
+ torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
+ # quantize_per_channel for weights are const propagated
+ torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+ torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
+ }
+ node_list = [
+ torch.ops.quantized_decomposed.choose_qparams.tensor,
+ torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+ torch.ops.quantized_decomposed.dequantize_per_channel.default,
+ torch.ops.aten.linear.default,
+ ]
+ self._test_quantizer(
+ m,
+ example_inputs,
+ quantizer,
+ node_occurrence,
+ node_list,
+ )
+
+ @skipIfNoX86
+ def test_qat_dynamic_quant_linear(self):
+ """
+ Test pattern of qat dynamic quantization of linear with X86InductorQuantizer.
+ """
+ with override_quantized_engine("x86"), torch.no_grad():
+ m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval()
+ example_inputs = (torch.randn(1, 4, 64),)
+ quantizer = X86InductorQuantizer().set_global(
+ xiq.get_default_x86_inductor_quantization_config(
+ is_qat=True,
+ is_dynamic=True
+ )
+ )
+ node_occurrence = {
+ torch.ops.quantized_decomposed.choose_qparams.tensor: 1,
+ torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
+ # quantize_per_channel for weights are const propagated
+ torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+ torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
+ }
+ node_list = [
+ torch.ops.quantized_decomposed.choose_qparams.tensor,
+ torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+ torch.ops.quantized_decomposed.dequantize_per_channel.default,
+ torch.ops.aten.linear.default,
+ ]
+ self._test_quantizer(
+ m,
+ example_inputs,
+ quantizer,
+ node_occurrence,
+ node_list,
+ is_qat=True,
+ )
diff --git a/torch/ao/quantization/pt2e/duplicate_dq_pass.py b/torch/ao/quantization/pt2e/duplicate_dq_pass.py
index 6780f99..48c7d72 100644
--- a/torch/ao/quantization/pt2e/duplicate_dq_pass.py
+++ b/torch/ao/quantization/pt2e/duplicate_dq_pass.py
@@ -1,4 +1,5 @@
import logging
+import operator
import torch
@@ -16,6 +17,12 @@
__all__ = ["DuplicateDQPass"]
+_QUANTIZE_OPS = [
+ torch.ops.quantized_decomposed.quantize_per_tensor.default,
+ torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+ torch.ops.quantized_decomposed.quantize_per_channel.default,
+]
+
_DEQUANTIZE_OPS = [
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
@@ -51,6 +58,24 @@
dq_users = _filter_sym_size_users(node)
if len(dq_users) <= 1:
continue
+ # Do not duplicate dq for dynamic quantization
+ # Pattern: choose_qparam - getitem - q - dq
+ q_node = node.args[0]
+ if q_node.op == "call_function" and q_node.target in _QUANTIZE_OPS:
+ getitem_node = q_node.args[1]
+ if (
+ isinstance(getitem_node, torch.fx.node.Node)
+ and getitem_node.op == "call_function"
+ and getitem_node.target == operator.getitem
+ ):
+ choose_qparam_node = getitem_node.args[0]
+ if (
+ isinstance(choose_qparam_node, torch.fx.node.Node)
+ and choose_qparam_node.op == "call_function"
+ and choose_qparam_node.target
+ == torch.ops.quantized_decomposed.choose_qparams.tensor
+ ):
+ continue
for user in dq_users:
_maybe_duplicate_dq(graph_module, node, user)
graph_module.graph.eliminate_dead_code()
diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
index 4f0aa7f..86a8738 100644
--- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
+++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
@@ -7,9 +7,13 @@
import torch
import torch.nn.functional as F
-from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
+from torch.ao.quantization.fake_quantize import (
+ FakeQuantize,
+ FusedMovingAvgObsFakeQuantize,
+)
from torch.ao.quantization.observer import (
HistogramObserver,
+ MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
PlaceholderObserver,
@@ -164,10 +168,25 @@
@functools.lru_cache
-def get_default_x86_inductor_quantization_config(is_qat: bool = False):
- act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
- FusedMovingAvgObsFakeQuantize if is_qat else HistogramObserver
- )
+def get_default_x86_inductor_quantization_config(
+ is_qat: bool = False,
+ is_dynamic: bool = False,
+):
+ extra_args: Dict[str, Any] = {"eps": 2**-12}
+ if is_qat:
+ if is_dynamic:
+ act_observer_or_fake_quant_ctr = FakeQuantize
+ dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
+ averaging_constant=1
+ )
+ extra_args["observer"] = dynamic_quant_observer
+ else:
+ act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize # type: ignore[assignment]
+ else:
+ if is_dynamic:
+ act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment]
+ else:
+ act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment]
# Copy from x86 default qconfig from torch/ao/quantization/qconfig.py
act_quantization_spec = QuantizationSpec(
@@ -175,9 +194,9 @@
quant_min=0,
quant_max=255, # reduce_range=False
qscheme=torch.per_tensor_affine,
- is_dynamic=False,
+ is_dynamic=is_dynamic,
observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
- eps=2**-12
+ **extra_args
),
)
@@ -185,7 +204,6 @@
FusedMovingAvgObsFakeQuantize if is_qat else PerChannelMinMaxObserver
)
- extra_args: Dict[str, Any] = {"eps": 2**-12}
if is_qat:
# Only support per channel quant for now
extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item]
@@ -200,12 +218,7 @@
**extra_args
),
)
- bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
- PlaceholderObserver
- )
- bias_quantization_spec = QuantizationSpec(
- dtype=torch.float, observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
- )
+ bias_quantization_spec = None # will use placeholder observer by default
quantization_config = QuantizationConfig(
act_quantization_spec,
act_quantization_spec,
@@ -370,7 +383,10 @@
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""just handling global spec for now"""
- model = self._annotate_for_static_quantization_config(model)
+ if self.global_config and self.global_config.input_activation.is_dynamic: # type: ignore[union-attr]
+ model = self._annotate_for_dynamic_quantization_config(model)
+ else:
+ model = self._annotate_for_static_quantization_config(model)
return model
def _annotate_for_static_quantization_config(
@@ -412,6 +428,13 @@
return model
+ def _annotate_for_dynamic_quantization_config(
+ self, model: torch.fx.GraphModule
+ ) -> torch.fx.GraphModule:
+ config = self.global_config
+ self._annotate_linear(model, config)
+ return model
+
def _annotate_qat_conv2d_fusion_pattern(
self, model: torch.fx.GraphModule, config: QuantizationConfig
):
diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py
index 49b1b81..c5ab78c 100644
--- a/torch/ao/quantization/utils.py
+++ b/torch/ao/quantization/utils.py
@@ -152,7 +152,7 @@
torch.int16: torch.int16,
torch.int32: torch.int32,
}
- assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + qdtype
+ assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype)
return DTYPE_MAPPING[qdtype]
def get_qparam_dict(observer_or_fake_quant):