# blob: a8d48420c8d34b5d91984c1636c758d0b67cc139
from torch.fx import (
GraphModule,
Node,
map_arg
)
from torch.fx.graph import Graph
from ..utils import (
get_combined_dict
)
from .graph_module import (
FusedGraphModule
)
from .match_utils import (
is_match,
MatchAllNode,
)
from .pattern_utils import (
get_default_fusion_patterns,
)
from .backend_config.utils import get_fusion_pattern_to_fuse_handler_cls
from .backend_config.utils import get_fuser_method_mapping
from .backend_config.utils import get_fusion_pattern_to_root_node_getter
from .backend_config.utils import get_fusion_pattern_to_extra_inputs_getter
from .fusion_patterns import * # noqa: F401,F403
from typing import Callable, Tuple, Dict, Any, Optional, List
from .quantization_types import Pattern, NodePattern
def fuse(
    model: GraphModule,
    is_qat: bool,
    fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
    backend_config_dict: Optional[Dict[str, Any]] = None,
) -> GraphModule:
    """Rebuild ``model``'s graph with matched fusion patterns fused.

    The graph of ``model`` is walked in order and a new ``Graph`` is built:

    * the node recorded as the root of a match is replaced by whatever the
      match's handler returns from its ``fuse`` method;
    * other nodes that belong to a match are left out of the new graph,
      unless their subpattern is ``MatchAllNode``, in which case they are
      copied;
    * nodes that are not part of any match are copied unchanged.

    Args:
        model: the ``GraphModule`` whose graph and named modules are read.
        is_qat: forwarded as-is to every handler's ``fuse`` call.
        fuse_custom_config_dict: optional configuration. The keys read here
            are ``"additional_fusion_pattern"`` (only when
            ``backend_config_dict`` is ``None``) and
            ``"preserved_attributes"``; the whole dict is also forwarded to
            the handlers. ``None`` is treated as an empty dict.
        backend_config_dict: when given, the pattern-to-handler table, fuser
            method mapping, root-node getters and extra-inputs getters are
            all derived from it; when ``None``, the default fusion patterns
            combined with the additional ones are used instead.

    Returns:
        A ``FusedGraphModule`` built from ``model``, the new graph and the
        set of preserved attribute names.
    """
    if fuse_custom_config_dict is None:
        fuse_custom_config_dict = {}

    source_graph = model.graph
    modules_by_name = dict(model.named_modules())

    # TODO: remove this branch after we define the configurations for the
    # default/native backend
    if backend_config_dict is None:
        extra_patterns = fuse_custom_config_dict.get("additional_fusion_pattern", {})
        handler_cls_by_pattern = get_combined_dict(
            get_default_fusion_patterns(), extra_patterns)
        fuser_method_mapping = None
        root_getter_by_pattern = {}
        extra_inputs_getter_by_pattern = {}
    else:
        handler_cls_by_pattern = get_fusion_pattern_to_fuse_handler_cls(backend_config_dict)
        fuser_method_mapping = get_fuser_method_mapping(backend_config_dict)
        root_getter_by_pattern = get_fusion_pattern_to_root_node_getter(backend_config_dict)
        extra_inputs_getter_by_pattern = get_fusion_pattern_to_extra_inputs_getter(backend_config_dict)

    # find fusion
    matches_by_name = _find_matches(model, source_graph, handler_cls_by_pattern)
    fused_graph = Graph()
    # original node name -> corresponding node in ``fused_graph``
    fused_by_name: Dict[Any, Any] = {}

    def load_arg(a):
        # Translate references to original nodes into their fused-graph nodes.
        def lookup(original):
            return fused_by_name[original.name]
        return map_arg(a, lookup)

    def default_root_node_getter(node_pattern):
        # Keep descending into the last element until it is a Node.
        last = node_pattern[-1]
        while not isinstance(last, Node):
            last = last[-1]
        return last

    unmatched = (None, None, None, None, None)
    for node in source_graph.nodes:
        match_root, pattern, matched_node_pattern, handler, subpattern_by_node = \
            matches_by_name.get(node.name, unmatched)

        # get the corresponding subpattern for the current node
        node_subpattern = (
            None if subpattern_by_node is None
            else subpattern_by_node.get(node, None)
        )

        if match_root is node:
            assert handler is not None
            get_root = root_getter_by_pattern.get(pattern, default_root_node_getter)
            root_node = get_root(matched_node_pattern)  # type: ignore[index]
            get_extra_inputs = extra_inputs_getter_by_pattern.get(pattern, None)
            extra_inputs = (
                [] if get_extra_inputs is None
                else get_extra_inputs(matched_node_pattern)
            )
            # TODO: add validation that root_node is a module and has the same type
            # as the root_module in the configuration
            fused_by_name[node.name] = handler.fuse(
                load_arg, modules_by_name, fused_graph, root_node, extra_inputs,
                matched_node_pattern,  # type: ignore[arg-type]
                fuse_custom_config_dict, fuser_method_mapping, is_qat)
            continue

        if match_root is None or node_subpattern is MatchAllNode:
            fused_by_name[node.name] = fused_graph.node_copy(node, load_arg)
        # node matched in patterns and is not root is removed here

    preserved_attributes = set(fuse_custom_config_dict.get("preserved_attributes", []))
    return FusedGraphModule(model, fused_graph, preserved_attributes)
