[quant][fx][refactor] Move ObservationType to backend_config.py (#83368)
Summary:
Now that we have a separate file defining the BackendConfig-related classes, we can move ObservationType into that file as well and leave the standalone observation_type module empty.
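
With this change, ObservationType is defined in backend_config.py and re-exported from the package __init__.py, so the old observation_type module no longer provides it and import sites change. A minimal before/after sketch of an import site (illustrative only):

    # Before this PR: ObservationType lived in its own module
    from torch.ao.quantization.backend_config.observation_type import ObservationType

    # After this PR: it is defined in backend_config.py and re-exported from the
    # package, so it can be imported alongside the other BackendConfig classes
    from torch.ao.quantization.backend_config import (
        BackendConfig,
        BackendPatternConfig,
        DTypeConfig,
        ObservationType,
    )
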
Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
python test/test_quantization.py TestQuantizeFxModels
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83368
Approved by: https://github.com/andrewor14
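
For context, a minimal sketch of how ObservationType is typically consumed when building a BackendPatternConfig. The builder methods set_observation_type and add_dtype_config, and the DTypeConfig field names used below, are not part of this diff and are shown only as an assumed illustration of the surrounding API:

    import torch
    from torch.ao.quantization.backend_config import (
        BackendPatternConfig,
        DTypeConfig,
        ObservationType,
    )

    # Illustrative sketch (not part of this PR): a pattern config for nn.Linear
    # that observes the output with a different observer instance than the input,
    # the common case for weighted ops such as conv and linear.
    linear_config = (
        BackendPatternConfig(torch.nn.Linear)
        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
        .add_dtype_config(DTypeConfig(
            input_dtype=torch.quint8,
            output_dtype=torch.quint8,
            weight_dtype=torch.qint8,
            bias_dtype=torch.float,
        ))
    )
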
diff --git a/torch/ao/quantization/backend_config/__init__.py b/torch/ao/quantization/backend_config/__init__.py
index 5a985b6..3eab482 100644
--- a/torch/ao/quantization/backend_config/__init__.py
+++ b/torch/ao/quantization/backend_config/__init__.py
@@ -1,6 +1,5 @@
-from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig
+from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig, ObservationType
from .native import get_native_backend_config, get_native_backend_config_dict
-from .observation_type import ObservationType
from .tensorrt import get_tensorrt_backend_config, get_tensorrt_backend_config_dict
__all__ = [
@@ -8,4 +7,8 @@
"get_native_backend_config_dict",
"get_tensorrt_backend_config",
"get_tensorrt_backend_config_dict",
+ "BackendConfig",
+ "BackendPatternConfig",
+ "DTypeConfig",
+ "ObservationType",
]
diff --git a/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/torch/ao/quantization/backend_config/_common_operator_config_utils.py
index 1855ae5..ea103f3 100644
--- a/torch/ao/quantization/backend_config/_common_operator_config_utils.py
+++ b/torch/ao/quantization/backend_config/_common_operator_config_utils.py
@@ -8,8 +8,11 @@
import torch.nn.quantized._reference as nnqr
from collections import namedtuple
from typing import List
-from .observation_type import ObservationType
-from .backend_config import BackendPatternConfig, DTypeConfig
+from .backend_config import (
+ BackendPatternConfig,
+ DTypeConfig,
+ ObservationType,
+)
from ..fuser_method_mappings import (
reverse_sequential_wrapper2,
reverse2,
diff --git a/torch/ao/quantization/backend_config/backend_config.py b/torch/ao/quantization/backend_config/backend_config.py
index 0ab3c27..223fc5a 100644
--- a/torch/ao/quantization/backend_config/backend_config.py
+++ b/torch/ao/quantization/backend_config/backend_config.py
@@ -3,15 +3,16 @@
from typing import Any, Callable, Dict, List, Optional, Type
import torch
-from torch.ao.quantization.backend_config.observation_type import ObservationType
from torch.ao.quantization.observer import _PartialWrapper
from torch.ao.quantization.utils import Pattern
+from enum import Enum
__all__ = [
"BackendConfig",
"BackendPatternConfig",
"DTypeConfig",
+ "ObservationType",
]
@@ -43,6 +44,17 @@
OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY = "overwrite_output_fake_quantize"
OVERWRITE_OUTPUT_OBSERVER_DICT_KEY = "overwrite_output_observer"
+# TODO: maybe rename this to something that's not related to observer
+# e.g. QParamsType
+class ObservationType(Enum):
+ # this means input and output are observed with different observers, based
+ # on qconfig.activation
+ # example: conv, linear, softmax
+ OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT = 0
+ # this means the output will use the same observer instance as input, based
+ # on qconfig.activation
+ # example: torch.cat, maxpool
+ OUTPUT_SHARE_OBSERVER_WITH_INPUT = 1
@dataclass
class DTypeConfig:
diff --git a/torch/ao/quantization/backend_config/native.py b/torch/ao/quantization/backend_config/native.py
index 80902b9..704fef0 100644
--- a/torch/ao/quantization/backend_config/native.py
+++ b/torch/ao/quantization/backend_config/native.py
@@ -10,8 +10,12 @@
_get_conv_configs,
_get_share_qparams_op_configs,
)
-from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig
-from .observation_type import ObservationType
+from .backend_config import (
+ BackendConfig,
+ BackendPatternConfig,
+ DTypeConfig,
+ ObservationType
+)
from ..fake_quantize import FixedQParamsFakeQuantize
from ..fuser_method_mappings import (
reverse_sequential_wrapper2,
diff --git a/torch/ao/quantization/backend_config/observation_type.py b/torch/ao/quantization/backend_config/observation_type.py
index 9a25f1d..e69de29 100644
--- a/torch/ao/quantization/backend_config/observation_type.py
+++ b/torch/ao/quantization/backend_config/observation_type.py
@@ -1,13 +0,0 @@
-from enum import Enum
-
-__all__ = ['ObservationType']
-
-class ObservationType(Enum):
- # this means input and output are observed with different observers, based
- # on qconfig.activation
- # example: conv, linear, softmax
- OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT = 0
- # this means the output will use the same observer instance as input, based
- # on qconfig.activation
- # example: torch.cat, maxpool
- OUTPUT_SHARE_OBSERVER_WITH_INPUT = 1
diff --git a/torch/ao/quantization/backend_config/tensorrt.py b/torch/ao/quantization/backend_config/tensorrt.py
index f709e63..9b6fb39 100644
--- a/torch/ao/quantization/backend_config/tensorrt.py
+++ b/torch/ao/quantization/backend_config/tensorrt.py
@@ -1,6 +1,10 @@
import torch
-from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig
-from .observation_type import ObservationType
+from .backend_config import (
+ BackendConfig,
+ BackendPatternConfig,
+ DTypeConfig,
+ ObservationType
+)
from ._common_operator_config_utils import (
_get_binary_op_configs,
_get_linear_configs,
diff --git a/torch/ao/quantization/fx/backend_config_utils.py b/torch/ao/quantization/fx/backend_config_utils.py
index 4771479..00a3f03 100644
--- a/torch/ao/quantization/fx/backend_config_utils.py
+++ b/torch/ao/quantization/fx/backend_config_utils.py
@@ -1,7 +1,9 @@
import torch
from torch.ao.quantization.fx.pattern_utils import get_default_quant_patterns, sorted_patterns_dict
-from torch.ao.quantization.backend_config import get_native_backend_config
-from torch.ao.quantization.backend_config.observation_type import ObservationType
+from torch.ao.quantization.backend_config import (
+ get_native_backend_config,
+ ObservationType,
+)
from torch.ao.quantization.quantization_types import (
Pattern,
NodePattern,