Revert "[quant][pt2e][xnnpack_quantizer] Add support for mul and mul_relu (#107930)"

This reverts commit 1d1739dc6d7365c28719cd0175081f9d9aab0324.

Reverted https://github.com/pytorch/pytorch/pull/107930 on behalf of https://github.com/facebook-github-bot because the diff was reverted internally ([comment](https://github.com/pytorch/pytorch/pull/107930#issuecomment-1694069330))
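
Scope of the revert, per the diff below:

* the `mul` / `mul_relu` annotators and their `OP_TO_ANNOTATOR` entries are removed from `xnnpack_quantizer_utils.py`, and `XNNPACKQuantizer` no longer invokes them;
* the `add` annotator no longer matches `operator.iadd`, and the `{torch.mul, operator.mul, operator.imul}` fusion pattern is dropped from `graph_utils.py`;
* the observer-deduplication logic in `prepare.py` is removed, which is why the x86 inductor tests go back to expecting 4 (rather than 3) per-tensor quantize/dequantize ops;
* `_annotate_adaptive_avg_pool2d` returns to the shared `_annotate_input_out_obs_sharing_op` helper (annotating only when its input is already annotated) instead of the "always annotate" variant;
* the `AddInplaceAdd` / `MulInplaceMul` test helpers and their tests are deleted.

For reference, a minimal sketch of the flow the deleted `test_mul_and_inplace_mul` exercised. The module and input shapes come from the removed test; the capture/prepare/convert calls are the pt2e entry points the test harness of this era relies on, shown here only to illustrate the user-visible effect and not as part of the patch.

```python
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)


class MulInplaceMul(torch.nn.Module):
    # Same toy module the deleted test used: one out-of-place and one in-place mul.
    def forward(self, x, y):
        x = x * y
        x *= y
        return x


example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5))
m = capture_pre_autograd_graph(MulInplaceMul(), example_inputs)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

m = prepare_pt2e(m, quantizer)  # with #107930: both mul nodes get annotated
m(*example_inputs)              # calibrate the inserted observers
m = convert_pt2e(m)             # q/dq pairs appear around aten.mul / aten.mul_
```

After this revert, the same script still runs, but neither mul node is annotated, so no quantize/dequantize nodes are inserted around them.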
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index dc1a6dd..23dc5a2 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -227,17 +227,6 @@
             conv_out = torch.squeeze(conv_out, dim=0)
             return self.linear(conv_out)
 
-    class AddInplaceAdd(torch.nn.Module):
-        def forward(self, x, y):
-            x = x + y
-            x += y
-            return x
-
-    class MulInplaceMul(torch.nn.Module):
-        def forward(self, x, y):
-            x = x * y
-            x *= y
-            return x
 
 class PT2EQuantizationTestCase(QuantizationTestCase):
     """
@@ -1234,60 +1223,6 @@
             m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
         )
 
-    def test_add_and_inplace_add(self):
-        quantizer = XNNPACKQuantizer()
-        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
-        quantizer.set_global(quantization_config)
-        example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5),)
-        node_occurrence = {
-            # two input and one output for first add, and output for second add
-            torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
-        }
-        node_list = [
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
-            torch.ops.aten.add.Tensor,
-            torch.ops.quantized_decomposed.quantize_per_tensor.default,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
-            torch.ops.aten.add_.Tensor,
-            torch.ops.quantized_decomposed.quantize_per_tensor.default,
-        ]
-        self._test_quantizer(
-            TestHelperModules.AddInplaceAdd(),
-            example_inputs,
-            quantizer,
-            node_occurrence,
-            node_list,
-        )
-
-    def test_mul_and_inplace_mul(self):
-        quantizer = XNNPACKQuantizer()
-        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
-        quantizer.set_global(quantization_config)
-        example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5),)
-        node_occurrence = {
-            # two input and one output for first add, and output for second add
-            torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
-        }
-        node_list = [
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
-            torch.ops.aten.mul.Tensor,
-            torch.ops.quantized_decomposed.quantize_per_tensor.default,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
-            torch.ops.aten.mul_.Tensor,
-            torch.ops.quantized_decomposed.quantize_per_tensor.default,
-        ]
-        self._test_quantizer(
-            TestHelperModules.MulInplaceMul(),
-            example_inputs,
-            quantizer,
-            node_occurrence,
-            node_list,
-        )
-
     def test_xnnpack_quantizer_conv(self):
         quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config(is_per_channel=True)
diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py
index 6115ced..dad00f2 100644
--- a/test/quantization/pt2e/test_x86inductor_quantizer.py
+++ b/test/quantization/pt2e/test_x86inductor_quantizer.py
@@ -290,8 +290,8 @@
                         # one for output for the add
                         # 2 conv will share same input quant/dequant
                         # one for extra input node of add
-                        torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
-                        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
                         torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
                         torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
                     }
@@ -344,8 +344,8 @@
                         # one for output for the relu
                         # 2 conv will share same input quant/dequant
                         # one for extra input node of add
-                        torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
-                        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
                         torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
                         torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
                     }
diff --git a/torch/ao/quantization/pt2e/graph_utils.py b/torch/ao/quantization/pt2e/graph_utils.py
index e9b31e8..613eeb3 100644
--- a/torch/ao/quantization/pt2e/graph_utils.py
+++ b/torch/ao/quantization/pt2e/graph_utils.py
@@ -24,7 +24,6 @@
     {torch.nn.BatchNorm2d, torch.nn.functional.batch_norm},
     {torch.nn.Hardtanh, torch.nn.functional.hardtanh, torch.nn.functional.hardtanh_},
     {torch.add, operator.add, operator.iadd, "add", "add_"},
-    {torch.mul, operator.mul, operator.imul},
 ]
 
 
diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py
index ac59fc3..c76365c 100644
--- a/torch/ao/quantization/pt2e/prepare.py
+++ b/torch/ao/quantization/pt2e/prepare.py
@@ -88,38 +88,11 @@
             new_arg = arg
             obs_or_fq_map[(observed_arg, node)] = arg_as_input_act_obs_or_fq
         else:
-            # skip inserting new observers if there is an observer inserted for the arg before
-            # that has the same dtype that we want to insert here
-            # alternatively we could have a dedup pass after we insert all observers to deduplicate
-            # observers
-            # Example:
-            # arg -> existing_obs -> conv1
-            #    \ -> conv2
-            #
-            # instead of inserting new observers we will have:
-            # arg -> existing_obs -> conv1
-            #                   \ -> conv2
-            existing_obs_node = None
-            for maybe_obs_node in arg.users.keys():
-                if maybe_obs_node.op == 'call_module':
-                    maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
-                    if (
-                        type(maybe_obs_mod) == type(arg_as_input_act_obs_or_fq) and
-                        maybe_obs_mod.dtype == arg_as_input_target_dtype
-                    ):
-                        arg_as_input_act_obs_or_fq = maybe_obs_mod  # type: ignore[assignment]
-                        existing_obs_node = maybe_obs_node
-                        break
-
             assert arg_as_input_act_obs_or_fq is not None
+            new_obs_node = _insert_obs_or_fq(
+                arg, arg_as_input_act_obs_or_fq, model, named_modules, model.graph)  # type: ignore[arg-type]
+            new_arg = new_obs_node
             obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq
-            if existing_obs_node is None:
-                new_obs_node = _insert_obs_or_fq(
-                    arg, arg_as_input_act_obs_or_fq, model, named_modules, model.graph)
-                # override this arg to be the observed arg
-                new_arg = new_obs_node
-            else:
-                new_arg = existing_obs_node
 
     return new_arg
 
diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py
index ccc20eb..17d66ee 100644
--- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py
+++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py
@@ -329,8 +329,6 @@
         self._annotate_conv2d_patterns(model, config, filter_fn)
         self._annotate_max_pool2d(model, config, filter_fn)
         self._annotate_add_patterns(model, config, filter_fn)
-        OP_TO_ANNOTATOR["mul_relu"](model, config, filter_fn)
-        OP_TO_ANNOTATOR["mul"](model, config, filter_fn)
         self._annotate_adaptive_avg_pool2d(model, config, filter_fn)
         self._annotate_gru_io_only(model, config, filter_fn)
         return model
diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
index d921160..a920d5e 100644
--- a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
+++ b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
@@ -521,28 +521,20 @@
         )
 
 
-def _annotate_adaptive_avg_pool2d(
+def _annotate_input_out_obs_sharing_op(
+    op: Callable,
     gm: torch.fx.GraphModule,
     quantization_config: Optional[QuantizationConfig],
     filter_fn: Optional[Callable[[Node], bool]] = None,
 ) -> None:
-    """Always annotate adaptive_avg_pool2d op"""
-    module_partitions = get_source_partitions(
-        gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn
-    )
+    module_partitions = get_source_partitions(gm.graph, [op], filter_fn)
     partitions = list(itertools.chain(*module_partitions.values()))
     for partition in partitions:
-        pool_node = partition.output_nodes[0]
-        if (
-            pool_node.op != "call_function"
-            or pool_node.target != torch.ops.aten.adaptive_avg_pool2d.default
-        ):
-            raise ValueError(f"{pool_node} is not an aten adaptive_avg_pool2d operator")
-
-        if _is_annotated([pool_node]):
+        io_obs_sharing_node = partition.output_nodes[0]
+        if _is_annotated([io_obs_sharing_node]):
             continue
 
-        input_act = pool_node.args[0]
+        input_act = io_obs_sharing_node.args[0]
         assert isinstance(input_act, Node)
 
         # only annotate input output sharing operator
@@ -552,21 +544,31 @@
             or not input_act.meta["quantization_annotation"]._annotated
             or input_act.meta["quantization_annotation"].output_qspec is None
         ):
-            input_act_qspec = get_input_act_qspec(quantization_config)
-        else:
-            input_act_qspec = SharedQuantizationSpec(input_act)
+            continue
 
-        # output sharing with input
-        output_act_qspec = SharedQuantizationSpec((input_act, pool_node))
-        pool_node.meta["quantization_annotation"] = QuantizationAnnotation(
+        act_qspec = SharedQuantizationSpec(input_act)
+        io_obs_sharing_node.meta["quantization_annotation"] = QuantizationAnnotation(
             input_qspec_map={
-                input_act: input_act_qspec,
+                input_act: act_qspec,
             },
-            output_qspec=output_act_qspec,
+            output_qspec=act_qspec,
             _annotated=True,
         )
 
 
+def _annotate_adaptive_avg_pool2d(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> None:
+    _annotate_input_out_obs_sharing_op(
+        torch.nn.AdaptiveAvgPool2d, gm, quantization_config, filter_fn
+    )
+    _annotate_input_out_obs_sharing_op(
+        F.adaptive_avg_pool2d, gm, quantization_config, filter_fn
+    )
+
+
 def _annotate_add_relu(
     gm: torch.fx.GraphModule,
     quantization_config: Optional[QuantizationConfig],
@@ -615,7 +617,7 @@
     filter_fn: Optional[Callable[[Node], bool]] = None,
 ) -> None:
     add_partitions = get_source_partitions(
-        gm.graph, [operator.add, torch.add, operator.iadd], filter_fn
+        gm.graph, [operator.add, torch.add], filter_fn
     )
     add_partitions = list(itertools.chain(*add_partitions.values()))
     for add_partition in add_partitions:
@@ -642,81 +644,6 @@
         )
 
 
-def _annotate_mul_relu(
-    gm: torch.fx.GraphModule,
-    quantization_config: Optional[QuantizationConfig],
-    filter_fn: Optional[Callable[[Node], bool]] = None,
-) -> None:
-    fused_partitions = find_sequential_partitions(
-        gm, [torch.mul, torch.nn.ReLU], filter_fn
-    )
-    for fused_partition in fused_partitions:
-        mul_partition, relu_partition = fused_partition
-        if len(relu_partition.output_nodes) > 1:
-            raise ValueError("Relu partition has more than one output node")
-        relu_node = relu_partition.output_nodes[0]
-        if len(mul_partition.output_nodes) > 1:
-            raise ValueError("mul partition has more than one output node")
-        mul_node = mul_partition.output_nodes[0]
-
-        if _is_annotated([relu_node, mul_node]):
-            continue
-
-        input_act_qspec = get_input_act_qspec(quantization_config)
-        output_act_qspec = get_output_act_qspec(quantization_config)
-
-        input_qspec_map = {}
-        input_act0 = mul_node.args[0]
-        if isinstance(input_act0, Node):
-            input_qspec_map[input_act0] = input_act_qspec
-
-        input_act1 = mul_node.args[1]
-        if isinstance(input_act1, Node):
-            input_qspec_map[input_act1] = input_act_qspec
-
-        mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            input_qspec_map=input_qspec_map,
-            _annotated=True,
-        )
-        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            output_qspec=output_act_qspec,
-            _annotated=True,
-        )
-
-
-def _annotate_mul(
-    gm: torch.fx.GraphModule,
-    quantization_config: Optional[QuantizationConfig],
-    filter_fn: Optional[Callable[[Node], bool]] = None,
-) -> None:
-    mul_partitions = get_source_partitions(
-        gm.graph, [operator.mul, torch.mul, operator.imul], filter_fn
-    )
-    mul_partitions = list(itertools.chain(*mul_partitions.values()))
-    for mul_partition in mul_partitions:
-        mul_node = mul_partition.output_nodes[0]
-        if _is_annotated([mul_node]):
-            continue
-
-        input_act_qspec = get_input_act_qspec(quantization_config)
-        output_act_qspec = get_output_act_qspec(quantization_config)
-
-        input_qspec_map = {}
-        input_act0 = mul_node.args[0]
-        if isinstance(input_act0, Node):
-            input_qspec_map[input_act0] = input_act_qspec
-
-        input_act1 = mul_node.args[1]
-        if isinstance(input_act1, Node):
-            input_qspec_map[input_act1] = input_act_qspec
-
-        mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            input_qspec_map=input_qspec_map,
-            output_qspec=output_act_qspec,
-            _annotated=True,
-        )
-
-
 OP_TO_ANNOTATOR = {
     "linear": _annotate_linear,
     "conv2d": _annotate_conv2d,
@@ -726,8 +653,6 @@
     "max_pool2d": _annotate_max_pool2d,
     "add": _annotate_add,
     "add_relu": _annotate_add_relu,
-    "mul": _annotate_mul,
-    "mul_relu": _annotate_mul_relu,
     "adaptive_avg_pool2d": _annotate_adaptive_avg_pool2d,
     # input output only gru
     "gru_io_only": _annotate_gru_io_only,