def _find_matches(
        root: GraphModule, graph: Graph,
        patterns: Dict[Pattern, Callable]
) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]]:
    """Find fusion pattern matches in ``graph``.

    Nodes are visited in reverse graph order. For each node that has not
    already been claimed by a match, ``patterns`` is scanned in dict order
    and the first pattern for which ``is_match`` succeeds is applied; no
    further patterns are tried for that node.

    Args:
        root: module whose ``named_modules()`` are handed to ``is_match``.
        graph: the graph whose nodes are searched.
        patterns: maps each pattern to a callable that is invoked with the
            matched root node to produce the handler stored in the result.

    Returns:
        A dict keyed by node name. Every node claimed by a match maps to
        ``(root_node, pattern, matched_node_pattern, handler,
        node_to_subpattern)``. ``node_to_subpattern`` is a single dict shared
        by all entries; it maps each claimed node to the subpattern it was
        matched against.
    """
    modules = dict(root.named_modules())
    # node name -> (root_node, pattern, matched_node_pattern, handler, node_to_subpattern)
    match_map : Dict[
        str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]] = {}
    # a map from node to the matched subpattern
    node_to_subpattern: Dict[Node, Any] = {}

    def apply_match(pattern, node, match, matched_node_pattern, node_to_subpattern):
        if isinstance(pattern, tuple):
            # (op_pattern, *arg_patterns): match the op on ``node`` itself and
            # each argument pattern on the positionally corresponding arg.
            s, *args = pattern
            current_node_pattern: List[Node] = []
            apply_match(s, node, match, current_node_pattern, node_to_subpattern)
            for subpattern, arg in zip(args, node.args):
                apply_match(subpattern, arg, match, current_node_pattern, node_to_subpattern)
            matched_node_pattern.append(tuple(current_node_pattern))
        else:
            # the first pattern matches will take precedence
            if node.name not in match_map:
                node_to_subpattern[node] = pattern
                matched_node_pattern.append(node)
                root_node, matched_pattern, handler = match
                match_map[node.name] = (
                    root_node, matched_pattern, matched_node_pattern, handler, node_to_subpattern)

    for node in reversed(graph.nodes):
        if node.name in match_map:
            continue
        for pattern, value in patterns.items():
            if is_match(modules, node, pattern):
                matched_node_pattern: List[Node] = []
                apply_match(pattern, node, (node, pattern, value(node)), matched_node_pattern, node_to_subpattern)
                # Stop at the first matching pattern. Trying further patterns
                # for an already-claimed root could attach additional argument
                # nodes to this root under a different pattern than the one
                # its handler was created for, and would also build handlers
                # that are never used.
                break

    return match_map