blob: 4aa9275870c26439dc3be2cb9f3458b82d94ac91 [file] [log] [blame]
import sys
import torch
from torch.fx.graph import (
Graph,
Node,
)
from .quantization_types import Pattern
from .quantization_patterns import (
QuantizeHandler,
CustomModuleQuantizeHandler,
StandaloneModuleQuantizeHandler,
)
from ..qconfig import (
QConfigAny,
)
from .graph_module import (
is_observed_standalone_module,
)
from typing import Any, Dict, List, Callable, Optional, Tuple, Set
MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
QConfigAny]
class MatchAllNode:
""" A node pattern that matches all nodes
"""
pass
# Note: The order of patterns is important! match function will take whatever is matched first, so we'll
# need to put the fusion patterns before single patterns. For example, add_relu should be registered come before relu.
# decorators are applied in the reverse order we see. Also when we match the nodes in the graph with these patterns,
# we'll start from the last node of the graph and traverse back.
def is_match(modules, node, pattern, max_uses=sys.maxsize):
""" Matches a node in fx against a pattern
"""
if isinstance(pattern, tuple):
self_match, *arg_matches = pattern
if self_match is getattr:
assert len(pattern) == 2, 'Expecting getattr pattern to have two elements'
arg_matches = []
else:
self_match = pattern
arg_matches = []
if isinstance(self_match, type) and issubclass(self_match, MatchAllNode):
return True
if len(node.users) > max_uses:
return False
if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
if node.op != 'call_module':
return False
if not type(modules[node.target]) == self_match:
return False
elif callable(self_match):
if node.op != 'call_function' or node.target is not self_match:
return False
elif node.target is getattr:
if node.args[1] != pattern[1]:
return False
elif isinstance(self_match, str):
if node.op != 'call_method' or node.target != self_match:
return False
elif node.target != self_match:
return False
if not arg_matches:
return True
if len(arg_matches) != len(node.args):
return False
return all(is_match(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
def find_matches(
graph: Graph,
modules: Dict[str, torch.nn.Module],
patterns: Dict[Pattern, QuantizeHandler],
qconfig_map: Dict[str, QConfigAny],
standalone_module_names: List[str] = None,
standalone_module_classes: List[Callable] = None,
custom_module_classes: List[Any] = None) -> Dict[str, MatchResult]:
"""
Matches the nodes in the input graph to quantization patterns, and
outputs the information needed to quantize them in future steps.
Inputs:
- graph: an fx.Graph object
- modules: a mapping of fully qualified module name to instance,
for example, {'foo': ModuleFoo, ...}
- patterns: a mapping from a tuple of nodes in reverse order to
uninitialized QuantizeHandler subclass.
Outputs a map of
node_name ->
(node, matched_values, matched_pattern, QuantizeHandler instance,
qconfig)
For example, {
'relu_1': (relu_1, [relu_1], torch.nn.functional.relu,
<CopyNodeQuantizeHandler instance>, QConfig(...)),
...
}
"""
if custom_module_classes is None:
custom_module_classes = []
if standalone_module_classes is None:
standalone_module_classes = []
if standalone_module_names is None:
standalone_module_names = []
match_map: Dict[str, MatchResult] = {}
all_matched : Set[str] = set()
def record_match(pattern, node, matched):
if isinstance(pattern, tuple):
s, *args = pattern
record_match(s, node, matched)
if pattern[0] is not getattr:
for subpattern, arg in zip(args, node.args):
record_match(subpattern, arg, matched)
else:
matched.append(node)
cache_for_no_tensor_check: Dict[Node, bool] = dict()
for node in reversed(graph.nodes):
if node.name not in match_map and node.name not in all_matched:
for pattern, value in patterns.items():
if is_match(modules, node, pattern):
matched: List[Any] = []
record_match(pattern, node, matched)
for n in matched:
match_map[n.name] = (
node, matched, pattern, value(node, modules), # type: ignore[operator]
qconfig_map[n.name])
all_matched.add(n.name)
# break after finding the first match
break
# add custom module instances to the match result
assert modules is not None
for node in graph.nodes:
if node.op == 'call_module' and \
type(modules[node.target]) in custom_module_classes:
custom_module_qconfig = qconfig_map[node.name]
match_map[node.name] = (
node, [node], None, CustomModuleQuantizeHandler(node, modules),
custom_module_qconfig)
def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
assert modules is not None
return (
node_target in standalone_module_names or # type: ignore[operator]
type(modules[node_target]) in standalone_module_classes # type: ignore[operator]
)
# add standalone modules to the match
for node in graph.nodes:
if node.op == 'call_module' and \
(is_standalone_module(node.target, modules) or
is_observed_standalone_module(modules[node.target])):
# add node to matched nodes
custom_module_qconfig = qconfig_map[node.name]
match_map[node.name] = (
node, [node], None,
StandaloneModuleQuantizeHandler(node, modules),
custom_module_qconfig)
return match_map