ns for fx: make unshadowed activation comparison work for N models (#52357)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52357
Refactor the NS for FX compare unshadowed activations API to be able
to work on N models and do arbitrary matching strategies.
We factor out a util which takes a model and a list of
nodes to attach output loggers to, with reference names to
assign to the logged activations. The user can then call this
util with a set of nodes and names created in any way they want.
Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs
```
Imported from OSS
Reviewed By: raghuramank100
Differential Revision: D26487270
fbshipit-source-id: 1372ef07b5f3ddc7cebdfb2dee0221a2facd0527
diff --git a/test/quantization/test_numeric_suite_fx.py b/test/quantization/test_numeric_suite_fx.py
index 724bd49..de5d422 100644
--- a/test/quantization/test_numeric_suite_fx.py
+++ b/test/quantization/test_numeric_suite_fx.py
@@ -676,7 +676,8 @@
mq_ns(input_fp32)
# check activation result correctness
- act_compare_dict = get_matching_activations(mp_ns, mq_ns, OutputLogger)
+ act_compare_dict = get_matching_activations(
+ 'fp32_prepared', mp_ns, 'int8', mq_ns, OutputLogger)
self.assertTrue(len(act_compare_dict) == 2)
self.assert_ns_logger_act_compare_dict_valid(act_compare_dict)
@@ -727,7 +728,8 @@
mq_ns(input_fp32)
# check activation result correctness
- act_compare_dict = get_matching_activations(mp_ns, mq_ns, OutputLogger)
+ act_compare_dict = get_matching_activations(
+ 'fp32_prepared', mp_ns, 'int8', mq_ns, OutputLogger)
self.assertTrue(len(act_compare_dict) == 2)
self.assert_ns_logger_act_compare_dict_valid(act_compare_dict)
@@ -834,7 +836,8 @@
sparse_nn_q(idx, offsets, x)
# inspect results
- act_compare_dict = get_matching_activations(sparse_nn, sparse_nn_q, OutputLogger)
+ act_compare_dict = get_matching_activations(
+ 'fp32_prepared', sparse_nn, 'int8', sparse_nn_q, OutputLogger)
self.assertTrue(len(act_compare_dict) == 3)
self.assert_ns_logger_act_compare_dict_valid(act_compare_dict)
diff --git a/torch/quantization/ns/graph_passes.py b/torch/quantization/ns/graph_passes.py
index 25f9bb5..3447154 100644
--- a/torch/quantization/ns/graph_passes.py
+++ b/torch/quantization/ns/graph_passes.py
@@ -45,7 +45,7 @@
def remove_observers_add_loggers(
gm: GraphModule,
- node_to_instrument_to_other_node_name: Dict[Node, Optional[str]],
+ node_to_instrument_to_ref_node_name: Dict[Node, Optional[str]],
logger_cls: Callable,
model_name: str,
) -> GraphModule:
@@ -71,8 +71,8 @@
# remove activation post process node
env[node.name] = env[node.args[0].name]
- elif node in node_to_instrument_to_other_node_name:
- other_node_name = node_to_instrument_to_other_node_name[node]
+ elif node in node_to_instrument_to_ref_node_name:
+ other_node_name = node_to_instrument_to_ref_node_name[node]
# ensure env is populated with base node
env[node.name] = new_graph.node_copy(node, load_arg)
# add the logger after the base node
diff --git a/torch/quantization/ns/numeric_suite_core_apis_fx.py b/torch/quantization/ns/numeric_suite_core_apis_fx.py
index a89867e..e7e18c3 100644
--- a/torch/quantization/ns/numeric_suite_core_apis_fx.py
+++ b/torch/quantization/ns/numeric_suite_core_apis_fx.py
@@ -103,7 +103,7 @@
self,
node_name: str,
model_name: str,
- other_node_name: Optional[str] = None,
+ ref_node_name: Optional[str] = None,
):
super().__init__()
self.stats: List[torch.Tensor] = []
@@ -111,17 +111,38 @@
self.node_name = node_name
# name of the model from which the node originated from
self.model_name = model_name
- # name of the other node with a matching Logger
+ # name of the reference node with a matching Logger
# used to link node_a_copy -> logger_a to node_c -> logger_c
# in a_shadows_b
- self.other_node_name = other_node_name
+ self.ref_node_name = ref_node_name
def forward(self, x: torch.Tensor):
self.stats.append(x.detach())
return x
def __repr__(self):
- return f"OutputLogger(node_name={self.node_name}, model_name={self.model_name}, other_node_name={self.other_node_name})"
+ return f"OutputLogger(node_name={self.node_name}, model_name={self.model_name}, ref_node_name={self.ref_node_name})"
+
+def prepare_single_model_output(
+ model_name: str,
+ model: GraphModule,
+ subgraphs_to_instrument: List[Tuple[Tuple[Node, Node], str]],
+ logger_cls: Callable,
+) -> GraphModule:
+
+ # TODO(future PR): do not observe nodes we do not care
+ # about (both fp32, denylist, etc)
+ # Note: for matching activations we always use the end nodes,
+ # such as observing the output of relu in linear-relu
+ # Note: ref_node_name is set to None in model B's loggers,
+ # and set to the corresponding model B's node in model A's loggers.
+ node_to_instrument_to_ref_node_name: Dict[Node, Optional[str]] = {}
+ for (node_start, node_end), ref_node_name in subgraphs_to_instrument:
+ node_to_instrument_to_ref_node_name[node_end] = ref_node_name
+
+ model = remove_observers_add_loggers(
+ model, node_to_instrument_to_ref_node_name, logger_cls, model_name)
+ return model
# Note: this is not a user facing API
# TODO(future PR): wrap this in a user facing API which does not
@@ -133,35 +154,49 @@
gm_b: GraphModule,
logger_cls: Callable,
) -> Tuple[GraphModule, GraphModule]:
-
matched_subgraph_pairs = get_matching_subgraph_pairs(gm_a, gm_b)
+ subgraphs_to_instrument_a = []
+ subgraphs_to_instrument_b = []
+ for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
+ subgraphs_to_instrument_a.append((subgraph_a, match_name))
+ subgraphs_to_instrument_b.append((subgraph_b, match_name))
- node_to_instrument_to_other_node_name_a: Dict[Node, Optional[str]] = {}
- node_to_instrument_to_other_node_name_b: Dict[Node, Optional[str]] = {}
- for match_name, match in matched_subgraph_pairs.items():
- (node_start_a, node_end_a), (node_start_b, node_end_b) = match
- # TODO(future PR): do not observe pairs of nodes we do not care
- # about (both fp32, denylist, etc)
- # Note: for matching activations we always use the end nodes,
- # such as observing the output of relu in linear-relu
- # Note: other_node_name is set to None in model B's loggers,
- # and set to the corresponding model B's node in model A's loggers.
- node_to_instrument_to_other_node_name_a[node_end_a] = node_end_b.name
- node_to_instrument_to_other_node_name_b[node_end_b] = None
-
- gm_a = remove_observers_add_loggers(
- gm_a, node_to_instrument_to_other_node_name_a, logger_cls, name_a)
- gm_b = remove_observers_add_loggers(
- gm_b, node_to_instrument_to_other_node_name_b, logger_cls, name_b)
+ gm_a = prepare_single_model_output(
+ name_a, gm_a, subgraphs_to_instrument_a, logger_cls)
+ gm_b = prepare_single_model_output(
+ name_b, gm_b, subgraphs_to_instrument_b, logger_cls)
return (gm_a, gm_b)
+def add_activation_info_to_dict(
+ model_name: str,
+ model: GraphModule,
+ results: Dict[str, Dict[str, List[torch.Tensor]]],
+ logger_cls: Callable,
+) -> None:
+ for gm_name, mod in model.named_modules():
+ # TODO(future PR): better check when scripted
+ is_logger = (
+ isinstance(mod, logger_cls) # type: ignore
+ or (
+ isinstance(mod, torch.jit.RecursiveScriptModule)
+ and mod.original_name == 'OutputLogger'
+ )
+ )
+ if is_logger:
+ key = mod.ref_node_name + '.stats'
+ if key not in results:
+ results[key] = {}
+ results[key][model_name] = mod.stats
+
# Note: this is not a user facing API
# TODO(future PR): wrap this in a user facing API which does not
# expose FX types.
# TODO(future PR): align on naming
# this is equivalent of just the comparison extraction part of `ns.compare_model_outputs`
def get_matching_activations(
+ model_name_a: str,
gm_a: GraphModule,
+ model_name_b: str,
gm_b: GraphModule,
logger_cls: Callable,
) -> Dict[str, Dict[str, List[torch.Tensor]]]:
@@ -188,26 +223,11 @@
the return type for calibrating with 1 input vs N inputs.
3. `logger_cls` is included in the API for easy result extraction
"""
- results: Dict[str, Dict[str, List[torch.Tensor]]] = \
- collections.defaultdict(dict)
+ results: Dict[str, Dict[str, List[torch.Tensor]]] = {}
for gm in (gm_a, gm_b):
- for gm_name, mod in gm.named_modules():
- # TODO(future PR): better check when scripted
- is_logger = (
- isinstance(mod, logger_cls) # type: ignore
- or (
- isinstance(mod, torch.jit.RecursiveScriptModule)
- and mod.original_name == 'OutputLogger'
- )
- )
- if is_logger:
- # If logger_obj.other_node_name is populated, then this logger
- # is from model A, and other_node_name is the name from model B.
- if mod.other_node_name is None:
- results[mod.node_name + '.stats'][mod.model_name] = mod.stats
- else:
- results[mod.other_node_name + '.stats'][mod.model_name] = mod.stats
- return dict(results)
+ add_activation_info_to_dict(model_name_a, gm_a, results, logger_cls)
+ add_activation_info_to_dict(model_name_b, gm_b, results, logger_cls)
+ return results
# Note: this is not a user facing API
# TODO(future PR): wrap this in a user facing API which does not
@@ -253,10 +273,10 @@
)
)
if is_logger:
- # If logger_obj.other_node_name is populated, then this logger
- # is from model A, and other_node_name is the name from model B.
- if mod.other_node_name is None:
+ # If logger_obj.ref_node_name is populated, then this logger
+ # is from model A, and ref_node_name is the name from model B.
+ if mod.ref_node_name is None:
results[mod.node_name + '.stats'][mod.model_name] = mod.stats
else:
- results[mod.other_node_name + '.stats'][mod.model_name] = mod.stats
+ results[mod.ref_node_name + '.stats'][mod.model_name] = mod.stats
return dict(results)