Revert "[quant][pt2e][xnnpack_quantizer] Add support for mul and mul_relu (#107930)"
This reverts commit 1d1739dc6d7365c28719cd0175081f9d9aab0324.
Reverted https://github.com/pytorch/pytorch/pull/107930 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/107930#issuecomment-1694069330))
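For context, #107930 taught the XNNPACKQuantizer to annotate `mul`/`mul_` (and mul followed by ReLU) the same way it already handles `add`, and added observer deduplication to PT2E prepare; this revert removes all of that. Below is a minimal sketch of the usage exercised by the removed `test_mul_and_inplace_mul` test, limited to the module and quantizer setup shown in the diff (the actual prepare/convert plumbing runs through the test case's `_test_quantizer` helper and is omitted here):

```python
# Sketch only: module and quantizer setup mirrored from the removed
# test_mul_and_inplace_mul test; prepare/convert is driven by the test helper.
import torch
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)


class MulInplaceMul(torch.nn.Module):
    def forward(self, x, y):
        x = x * y   # out-of-place mul, matched by the removed "mul" annotator
        x *= y      # in-place mul, matched via operator.imul in the removed pattern
        return x


quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5))
```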
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index dc1a6dd..23dc5a2 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -227,17 +227,6 @@
conv_out = torch.squeeze(conv_out, dim=0)
return self.linear(conv_out)
- class AddInplaceAdd(torch.nn.Module):
- def forward(self, x, y):
- x = x + y
- x += y
- return x
-
- class MulInplaceMul(torch.nn.Module):
- def forward(self, x, y):
- x = x * y
- x *= y
- return x
class PT2EQuantizationTestCase(QuantizationTestCase):
"""
@@ -1234,60 +1223,6 @@
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
)
- def test_add_and_inplace_add(self):
- quantizer = XNNPACKQuantizer()
- quantization_config = get_symmetric_quantization_config(is_per_channel=True)
- quantizer.set_global(quantization_config)
- example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5),)
- node_occurrence = {
- # two inputs and one output for the first add, and one output for the second add
- torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
- torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
- }
- node_list = [
- torch.ops.quantized_decomposed.dequantize_per_tensor.default,
- torch.ops.quantized_decomposed.dequantize_per_tensor.default,
- torch.ops.aten.add.Tensor,
- torch.ops.quantized_decomposed.quantize_per_tensor.default,
- torch.ops.quantized_decomposed.dequantize_per_tensor.default,
- torch.ops.aten.add_.Tensor,
- torch.ops.quantized_decomposed.quantize_per_tensor.default,
- ]
- self._test_quantizer(
- TestHelperModules.AddInplaceAdd(),
- example_inputs,
- quantizer,
- node_occurrence,
- node_list,
- )
-
- def test_mul_and_inplace_mul(self):
- quantizer = XNNPACKQuantizer()
- quantization_config = get_symmetric_quantization_config(is_per_channel=True)
- quantizer.set_global(quantization_config)
- example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5),)
- node_occurrence = {
- # two inputs and one output for the first mul, and one output for the second mul
- torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
- torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
- }
- node_list = [
- torch.ops.quantized_decomposed.dequantize_per_tensor.default,
- torch.ops.quantized_decomposed.dequantize_per_tensor.default,
- torch.ops.aten.mul.Tensor,
- torch.ops.quantized_decomposed.quantize_per_tensor.default,
- torch.ops.quantized_decomposed.dequantize_per_tensor.default,
- torch.ops.aten.mul_.Tensor,
- torch.ops.quantized_decomposed.quantize_per_tensor.default,
- ]
- self._test_quantizer(
- TestHelperModules.MulInplaceMul(),
- example_inputs,
- quantizer,
- node_occurrence,
- node_list,
- )
-
def test_xnnpack_quantizer_conv(self):
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py
index 6115ced..dad00f2 100644
--- a/test/quantization/pt2e/test_x86inductor_quantizer.py
+++ b/test/quantization/pt2e/test_x86inductor_quantizer.py
@@ -290,8 +290,8 @@
# one for output for the add
# 2 conv will share same input quant/dequant
# one for extra input node of add
- torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
- torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+ torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
@@ -344,8 +344,8 @@
# one for output for the relu
# 2 conv will share same input quant/dequant
# one for extra input node of add
- torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
- torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+ torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
diff --git a/torch/ao/quantization/pt2e/graph_utils.py b/torch/ao/quantization/pt2e/graph_utils.py
index e9b31e8..613eeb3 100644
--- a/torch/ao/quantization/pt2e/graph_utils.py
+++ b/torch/ao/quantization/pt2e/graph_utils.py
@@ -24,7 +24,6 @@
{torch.nn.BatchNorm2d, torch.nn.functional.batch_norm},
{torch.nn.Hardtanh, torch.nn.functional.hardtanh, torch.nn.functional.hardtanh_},
{torch.add, operator.add, operator.iadd, "add", "add_"},
- {torch.mul, operator.mul, operator.imul},
]
diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py
index ac59fc3..c76365c 100644
--- a/torch/ao/quantization/pt2e/prepare.py
+++ b/torch/ao/quantization/pt2e/prepare.py
@@ -88,38 +88,11 @@
new_arg = arg
obs_or_fq_map[(observed_arg, node)] = arg_as_input_act_obs_or_fq
else:
- # skip inserting new observers if there is an observer inserted for the arg before
- # that has the same dtype that we want to insert here
- # alternatively we could have a dedup pass after we insert all observers to deduplicate
- # observers
- # Example:
- # arg -> existing_obs -> conv1
- # \ -> conv2
- #
- # instead of inserting new observers we will have:
- # arg -> existing_obs -> conv1
- # \ -> conv2
- existing_obs_node = None
- for maybe_obs_node in arg.users.keys():
- if maybe_obs_node.op == 'call_module':
- maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index]
- if (
- type(maybe_obs_mod) == type(arg_as_input_act_obs_or_fq) and
- maybe_obs_mod.dtype == arg_as_input_target_dtype
- ):
- arg_as_input_act_obs_or_fq = maybe_obs_mod # type: ignore[assignment]
- existing_obs_node = maybe_obs_node
- break
-
assert arg_as_input_act_obs_or_fq is not None
+ new_obs_node = _insert_obs_or_fq(
+ arg, arg_as_input_act_obs_or_fq, model, named_modules, model.graph) # type: ignore[arg-type]
+ new_arg = new_obs_node
obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq
- if existing_obs_node is None:
- new_obs_node = _insert_obs_or_fq(
- arg, arg_as_input_act_obs_or_fq, model, named_modules, model.graph)
- # override this arg to be the observed arg
- new_arg = new_obs_node
- else:
- new_arg = existing_obs_node
return new_arg
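The prepare.py hunk above reverts observer deduplication: with #107930, when an argument already fed an observer of the same type and target dtype, that observer was reused for additional consumers; after the revert a fresh observer/fake-quant is inserted for each consumer again, which matches the quantize/dequantize counts in the x86inductor tests going back from 3 to 4. An illustrative sketch of the reuse check being removed (the names below are stand-ins, not the exact locals or signature in prepare.py):

```python
# Illustrative only: the observer-reuse check this revert removes from
# torch/ao/quantization/pt2e/prepare.py. `arg` is an fx.Node, `named_modules`
# maps module names to modules, and `obs_cls`/`target_dtype` stand in for
# type(arg_as_input_act_obs_or_fq) and arg_as_input_target_dtype.
def find_reusable_observer(arg, named_modules, obs_cls, target_dtype):
    # arg -> existing_obs -> conv1
    #     \-> conv2
    # If conv2 wants an observer of the same type and dtype on `arg`,
    # reuse existing_obs instead of inserting a second observer on `arg`.
    for user in arg.users:
        if user.op == "call_module":
            mod = named_modules[user.target]
            if type(mod) is obs_cls and mod.dtype == target_dtype:
                return user
    # After the revert there is no reuse path: a new observer is always inserted.
    return None
```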
diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py
index ccc20eb..17d66ee 100644
--- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py
+++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py
@@ -329,8 +329,6 @@
self._annotate_conv2d_patterns(model, config, filter_fn)
self._annotate_max_pool2d(model, config, filter_fn)
self._annotate_add_patterns(model, config, filter_fn)
- OP_TO_ANNOTATOR["mul_relu"](model, config, filter_fn)
- OP_TO_ANNOTATOR["mul"](model, config, filter_fn)
self._annotate_adaptive_avg_pool2d(model, config, filter_fn)
self._annotate_gru_io_only(model, config, filter_fn)
return model
diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
index d921160..a920d5e 100644
--- a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
+++ b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
@@ -521,28 +521,20 @@
)
-def _annotate_adaptive_avg_pool2d(
+def _annotate_input_out_obs_sharing_op(
+ op: Callable,
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
- """Always annotate adaptive_avg_pool2d op"""
- module_partitions = get_source_partitions(
- gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn
- )
+ module_partitions = get_source_partitions(gm.graph, [op], filter_fn)
partitions = list(itertools.chain(*module_partitions.values()))
for partition in partitions:
- pool_node = partition.output_nodes[0]
- if (
- pool_node.op != "call_function"
- or pool_node.target != torch.ops.aten.adaptive_avg_pool2d.default
- ):
- raise ValueError(f"{pool_node} is not an aten adaptive_avg_pool2d operator")
-
- if _is_annotated([pool_node]):
+ io_obs_sharing_node = partition.output_nodes[0]
+ if _is_annotated([io_obs_sharing_node]):
continue
- input_act = pool_node.args[0]
+ input_act = io_obs_sharing_node.args[0]
assert isinstance(input_act, Node)
# only annotate input output sharing operator
@@ -552,21 +544,31 @@
or not input_act.meta["quantization_annotation"]._annotated
or input_act.meta["quantization_annotation"].output_qspec is None
):
- input_act_qspec = get_input_act_qspec(quantization_config)
- else:
- input_act_qspec = SharedQuantizationSpec(input_act)
+ continue
- # output sharing with input
- output_act_qspec = SharedQuantizationSpec((input_act, pool_node))
- pool_node.meta["quantization_annotation"] = QuantizationAnnotation(
+ act_qspec = SharedQuantizationSpec(input_act)
+ io_obs_sharing_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={
- input_act: input_act_qspec,
+ input_act: act_qspec,
},
- output_qspec=output_act_qspec,
+ output_qspec=act_qspec,
_annotated=True,
)
+def _annotate_adaptive_avg_pool2d(
+ gm: torch.fx.GraphModule,
+ quantization_config: Optional[QuantizationConfig],
+ filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> None:
+ _annotate_input_out_obs_sharing_op(
+ torch.nn.AdaptiveAvgPool2d, gm, quantization_config, filter_fn
+ )
+ _annotate_input_out_obs_sharing_op(
+ F.adaptive_avg_pool2d, gm, quantization_config, filter_fn
+ )
+
+
def _annotate_add_relu(
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
@@ -615,7 +617,7 @@
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
add_partitions = get_source_partitions(
- gm.graph, [operator.add, torch.add, operator.iadd], filter_fn
+ gm.graph, [operator.add, torch.add], filter_fn
)
add_partitions = list(itertools.chain(*add_partitions.values()))
for add_partition in add_partitions:
@@ -642,81 +644,6 @@
)
-def _annotate_mul_relu(
- gm: torch.fx.GraphModule,
- quantization_config: Optional[QuantizationConfig],
- filter_fn: Optional[Callable[[Node], bool]] = None,
-) -> None:
- fused_partitions = find_sequential_partitions(
- gm, [torch.mul, torch.nn.ReLU], filter_fn
- )
- for fused_partition in fused_partitions:
- mul_partition, relu_partition = fused_partition
- if len(relu_partition.output_nodes) > 1:
- raise ValueError("Relu partition has more than one output node")
- relu_node = relu_partition.output_nodes[0]
- if len(mul_partition.output_nodes) > 1:
- raise ValueError("mul partition has more than one output node")
- mul_node = mul_partition.output_nodes[0]
-
- if _is_annotated([relu_node, mul_node]):
- continue
-
- input_act_qspec = get_input_act_qspec(quantization_config)
- output_act_qspec = get_output_act_qspec(quantization_config)
-
- input_qspec_map = {}
- input_act0 = mul_node.args[0]
- if isinstance(input_act0, Node):
- input_qspec_map[input_act0] = input_act_qspec
-
- input_act1 = mul_node.args[1]
- if isinstance(input_act1, Node):
- input_qspec_map[input_act1] = input_act_qspec
-
- mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
- input_qspec_map=input_qspec_map,
- _annotated=True,
- )
- relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
- output_qspec=output_act_qspec,
- _annotated=True,
- )
-
-
-def _annotate_mul(
- gm: torch.fx.GraphModule,
- quantization_config: Optional[QuantizationConfig],
- filter_fn: Optional[Callable[[Node], bool]] = None,
-) -> None:
- mul_partitions = get_source_partitions(
- gm.graph, [operator.mul, torch.mul, operator.imul], filter_fn
- )
- mul_partitions = list(itertools.chain(*mul_partitions.values()))
- for mul_partition in mul_partitions:
- mul_node = mul_partition.output_nodes[0]
- if _is_annotated([mul_node]):
- continue
-
- input_act_qspec = get_input_act_qspec(quantization_config)
- output_act_qspec = get_output_act_qspec(quantization_config)
-
- input_qspec_map = {}
- input_act0 = mul_node.args[0]
- if isinstance(input_act0, Node):
- input_qspec_map[input_act0] = input_act_qspec
-
- input_act1 = mul_node.args[1]
- if isinstance(input_act1, Node):
- input_qspec_map[input_act1] = input_act_qspec
-
- mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
- input_qspec_map=input_qspec_map,
- output_qspec=output_act_qspec,
- _annotated=True,
- )
-
-
OP_TO_ANNOTATOR = {
"linear": _annotate_linear,
"conv2d": _annotate_conv2d,
@@ -726,8 +653,6 @@
"max_pool2d": _annotate_max_pool2d,
"add": _annotate_add,
"add_relu": _annotate_add_relu,
- "mul": _annotate_mul,
- "mul_relu": _annotate_mul_relu,
"adaptive_avg_pool2d": _annotate_adaptive_avg_pool2d,
# input output only gru
"gru_io_only": _annotate_gru_io_only,