ns for fx: rename SugraphTypeRelationship to SubgraphTypeRelationship (#55155)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55155
Fixes typo in enum name, no logic change
Test Plan:
CI
Imported from OSS
Reviewed By: jerryzh168
Differential Revision: D27504625
fbshipit-source-id: 21605dadb48225987f1da5ad5f6c30b0183278f2
diff --git a/torch/quantization/ns/graph_matcher.py b/torch/quantization/ns/graph_matcher.py
index ffa1da8..3c242c9 100644
--- a/torch/quantization/ns/graph_matcher.py
+++ b/torch/quantization/ns/graph_matcher.py
@@ -358,7 +358,7 @@
"""
pass
-class SugraphTypeRelationship(enum.Enum):
+class SubgraphTypeRelationship(enum.Enum):
# same type
# example: F.linear and F.linear, or nn.Conv2d and nn.Conv2d
EQUAL = enum.auto()
@@ -390,7 +390,7 @@
gm_a: GraphModule,
gm_b: GraphModule,
type_a_related_to_b: Set[Tuple[Callable, Callable]],
-) -> SugraphTypeRelationship:
+) -> SubgraphTypeRelationship:
node_a = subgraph_a.base_op_node
node_b = subgraph_b.base_op_node
@@ -398,33 +398,33 @@
if node_a.op != node_b.op:
# for now, comparing call_module to call_function is not supported
# this can be added later if needed
- return SugraphTypeRelationship.NOT_RELATED
+ return SubgraphTypeRelationship.NOT_RELATED
if node_a.op == 'call_function':
if node_a.target == node_b.target:
node_a_has_prev = subgraph_a.base_op_node == subgraph_a.start_node
node_b_has_prev = subgraph_b.base_op_node == subgraph_b.start_node
if node_a_has_prev and (not node_b_has_prev):
- return SugraphTypeRelationship.RELATED_BUT_NOT_EQUAL
+ return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
elif (not node_a_has_prev) and node_b_has_prev:
- return SugraphTypeRelationship.RELATED_BUT_NOT_EQUAL
+ return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
elif (not node_a_has_prev) and (not node_b_has_prev):
if node_a.target in get_functions_signature_same_across_dtypes():
- return SugraphTypeRelationship.EQUAL_AND_SIGNATURE_SAME_ACROSS_DTYPES
+ return SubgraphTypeRelationship.EQUAL_AND_SIGNATURE_SAME_ACROSS_DTYPES
else:
- return SugraphTypeRelationship.EQUAL
+ return SubgraphTypeRelationship.EQUAL
else:
# TODO(future PR): check for matches start_op_node and base_op_node
if node_a.target in get_functions_signature_same_across_dtypes():
- return SugraphTypeRelationship.EQUAL_AND_SIGNATURE_SAME_ACROSS_DTYPES
+ return SubgraphTypeRelationship.EQUAL_AND_SIGNATURE_SAME_ACROSS_DTYPES
else:
- return SugraphTypeRelationship.EQUAL
+ return SubgraphTypeRelationship.EQUAL
key = (node_a.target, node_b.target)
if key in type_a_related_to_b:
- return SugraphTypeRelationship.RELATED_BUT_NOT_EQUAL
+ return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
else:
- return SugraphTypeRelationship.NOT_RELATED
+ return SubgraphTypeRelationship.NOT_RELATED
elif node_a.op == 'call_module':
assert (subgraph_a.base_op_node == subgraph_a.start_node and
subgraph_b.base_op_node == subgraph_b.start_node), \
@@ -437,15 +437,15 @@
# modules with equivalent types always match (i.e. nn.Conv2d and nn.Conv2d)
if type(mod_a) == type(mod_b):
if type(mod_a) in get_module_types_signature_same_across_dtypes():
- return SugraphTypeRelationship.EQUAL_AND_SIGNATURE_SAME_ACROSS_DTYPES
+ return SubgraphTypeRelationship.EQUAL_AND_SIGNATURE_SAME_ACROSS_DTYPES
else:
- return SugraphTypeRelationship.EQUAL
+ return SubgraphTypeRelationship.EQUAL
key = (type(mod_a), type(mod_b))
if key in type_a_related_to_b:
- return SugraphTypeRelationship.RELATED_BUT_NOT_EQUAL
+ return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
else:
- return SugraphTypeRelationship.NOT_RELATED
- return SugraphTypeRelationship.NOT_RELATED
+ return SubgraphTypeRelationship.NOT_RELATED
+ return SubgraphTypeRelationship.NOT_RELATED
def _get_name_for_subgraph(
subgraph_a: NSSubgraph,
@@ -617,12 +617,12 @@
subgraph_relationship = _get_subgraph_relationship_type(
cur_subgraph_a, cur_subgraph_b,
gm_a, gm_b, type_a_related_to_b)
- if subgraph_relationship == SugraphTypeRelationship.NOT_RELATED:
+ if subgraph_relationship == SubgraphTypeRelationship.NOT_RELATED:
msg = f"""
({cur_subgraph_a}, {type_start_a}) and
({cur_subgraph_b}, {type_start_b}) are not related"""
raise GraphMatchingException(msg)
- elif subgraph_relationship == SugraphTypeRelationship.EQUAL:
+ elif subgraph_relationship == SubgraphTypeRelationship.EQUAL:
# For now, skip nodes with equal types. In the future, this can
# be made configurable.
continue