[quant][pt2e] prepare_pt2e use quantization spec directly (#102054)

Summary:
In this PR we align with the design of the annotation API and use quantization specs directly for annotation.
The main change is in prepare: instead of consuming an observer or fake quantize constructor, it consumes the QuantizationSpec object directly and creates the constructor
internally. After this PR, annotation API users only need to interact with QuantizationSpec objects.
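
For context, a minimal sketch of what an annotator writes after this change. `annotate_conv` and its node arguments are illustrative names, the import path follows the `_pt2e` layout touched in this PR, and the spec values mirror the updated test below:
```
import torch
from torch.ao.quantization import observer
from torch.ao.quantization._pt2e.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
)
from torch.fx import Node


def annotate_conv(conv_node: Node, input_act: Node, weight: Node, bias: Node) -> None:
    # Annotators now describe what they want with QuantizationSpec objects;
    # prepare_pt2e builds the observer/fake-quant constructors from them internally.
    act_qspec = QuantizationSpec(
        dtype=torch.uint8,
        quant_min=0,
        quant_max=255,
        qscheme=torch.per_tensor_affine,
        is_dynamic=False,
        observer_or_fake_quant_ctr=observer.default_observer,
    )
    weight_qspec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
        is_dynamic=False,
        observer_or_fake_quant_ctr=observer.default_weight_observer,
    )
    bias_qspec = QuantizationSpec(
        dtype=torch.float32,
        is_dynamic=False,
        observer_or_fake_quant_ctr=observer.PlaceholderObserver,
    )
    conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={input_act: act_qspec, weight: weight_qspec, bias: bias_qspec},
        output_qspec=act_qspec,
        _annotated=True,
    )
```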

Test Plan:
```
buck2 test mode/opt caffe2/test:quantization_pt2e -- --exact 'caffe2/test:quantization_pt2e - test_resnet18_with_quantizer_api (quantization.pt2e.test_quantize_pt2e.TestQuantizePT2EModels)'
```

Reviewed By: kimishpatel

Differential Revision: D45934088

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102054
Approved by: https://github.com/kimishpatel
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index c6a7c05..f401b1c 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -18,6 +18,7 @@
     OperatorConfig,
     QNNPackQuantizer,
     Quantizer,
+    QuantizationSpec,
 )
 from torch.ao.quantization._quantize_pt2e import (
     convert_pt2e,
@@ -71,13 +72,34 @@
                         assert isinstance(weight, Node)
                         bias = node.args[2]
                         assert isinstance(bias, Node)
+                        act_qspec = QuantizationSpec(
+                            dtype=torch.uint8,
+                            quant_min=0,
+                            quant_max=255,
+                            qscheme=torch.per_tensor_affine,
+                            is_dynamic=False,
+                            observer_or_fake_quant_ctr=observer.default_observer,
+                        )
+                        weight_qspec = QuantizationSpec(
+                            dtype=torch.int8,
+                            quant_min=-128,
+                            quant_max=127,
+                            qscheme=torch.per_tensor_affine,
+                            is_dynamic=False,
+                            observer_or_fake_quant_ctr=observer.default_weight_observer,
+                        )
+                        bias_qspec = QuantizationSpec(
+                            dtype=torch.float32,
+                            is_dynamic=False,
+                            observer_or_fake_quant_ctr=observer.PlaceholderObserver,
+                        )
                         node.meta["quantization_annotation"] = QuantizationAnnotation(
                             input_qspec_map={
-                                input_act: observer.default_observer,
-                                weight: observer.default_weight_observer,
-                                bias: observer.PlaceholderObserver.with_args(dtype=torch.float),
+                                input_act: act_qspec,
+                                weight: weight_qspec,
+                                bias: bias_qspec,
                             },
-                            output_qspec=observer.default_observer,
+                            output_qspec=act_qspec,
                             _annotated=True,
                         )
 
@@ -140,13 +162,34 @@
                         assert isinstance(weight, Node)
                         bias = node.args[2]
                         assert isinstance(bias, Node)
+                        act_qspec = QuantizationSpec(
+                            dtype=torch.uint8,
+                            quant_min=0,
+                            quant_max=255,
+                            qscheme=torch.per_tensor_affine,
+                            is_dynamic=False,
+                            observer_or_fake_quant_ctr=observer.default_observer,
+                        )
+                        weight_qspec = QuantizationSpec(
+                            dtype=torch.int8,
+                            quant_min=-128,
+                            quant_max=127,
+                            qscheme=torch.per_tensor_affine,
+                            is_dynamic=False,
+                            observer_or_fake_quant_ctr=observer.default_weight_observer,
+                        )
+                        bias_qspec = QuantizationSpec(
+                            dtype=torch.float32,
+                            is_dynamic=False,
+                            observer_or_fake_quant_ctr=observer.PlaceholderObserver,
+                        )
                         node.meta["quantization_annotation"] = QuantizationAnnotation(
                             input_qspec_map={
-                                input_act: observer.default_observer,
-                                weight: observer.default_weight_observer,
-                                bias: observer.PlaceholderObserver.with_args(dtype=torch.float),
+                                input_act: act_qspec,
+                                weight: weight_qspec,
+                                bias: bias_qspec,
                             },
-                            output_qspec=observer.default_observer,
+                            output_qspec=act_qspec,
                             _annotated=True,
                         )
                     if (
@@ -160,12 +203,12 @@
                         assert isinstance(input_act, Node)
                         maxpool_node.meta["quantization_annotation"] = QuantizationAnnotation(
                             input_qspec_map={
-                                input_act: observer.default_observer,
+                                input_act: act_qspec,
                             },
                             _annotated=True,
                         )
                         getitem_node.meta["quantization_annotation"] = QuantizationAnnotation(
-                            output_qspec=observer.default_observer,
+                            output_qspec=act_qspec,
                             _input_output_share_observers=True,
                             _annotated=True,
                         )
diff --git a/torch/ao/quantization/_pt2e/quantizer/__init__.py b/torch/ao/quantization/_pt2e/quantizer/__init__.py
index 93cf803..df388ab 100644
--- a/torch/ao/quantization/_pt2e/quantizer/__init__.py
+++ b/torch/ao/quantization/_pt2e/quantizer/__init__.py
@@ -1,8 +1,14 @@
 from .qnnpack_quantizer import QNNPackQuantizer
-from .quantizer import OperatorConfig, Quantizer, QuantizationAnnotation
+from .quantizer import (
+    OperatorConfig,
+    Quantizer,
+    QuantizationSpec,
+    QuantizationAnnotation,
+)
 
 __all__ = [
     "Quantizer",
+    "QuantizationSpec",
     "QNNPackQuantizer",
     "QuantizationAnnotation",
 ]
diff --git a/torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py b/torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py
index e82427f..caa2fbf 100644
--- a/torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py
+++ b/torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py
@@ -10,9 +10,9 @@
 import torch.nn.functional as F
 
 from torch.ao.quantization._pt2e.quantizer.utils import (
-    get_act_obs_or_fq_ctr,
-    get_bias_obs_or_fq_ctr,
-    get_weight_obs_or_fq_ctr,
+    get_act_qspec,
+    get_weight_qspec,
+    get_bias_qspec,
 )
 
 from torch.fx import Node
@@ -329,15 +329,15 @@
         input_qspec_map = {}
         input_act = conv_node.args[0]
         assert isinstance(input_act, Node)
-        input_qspec_map[input_act] = get_act_obs_or_fq_ctr(quantization_config)
+        input_qspec_map[input_act] = get_act_qspec(quantization_config)
 
         weight = conv_node.args[1]
         assert isinstance(weight, Node)
-        input_qspec_map[weight] = get_weight_obs_or_fq_ctr(quantization_config)
+        input_qspec_map[weight] = get_weight_qspec(quantization_config)
 
         bias = conv_node.args[2]
         if isinstance(bias, Node):
-            input_qspec_map[bias] = get_bias_obs_or_fq_ctr(quantization_config)
+            input_qspec_map[bias] = get_bias_qspec(quantization_config)
 
         conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
             input_qspec_map=input_qspec_map,
@@ -348,7 +348,7 @@
             _annotated=True
         )
         getitem_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            output_qspec=get_act_obs_or_fq_ctr(quantization_config),  # type: ignore[arg-type]
+            output_qspec=get_act_qspec(quantization_config),  # type: ignore[arg-type]
             _annotated=True
         )
 
@@ -402,15 +402,15 @@
         input_qspec_map = {}
         input_act = conv_node.args[0]
         assert isinstance(input_act, Node)
-        input_qspec_map[input_act] = get_act_obs_or_fq_ctr(quantization_config)
+        input_qspec_map[input_act] = get_act_qspec(quantization_config)
 
         weight = conv_node.args[1]
         assert isinstance(weight, Node)
-        input_qspec_map[weight] = get_weight_obs_or_fq_ctr(quantization_config)
+        input_qspec_map[weight] = get_weight_qspec(quantization_config)
 
         bias = conv_node.args[2]
         if isinstance(bias, Node):
-            input_qspec_map[bias] = get_bias_obs_or_fq_ctr(quantization_config)
+            input_qspec_map[bias] = get_bias_qspec(quantization_config)
 
         conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
             input_qspec_map=input_qspec_map,
@@ -423,7 +423,7 @@
             _annotated=True
         )
         relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            output_qspec=get_act_obs_or_fq_ctr(quantization_config),  # type: ignore[arg-type]
+            output_qspec=get_act_qspec(quantization_config),  # type: ignore[arg-type]
             _annotated=True
         )
 
@@ -449,22 +449,22 @@
         input_qspec_map = {}
         input_act = conv_node.args[0]
         assert isinstance(input_act, Node)
-        input_qspec_map[input_act] = get_act_obs_or_fq_ctr(quantization_config)
+        input_qspec_map[input_act] = get_act_qspec(quantization_config)
 
         weight = conv_node.args[1]
         assert isinstance(weight, Node)
-        input_qspec_map[weight] = get_weight_obs_or_fq_ctr(quantization_config)
+        input_qspec_map[weight] = get_weight_qspec(quantization_config)
 
         bias = conv_node.args[2]
         if isinstance(bias, Node):
-            input_qspec_map[bias] = get_bias_obs_or_fq_ctr(quantization_config)
+            input_qspec_map[bias] = get_bias_qspec(quantization_config)
 
         conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
             input_qspec_map=input_qspec_map,
             _annotated=True
         )
         relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            output_qspec=get_act_obs_or_fq_ctr(quantization_config),  # type: ignore[arg-type]
+            output_qspec=get_act_qspec(quantization_config),  # type: ignore[arg-type]
             _annotated=True
         )
 
@@ -484,19 +484,19 @@
         input_qspec_map = {}
         input_act = conv_node.args[0]
         assert isinstance(input_act, Node)
-        input_qspec_map[input_act] = get_act_obs_or_fq_ctr(quantization_config)
+        input_qspec_map[input_act] = get_act_qspec(quantization_config)
 
         weight = conv_node.args[1]
         assert isinstance(weight, Node)
-        input_qspec_map[weight] = get_weight_obs_or_fq_ctr(quantization_config)
+        input_qspec_map[weight] = get_weight_qspec(quantization_config)
 
         bias = conv_node.args[2]
         if isinstance(bias, Node):
-            input_qspec_map[bias] = get_bias_obs_or_fq_ctr(quantization_config)
+            input_qspec_map[bias] = get_bias_qspec(quantization_config)
 
         conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
             input_qspec_map=input_qspec_map,
-            output_qspec=get_act_obs_or_fq_ctr(quantization_config),
+            output_qspec=get_act_qspec(quantization_config),
             _annotated=True
         )
 
@@ -506,6 +506,9 @@
         module_partitions = get_source_partitions(
             gm.graph, [torch.nn.Linear, torch.nn.functional.linear]
         )
+        act_qspec = get_act_qspec(quantization_config)
+        weight_qspec = get_weight_qspec(quantization_config)
+        bias_qspec = get_bias_qspec(quantization_config)
         for module_or_fn_type, partitions in module_partitions.items():
             if module_or_fn_type == torch.nn.Linear:
                 for p in partitions:
@@ -535,20 +538,14 @@
                         _annotate_input_qspec_map(
                             act_use_node,
                             act_node,
-                            get_act_obs_or_fq_ctr(quantization_config),
+                            act_qspec,
                         )
                     if bias_node and _is_annotated([bias_node]) is False:
-                        _annotate_output_qspec(
-                            bias_node, get_bias_obs_or_fq_ctr(quantization_config)
-                        )
+                        _annotate_output_qspec(bias_node, bias_qspec)
                     if _is_annotated([weight_node]) is False:  # type: ignore[list-item]
-                        _annotate_output_qspec(
-                            weight_node, get_weight_obs_or_fq_ctr(quantization_config)
-                        )
+                        _annotate_output_qspec(weight_node, weight_qspec)
                     if _is_annotated([output_node]) is False:
-                        _annotate_output_qspec(
-                            output_node, get_act_obs_or_fq_ctr(quantization_config)
-                        )
+                        _annotate_output_qspec(output_node, act_qspec)
                     nodes_to_mark_annotated = list(p.nodes)
                     _mark_nodes_as_annotated(nodes_to_mark_annotated)
 
@@ -576,14 +573,16 @@
 
         input_act = maxpool_node.args[0]
         assert isinstance(input_act, Node)
+
+        act_qspec = get_act_qspec(quantization_config)
         maxpool_node.meta["quantization_annotation"] = QuantizationAnnotation(
             input_qspec_map={
-                input_act: get_act_obs_or_fq_ctr(quantization_config)
+                input_act: act_qspec,
             },
             _annotated=True,
         )
         getitem_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            output_qspec=get_act_obs_or_fq_ctr(quantization_config),
+            output_qspec=act_qspec,
             _input_output_share_observers=True,
             _annotated=True,
         )
@@ -605,11 +604,13 @@
 
         input_act = io_obs_sharing_node.args[0]
         assert isinstance(input_act, Node)
+
+        act_qspec = get_act_qspec(quantization_config)
         io_obs_sharing_node.meta["quantization_annotation"] = QuantizationAnnotation(
             input_qspec_map={
-                input_act: get_act_obs_or_fq_ctr(quantization_config)
+                input_act: act_qspec,
             },
-            output_qspec=get_act_obs_or_fq_ctr(quantization_config),
+            output_qspec=act_qspec,
             _input_output_share_observers=True,
             _annotated=True,
         )
@@ -657,21 +658,23 @@
         if _is_annotated([relu_node, add_node]):
             return
 
+        act_qspec = get_act_qspec(quantization_config)
+
         input_qspec_map = {}
         input_act0 = add_node.args[0]
         if isinstance(input_act0, Node):
-            input_qspec_map[input_act0] = get_act_obs_or_fq_ctr(quantization_config)
+            input_qspec_map[input_act0] = act_qspec
 
         input_act1 = add_node.args[1]
         if isinstance(input_act1, Node):
-            input_qspec_map[input_act1] = get_act_obs_or_fq_ctr(quantization_config)
+            input_qspec_map[input_act1] = act_qspec
 
         add_node.meta["quantization_annotation"] = QuantizationAnnotation(
             input_qspec_map=input_qspec_map,
             _annotated=True,
         )
         relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            output_qspec=get_act_obs_or_fq_ctr(quantization_config),
+            output_qspec=act_qspec,
             _annotated=True,
         )
 
@@ -687,18 +690,20 @@
         if _is_annotated([add_node]):
             return
 
+        act_qspec = get_act_qspec(quantization_config)
+
         input_qspec_map = {}
         input_act0 = add_node.args[0]
         if isinstance(input_act0, Node):
-            input_qspec_map[input_act0] = get_act_obs_or_fq_ctr(quantization_config)
+            input_qspec_map[input_act0] = act_qspec
 
         input_act1 = add_node.args[1]
         if isinstance(input_act1, Node):
-            input_qspec_map[input_act1] = get_act_obs_or_fq_ctr(quantization_config)
+            input_qspec_map[input_act1] = act_qspec
 
         add_node.meta["quantization_annotation"] = QuantizationAnnotation(
             input_qspec_map=input_qspec_map,
-            output_qspec=get_act_obs_or_fq_ctr(quantization_config),
+            output_qspec=act_qspec,
             _annotated=True,
         )
 
diff --git a/torch/ao/quantization/_pt2e/quantizer/quantizer.py b/torch/ao/quantization/_pt2e/quantizer/quantizer.py
index aeb726c..2ce80bc 100644
--- a/torch/ao/quantization/_pt2e/quantizer/quantizer.py
+++ b/torch/ao/quantization/_pt2e/quantizer/quantizer.py
@@ -1,6 +1,5 @@
-import copy
 from abc import ABC, abstractmethod
-from dataclasses import asdict, dataclass, field
+from dataclasses import dataclass, field
 from torch.fx import Node
 from typing import Callable, List, NamedTuple, Optional, Dict, Any
 from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
@@ -9,6 +8,7 @@
 
 __all__ = [
     "Quantizer",
+    "QuantizationSpec",
     "QuantizationAnnotation",
 ]
 
@@ -22,17 +22,6 @@
     torch.per_channel_affine_float_qparams,
 ]
 
-# TODO: add support for torch dtype in quant code base
-# this includes observers and prepare/convert code
-_TORCH_DTYPE_TO_QDTYPE = {
-    torch.int8: torch.qint8,
-    torch.uint8: torch.quint8,
-    torch.int32: torch.qint32,
-    torch.float16: torch.float16,
-    torch.float32: torch.float32,
-}
-
-
 @dataclass(eq=True, frozen=True)
 class QuantizationSpec:
     dtype: torch.dtype
@@ -72,12 +61,6 @@
             raise ValueError("Ch_axis is < 0.")
 
 
-def get_observer_kwargs(quant_spec: QuantizationSpec):
-    kwargs_dict = asdict(quant_spec)
-    kwargs_dict["dtype"] = _TORCH_DTYPE_TO_QDTYPE[quant_spec.dtype]
-    return copy.deepcopy(kwargs_dict)
-
-
 # In the absence of better name, just winging it with QuantizationConfig
 @dataclass(eq=True, frozen=True)
 class QuantizationConfig:
diff --git a/torch/ao/quantization/_pt2e/quantizer/utils.py b/torch/ao/quantization/_pt2e/quantizer/utils.py
index 41eba43..a0d3b20 100644
--- a/torch/ao/quantization/_pt2e/quantizer/utils.py
+++ b/torch/ao/quantization/_pt2e/quantizer/utils.py
@@ -1,34 +1,10 @@
 import torch
 from torch.ao.quantization._pt2e.quantizer.quantizer import (
-    get_observer_kwargs,
     QuantizationConfig,
     QuantizationSpec,
 )
-from torch.ao.quantization.observer import (
-    _PartialWrapper,
-    PlaceholderObserver,
-)
-from torch.ao.quantization.qconfig import _obs_or_fq_ctr_equals
 
-def create_observer(quantization_spec: QuantizationSpec, **extra_kwargs):
-    if quantization_spec is None:
-        return None
-    observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr
-    kwargs = get_observer_kwargs(quantization_spec)
-    kwargs.pop("observer_or_fake_quant_ctr")
-    # we will remove is_dynamic from QuantizationSpec because
-    # it seems that dynamic range quantization
-    if not _obs_or_fq_ctr_equals(observer_or_fake_quant_ctr, PlaceholderObserver):
-        kwargs.pop("is_dynamic")
-    obs_or_fq_class = observer_or_fake_quant_ctr
-    if isinstance(observer_or_fake_quant_ctr, _PartialWrapper):
-        obs_or_fq_class = observer_or_fake_quant_ctr.p.func  # type: ignore[union-attr, assignment]
-    if "PerChannel" not in obs_or_fq_class.__name__:  # type: ignore[operator, union-attr]
-        kwargs.pop("ch_axis")
-    return observer_or_fake_quant_ctr.with_args(**kwargs, **extra_kwargs)
-
-
-def get_act_obs_or_fq_ctr(quantization_config: QuantizationConfig):
+def get_act_qspec(quantization_config: QuantizationConfig):
     if quantization_config is None:
         return None
     if quantization_config.activation is None:
@@ -43,9 +19,9 @@
         raise Exception(
             "Unsupported quantization_spec for activation: {}".format(quantization_spec)
         )
-    return create_observer(quantization_spec)
+    return quantization_spec
 
-def get_weight_obs_or_fq_ctr(quantization_config: QuantizationConfig):
+def get_weight_qspec(quantization_config: QuantizationConfig):
     if quantization_config is None:
         return None
     assert quantization_config is not None
@@ -59,9 +35,9 @@
         raise ValueError(
             f"Unsupported quantization_spec {quantization_spec} for weight"
         )
-    return create_observer(quantization_spec)
+    return quantization_spec
 
-def get_bias_obs_or_fq_ctr(quantization_config: QuantizationConfig):
+def get_bias_qspec(quantization_config: QuantizationConfig):
     if quantization_config is None:
         return None
     assert quantization_config is not None
@@ -71,4 +47,4 @@
     assert (
         quantization_spec.dtype == torch.float
     ), "Only float dtype for bias is supported for bias right now"
-    return create_observer(quantization_spec)
+    return quantization_spec
diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py
index 03aa840..4828086 100644
--- a/torch/ao/quantization/fx/prepare.py
+++ b/torch/ao/quantization/fx/prepare.py
@@ -15,10 +15,12 @@
 )
 from ..observer import (
     ObserverBase,
-    _is_activation_post_process
+    _is_activation_post_process,
+    _PartialWrapper,
 )
 from ..qconfig import (
     _is_reuse_input_qconfig,
+    _obs_or_fq_ctr_equals,
     QConfigAny,
 )
 from ..qconfig_mapping import (
@@ -99,11 +101,12 @@
     PrepareCustomConfig,
     StandaloneModuleConfigEntry,
 )
+from torch.ao.quantization._pt2e.quantizer import QuantizationSpec
 
 from torch._subclasses import FakeTensor
 
 from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union, Callable
-
+from dataclasses import asdict
 
 __all__ = [
     "insert_observers_for_model",
@@ -130,6 +133,40 @@
     "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation
 }
 
+# TODO: add support for torch dtype in quant code base
+# this includes observers and prepare/convert code
+_TORCH_DTYPE_TO_QDTYPE = {
+    torch.int8: torch.qint8,
+    torch.uint8: torch.quint8,
+    torch.int32: torch.qint32,
+    torch.float16: torch.float16,
+    torch.float32: torch.float32,
+}
+
+def _get_observer_kwargs(quant_spec: QuantizationSpec):
+    kwargs_dict = asdict(quant_spec)
+    kwargs_dict["dtype"] = _TORCH_DTYPE_TO_QDTYPE[quant_spec.dtype]
+    return copy.deepcopy(kwargs_dict)
+
+def _create_obs_or_fq_ctr_from_qspec(quantization_spec: QuantizationSpec, **extra_kwargs):
+    """ Create observer or fake quantize constructors based on quantization spec
+    """
+    if quantization_spec is None:
+        return None
+    observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr
+    kwargs = _get_observer_kwargs(quantization_spec)
+    kwargs.pop("observer_or_fake_quant_ctr")
+    # drop is_dynamic from the kwargs unless the constructor is PlaceholderObserver,
+    # since the other observer/fake-quant constructors do not accept it
+    if not _obs_or_fq_ctr_equals(observer_or_fake_quant_ctr, PlaceholderObserver):
+        kwargs.pop("is_dynamic")
+    obs_or_fq_class = observer_or_fake_quant_ctr
+    if isinstance(observer_or_fake_quant_ctr, _PartialWrapper):
+        obs_or_fq_class = observer_or_fake_quant_ctr.p.func  # type: ignore[union-attr, assignment]
+    if "PerChannel" not in obs_or_fq_class.__name__:  # type: ignore[operator, union-attr]
+        kwargs.pop("ch_axis")
+    return observer_or_fake_quant_ctr.with_args(**kwargs, **extra_kwargs)
+
 def _needs_obs_or_fq(
         prev_output_dtype: Any,
         prev_output_is_dynamic: bool,
@@ -524,7 +561,7 @@
     """
     assert isinstance(arg, Node)
     if "quantization_annotation" in arg.meta:
-        return arg.meta["quantization_annotation"].output_qspec
+        return _create_obs_or_fq_ctr_from_qspec(arg.meta["quantization_annotation"].output_qspec)
 
     # Custom module LSTM output is a tuple that we broke down into the internal nodes in order
     # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
@@ -539,7 +576,7 @@
         observed_arg = arg.args[0]
         assert isinstance(observed_arg, Node), "Currently we only support observing Node"
         if "quantization_annotation" in observed_arg.meta:
-            output_act_obs_or_fq_ctr = observed_arg.meta["quantization_annotation"].output_qspec
+            output_act_obs_or_fq_ctr = _create_obs_or_fq_ctr_from_qspec(observed_arg.meta["quantization_annotation"].output_qspec)
         else:
             assert "target_dtype_info" in observed_arg.meta
             output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
@@ -578,8 +615,10 @@
     # conv.meta[...] = QuantizationAnnotation("input_qspec_map": {x: MinMaxObserver.with_args(dtype=torch.qint8)}, ...)
     #
     if "quantization_annotation" in node.meta:
-        input_act_obs_or_fq_ctr = \
-            node.meta["quantization_annotation"].input_qspec_map.get(arg, _DEFAULT_FP32_OBS_OR_FQ_CTR)
+        input_qspec_map = node.meta["quantization_annotation"].input_qspec_map
+        input_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR
+        if arg in input_qspec_map:
+            input_act_obs_or_fq_ctr = _create_obs_or_fq_ctr_from_qspec(input_qspec_map[arg])
         return input_act_obs_or_fq_ctr
 
     # we can remove the following path in the future if fx graph mode quantization is
@@ -839,7 +878,7 @@
 
     is_standalone_module = False
     if "quantization_annotation" in node.meta:
-        output_act_obs_or_fq_ctr = node.meta["quantization_annotation"].output_qspec
+        output_act_obs_or_fq_ctr = _create_obs_or_fq_ctr_from_qspec(node.meta["quantization_annotation"].output_qspec)
     else:
         assert "target_dtype_info" in node.meta
         is_standalone_module = node.meta["target_dtype_info"].get("_is_standalone_module", False)