[quant][fx] Remove Standalone and CustomModule QuantizeHandler type checks in prepare (#75202)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75202
Instead of checking the QuantizeHandler type (via isinstance against the
CustomModuleQuantizeHandler / StandaloneModuleQuantizeHandler subclasses), we use
methods on the base QuantizeHandler, is_custom_module() and is_standalone_module(),
to check whether a matched module is a custom or standalone module. These methods
are not user facing.
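In short, the change replaces type-dispatch with a flag query on the base class.
A minimal before/after sketch, taken from the prepare.py hunk below (qhandler is
a matched QuantizeHandler, possibly None):

    # before: dispatch on the subclass type
    is_standalone_module = qhandler is not None and \
        isinstance(qhandler, StandaloneModuleQuantizeHandler)

    # after: query the base class directly
    is_standalone_module = qhandler is not None and qhandler.is_standalone_module()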
Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
Imported from OSS
Reviewed By: bdhirsh
Differential Revision: D35379641
fbshipit-source-id: c2f970c7e27f74793fa67f8fd5a16a43525e35aa
(cherry picked from commit 251500f06359c9046dd9067543cc80be24ddee33)
diff --git a/torch/ao/quantization/fx/match_utils.py b/torch/ao/quantization/fx/match_utils.py
index 38fb6e8..a1217ec 100644
--- a/torch/ao/quantization/fx/match_utils.py
+++ b/torch/ao/quantization/fx/match_utils.py
@@ -7,8 +7,6 @@
from .quantization_types import Pattern
from .quantization_patterns import (
QuantizeHandler,
- CustomModuleQuantizeHandler,
- StandaloneModuleQuantizeHandler,
)
from ..qconfig import (
QConfigAny,
@@ -198,7 +196,7 @@
type(modules[node.target]) in custom_module_classes:
custom_module_qconfig = qconfig_map[node.name]
match_map[node.name] = (
- node, node, None, CustomModuleQuantizeHandler(node, modules),
+ node, node, None, QuantizeHandler(node, modules, is_custom_module=True),
custom_module_qconfig)
def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
@@ -214,10 +212,10 @@
(is_standalone_module(node.target, modules) or
is_observed_standalone_module(modules[node.target])):
# add node to matched nodes
- custom_module_qconfig = qconfig_map[node.name]
+ standalone_module_qconfig = qconfig_map[node.name]
match_map[node.name] = (
node, node, None,
- StandaloneModuleQuantizeHandler(node, modules),
- custom_module_qconfig)
+ QuantizeHandler(node, modules, is_standalone_module=True),
+ standalone_module_qconfig)
return match_map
diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py
index 72fcdcd..d5fee7a 100644
--- a/torch/ao/quantization/fx/prepare.py
+++ b/torch/ao/quantization/fx/prepare.py
@@ -30,8 +30,6 @@
from .quantization_patterns import (
QuantizeHandler,
- CustomModuleQuantizeHandler,
- StandaloneModuleQuantizeHandler,
)
from .quantization_types import (
@@ -488,8 +486,7 @@
# default (no observer)
new_arg = arg
- is_standalone_module = qhandler is not None and \
- isinstance(qhandler, StandaloneModuleQuantizeHandler)
+ is_standalone_module = qhandler is not None and qhandler.is_standalone_module()
assert qconfig is not None
if not is_standalone_module:
# regular flow for most nodes, except standalone modules
@@ -713,8 +710,7 @@
assert qconfig is not None
assert node.op != 'output', 'observer insertion for outputs is handled elsewhere'
- is_standalone_module = qhandler is not None and \
- isinstance(qhandler, StandaloneModuleQuantizeHandler)
+ is_standalone_module = qhandler is not None and qhandler.is_standalone_module()
dtype = node_name_to_target_dtype[node.name]["output_activation_dtype"]
should_insert_observer = dtype not in DO_NOT_OBS_DTYPE_LIST + [torch.float]
@@ -1251,7 +1247,7 @@
if not maybe_make_input_output_share_observers(node, model, modules):
remove_output_observer(node, model, modules)
- if isinstance(qhandler, CustomModuleQuantizeHandler):
+ if qhandler is not None and qhandler.is_custom_module():
swap_custom_module_to_observed(node, qconfig, modules, prepare_custom_config_dict)
else: # output
@@ -1294,7 +1290,7 @@
) in matches.items():
if qhandler is None:
continue
- elif not isinstance(qhandler, StandaloneModuleQuantizeHandler):
+ elif not qhandler.is_standalone_module():
continue
sm_qconfig_dict, sm_prepare_config_dict, sm_backend_config_dict = \
diff --git a/torch/ao/quantization/fx/quantization_patterns.py b/torch/ao/quantization/fx/quantization_patterns.py
index 5d81dcf..9114664 100644
--- a/torch/ao/quantization/fx/quantization_patterns.py
+++ b/torch/ao/quantization/fx/quantization_patterns.py
@@ -25,7 +25,6 @@
import operator
from typing import Any, Callable, Dict, Optional
-# this is temporary, will be removed soon
def _default_root_node_getter(node_pattern):
if node_pattern is None:
return node_pattern
@@ -47,7 +46,9 @@
self,
node_pattern: NodePattern,
modules: Dict[str, torch.nn.Module],
- root_node_getter: Callable = None):
+ root_node_getter: Callable = None,
+ is_custom_module=False,
+ is_standalone_module=False):
""" Records pattern information in __init__, which will be used
in convert
"""
@@ -56,6 +57,8 @@
if root_node_getter is None:
root_node_getter = _default_root_node_getter
self.root_node = root_node_getter(node_pattern)
+ self.is_custom_module_ = is_custom_module
+ self.is_standalone_module_ = is_standalone_module
self.num_tensor_args = 0
# determine how many of the first two args are Tensors (versus scalars)
# this distinguishes things like "x + y" from "x + 2" or "2 + x"
@@ -107,6 +110,12 @@
"""
return qconfig.activation
+ def is_custom_module(self):
+ return self.is_custom_module_
+
+ def is_standalone_module(self):
+ return self.is_standalone_module_
+
@register_quant_pattern(operator.sub)
@register_quant_pattern(operator.mul)
@register_quant_pattern(operator.truediv)
@@ -261,9 +270,6 @@
def is_general_tensor_value_op(self) -> bool:
return True
-class CustomModuleQuantizeHandler(QuantizeHandler):
- pass
-
@register_quant_pattern(torch.nn.Identity)
@register_quant_pattern(torch.transpose)
@register_quant_pattern(torch.repeat_interleave)
@@ -298,8 +304,10 @@
def is_general_tensor_value_op(self) -> bool:
return True
+# TODO: not used, can be removed after torch.quantization namespace is deprecated
+class CustomModuleQuantizeHandler(QuantizeHandler):
+ pass
+
+# TODO: not used, can be removed after torch.quantization namespace is deprecated
class StandaloneModuleQuantizeHandler(QuantizeHandler):
- """ Converts an observed standalone module to quantized standalone module
- by calling convert_fx on the observed standalone module.
- """
pass
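
For reference, a minimal usage sketch of the new constructor flags (behavior as
defined by the quantization_patterns.py hunk above; node and modules are the
match-loop variables from match_utils.py):

    # custom module match: construct the base handler with the flag set
    qhandler = QuantizeHandler(node, modules, is_custom_module=True)
    assert qhandler.is_custom_module() and not qhandler.is_standalone_module()

    # standalone module match
    qhandler = QuantizeHandler(node, modules, is_standalone_module=True)
    assert qhandler.is_standalone_module()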