[quant][pt2e] Fix propagate_annotation after recent refactors (#102422)
Summary:
We recently changed the annotation key from "target_dtype_info" to "quantization_annotation" and introduced the QuantizationAnnotation and SharedQuantizationSpec APIs so that users can express observer/fake-quant sharing between inputs and outputs. This PR updates the _propagate_annotation pass to accommodate those changes.
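For context, a minimal sketch of how a node gets annotated with the new API (mirroring what _propagate_annotation now emits in this PR; the helper name and node arguments are illustrative only):
```
from torch.ao.quantization._pt2e.quantizer import (
    QuantizationAnnotation,
    SharedQuantizationSpec,
)

def _annotate_shared_with_prev(node, prev_node):
    # the current node's input and output share the same observer/fake-quant
    # instance as the output of `prev_node`
    shared_qspec = SharedQuantizationSpec(prev_node)
    node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={prev_node: shared_qspec},
        output_qspec=shared_qspec,
        _annotated=True,
    )
```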
Test Plan:
```
buck2 test mode/opt caffe2/test:quantization_pt2e -- 'caffe2/test:quantization_pt2e'
```
Reviewed By: kimishpatel
Differential Revision: D46153084
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102422
Approved by: https://github.com/kimishpatel
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index f7537d7..14c0145 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -665,6 +665,50 @@
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
)
+ def test_propagate_annotation(self):
+ class M(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(3, 3, 3)
+ self.linear = torch.nn.Linear(3, 3)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = x.view(-1, 3)
+ x = torch.nn.functional.hardtanh(x, -0.5, 0.5)
+ x = self.linear(x)
+ return x
+
+ import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
+ quantizer = QNNPackQuantizer()
+ operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
+ quantizer.set_global(operator_config)
+ m = M().eval()
+ example_inputs = (torch.randn(1, 3, 5, 5),)
+
+ # program capture
+ m, guards = torchdynamo.export(
+ m,
+ *copy.deepcopy(example_inputs),
+ aten_graph=True,
+ )
+
+ m = prepare_pt2e_quantizer(m, quantizer)
+ m(*example_inputs)
+ self.assertEqual(id(m.activation_post_process_2), id(m.activation_post_process_3))
+ self.assertEqual(id(m.activation_post_process_3), id(m.activation_post_process_4))
+ m = convert_pt2e(m)
+ node_occurrence = {
+ # input and output are using quantize_per_tensor and weight is using quantize_per_channel
+ ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 5,
+ ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 5,
+ ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel): 2,
+ ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel): 2,
+ }
+ self.checkGraphModuleNodes(
+ m, expected_node_occurrence=node_occurrence
+ )
+
def test_prepare_qat_conv_bn_fusion(self):
class M(torch.nn.Module):
def __init__(self):
diff --git a/torch/ao/quantization/_pt2e/_propagate_annotation.py b/torch/ao/quantization/_pt2e/_propagate_annotation.py
index 17df6a8..7508430 100644
--- a/torch/ao/quantization/_pt2e/_propagate_annotation.py
+++ b/torch/ao/quantization/_pt2e/_propagate_annotation.py
@@ -3,8 +3,13 @@
from typing import (
Callable,
)
+from torch.ao.quantization._pt2e.quantizer import (
+ QuantizationAnnotation,
+ SharedQuantizationSpec,
+)
def _is_share_obs_or_fq_op(op: Callable) -> bool:
+ # TODO: remove some of these ops in qnnpack_quantizer
return op in [
torch.ops.aten.hardtanh.default,
torch.ops.aten.mean.default,
@@ -23,22 +28,24 @@
if not isinstance(prev_node, Node):
continue
- target_dtype_info = prev_node.meta.get("target_dtype_info", None)
- if not target_dtype_info:
+ quantization_annotation = prev_node.meta.get("quantization_annotation", None)
+ if not quantization_annotation:
continue
- output_act_obs_or_fq_ctr = target_dtype_info.get("output_act_obs_or_fq_ctr", None)
- if not output_act_obs_or_fq_ctr:
+ output_qspec = quantization_annotation.output_qspec
+ if not output_qspec:
continue
# make sure current node is not annotated
- if "target_dtype_info" in n.meta and n.meta["target_dtype_info"].get("_annotated", False):
+ if "quantization_annotation" in n.meta and n.meta["quantization_annotation"]._annotated:
continue
- # propagate the previous output_act_obs_or_fq to the current node
- n.meta["target_dtype_info"] = {
- "input_act_obs_or_fq_ctr": output_act_obs_or_fq_ctr,
- "output_act_obs_or_fq_ctr": output_act_obs_or_fq_ctr,
- "input_output_share_observers": True,
- "_annotated": True,
- }
+ shared_qspec = SharedQuantizationSpec(prev_node)
+ # propagate the previous output_qspec to the current node
+ n.meta["quantization_annotation"] = QuantizationAnnotation(
+ input_qspec_map={
+ prev_node: shared_qspec,
+ },
+ output_qspec=shared_qspec,
+ _annotated=True,
+ )