[Quant][PT2E] Enable conv2d unary and binary recipe for x86 inductor quantizer (#98826)
**Summary**
- Add the recipe to annotate the `conv2d_relu` pattern in `X86InductorQuantizer`.
- Add the recipe to annotate the `conv2d_add` pattern in `X86InductorQuantizer`.
- Add the recipe to annotate the `conv2d_add_relu` pattern in `X86InductorQuantizer`.
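Below is a minimal usage sketch of the new recipes, following the flow exercised by the tests in this PR. The toy module and the import path for `prepare_pt2e_quantizer`/`convert_pt2e` are illustrative assumptions and may differ across PyTorch versions:
```
import copy
import torch
import torch.nn as nn
import torch._dynamo as torchdynamo
import torch.ao.quantization._pt2e.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization._pt2e.quantizer.x86_inductor_quantizer import X86InductorQuantizer
# Assumed import path for the PT2E entry points; it may differ across versions.
from torch.ao.quantization._quantize_pt2e import prepare_pt2e_quantizer, convert_pt2e


class ConvAddReLU(nn.Module):
    """Toy conv2d -> add -> relu pattern matched by the new conv2d_add_relu recipe."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x) + x)


example_inputs = (torch.randn(2, 3, 16, 16),)
m = ConvAddReLU().eval()

with torch.no_grad():
    # Capture an aten-level graph.
    m, guards = torchdynamo.export(m, *copy.deepcopy(example_inputs), aten_graph=True)
    # Configure X86InductorQuantizer with the default x86 inductor config.
    quantizer = X86InductorQuantizer().set_global(
        xiq.get_default_x86_inductor_quantization_config()
    )
    # Insert observers, calibrate, and convert; the conv2d_add_relu pattern is
    # annotated so quant/dequant nodes are placed around the fused pattern.
    m = prepare_pt2e_quantizer(m, quantizer)
    m(*example_inputs)  # calibration
    m = convert_pt2e(m)
    out = m(*example_inputs)
```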
**Test Plan**
```
python -u -m pytest -s -v test_x86inductor_quantizer.py -k TestQuantizePT2EX86Inductor
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98826
Approved by: https://github.com/jerryzh168
diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py
index 3c28c8a..70cee5b 100644
--- a/test/quantization/pt2e/test_x86inductor_quantizer.py
+++ b/test/quantization/pt2e/test_x86inductor_quantizer.py
@@ -17,54 +17,401 @@
skipIfNoDynamoSupport,
)
from torch.testing._internal.common_quantized import override_quantized_engine
+from enum import Enum
+import itertools
+import torch.ao.quantization._pt2e.quantizer.x86_inductor_quantizer as xiq
+
+
+class Conv2DType(Enum):
+ left = 1
+ right = 2
+ both = 3
+
+class TestHelperModules:
+ class SingleConv2dModule(torch.nn.Module):
+ def __init__(self, ) -> None:
+ super().__init__()
+ self.conv = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1))
+
+ def forward(self, x):
+ return self.conv(x)
+
+ class Conv2dReLUModule(torch.nn.Module):
+ def __init__(self, inplace_relu: bool = False, use_bias: bool = False) -> None:
+ super().__init__()
+ self.conv = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1), bias=use_bias)
+ self.relu = nn.ReLU(inplace=inplace_relu)
+
+ def forward(self, x):
+ return self.relu(self.conv(x))
+
+ class Conv2dAddModule(torch.nn.Module):
+ def __init__(self,
+ inplace_add: bool = False,
+ conv2d_type: Conv2DType = Conv2DType.left,
+ use_bias: bool = False,
+ ) -> None:
+ super().__init__()
+ self.conv = torch.nn.Conv2d(
+ in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1, bias=use_bias
+ )
+ self.conv2 = torch.nn.Conv2d(
+ in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1, bias=use_bias
+ )
+ self.relu = nn.ReLU()
+ self.inplace_add = inplace_add
+ self.conv2d_type = conv2d_type
+
+ def forward(self, x):
+ if self.conv2d_type == Conv2DType.left:
+ if self.inplace_add:
+ tmp = self.conv(x)
+ tmp += self.relu(x)
+ return tmp
+ else:
+ return self.conv(x) + self.relu(x)
+ elif self.conv2d_type == Conv2DType.right:
+ if self.inplace_add:
+ tmp = self.relu(x)
+ tmp += self.conv(x)
+ return tmp
+ else:
+ return self.relu(x) + self.conv(x)
+ elif self.conv2d_type == Conv2DType.both:
+ if self.inplace_add:
+ tmp = self.conv(x)
+ tmp += self.conv2(x)
+ return tmp
+ else:
+ return self.conv(x) + self.conv2(x)
+
+ class Conv2dAddReLUModule(torch.nn.Module):
+ def __init__(self,
+ inplace_add: bool = False,
+ conv2d_type: Conv2DType = Conv2DType.left,
+ inplace_relu: bool = False,
+ use_bias: bool = False,
+ ) -> None:
+ super().__init__()
+ self.conv = torch.nn.Conv2d(
+ in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1, bias=use_bias
+ )
+ self.conv2 = torch.nn.Conv2d(
+ in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1, bias=use_bias
+ )
+ self.relu = nn.ReLU()
+ self.inplace_add = inplace_add
+ self.conv2d_type = conv2d_type
+ self.relu2 = nn.ReLU(inplace=inplace_relu)
+
+ def forward(self, x):
+ if self.conv2d_type == Conv2DType.left:
+ if self.inplace_add:
+ tmp = self.conv(x)
+ tmp += self.relu(x)
+ return self.relu2(tmp)
+ else:
+ return self.relu2(self.conv(x) + self.relu(x))
+ elif self.conv2d_type == Conv2DType.right:
+ if self.inplace_add:
+ tmp = self.relu(x)
+ tmp += self.conv(x)
+ return self.relu2(tmp)
+ else:
+ return self.relu2(self.relu(x) + self.conv(x))
+ elif self.conv2d_type == Conv2DType.both:
+ if self.inplace_add:
+ tmp = self.conv(x)
+ tmp += self.conv2(x)
+ return self.relu2(tmp)
+ else:
+ return self.relu2(self.conv(x) + self.conv2(x))
+
+ class SerialsConv2dAddReLUModule(torch.nn.Module):
+ """ Two sequential Conv2d -> Add -> ReLU patterns.
+ """
+ def __init__(self, ) -> None:
+ super().__init__()
+ self.conv = torch.nn.Conv2d(
+ in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True
+ )
+ self.conv2 = torch.nn.Conv2d(
+ in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True
+ )
+ self.conv3 = torch.nn.Conv2d(
+ in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True
+ )
+ self.conv4 = torch.nn.Conv2d(
+ in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True
+ )
+ self.relu = nn.ReLU()
+ self.relu2 = nn.ReLU()
+
+ def forward(self, x):
+ x1 = self.conv(x)
+ res1 = self.relu(self.conv2(x1) + self.conv3(x1))
+ res2 = self.relu2(self.conv4(res1) + res1)
+ return res2
+
+class X86InductorQuantTestCase(QuantizationTestCase):
+ def _test_quantizer(
+ self,
+ model,
+ example_inputs,
+ quantizer,
+ expected_node_occurrence,
+ expected_node_list=None,
+ ):
+ m_eager = model.eval()
+
+ # program capture
+ m = copy.deepcopy(m_eager)
+ m, guards = torchdynamo.export(
+ m,
+ *copy.deepcopy(example_inputs),
+ aten_graph=True,
+ )
+ m = prepare_pt2e_quantizer(m, quantizer)
+ # Calibrate
+ m(*example_inputs)
+ m = convert_pt2e(m)
+ pt2_quant_output = m(*example_inputs)
+ node_occurrence = {
+ ns.call_function(k): v for k, v in expected_node_occurrence.items()
+ }
+ if expected_node_list is None:
+ expected_node_list = []
+ node_list = [ns.call_function(n) for n in expected_node_list]
+ self.checkGraphModuleNodes(
+ m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
+ )
@skipIfNoDynamoSupport
-class TestQuantizePT2EX86Inductor(QuantizationTestCase):
+class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
@skipIfNoX86
def test_conv2d_with_quantizer_api(self):
- class Mod(torch.nn.Module):
- def __init__(self, ) -> None:
- super().__init__()
- self.conv = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1))
+ """
+ Test pattern of single conv2d with X86InductorQuantizer.
+ """
+ with override_quantized_engine("x86"), torch.no_grad():
+ m = TestHelperModules.SingleConv2dModule().eval()
+ example_inputs = (torch.randn(2, 3, 16, 16),)
+ quantizer = X86InductorQuantizer().set_global(
+ xiq.get_default_x86_inductor_quantization_config()
+ )
+ node_occurrence = {
+ # one for input and weight of the conv, one for the output of the conv
+ torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+ torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
+ torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+ }
+ node_list = [
+ torch.ops.quantized_decomposed.quantize_per_tensor.default,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+ torch.ops.aten.convolution.default,
+ torch.ops.quantized_decomposed.quantize_per_tensor.default,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+ ]
+ self._test_quantizer(
+ m,
+ example_inputs,
+ quantizer,
+ node_occurrence,
+ node_list,
+ )
- def forward(self, x):
- return self.conv(x)
-
- with override_quantized_engine("x86"):
- with torch.no_grad():
- m = Mod().eval()
- m_copy = copy.deepcopy(m)
+ @skipIfNoX86
+ def test_conv2d_unary_with_quantizer_api(self):
+ """
+ Test pattern of conv2d with unary post ops (such as relu, sigmoid) with X86InductorQuantizer.
+ Currently, only relu as unary post op is supported.
+ """
+ inplace_relu_list = [True, False]
+ use_bias_list = [True, False]
+ with override_quantized_engine("x86"), torch.no_grad():
+ for inplace_relu, use_bias in itertools.product(inplace_relu_list, use_bias_list):
+ m = TestHelperModules.Conv2dReLUModule(inplace_relu=inplace_relu, use_bias=use_bias).eval()
example_inputs = (torch.randn(2, 3, 16, 16),)
- # program capture
- m, guards = torchdynamo.export(
- m,
- *copy.deepcopy(example_inputs),
- aten_graph=True,
+ quantizer = X86InductorQuantizer().set_global(
+ xiq.get_default_x86_inductor_quantization_config()
)
-
- before_fusion_result = m(*example_inputs)
- import torch.ao.quantization._pt2e.quantizer.x86_inductor_quantizer as xiq
- quantizer = X86InductorQuantizer()
- operator_config = xiq.get_default_x86_inductor_quantization_config()
- quantizer.set_global(operator_config)
- # Insert Observer
- m = prepare_pt2e_quantizer(m, quantizer)
- after_prepare_result = m(*example_inputs)
- m = convert_pt2e(m)
node_occurrence = {
- # one for input and weight of the conv, one for output for the conv
- ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
- ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
- ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1,
- ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,
+ # one for input and weight of the conv, one for the output of the relu
+ torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+ torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
+ torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
- ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default),
- ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
- ns.call_function(torch.ops.aten.convolution.default),
- ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default),
- ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
+ torch.ops.quantized_decomposed.quantize_per_tensor.default,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+ torch.ops.aten.convolution.default,
+ torch.ops.aten.relu_.default if inplace_relu else torch.ops.aten.relu.default,
+ torch.ops.quantized_decomposed.quantize_per_tensor.default,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default,
]
- self.checkGraphModuleNodes(m,
- expected_node_occurrence=node_occurrence,
- expected_node_list=node_list)
+ self._test_quantizer(
+ m,
+ example_inputs,
+ quantizer,
+ node_occurrence,
+ node_list,
+ )
+
+ @skipIfNoX86
+ def test_conv2d_binary_with_quantizer_api(self):
+ """
+ Test pattern of conv2d with binary post ops (such as add) with X86InductorQuantizer.
+ Currently, only add as binary post op is supported.
+ """
+ inplace_add_list = [True, False]
+ conv2d_type_list = [Conv2DType.left, Conv2DType.right, Conv2DType.both]
+ use_bias_list = [True, False]
+
+ with override_quantized_engine("x86"), torch.no_grad():
+ for inplace_add, conv2d_type, use_bias in itertools.product(inplace_add_list, conv2d_type_list, use_bias_list):
+ m = TestHelperModules.Conv2dAddModule(inplace_add=inplace_add, conv2d_type=conv2d_type, use_bias=use_bias).eval()
+ example_inputs = (torch.randn(2, 3, 16, 16),)
+ quantizer = X86InductorQuantizer().set_global(
+ xiq.get_default_x86_inductor_quantization_config()
+ )
+ if conv2d_type != Conv2DType.both:
+ node_occurrence = {
+ # one for input and weight of the conv
+ # one for the output of the add
+ # one for extra input node of add
+ torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+ torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
+ torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+ }
+ else:
+ node_occurrence = {
+ # one for input and weight of the conv
+ # one for input and weight of another conv
+ # one for the output of the add
+ # the two convs share the same input quant/dequant
+ # one for extra input node of add
+ torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
+ torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
+ torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
+ }
+ node_list = [
+ torch.ops.quantized_decomposed.quantize_per_tensor.default,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+ torch.ops.aten.convolution.default,
+ torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor,
+ torch.ops.quantized_decomposed.quantize_per_tensor.default,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+ ]
+ self._test_quantizer(
+ m,
+ example_inputs,
+ quantizer,
+ node_occurrence,
+ node_list,
+ )
+
+ @skipIfNoX86
+ def test_conv2d_binary_unary_with_quantizer_api(self):
+ """
+ Test pattern of conv2d with binary + unary post ops (such as add + relu) with X86InductorQuantizer.
+ Currently, only add as binary post op and relu as unary post op are supported.
+ """
+ inplace_add_list = [True, False]
+ conv2d_type_list = [Conv2DType.left, Conv2DType.right, Conv2DType.both]
+ inplace_relu_list = [True, False]
+ use_bias_list = [True, False]
+
+ with override_quantized_engine("x86"), torch.no_grad():
+ for inplace_add, conv2d_type, inplace_relu, use_bias in itertools.product(
+ inplace_add_list,
+ conv2d_type_list,
+ inplace_relu_list,
+ use_bias_list,
+ ):
+ m = TestHelperModules.Conv2dAddReLUModule(
+ inplace_add=inplace_add,
+ conv2d_type=conv2d_type,
+ inplace_relu=inplace_relu,
+ use_bias=use_bias
+ ).eval()
+ example_inputs = (torch.randn(2, 3, 16, 16),)
+ quantizer = X86InductorQuantizer().set_global(
+ xiq.get_default_x86_inductor_quantization_config()
+ )
+ if conv2d_type != Conv2DType.both:
+ node_occurrence = {
+ # one for input and weight of the conv
+ # one for the output of the relu
+ # one for extra input node of add
+ torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+ torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
+ torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+ }
+ else:
+ node_occurrence = {
+ # one for input and weight of the conv
+ # one for input and weight of another conv
+ # one for the output of the relu
+ # the two convs share the same input quant/dequant
+ # one for extra input node of add
+ torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
+ torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
+ torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
+ }
+ node_list = [
+ torch.ops.quantized_decomposed.quantize_per_tensor.default,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+ torch.ops.aten.convolution.default,
+ torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor,
+ torch.ops.quantized_decomposed.quantize_per_tensor.default,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+ ]
+ self._test_quantizer(
+ m,
+ example_inputs,
+ quantizer,
+ node_occurrence,
+ node_list,
+ )
+
+ @skipIfNoX86
+ def test_conv2d_serials_binary_unary_with_quantizer_api(self):
+ """
+ Test pattern of two consecutive Conv2d -> Add -> ReLU blocks with X86InductorQuantizer.
+ """
+ with override_quantized_engine("x86"), torch.no_grad():
+ m = TestHelperModules.SerialsConv2dAddReLUModule().eval()
+ example_inputs = (torch.randn(2, 3, 16, 16),)
+ quantizer = X86InductorQuantizer().set_global(xiq.get_default_x86_inductor_quantization_config())
+ node_occurrence = {
+ torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
+ torch.ops.quantized_decomposed.quantize_per_channel.default: 4,
+ torch.ops.quantized_decomposed.dequantize_per_channel.default: 4,
+ }
+ node_list = [
+ torch.ops.quantized_decomposed.quantize_per_tensor.default,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+ torch.ops.aten.convolution.default,
+ torch.ops.quantized_decomposed.quantize_per_tensor.default,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+ torch.ops.aten.convolution.default,
+ torch.ops.aten.convolution.default,
+ torch.ops.aten.add.Tensor,
+ torch.ops.aten.relu.default,
+ torch.ops.quantized_decomposed.quantize_per_tensor.default,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+ ]
+ self._test_quantizer(
+ m,
+ example_inputs,
+ quantizer,
+ node_occurrence,
+ node_list,
+ )
diff --git a/torch/ao/quantization/_pt2e/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/_pt2e/quantizer/x86_inductor_quantizer.py
index 5ba785e..fe345a0 100644
--- a/torch/ao/quantization/_pt2e/quantizer/x86_inductor_quantizer.py
+++ b/torch/ao/quantization/_pt2e/quantizer/x86_inductor_quantizer.py
@@ -3,6 +3,7 @@
import copy
import functools
import itertools
+import operator
from .quantizer import (
OperatorConfig,
OperatorPatternType,
@@ -11,6 +12,7 @@
Quantizer,
QuantizationAnnotation,
)
+from torch.ao.quantization._pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization._pt2e.quantizer.utils import (
get_input_act_qspec,
get_output_act_qspec,
@@ -28,7 +30,10 @@
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from typing import List, Dict, Optional, Set, Any
from torch.fx import Node
-from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
+from torch.fx.passes.utils.source_matcher_utils import (
+ get_source_partitions,
+ SourcePartition,
+)
__all__ = [
"X86InductorQuantizer",
@@ -43,6 +48,24 @@
[F.conv2d],
],
}
+
+ # Append Conv Optional(Add) Optional(ReLU)
+ conv_add_relu_options = itertools.product(
+ [torch.nn.Conv2d, F.conv2d],
+ [torch.add, operator.add, None], # add
+ [torch.nn.ReLU, F.relu, None], # relu
+ )
+ for conv_op, add_op, relu_op in conv_add_relu_options:
+ if add_op is None:
+ # Append Conv ReLU
+ supported_operators["conv2d"].append([conv_op, relu_op])
+ elif relu_op is None:
+ # Append Conv Add
+ supported_operators["conv2d"].append([conv_op, add_op])
+ else:
+ # Append Conv Add ReLU
+ supported_operators["conv2d"].append([conv_op, add_op, relu_op])
+
return copy.deepcopy(supported_operators)
@@ -141,6 +164,78 @@
self.operator_type_config[operator_type] = quantization_config
return self
+ def _annotate_conv_node_helper(
+ self,
+ conv_node: torch.fx.Node,
+ annotate_output: bool,
+ quantization_config: QuantizationConfig,
+ ) -> None:
+ """ Helper function to annotate the conv node
+ """
+ input_qspec_map = {}
+ input_node = conv_node.args[0]
+ assert isinstance(input_node, Node)
+ input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
+ weight_node = conv_node.args[1]
+ assert isinstance(weight_node, Node)
+ input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
+ bias_node = conv_node.args[2]
+ if isinstance(bias_node, Node):
+ input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
+ if annotate_output:
+ conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
+ input_qspec_map=input_qspec_map,
+ # TODO<leslie> Remove the output annotation once oneDNN qconv supports fp32 output.
+ output_qspec=get_output_act_qspec(quantization_config),
+ _annotated=True
+ )
+ else:
+ conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
+ input_qspec_map=input_qspec_map,
+ _annotated=True
+ )
+
+ def _get_output_nodes_of_partitions(
+ self,
+ partition_list: List[SourcePartition],
+ ) -> List[torch.fx.Node]:
+ """ Helper function to get the output node list from partition list
+ """
+ output_node_list = []
+ for partition in partition_list:
+ if len(partition.output_nodes) > 1:
+ raise ValueError("Input partition has more than one output node")
+ output_node = partition.output_nodes[0]
+ assert isinstance(output_node, Node)
+ output_node_list.append(output_node)
+ if len(output_node_list) != len(partition_list):
+ raise ValueError("length of output_node_list should equal the length of partition_list")
+ return output_node_list
+
+ def _get_input_idx_for_binary_node(
+ self,
+ conv_gemm_node: torch.fx.Node,
+ binary_node: torch.fx.Node,
+ ):
+ """ Helper function to find the input index of the conv_gemm node and of the
+ extra input node for a binary node fused with conv_gemm.
+ """
+ conv_gemm_node_idx = None
+ extra_input_node_idx = None
+ if (binary_node.args[0].op == "call_function") and (
+ binary_node.args[0] == conv_gemm_node
+ ):
+ conv_gemm_node_idx = 0
+ extra_input_node_idx = 1
+ elif (binary_node.args[1].op == "call_function") and (
+ binary_node.args[1] == conv_gemm_node
+ ):
+ conv_gemm_node_idx = 1
+ extra_input_node_idx = 0
+ if extra_input_node_idx is None:
+ # Neither input of the binary node is the conv/gemm output; callers skip this case.
+ return None, None
+ extra_input_node = binary_node.args[extra_input_node_idx]
+ assert isinstance(extra_input_node, Node)
+ return conv_gemm_node_idx, extra_input_node_idx
+
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
""" just handling global spec for now
"""
@@ -155,10 +250,104 @@
# and we will mark the matched node with "_annoated" so fusion operator pattern
# can take precedence over single operator pattern in this way
config = self.global_config
+ self._annotate_conv2d_binary_unary(model, config)
+ self._annotate_conv2d_binary(model, config)
+ self._annotate_conv2d_unary(model, config)
self._annotate_conv2d(model, config)
-
return model
+ def _annotate_conv2d_binary_unary(
+ self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
+ ) -> None:
+ # Conv2d + add + unary op
+ fused_partitions = find_sequential_partitions(
+ gm, [torch.nn.Conv2d, operator.add, torch.nn.ReLU]
+ )
+ for fused_partition in fused_partitions:
+ conv_partition, binary_partition, unary_partition = fused_partition
+ conv_node, binary_node, unary_node = self._get_output_nodes_of_partitions(
+ [conv_partition, binary_partition, unary_partition]
+ )
+ conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(conv_node, binary_node)
+ if (conv_node_idx is None) or (extra_input_node_idx is None):
+ continue
+ if conv_node != binary_node.args[conv_node_idx]:
+ raise ValueError(f"{conv_node} doesn't match input of binary node")
+ extra_input_node = binary_node.args[extra_input_node_idx]
+ if conv_node.op != "call_function" or conv_node.target != torch.ops.aten.convolution.default:
+ # No conv node found to be fused with add
+ continue
+ if _is_annotated([unary_node, binary_node, conv_node]):
+ continue
+ self._annotate_conv_node_helper(conv_node, False, quantization_config)
+ binary_node_input_qspec_map = {}
+ binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(quantization_config)
+ binary_node.meta["quantization_annotation"] = QuantizationAnnotation(
+ input_qspec_map=binary_node_input_qspec_map,
+ _annotated=True
+ )
+ unary_node.meta["quantization_annotation"] = QuantizationAnnotation(
+ # TODO<leslie> Remove the output annotation once oneDNN qconv supports fp32 output.
+ output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
+ _annotated=True
+ )
+
+ def _annotate_conv2d_binary(
+ self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
+ ) -> None:
+ # Conv2d + add
+ fused_partitions = find_sequential_partitions(
+ gm, [torch.nn.Conv2d, operator.add]
+ )
+ for fused_partition in fused_partitions:
+ conv_partition, binary_partition = fused_partition
+ conv_node, binary_node = self._get_output_nodes_of_partitions(
+ [conv_partition, binary_partition]
+ )
+ conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(conv_node, binary_node)
+ if (conv_node_idx is None) or (extra_input_node_idx is None):
+ continue
+ if conv_node != binary_node.args[conv_node_idx]:
+ raise ValueError(f"{conv_node} doesn't match input of binary node")
+ extra_input_node = binary_node.args[extra_input_node_idx]
+ assert isinstance(conv_node, Node)
+ if conv_node.op != "call_function" or conv_node.target != torch.ops.aten.convolution.default:
+ # No conv node found to be fused with add
+ continue
+ if _is_annotated([binary_node, conv_node]):
+ continue
+ self._annotate_conv_node_helper(conv_node, False, quantization_config)
+ binary_node_input_qspec_map = {}
+ binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(quantization_config)
+ binary_node.meta["quantization_annotation"] = QuantizationAnnotation(
+ input_qspec_map=binary_node_input_qspec_map,
+ # TODO<leslie> Remove the output annotation once oneDNN qconv supports fp32 output.
+ output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
+ _annotated=True
+ )
+
+ def _annotate_conv2d_unary(
+ self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
+ ) -> None:
+ fused_partitions = find_sequential_partitions(
+ gm, [torch.nn.Conv2d, torch.nn.ReLU]
+ )
+ for fused_partition in fused_partitions:
+ conv_partition, unary_partition = fused_partition
+ conv_node, unary_node = self._get_output_nodes_of_partitions(
+ [conv_partition, unary_partition]
+ )
+ if conv_node.op != "call_function" or conv_node.target != torch.ops.aten.convolution.default:
+ continue
+ if _is_annotated([unary_node, conv_node]):
+ continue
+ self._annotate_conv_node_helper(conv_node, False, quantization_config)
+ unary_node.meta["quantization_annotation"] = QuantizationAnnotation(
+ # TODO<leslie> Remove the output annotation once oneDNN qconv supports fp32 output.
+ output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
+ _annotated=True
+ )
+
def _annotate_conv2d(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
@@ -178,25 +367,7 @@
# skip annotation if it is already annotated
if _is_annotated([conv_node]):
continue
- input_qspec_map = {}
- input_node = conv_node.args[0]
- assert isinstance(input_node, Node)
- input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
-
- weight_node = conv_node.args[1]
- assert isinstance(weight_node, Node)
- input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
-
- bias_node = conv_node.args[2]
- if isinstance(bias_node, Node):
- input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
-
- conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
- input_qspec_map=input_qspec_map,
- # TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
- output_qspec=get_output_act_qspec(quantization_config),
- _annotated=True
- )
+ self._annotate_conv_node_helper(conv_node, True, quantization_config)
def validate(self, model: torch.fx.GraphModule) -> None:
pass