[ao] Add method in ModelReport to generate visualizer (#81589)
Summary: We created a ModelReportVisualizer class. The primary intended
way to obtain an instance is through the ModelReport instance:
```
model_report_visualizer = model_reporter.generate_visualizer()
```
This method only works after reports have been generated (it raises an
exception otherwise). It takes the generated reports, merges them per
module, and orders them by module into the format required by the
ModelReportVisualizer. It then creates the visualizer instance and returns it.
Test Plan: python test/test_quantization.py TestFxModelReportClass.test_generate_visualizer
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81589
Approved by: https://github.com/andrewor14
diff --git a/test/quantization/fx/test_model_report_fx.py b/test/quantization/fx/test_model_report_fx.py
index c4eef7f..67c0403 100644
--- a/test/quantization/fx/test_model_report_fx.py
+++ b/test/quantization/fx/test_model_report_fx.py
@@ -13,6 +13,7 @@
OutlierDetector,
)
from torch.ao.quantization.fx._model_report.model_report_observer import ModelReportObserver
+from torch.ao.quantization.fx._model_report.model_report_visualizer import ModelReportVisualizer
from torch.ao.quantization.fx._model_report.model_report import ModelReport
from torch.ao.quantization.observer import HistogramObserver, default_per_channel_weight_observer
from torch.nn.intrinsic.modules.fused import ConvReLU2d, LinearReLU
@@ -1033,7 +1034,7 @@
# prepare and callibrate two different instances of same model
# prepare the model
- example_input = torch.randn(1, 3, 3, 3)
+ example_input = model_full.get_example_inputs()[0]
current_backend = torch.backends.quantized.engine
q_config_mapping = QConfigMapping()
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
@@ -1084,6 +1085,61 @@
# make sure we don't run into error for single report
model_single_report = model_report_single.generate_model_report(False)
+ @skipIfNoFBGEMM
+ def test_generate_visualizer(self):
+ """
+ Tests that the ModelReport class can properly create the ModelReportVisualizer instance
+ Checks that:
+ - Correct number of modules are represented
+ - Modules are sorted
+ - Correct number of features for each module
+ """
+ with override_quantized_engine('fbgemm'):
+ # set the backend for this test
+ torch.backends.quantized.engine = "fbgemm"
+ # test with multiple detectors
+ detector_set = set()
+ detector_set.add(OutlierDetector(reference_percentile=0.95))
+ detector_set.add(InputWeightEqualizationDetector(0.5))
+
+ model = self.TwoThreeOps()
+
+ # get tst model and callibrate
+ prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
+ model, detector_set, model.get_example_inputs()[0]
+ )
+
+ # now we actually callibrate the model
+ example_input = model.get_example_inputs()[0]
+ example_input = example_input.to(torch.float)
+
+ prepared_for_callibrate_model(example_input)
+
+ # try to visualize without generating report, should throw error
+ with self.assertRaises(Exception):
+ mod_rep_visualizaiton = mod_report.generate_visualizer()
+
+ # now get the report by running it through ModelReport instance
+ generated_report = mod_report.generate_model_report(remove_inserted_observers=False)
+
+ # now we get the visualizer should not error
+ mod_rep_visualizer: ModelReportVisualizer = mod_report.generate_visualizer()
+
+ # since we tested with outlier detector, which looks at every base level module
+ # should be six entries in the ordered dict
+ mod_fqns_to_features = mod_rep_visualizer.generated_reports
+
+ self.assertEqual(len(mod_fqns_to_features), 6)
+
+ # outlier detector has 9 feature per module
+ # input-weight has 12 features per module
+ # there are 1 common data point, so should be 12 + 9 - 1 = 20 unique features per common modules
+ # all linears will be common
+ for module_fqn in mod_fqns_to_features:
+ if ".linear" in module_fqn:
+ linear_info = mod_fqns_to_features[module_fqn]
+ self.assertEqual(len(linear_info), 20)
+
class TestFxDetectInputWeightEqualization(QuantizationTestCase):
class LinearConv(torch.nn.Module):
diff --git a/torch/ao/quantization/fx/_model_report/model_report.py b/torch/ao/quantization/fx/_model_report/model_report.py
index 1d4b855..25be9e4 100644
--- a/torch/ao/quantization/fx/_model_report/model_report.py
+++ b/torch/ao/quantization/fx/_model_report/model_report.py
@@ -1,5 +1,5 @@
from typing import Any, Dict, Set, Tuple
-
+from collections import OrderedDict
import torch
from torch.ao.quantization.fx._model_report.detector import (
DetectorBase,
@@ -8,9 +8,11 @@
DETECTOR_IS_POST_OBS_KEY,
DETECTOR_TARGET_NODE_KEY
)
+from torch.ao.quantization.fx._model_report.model_report_visualizer import ModelReportVisualizer
from torch.ao.quantization.fx.graph_module import GraphModule
from torch.ao.quantization.observer import ObserverBase
+
class ModelReport:
r"""
The ModelReport class aims to provide users an easy way to diagnose issues that they run into
@@ -62,7 +64,13 @@
2.) Prepare your model with prepare_fx
3.) Call model_report.prepare_detailed_calibration to add relavent observers
4.) Callibrate your model with data
- 5.) Call model_report.generate_report to generate report and optionally remove added observers
+ 5.) Call model_report.generate_model_report on your model to generate report and optionally remove added observers
+ Optional
+ 6.) Call model_report.generate_visualizer to get a ModelReportVisualizer instance
+ 7.) To help in parsing report information and debugging, view report info as a:
+ - Table
+ - Histogram
+ - Line plot
Example (with QuantizationTracer):
>>> # get the necessary qconfig
@@ -87,6 +95,9 @@
>>> # finally we generate the reports and optionally remove the observers we inserted
>>> reports = tracer_reporter.generate_model_report(remove_inserted_observers=True)
+ >>> # Optional: we get a ModelReportVisualizer instance to do any visualizations desired
+ >>> model_report_visualizer = tracer_reporter.generate_visualizer()
+
"""
def __init__(self, model: GraphModule, desired_report_detectors: Set[DetectorBase]):
@@ -110,10 +121,14 @@
for desired_report in self._desired_detector_names:
self._detector_name_to_observer_fqns[desired_report] = set([])
- # flags to ensure that we can only prepare and generate report once
+ # flags to ensure that we can only prepare and remove observers once
self._prepared_flag = False
self._removed_observers = False
+ # store the reports that we generated for visualization purposes
+ # intially empty since no reports generated
+ self._generated_reports: Dict[str, Dict] = {}
+
def get_desired_reports_names(self) -> Set[str]:
""" Returns a copy of the desired reports for viewing """
return self._desired_detector_names.copy()
@@ -237,7 +252,15 @@
Returns a mapping of each desired report name to a tuple with:
The textual summary of that report information
A dictionary containing relavent statistics or information for that report
+
+ Note:
+ Throws exception if we try to generate report on model we already removed observers from
+ Throws exception if we try to generate report without preparing for callibration
"""
+ # if we haven't prepped model for callibration, then we shouldn't generate report yet
+ if not self._prepared_flag:
+ raise Exception("Cannot generate report without preparing model for callibration")
+
# if we already removed the observers, we cannot generate report
if self._removed_observers:
raise Exception("Cannot generate report on model you already removed observers from")
@@ -275,5 +298,120 @@
# remember to recompile the model
self._model.recompile()
+ # save the generated reports for visualization purposes
+ saved_reports: Dict[str, Dict] = {
+ report_name : report_tuple[1] for report_name, report_tuple in reports_of_interest.items()
+ }
+
+ self._generated_reports = saved_reports
+
# return the reports of interest
return reports_of_interest
+
+ def _is_same_info_for_same_key(self, info_dict_a: Dict, info_dict_b: Dict) -> bool:
+ r"""
+ Takes in two dictionaries and ensures that any common keys between the two have the same
+ values.
+
+ Args:
+ info_dict_a (Dict): First dictionary we wish to compare
+ info_dict_b (Dict): Second dictionary we wish to compare
+
+ Returns True if all shared keys have same values, false otherwise
+ """
+ # get the set of keys for both
+ dict_a_keys: Set = set(info_dict_a.keys())
+ dict_b_keys: Set = set(info_dict_b.keys())
+
+ # get the insersection keys and check if same value for both dicts
+ intersecting_keys: Set = dict_a_keys.intersection(dict_b_keys)
+
+ for key in intersecting_keys:
+ dict_a_val = info_dict_a[key]
+ dict_b_val = info_dict_b[key]
+
+ # if it's a tensor we have to handle seperately
+ if type(dict_a_val) == torch.Tensor:
+ # if dict_b_val not tensor, automatically false
+ if type(dict_b_val) != torch.Tensor or sum(dict_a_val != dict_b_val) != 0:
+ return False
+ else:
+ # for non-tensor vals
+ if dict_a_val != dict_b_val:
+ return False
+
+ # if no non matching shared keys found, return true
+ return True
+
+ def _reformat_reports_for_visualizer(self) -> OrderedDict:
+ r"""
+ Takes the generated reports and reformats them into the format that is desired by the
+ ModelReportVisualizer
+
+ Returns an OrderedDict mapping module_fqns to their features
+ """
+ # we want to reorder and reformat the information so it is ordered in terms of order
+ # found in the model
+
+ # first create new dict with all modules as keys and features under respective module
+ module_fqns_to_features: Dict[str, Dict] = {}
+
+ for report_name in self._generated_reports:
+ # get mod -> feature dict and go through
+ module_info = self._generated_reports[report_name]
+
+ for module_fqn in module_info:
+ # check if already in our accumulation dict
+ if module_fqn in module_fqns_to_features:
+ # we merge all the features together
+ new_info: Dict = module_info[module_fqn]
+ present_info: Dict = module_fqns_to_features[module_fqn]
+
+ # merge them together into the new unioned dict
+ # same features keys -> same info, so okay if override
+
+ # do safety check to make sure shared keys have same info
+ if self._is_same_info_for_same_key(new_info, present_info):
+ module_fqns_to_features[module_fqn] = {**new_info, **present_info}
+ else:
+ error_str = "You have the same key with different values across detectors. "
+ error_str += "Someone incorrectly implemented a detector with conflicting keys to exisiting detectors."
+ raise ValueError(error_str)
+ else:
+ # we just set it
+ module_fqns_to_features[module_fqn] = module_info[module_fqn]
+
+ # our ordered dict so that modules can be ordered in order of how they appear in model
+ features_by_module: OrderedDict[str, Dict] = OrderedDict()
+
+ # we loop through modules in graph in order
+ for fqn, module in self._model.named_modules():
+ # find that fqn in fqns_to_features
+ if fqn in module_fqns_to_features:
+ # add it to our ordered dict
+ features_by_module[fqn] = module_fqns_to_features[fqn]
+
+ # return the ordered dict of info we created
+ return features_by_module
+
+ def generate_visualizer(self) -> ModelReportVisualizer:
+ r"""
+ Generates a ModelReportVisualizer instance using the reports generated
+ by the generate_model_report() method.
+
+ Returns the generated ModelReportVisualizer instance initialized
+
+ Note:
+ Throws exception if attempt to get visualizers without generating report
+ """
+ # check if user has generated reports at least once
+ if len(self._generated_reports) == 0:
+ raise Exception("Unable to generate visualizers without first generating reports")
+
+ # get the ordered dict mapping modules to their full set of collected features / stats
+ module_fqns_to_features: OrderedDict = self._reformat_reports_for_visualizer()
+
+ # create and return ModelReportVisualizer instance
+ visualizer: ModelReportVisualizer = ModelReportVisualizer(module_fqns_to_features)
+
+ return visualizer
diff --git a/torch/ao/quantization/fx/_model_report/model_report_visualizer.py b/torch/ao/quantization/fx/_model_report/model_report_visualizer.py
index 33bd875..f51390d 100644
--- a/torch/ao/quantization/fx/_model_report/model_report_visualizer.py
+++ b/torch/ao/quantization/fx/_model_report/model_report_visualizer.py
@@ -57,7 +57,7 @@
generated_reports (Dict[str, Any]): The reports generated by the ModelReport class
can also be a dictionary generated in another manner, as long as format is same
"""
- pass
+ self.generated_reports = generated_reports
def get_all_unique_module_fqns(self) -> Set[str]:
r"""