import enum
import operator
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.qat as nnqat
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.qat as nniqat
toq = torch.ops.quantized
from torch.fx import GraphModule
from torch.fx.graph import Graph, Node
from .utils import getattr_from_fqn
from typing import Dict, Tuple, List, Optional, Set, Callable, Any
def _get_output_nodes(g: Graph) -> List[Node]:
return [n for n in g.nodes if n.op == 'output']
def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[Callable]]:
base_name_to_sets_of_related_ops: Dict[str, Set[Callable]] = {
# conv modules
'torch.nn.Conv2d': set([
nn.Conv2d,
nnq.Conv2d,
nnqat.Conv2d,
            # Note: matching weights may not work with nniqat.ConvBn2d directly;
            # leaving that as a problem for a future PR to solve.
nniqat.ConvBn2d,
nniq.ConvReLU2d,
]),
# linear modules
'torch.nn.Linear': set([
nn.Linear,
nnq.Linear,
nnqat.Linear,
nnqd.Linear,
]),
# linear functionals
'torch.nn.functional.linear': set([
F.linear,
toq.linear,
toq.linear_relu,
]),
# LSTM
'torch.nn.LSTM': set([
nn.LSTM,
nnqd.LSTM,
]),
# add
'torch.add': set([
torch.add,
toq.add,
operator.add, # x + y
]),
# cat
'torch.cat': set([
torch.cat,
toq.cat,
]),
# mul
'torch.mul': set([
torch.mul,
toq.mul,
]),
}
return base_name_to_sets_of_related_ops
def get_type_a_related_to_b(
base_name_to_sets_of_related_ops: Dict[str, Set[Callable]],
) -> Set[Tuple[Callable, Callable]]:
    """
    Returns the set of related (type_a, type_b) pairs implied by
    `base_name_to_sets_of_related_ops`: for every set of related ops, every
    ordered pair of two distinct members is included, in both directions.
    """
    # TODO(future PR): allow customizations
    # TODO(future PR): reuse existing quantization mappings
    # TODO(future PR): add the rest of modules and ops here
type_a_related_to_b: Set[Tuple[Callable, Callable]] = set()
for base_name, s in base_name_to_sets_of_related_ops.items():
s_list = list(s)
# add every bidirectional pair
for idx_0 in range(0, len(s_list) - 1):
for idx_1 in range(idx_0 + 1, len(s_list)):
type_a_related_to_b.add((s_list[idx_0], s_list[idx_1]))
type_a_related_to_b.add((s_list[idx_1], s_list[idx_0]))
return type_a_related_to_b
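
# Illustrative sketch (not part of the original API): shows how the related-type
# pairs produced above can be queried. The `_example_*` name is hypothetical.
def _example_check_linear_relationship() -> bool:
    base_ops = get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = get_type_a_related_to_b(base_ops)
    # F.linear and toq.linear live in the same set, so both orderings are present
    return ((F.linear, toq.linear) in type_a_related_to_b and
            (toq.linear, F.linear) in type_a_related_to_b)
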
def get_non_matchable_functions() -> Set[Callable]:
"""
`call_function` nodes pointing to these functions are non-matchable.
"""
# TODO(future PR): allow customizations
return set([
torch.quantize_per_tensor,
])
def get_non_matchable_modules() -> Set[Callable]:
"""
`call_module` nodes pointing to instances of these types are non-matchable.
"""
# TODO(future PR): allow customizations
return set([
torch.quantization.ObserverBase,
torch.quantization.FakeQuantizeBase,
])
def get_reversed_fusions() -> Set[Tuple[Callable, Callable]]:
"""
    Set of potential fusions, in reverse order. The order is reversed
    to match how fusion patterns are defined in quantization code;
    for example, a linear-relu fusion is represented as (F.relu, F.linear).
"""
return set([
(F.relu, F.linear),
(nn.ReLU, nn.Conv2d),
])
# TODO(future PR): we should see if we can reuse quantization's fusion
# patterns here.
def end_node_matches_reversed_fusion(
end_node: Node,
reversed_fusion: Tuple[Callable, Callable],
gm: GraphModule,
) -> bool:
"""
    Returns True if the pattern ending at `end_node` matches the
    reversed fusion pattern `reversed_fusion`, walking backwards through args.
"""
if end_node.op == 'call_function':
cur_node = end_node
for fusion_idx in range(len(reversed_fusion)):
cur_fusion_op = reversed_fusion[fusion_idx]
if cur_node.target != cur_fusion_op:
return False
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
cur_node = cur_node.args[0]
else:
return False
return True
elif end_node.op == 'call_module':
cur_node = end_node
for fusion_idx in range(len(reversed_fusion)):
cur_fusion_op = reversed_fusion[fusion_idx]
assert isinstance(cur_node.target, str)
target_mod = getattr_from_fqn(gm, cur_node.target)
if not isinstance(cur_fusion_op, type):
return False
if not isinstance(target_mod, cur_fusion_op):
return False
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
cur_node = cur_node.args[0]
else:
return False
return True
return False
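
# Illustrative sketch (not part of the original API): traces a tiny functional
# linear-relu model and checks that its relu node matches the reversed
# (F.relu, F.linear) fusion defined above. The `_example_*` name is hypothetical.
def _example_match_linear_relu_fusion() -> bool:
    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.w = nn.Parameter(torch.randn(2, 2))
            self.b = nn.Parameter(torch.zeros(2))

        def forward(self, x):
            return F.relu(F.linear(x, self.w, self.b))

    gm = torch.fx.symbolic_trace(M())
    relu_node = next(n for n in gm.graph.nodes if n.target is F.relu)
    return end_node_matches_reversed_fusion(relu_node, (F.relu, F.linear), gm)
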
class _NSGraphMatchableSubgraphsIterator:
"""
Iterates through the graph of gm, starting with the output nodes
and continuing backwards.
1. Returns matchable subgraphs, in order. A subgraph is defined by
(start_node, end_node).
2. Skips over non-matchable subgraphs
"""
def __init__(
self,
gm: GraphModule,
non_matchable_functions: Set[Callable],
non_matchable_modules: Set[Callable],
):
self.gm: GraphModule = gm
self.non_matchable_functions: Set[Callable] = non_matchable_functions
self.non_matchable_modules: Set[Callable] = non_matchable_modules
self.seen_nodes: Set[Node] = set()
self.stack: List[Node] = []
for start_node in _get_output_nodes(self.gm.graph):
self.stack.append(start_node)
def __iter__(self):
return self
def __next__(self) -> Tuple[Node, Node]:
"""
Returns the next matchable subgraph, defined by (start_node, end_node)
"""
while len(self.stack) > 0:
cur_end_node = self.stack.pop()
if cur_end_node in self.seen_nodes:
continue
            # for subgraphs which are single nodes, start_node == end_node
            # for subgraphs with more than one node, start_node != end_node
cur_start_node = cur_end_node
# Check for potential fusions. For now, we are greedy
# and always skip all non-base nodes of a fusion. For example,
# if we match linear-relu backwards, we will always skip the
# relu node and attempt to match the linear node. This can
# be made configurable later if needed.
for _reverse_fusion_ops in get_reversed_fusions():
is_match = end_node_matches_reversed_fusion(
cur_end_node, _reverse_fusion_ops, self.gm)
if is_match:
# navigate to the base node
for fusion_idx in range(len(_reverse_fusion_ops) - 1):
self.seen_nodes.add(cur_start_node)
# for now, assume that there are no other nodes
# which need to be added to the stack
cur_start_node = cur_start_node.args[0] # type: ignore
break
self.seen_nodes.add(cur_start_node)
# add args of previous nodes to stack
# TODO(future PR): handle kwargs as needed
for arg in cur_start_node.args:
self._recursively_add_node_arg_to_stack(arg)
# skip observers, etc
# note: this check is done on the start_node, i.e.
# if we are matching linear-relu in reverse, this would do the matchable
# check on the linear
if not self._is_matchable(cur_start_node):
continue
return cur_start_node, cur_end_node
raise StopIteration
def _recursively_add_node_arg_to_stack(self, arg: Any) -> None:
"""
        Adds all of the nodes in this arg to the stack, properly navigating
        through lists, dicts, and tuples.
"""
if isinstance(arg, Node):
self.stack.append(arg)
elif isinstance(arg, torch.fx.immutable_collections.immutable_list) or type(arg) is tuple:
for inner_arg in arg:
self._recursively_add_node_arg_to_stack(inner_arg)
elif isinstance(arg, torch.fx.immutable_collections.immutable_dict):
for key, value in arg.items():
self._recursively_add_node_arg_to_stack(value)
def _is_matchable(self, node: Node) -> bool:
        if node.op == 'call_function':
            return node.target not in self.non_matchable_functions
        elif node.op == 'call_module':
            assert isinstance(node.target, str)
            target_mod = getattr_from_fqn(self.gm, node.target)
            return not any(
                isinstance(target_mod, t)  # type: ignore
                for t in self.non_matchable_modules)
else:
return False
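
# Illustrative sketch (not part of the original API): iterates the matchable
# subgraphs of a small traced conv-relu model; the conv and relu are returned
# together as a single (start_node, end_node) subgraph because they match the
# (nn.ReLU, nn.Conv2d) reversed fusion. The `_example_*` name is hypothetical.
def _example_iterate_matchable_subgraphs() -> List[Tuple[Node, Node]]:
    m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.ReLU())
    gm = torch.fx.symbolic_trace(m)
    it = _NSGraphMatchableSubgraphsIterator(
        gm, get_non_matchable_functions(), get_non_matchable_modules())
    return list(it)  # expected: one (conv_node, relu_node) pair
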
class GraphMatchingException(Exception):
"""
Exception raised when two graphs cannot be matched.
"""
pass
class NodeTypeRelationship(enum.Enum):
    # same type
    # example: F.linear and F.linear, or nn.Conv2d and nn.Conv2d
    EQUAL = enum.auto()
    # related (in the same set of related ops), but not the same type
    # example: F.linear and toq.linear
RELATED_BUT_NOT_EQUAL = enum.auto()
# not related
NOT_RELATED = enum.auto()
def _get_node_relationship_type(
node_a: Node,
node_b: Node,
gm_a: GraphModule,
gm_b: GraphModule,
type_a_related_to_b: Set[Tuple[Callable, Callable]],
) -> NodeTypeRelationship:
    """
    Returns whether the underlying types of `node_a` and `node_b` are equal,
    related but not equal, or not related. For `call_module` nodes, the module
    types are looked up in `gm_a` and `gm_b`.
    """
    if node_a.op != node_b.op:
        # for now, comparing call_module to call_function is not supported
        # this can be added later if needed
        return NodeTypeRelationship.NOT_RELATED
if node_a.op == 'call_function':
if node_a.target == node_b.target:
        # nodes with equivalent targets always match (e.g. F.linear and F.linear)
return NodeTypeRelationship.EQUAL
key = (node_a.target, node_b.target)
if key in type_a_related_to_b:
return NodeTypeRelationship.RELATED_BUT_NOT_EQUAL
else:
return NodeTypeRelationship.NOT_RELATED
elif node_a.op == 'call_module':
# for call_module, we need to look up the modules to do the type check
assert isinstance(node_a.target, str)
mod_a = getattr_from_fqn(gm_a, node_a.target)
assert isinstance(node_b.target, str)
mod_b = getattr_from_fqn(gm_b, node_b.target)
        # modules with equivalent types always match (e.g. nn.Conv2d and nn.Conv2d)
if type(mod_a) == type(mod_b):
return NodeTypeRelationship.EQUAL
key = (type(mod_a), type(mod_b))
if key in type_a_related_to_b:
return NodeTypeRelationship.RELATED_BUT_NOT_EQUAL
else:
return NodeTypeRelationship.NOT_RELATED
return NodeTypeRelationship.NOT_RELATED
def _get_name_for_subgraph(
start_node_a: Node,
end_node_a: Node,
gm_a: GraphModule,
base_name_to_sets_of_related_ops: Dict[str, Set[Callable]],
existing_names: Set[str],
) -> str:
"""
Returns a unique name for a subgraph. This name is based on two things:
1. the name of the set containing the underlying type of the base op in the
subgraph (i.e. 'torch.nn.functional.linear' if this is related to a linear op)
2. the number of previous subgraphs with related underlying type of the base op
For example, in the graph
linear0 -> relu0 -> linear1 -> relu1
The subgraphs are (linear0, relu0) and (linear1, relu1). If we iterate
from the output node backwards, the name given to (linear1, relu1) will be
`base_op_torch.nn.functional.linear_0`, and the name given to (linear0, relu0)
will be `base_op_torch.nn.functional.linear_1`.
    Why not just use the node name? Because of two requirements:
A. fusions must be supported
B. some Numeric Suite APIs can be called without having all of the models in memory
For example, let's say we need to match nodes of
(1) ... -> linear0 -> relu0 -> ...
And
(2) ... -> linear_relu0 -> ...
Without being able to inspect them together. With the current naming scheme, if
we iterate through both of these graphs in the same order, and assuming the rest
of the graphs match, both of these subgraphs will get the same name without
(1) and (2) knowing anything about each other.
"""
target_type = _get_node_target_type(start_node_a, gm_a)
target_base_type = None
for base_name, sets_of_related_ops in base_name_to_sets_of_related_ops.items():
if target_type in sets_of_related_ops:
target_base_type = base_name
target_base_name = 'base_op_' + str(target_base_type)
counter = 0
proposed_name = target_base_name + '_' + str(counter)
while proposed_name in existing_names:
counter += 1
proposed_name = target_base_name + '_' + str(counter)
existing_names.add(proposed_name)
return proposed_name
def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[Callable]:
if node.op == 'call_function':
return node.target # type: ignore
elif node.op == 'call_module':
assert isinstance(node.target, str)
mod = getattr_from_fqn(gm, node.target)
return type(mod)
return None
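
# Illustrative sketch (not part of the original API): shows the names produced
# for two linear subgraphs. Since iteration starts from the output and walks
# backwards, the later linear is named first and gets the `_0` suffix. The
# `_example_*` name is hypothetical.
def _example_subgraph_naming() -> List[str]:
    m = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
    gm = torch.fx.symbolic_trace(m)
    it = _NSGraphMatchableSubgraphsIterator(
        gm, get_non_matchable_functions(), get_non_matchable_modules())
    existing_names: Set[str] = set()
    base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
    names: List[str] = []
    for start_node, end_node in it:
        names.append(_get_name_for_subgraph(
            start_node, end_node, gm, base_name_to_sets_of_related_ops,
            existing_names))
    # expected: ['base_op_torch.nn.Linear_0', 'base_op_torch.nn.Linear_1']
    return names
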
def get_matching_subgraph_pairs(
gm_a: GraphModule,
gm_b: GraphModule,
) -> Dict[str, Tuple[Tuple[Node, Node], Tuple[Node, Node]]]:
"""
Matches matchable subgraphs of graph_a to graph_b.
    For a node, "matchable" is defined as a node which is not an observer,
    fake_quant, quantize, or dequantize op.
A subgraph can contain one or more nodes. A subgraph is matchable if
at least one node inside of it is matchable. Currently, all nodes in
a subgraph must be matchable (because we assume no observers will be
inserted in the middle of a fusion).
A subgraph is defined by (start_node, end_node). We assume that only
start_node and end_node are linked with the surrounding graph, all other
nodes in a subgraph are self-contained.
A pair of nodes is "related" if both nodes represent the same mathematical
operation across different quantization flavors. For example,
`F.linear` and `torch.ops.quantized.linear` are related, and
    `F.linear` and `torch.nn.Conv2d` are not related.
For each matchable pair of nodes node_a and node_b, they will match
if node_a and node_b are related.
For graphs A and B, they will match iff:
1. the number of matchable subgraphs in A and B is equivalent
2. when iterating through the matchable subgraphs of A and B in the same order, each
corresponding pair of base nodes is related.
This enables us to find the corresponding subgraphs between
graphs of related models. For example, if we had two graphs such as:
graph_a: x0 -> conv_0 (type: nn.Conv2d) -> obs_0 -> x1
w -/
b -/
graph_b: x0 -> quant_0 -> qconv_0 (type: nnq.Conv2d) -> dequant_0 -> x1
packed_params_0 -/
    This function will return the following result:
    {
        'base_op_torch.nn.Conv2d_0': (  # subgraph name, see _get_name_for_subgraph
            (conv_0, conv_0),  # (start_node_a, end_node_a)
            (qconv_0, qconv_0),  # (start_node_b, end_node_b)
        ),
    }
Or, if we have a fusion pattern,
graph_a: x0 -> linear_0 -> relu_0 -> obs_0 -> x1
w -/
b -/
graph_b: x0 -> quant_0 -> linear_relu_0 -> dequant_0 -> x1
packed_params_0 -/
    This function will return the following result:
    {
        'base_op_torch.nn.functional.linear_0': (  # subgraph name
            (linear_0, relu_0),  # (start_node_a, end_node_a)
            (linear_relu_0, linear_relu_0),  # (start_node_b, end_node_b)
        ),
    }
"""
non_matchable_functions = get_non_matchable_functions()
non_matchable_modules = get_non_matchable_modules()
graph_a_iterator = _NSGraphMatchableSubgraphsIterator(
gm_a, non_matchable_functions, non_matchable_modules)
graph_b_iterator = _NSGraphMatchableSubgraphsIterator(
gm_b, non_matchable_functions, non_matchable_modules)
results = {}
base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
type_a_related_to_b = \
get_type_a_related_to_b(base_name_to_sets_of_related_ops)
existing_names_a: Set[str] = set()
existing_names_b: Set[str] = set()
while True:
# fetch the next nodes from a and b
cur_start_node_a, cur_start_node_b = None, None
cur_end_node_a, cur_end_node_b = None, None
try:
cur_start_node_a, cur_end_node_a = next(graph_a_iterator)
except StopIteration:
pass
try:
cur_start_node_b, cur_end_node_b = next(graph_b_iterator)
except StopIteration:
pass
# look up types of a and b for useful error messages
type_start_a, type_start_b = None, None
if cur_end_node_a is not None:
type_start_a = _get_node_target_type(cur_start_node_a, gm_a) # type: ignore
if cur_end_node_b is not None:
type_start_b = _get_node_target_type(cur_start_node_b, gm_b) # type: ignore
# check for results and determine what to do next
if cur_end_node_a is not None and cur_end_node_b is not None:
assert isinstance(cur_start_node_a, Node)
assert isinstance(cur_start_node_b, Node)
# both nodes were fetched, check for node_relationship
# note: node_relationship is checked on the start node, i.e.
# if a linear-relu pattern is checked, we would check for node_relationship
# of the linear
node_relationship = _get_node_relationship_type(
cur_start_node_a, cur_start_node_b,
gm_a, gm_b, type_a_related_to_b)
if node_relationship == NodeTypeRelationship.NOT_RELATED:
msg = f"""
({cur_start_node_a}, {type_start_a}) and
({cur_start_node_b}, {type_start_b}) are not related"""
raise GraphMatchingException(msg)
elif node_relationship == NodeTypeRelationship.EQUAL:
# For now, skip nodes with equal types. In the future, this can
# be made configurable.
continue
key_name_a = _get_name_for_subgraph(
cur_start_node_a, cur_end_node_a, gm_a, base_name_to_sets_of_related_ops,
existing_names_a)
key_name_b = _get_name_for_subgraph(
cur_start_node_b, cur_end_node_b, gm_b, base_name_to_sets_of_related_ops,
existing_names_b)
assert key_name_a == key_name_b, \
f"Subgraph names {key_name_a} and {key_name_b} do not match"
results[key_name_a] = (
(cur_start_node_a, cur_end_node_a),
(cur_start_node_b, cur_end_node_b),
)
continue
elif cur_end_node_a is None and cur_end_node_b is None:
# we reached the end of both graphs
break
else:
# only one node was fetched, no match possible, throw error
msg = f"""
Matchable nodes count mismatch: ({cur_start_node_a}, {type_start_a}) and
({cur_start_node_b}, {type_start_b})"""
raise GraphMatchingException(msg)
return results
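
# Illustrative end-to-end sketch (not part of the original API): matches a toy
# float model against its FX-quantized counterpart. Assumes
# torch.quantization.quantize_fx and a quantized engine (e.g. fbgemm) are
# available in this build; the `_example_*` name is hypothetical.
def _example_get_matching_subgraph_pairs(
) -> Dict[str, Tuple[Tuple[Node, Node], Tuple[Node, Node]]]:
    # imported locally to keep this sketch self-contained and to avoid any
    # import cycles at module load time
    from torch.quantization import quantize_fx

    m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
    qconfig_dict = {'': torch.quantization.default_qconfig}
    mp = quantize_fx.prepare_fx(m, qconfig_dict)
    mp(torch.randn(1, 1, 4, 4))  # calibration
    mq = quantize_fx.convert_fx(mp)
    gm_a = torch.fx.symbolic_trace(m)
    # expected: a single entry keyed 'base_op_torch.nn.Conv2d_0', pairing the
    # fp32 conv node with the quantized conv node
    return get_matching_subgraph_pairs(gm_a, mq)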