[quant][graphmode][fx] Add input_idx_to_dtype and output_idx_to_dtype to backend_config_dict (#67067)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67067
We plan to gradually add features to backend_config_dict; this PR adds support for specifying the dtypes for the input and output of a given pattern.
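For reference, a per-pattern entry in the new format looks like the following sketch (mirroring the TensorRT config in this diff; the exact keys are experimental and may change):
```
import torch
from torch.ao.quantization.fx.backend_config_dict.observation_type import ObservationType

# dtype constraints the backend supports for this pattern (all keys are optional)
weighted_op_qint8_dtype_config = {
    "input_dtype": torch.qint8,
    "weight_dtype": torch.qint8,
    "bias_dtype": torch.float,
    "output_dtype": torch.qint8,
}

linear_module_config = {
    "pattern": torch.nn.Linear,
    "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
    "dtype_configs": [weighted_op_qint8_dtype_config],
}

backend_config_dict = {
    "name": "tensorrt",                   # optional
    "configs": [linear_module_config],    # one entry per supported pattern
}
```
During prepare, the dtypes implied by a pattern's qconfig are checked against its dtype_configs; if none match, no observers are inserted for that pattern and it stays in floating point.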
Test Plan: Imported from OSS
Reviewed By: vkuzo
Differential Revision: D31849074
fbshipit-source-id: ca2fbb873176fe72e08ea79ed1bc659bf27cbd8a
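A minimal usage sketch, mirroring the new test_unsupported_qconfig test below (the module, shapes, and variable names are illustrative; import paths are as of this commit):
```
import torch
from torch.ao.quantization import default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx
from torch.ao.quantization.fx.backend_config_dict import get_tensorrt_backend_config_dict

model = torch.nn.Sequential(torch.nn.Linear(5, 10)).eval()

# default_qconfig uses quint8 activations, which the TensorRT config (qint8 only)
# does not list in its dtype_configs, so prepare_fx skips observer insertion for Linear
prepared = prepare_fx(
    model,
    {"": default_qconfig},
    backend_config_dict=get_tensorrt_backend_config_dict(),
)
prepared(torch.rand(8, 5))  # calibration; with no supported pattern, convert leaves the model unquantized
```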
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index c4dc4f0..916ba5b 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -5381,7 +5381,7 @@
),
weight=torch.ao.quantization.default_weight_observer
)
- self.backend_config_dict = get_tensorrt_backend_config_dict()
+ self.trt_backend_config_dict = get_tensorrt_backend_config_dict()
def test_conv(self):
class Conv2d(torch.nn.Module):
@@ -5396,7 +5396,7 @@
conv2d_module_args = (3, 3, 3)
m = Conv2d(*conv2d_module_args).eval()
- prepared = prepare_fx(m, {"": self.qconfig}, backend_config_dict=self.backend_config_dict)
+ prepared = prepare_fx(m, {"": self.qconfig}, backend_config_dict=self.trt_backend_config_dict)
# calibration
prepared(conv2d_input)
quantized = _convert_fx_do_not_use(prepared, is_reference=True)
@@ -5422,7 +5422,7 @@
linear_module_input = torch.rand(8, 5)
m = LinearModule().eval()
- prepared = prepare_fx(m, {"": self.qconfig}, backend_config_dict=self.backend_config_dict)
+ prepared = prepare_fx(m, {"": self.qconfig}, backend_config_dict=self.trt_backend_config_dict)
# calibration
prepared(linear_module_input)
quantized = _convert_fx_do_not_use(prepared, is_reference=True)
@@ -5441,6 +5441,34 @@
# make sure it runs
trt_mod(linear_module_input.cuda())
+ def test_unsupported_qconfig(self):
+ """ Check that we won't quantize the model if the qconfig is not supported
+ """
+ class LinearModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = torch.nn.Linear(5, 10)
+
+ def forward(self, x):
+ return self.linear(x)
+
+ linear_module_input = torch.rand(8, 5)
+
+ m = LinearModule().eval()
+ trt_unsupported_qconfig = default_qconfig
+ prepared = prepare_fx(m, {"": trt_unsupported_qconfig}, backend_config_dict=self.trt_backend_config_dict)
+ # calibration
+ prepared(linear_module_input)
+ quantized = _convert_fx_do_not_use(prepared, is_reference=True)
+ node_occurrence = {
+ ns.call_function(torch.quantize_per_tensor): 0,
+ ns.call_method("dequantize"): 0,
+ ns.call_module(torch.nn.Linear): 1,
+ ns.call_module(torch.nn.quantized._reference.Linear): 0,
+ }
+ # check model is not quantized
+ self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
+
class TestQuantizeFxModels(QuantizationTestCase):
@skipIfNoFBGEMM
@unittest.skipIf(not TEST_CUDA, "gpu is not available.")
diff --git a/torch/ao/quantization/fx/_convert_do_not_use.py b/torch/ao/quantization/fx/_convert_do_not_use.py
index 145701d..09daf19 100644
--- a/torch/ao/quantization/fx/_convert_do_not_use.py
+++ b/torch/ao/quantization/fx/_convert_do_not_use.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Tuple, List
+from typing import Any, Dict, List
import torch
from torch.fx import (
GraphModule,
@@ -14,17 +14,12 @@
get_qparam_dict,
)
-from .quantization_types import Pattern
from .match_utils import (
find_matches,
)
from .graph_module import (
- is_observed_module,
QuantizedGraphModule,
)
-from .quantization_patterns import (
- QuantizeHandler,
-)
from ._equalize import update_obs_for_equalization, convert_eq_obs
from .utils import (
get_custom_module_class_keys,
@@ -38,24 +33,25 @@
is_activation_post_process,
)
-
-def restore_state(
- observed: GraphModule
-) -> Tuple[Dict[Pattern, QuantizeHandler], Dict[str, Tuple[str, type]], Dict[str, Any]]:
- assert is_observed_module(observed), \
- 'incoming model must be produced by prepare_fx'
- prepare_custom_config_dict: Dict[str, Any] = \
- observed._prepare_custom_config_dict # type: ignore[assignment]
- node_name_to_scope: Dict[str, Tuple[str, type]] = observed._node_name_to_scope # type: ignore[assignment]
- patterns: Dict[Pattern, QuantizeHandler] = observed._patterns # type: ignore[assignment]
- return patterns, node_name_to_scope, prepare_custom_config_dict
+from .convert import restore_state
def _convert_do_not_use(
model: GraphModule, is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None,
is_standalone_module: bool = False,
_remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
- """ standalone_module means it a submodule that is not inlined in
+ """
+    Convert an observed model (a module with observer calls) to a reference
+    quantized model. The rules are simple:
+    1. for each observer module call in the graph, convert it to calls to
+       quantize and dequantize functions based on the observer instance
+    2. for weighted operations like linear/conv, convert them to reference
+       quantized modules; this requires knowing whether the dtype configured for
+       the weight is supported by the backend, which is determined in the prepare
+       step and recorded in observed_node_names, so we decide whether to swap the
+       module based on that set
+
+    standalone_module means it is a submodule that is not inlined in the
parent module, and will be quantized separately as one unit.
Returns a quantized standalone module, whether input/output is quantized is
@@ -65,7 +61,7 @@
"""
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
- patterns, node_name_to_scope, prepare_custom_config_dict = restore_state(model)
+ patterns, node_name_to_scope, prepare_custom_config_dict, observed_node_names = restore_state(model)
qconfig_map: Dict[str, QConfigAny] = model._qconfig_map # type: ignore[assignment]
assert is_reference, "convert2 only supports reference option"
@@ -178,8 +174,11 @@
torch.nn.Conv3d]:
fmodule = modules[node.target]
qconfig = fmodule.qconfig
+
+ is_observed = node.name in observed_node_names
+ is_weight_quantized = weight_is_statically_quantized(qconfig)
# TODO: rename weight_is_statically_quantized to weight_is_int8_quantized
- if qconfig is not None and weight_is_statically_quantized(qconfig):
+ if qconfig is not None and is_observed and is_weight_quantized:
weight_post_process = qconfig.weight()
# run weight observer
weight_post_process(fmodule.weight) # type: ignore[operator]
diff --git a/torch/ao/quantization/fx/backend_config_dict/README.md b/torch/ao/quantization/fx/backend_config_dict/README.md
new file mode 100644
index 0000000..e17d728
--- /dev/null
+++ b/torch/ao/quantization/fx/backend_config_dict/README.md
@@ -0,0 +1,22 @@
+# Backend Configuration
+## Pattern Format
+The patterns we match against are float module types, functional operators, and PyTorch operators, written in reverse order:
+```
+operator = module_type | functional | torch op | native op | MatchAllNode
+Pattern = (operator, Pattern, Pattern, ...) | operator
+```
+where the first item of a Pattern is the operator we want to match, and the rest are the patterns for the arguments of the operator.
+For example, the pattern (nn.ReLU, (operator.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d))) would match the following graph:
+```
+tensor_1 tensor_2
+ | |
+ *(MatchAllNode) nn.Conv2d
+ | |
+ | nn.BatchNorm2d
+ \ /
+ -- operator.add --
+ |
+ nn.ReLU
+```
+
+We match the last node as the anchor point of the match, and we can retrieve the whole matched subgraph by tracing back from that node. For example, in the graph above we match the nn.ReLU node, and node.args[0] is the operator.add node.
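+
+For illustration (a hypothetical sketch, not one of the configs added in this PR; the name conv_bn_relu_config is made up), a fused Conv2d -> BatchNorm2d -> ReLU pattern could be registered in a backend_config_dict entry as:
+```
+import torch
+import torch.nn as nn
+from torch.ao.quantization.fx.backend_config_dict.observation_type import ObservationType
+
+conv_bn_relu_config = {
+    # reverse-order tuple: ReLU is the anchor node, its input is BatchNorm2d, whose input is Conv2d
+    "pattern": (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)),
+    "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+    "dtype_configs": [{
+        "input_dtype": torch.qint8,
+        "weight_dtype": torch.qint8,
+        "bias_dtype": torch.float,
+        "output_dtype": torch.qint8,
+    }],
+}
+```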
diff --git a/torch/ao/quantization/fx/backend_config_dict/__init__.py b/torch/ao/quantization/fx/backend_config_dict/__init__.py
index 96f0631..c8cbd98 100644
--- a/torch/ao/quantization/fx/backend_config_dict/__init__.py
+++ b/torch/ao/quantization/fx/backend_config_dict/__init__.py
@@ -1,4 +1,4 @@
from .tensorrt import get_tensorrt_backend_config_dict
def validate_backend_config_dict(backend_config_dict):
- return "quant_patterns" in backend_config_dict
+ return "configs" in backend_config_dict
diff --git a/torch/ao/quantization/fx/backend_config_dict/quantize_handler.py b/torch/ao/quantization/fx/backend_config_dict/quantize_handler.py
new file mode 100644
index 0000000..d5dce81
--- /dev/null
+++ b/torch/ao/quantization/fx/backend_config_dict/quantize_handler.py
@@ -0,0 +1,16 @@
+import torch
+from typing import Dict
+from torch.fx.graph import Node
+from .observation_type import ObservationType
+from ..quantization_patterns import QuantizeHandler
+
+def get_quantize_handler_cls(observation_type, pattern_configs):
+ assert observation_type == ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, \
+ "Only OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT is supported right now"
+
+ class ConfigurableQuantizeHandler(QuantizeHandler):
+ def __init__(self, node: Node, modules: Dict[str, torch.nn.Module]):
+ super().__init__(node, modules)
+ self.pattern_configs = pattern_configs
+
+ return ConfigurableQuantizeHandler
diff --git a/torch/ao/quantization/fx/backend_config_dict/tensorrt.py b/torch/ao/quantization/fx/backend_config_dict/tensorrt.py
index fa91bfd..1a53a1c 100644
--- a/torch/ao/quantization/fx/backend_config_dict/tensorrt.py
+++ b/torch/ao/quantization/fx/backend_config_dict/tensorrt.py
@@ -1,15 +1,41 @@
import torch
-from ..quantization_patterns import ConvReluQuantizeHandler, LinearReLUQuantizeHandler
+from .observation_type import ObservationType
def get_tensorrt_backend_config_dict():
""" Get the backend config dictionary for tensorrt backend
NOTE: Current api will change in the future, it's just to unblock experimentation for
new backends, please don't use it right now.
"""
- quant_patterns = {
- torch.nn.Conv2d: ConvReluQuantizeHandler,
- torch.nn.Linear: LinearReLUQuantizeHandler
+ weighted_op_qint8_dtype_config = {
+ # optional, input activation dtype
+ "input_dtype": torch.qint8,
+ # optional, weight dtype
+ "weight_dtype": torch.qint8,
+ # optional, bias dtype
+ "bias_dtype": torch.float,
+ # optional, output activation dtype
+ "output_dtype": torch.qint8
+ }
+ linear_module_config = {
+ # Please see README under this folder for pattern format
+ "pattern": torch.nn.Linear,
+ "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+ "dtype_configs": [
+ weighted_op_qint8_dtype_config,
+ ]
+ }
+ conv_module_config = {
+ "pattern": torch.nn.Conv2d,
+ "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+ "dtype_configs": [
+ weighted_op_qint8_dtype_config,
+ ]
}
return {
- "quant_patterns": quant_patterns
+ # optional
+ "name": "tensorrt",
+ "configs": [
+ linear_module_config,
+ conv_module_config,
+ ]
}
diff --git a/torch/ao/quantization/fx/backend_config_dict/utils.py b/torch/ao/quantization/fx/backend_config_dict/utils.py
new file mode 100644
index 0000000..e252200
--- /dev/null
+++ b/torch/ao/quantization/fx/backend_config_dict/utils.py
@@ -0,0 +1,32 @@
+import torch
+from .quantize_handler import get_quantize_handler_cls
+from typing import Dict, Any, List
+from ..quantization_types import Pattern, QuantizerCls
+
+def get_pattern_to_quantize_handlers(
+ backend_config_dict: Dict[str, Any]) -> Dict[Pattern, QuantizerCls]:
+ """
+    Note: QuantizeHandler is just a holder for some check methods
+    (e.g. should_insert_observer_for_output); maybe this can be an enum instead.
+    We can refactor this after the fbgemm/qnnpack path is fully converted to the
+    new path. This is not exposed to backend developers.
+ """
+ pattern_to_quantize_handlers = dict()
+ for config in backend_config_dict["configs"]:
+ pattern = config["pattern"]
+ observation_type = config["observation_type"]
+ dtype_configs = config["dtype_configs"]
+ pattern_to_quantize_handlers[pattern] = \
+ get_quantize_handler_cls(observation_type, dtype_configs)
+
+ return pattern_to_quantize_handlers
+
+
+def get_pattern_to_dtype_configs(
+ backend_config_dict: Dict[str, Any]) -> Dict[Pattern, List[Dict[str, torch.dtype]]]:
+ pattern_to_dtype_configs: Dict[Pattern, List[Dict[str, torch.dtype]]] = dict()
+ for config in backend_config_dict["configs"]:
+ pattern = config["pattern"]
+ dtype_configs = config["dtype_configs"]
+ pattern_to_dtype_configs[pattern] = dtype_configs
+ return pattern_to_dtype_configs
diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py
index 143627f..d1a0546 100644
--- a/torch/ao/quantization/fx/convert.py
+++ b/torch/ao/quantization/fx/convert.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Tuple, List, Callable, Optional, Union
+from typing import Any, Dict, Tuple, List, Callable, Optional, Union, Set
from collections import defaultdict
import copy
import torch
@@ -205,14 +205,18 @@
def restore_state(
observed: GraphModule
-) -> Tuple[Dict[Pattern, QuantizeHandler], Dict[str, Tuple[str, type]], Dict[str, Any]]:
+) -> Tuple[Dict[Pattern, QuantizeHandler],
+ Dict[str, Tuple[str, type]],
+ Dict[str, Any],
+ Set[str]]:
assert is_observed_module(observed), \
'incoming model must be produced by prepare_fx'
prepare_custom_config_dict: Dict[str, Any] = \
observed._prepare_custom_config_dict # type: ignore[assignment]
node_name_to_scope: Dict[str, Tuple[str, type]] = observed._node_name_to_scope # type: ignore[assignment]
patterns: Dict[Pattern, QuantizeHandler] = observed._patterns # type: ignore[assignment]
- return patterns, node_name_to_scope, prepare_custom_config_dict
+ observed_node_names: Set[str] = observed._observed_node_names # type: ignore[assignment]
+ return patterns, node_name_to_scope, prepare_custom_config_dict, observed_node_names
def convert(model: GraphModule, is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None,
@@ -229,7 +233,7 @@
"""
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
- patterns, node_name_to_scope, prepare_custom_config_dict = restore_state(model)
+ patterns, node_name_to_scope, prepare_custom_config_dict, _ = restore_state(model)
qconfig_map: Dict[str, QConfigAny] = model._qconfig_map # type: ignore[assignment]
# TODO this should be removed now that gpu support for quantization is being supported.
diff --git a/torch/ao/quantization/fx/graph_module.py b/torch/ao/quantization/fx/graph_module.py
index 7e6f1d9..133d859 100644
--- a/torch/ao/quantization/fx/graph_module.py
+++ b/torch/ao/quantization/fx/graph_module.py
@@ -32,7 +32,8 @@
'_equalization_qconfig_map',
'_node_name_to_scope',
'_qconfig_dict',
- '_is_training']).union(preserved_attr_names)
+ '_is_training',
+ '_observed_node_names']).union(preserved_attr_names)
preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)}
super().__init__(root, graph)
for attr in preserved_attrs:
diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py
index 4fb3f5c..9fa9577 100644
--- a/torch/ao/quantization/fx/prepare.py
+++ b/torch/ao/quantization/fx/prepare.py
@@ -84,7 +84,12 @@
weight_dtype,
)
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from .backend_config_dict.utils import (
+ get_pattern_to_quantize_handlers,
+ get_pattern_to_dtype_configs,
+)
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Set
def is_activation_post_process_node(node: Node, modules: Dict[str, torch.nn.Module]) -> bool:
return isinstance(node, torch.fx.Node) and node.op == "call_module" and \
@@ -113,6 +118,83 @@
node.kwargs.get('bias', None) is arg)
)
+def is_input_arg_dtype_supported_by_backend(
+ arg: Argument,
+ node: Node,
+ node_name_to_target_dtype: Dict[str, Any],
+ dtype_config: Dict[str, torch.dtype],
+) -> bool:
+ """ Check if the configured qconfig for the argument
+ is supported by the backend or not
+ """
+ if isinstance(arg, (list, tuple)):
+ return all(map(lambda a: is_input_arg_dtype_supported_by_backend(a, node, node_name_to_target_dtype, dtype_config), arg))
+ if not isinstance(arg, Node):
+ return True
+ # TODO: support check for standalone module
+ is_weight = node_arg_is_weight(node, arg)
+ is_bias = node_arg_is_bias(node, arg)
+ is_activation = not (is_weight or is_bias)
+ input_activation_dtype = dtype_config.get("input_activation_dtype", None)
+ if is_activation:
+ return input_activation_dtype is None or node_name_to_target_dtype[arg.name] == input_activation_dtype
+ elif is_weight:
+ # TODO: we need to refactor get_target_activation_dtype_for_node to include
+ # weight, and maybe have a separate current_node_name_to_dtype dict
+ # return weight_dtype is None or node_name_to_target_dtype[arg.name] == weight_dtype
+ raise RuntimeError("weight is not handled yet")
+ elif is_bias:
+ # Note: config for bias is not supported in qconfig currently
+ raise RuntimeError("bias is not handled yet")
+ return True
+
+def is_output_dtype_supported_by_backend(
+ node: Node,
+ node_name_to_target_dtype: Dict[str, Any],
+ dtype_config: Dict[str, torch.dtype],
+) -> bool:
+ """ Check if the configured qconfig for the output
+ is supported by the backend or not
+ """
+ output_dtype = dtype_config.get("output_dtype", None)
+ return output_dtype is None or output_dtype == node_name_to_target_dtype[node.name]
+
+
+def is_pattern_dtype_config_supported_by_backend(
+ pattern: Optional[Pattern],
+ matched_nodes: Optional[List[Node]],
+ node_name_to_target_dtype: Dict[str, Any],
+ backend_config_dict: Optional[Dict[str, Any]]
+) -> bool:
+ """ Check is the dtype configuration of a pattern is supported by
+ the backend or not
+ """
+ if backend_config_dict is None or pattern is None:
+ return True
+ assert matched_nodes is not None and len(matched_nodes) >= 1
+ pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config_dict)
+ dtype_configs: List[Dict[str, torch.dtype]] = pattern_to_dtype_configs.get(pattern, [])
+
+ input_node = matched_nodes[0]
+ output_node = matched_nodes[-1]
+ for dtype_config in dtype_configs:
+ # check if arg dtype are supported
+ supported = True
+ for arg in input_node.args:
+ supported = supported and \
+ is_input_arg_dtype_supported_by_backend(
+ arg, input_node, node_name_to_target_dtype, dtype_config)
+ for k, arg in input_node.kwargs.items():
+ supported = supported and \
+ is_input_arg_dtype_supported_by_backend(
+ arg, input_node, node_name_to_target_dtype, dtype_config)
+ # check if output dtype is supported
+ supported = supported and is_output_dtype_supported_by_backend(
+ output_node, node_name_to_target_dtype, dtype_config)
+ if supported:
+ return True
+ return False
+
def get_standalone_module_configs(
node: Node,
modules: Dict[str, torch.nn.Module],
@@ -793,6 +875,8 @@
equalization_config_map: Dict[str, Any],
input_quantized_idxs: List[int],
output_quantized_idxs: List[int],
+ backend_config_dict: Optional[Dict[str, Any]],
+ observed_node_names: Set[str],
) -> Optional[Node]:
"""
Inserts observers, using the following high level algorithm:
@@ -884,11 +968,20 @@
(qconfig is None) or
output_not_a_tensor or
is_getitem
- ) and (not node.op == 'output')
+ ) and (
+ not node.op == 'output'
+ )
- if not skip_inserting_observers:
+ is_supported_by_backend = is_pattern_dtype_config_supported_by_backend(
+ pattern, matched_nodes, node_name_to_target_dtype, backend_config_dict)
+
+ if not skip_inserting_observers and is_supported_by_backend:
modules = dict(model.named_modules(remove_duplicate=False))
if node.op != 'output':
+ assert matched_nodes is not None
+ # add matched nodes to the observed node name set
+ for n in matched_nodes:
+ observed_node_names.add(n.name)
# This is currently only used for equalization.
# Checks if the current node is in a branch in which the two
@@ -1039,6 +1132,7 @@
equalization_qconfig_map: Dict[str, Any],
qconfig_dict: Dict[str, Dict[Any, Any]],
is_training: bool,
+ observed_node_names: Set[str],
) -> None:
observed._patterns = patterns # type: ignore[assignment]
observed._qconfig_map = qconfig_map # type: ignore[assignment]
@@ -1048,6 +1142,7 @@
observed._equalization_qconfig_map = equalization_qconfig_map # type: ignore[assignment]
observed._qconfig_dict = qconfig_dict # type: ignore[assignment]
observed._is_training = is_training # type: ignore[assignment]
+ observed._observed_node_names = observed_node_names # type: ignore[assignment]
def prepare(
model: GraphModule,
@@ -1094,9 +1189,13 @@
# ((<function relu at 0x7f766a7360d0>, <built-in function add>):
# <class 'torch.ao.quantization.fx.quantize.Add'>),
# }
- quant_patterns = get_default_quant_patterns()
- patterns: Dict[Pattern, QuantizeHandler] = get_combined_dict(
- quant_patterns, additional_quant_patterns)
+ patterns: Dict[Pattern, QuantizeHandler] = {}
+ if backend_config_dict is None:
+ quant_patterns = get_default_quant_patterns()
+ patterns = get_combined_dict(
+ quant_patterns, additional_quant_patterns)
+ else:
+ patterns = get_pattern_to_quantize_handlers(backend_config_dict)
convert_dict_to_ordered_dict(qconfig_dict)
convert_dict_to_ordered_dict(equalization_qconfig_dict)
@@ -1148,14 +1247,23 @@
run_prepare_fx_on_standalone_modules(
model, modules, matches, prepare_custom_config_dict)
+    # record the names of observed nodes, so that in the convert step
+    # we know whether we need to convert a floating point module to a reference
+    # quantized module or not
+ observed_node_names: Set[str] = set()
+
result_node = insert_observers_for_model(
model, modules, matches, qconfig_map,
model.graph, prepare_custom_config_dict,
equalization_qconfig_map,
- input_quantized_idxs, output_quantized_idxs)
+ input_quantized_idxs,
+ output_quantized_idxs,
+ backend_config_dict,
+ observed_node_names)
save_state(model, qconfig_map, node_name_to_scope, patterns,
- prepare_custom_config_dict, equalization_qconfig_map, qconfig_dict, model.training)
+ prepare_custom_config_dict, equalization_qconfig_map, qconfig_dict, model.training, observed_node_names)
+
preserved_attributes = set(prepare_custom_config_dict.get("preserved_attributes", []))
model = ObservedGraphModule(model, model.graph, preserved_attributes)
if is_standalone_module:
diff --git a/torch/ao/quantization/fx/quantization_patterns.py b/torch/ao/quantization/fx/quantization_patterns.py
index f9ace07..3bc48ad 100644
--- a/torch/ao/quantization/fx/quantization_patterns.py
+++ b/torch/ao/quantization/fx/quantization_patterns.py
@@ -47,7 +47,7 @@
from ..qconfig import QConfigAny
-from abc import ABC, abstractmethod
+from abc import ABC
import operator
import warnings
@@ -169,8 +169,6 @@
"""
return True
-
- @abstractmethod
def convert(self,
node: Node,
qconfig: QConfigAny,