[ao] Added statistical threshold arg in Outlier Detector (#81174)

Summary: The outlier detector has a feature where it's able to notify
the user if below the whole set of batches that passed through were used
in Outlier calculation, which mainly happens as a result of 0-errors.
This changes the code so that instead of comparing against a value like
30 as we were before, we now let the user pass in an optional fractional
value and if the ratio of the batches used was below that value, the
detector alerts the user.

Test Plan: python test/test_quantization.py TestFxDetectOutliers

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81174
Approved by: https://github.com/andrewor14
diff --git a/test/quantization/fx/test_model_report_fx.py b/test/quantization/fx/test_model_report_fx.py
index a932160..8039947 100644
--- a/test/quantization/fx/test_model_report_fx.py
+++ b/test/quantization/fx/test_model_report_fx.py
@@ -1438,10 +1438,6 @@
                 # get the info for the specific module
                 module_dict = outlier_dict[module_fqn]
 
-                # because we only ran once, all batches run should say statisitically insignificant amount of data
-                sufficient_batches_info = module_dict[OutlierDetector.SUFFICIENT_BATCHES_KEY]
-                self.assertEqual(sum(sufficient_batches_info), 0)
-
                 # there really should not be any outliers since we used a normal distribution to perform this calculation
                 outlier_info = module_dict[OutlierDetector.OUTLIER_KEY]
                 self.assertEqual(sum(outlier_info), 0)
@@ -1493,10 +1489,6 @@
                 # get the info for the specific module
                 module_dict = outlier_dict[module_fqn]
 
-                # because we only ran once, all batches run should say statisitically insignificant amount of data
-                sufficient_batches_info = module_dict[OutlierDetector.SUFFICIENT_BATCHES_KEY]
-                self.assertEqual(sum(sufficient_batches_info), 0)
-
                 # everything should be an outlier because we said that the max should be equal to the min for all of them
                 # however we will just test and say most should be in case we have several 0 channel values
                 outlier_info = module_dict[OutlierDetector.OUTLIER_KEY]
@@ -1554,7 +1546,7 @@
 
                 # because we ran 30 times, we should have at least a couple be significant
                 # could be less because some channels could possibly be all 0
-                sufficient_batches_info = module_dict[OutlierDetector.SUFFICIENT_BATCHES_KEY]
+                sufficient_batches_info = module_dict[OutlierDetector.IS_SUFFICIENT_BATCHES_KEY]
                 assert sum(sufficient_batches_info) >= len(sufficient_batches_info) / 2
 
                 # half of them should be outliers, because we set a really high value every 2 channels
diff --git a/torch/ao/quantization/fx/_model_report/detector.py b/torch/ao/quantization/fx/_model_report/detector.py
index fe5ac6d..c1e2a3f 100644
--- a/torch/ao/quantization/fx/_model_report/detector.py
+++ b/torch/ao/quantization/fx/_model_report/detector.py
@@ -938,6 +938,11 @@
         reference_percentile (float, optional): The denominator to find the relative scale of the 100th percentile
             Should be between 0 and 1
             Default: 0.975
+        fraction_batches_used_threshold (float, optional): Threshold of fraction of batches per channel to determine outlier
+            If fraction is below this, we deem number of samples used to calculate outliers as insignificant and alert user
+            regardless of whether we detected outliers or not in channel to take a closer look at channel results
+            Should be between 0 and 1
+            Default: 0.95
         ch_axis (int, optional): The channel axis being observed to determine input weight equalization
             Default: 1
 
@@ -949,6 +954,10 @@
     * :attr:`reference_percentile`: The denominator of the top fraction to find the relative scale of the 100th percentile
         Should be between 0 and 1
 
+    * :attr:`fraction_batches_used_threshold`: The fraction of batches to determine outliers for each channel should be above this
+        Some batches may not be used because of 0-based errors, so this is to ensure a good amount of the total batches are used
+        Should be between 0 and 1
+
     * :attr:`ch_axis`: The channel axis being observed to determine outliers
 
     * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
@@ -960,20 +969,28 @@
     # names for dict keys
     OUTLIER_KEY = "outliers_detected"
     NUM_BATCHES_KEY = "batches_used"
-    SUFFICIENT_BATCHES_KEY = "sufficient_batches"
+    IS_SUFFICIENT_BATCHES_KEY = "is_sufficient_batches"
     COMP_METRIC_KEY = "percentile_ratios"
     RATIO_THRES_KEY = "ratio_threshold"
     REF_PERCENTILE_KEY = "reference_percentile"
     CHANNEL_AXIS_KEY = "channel_axis"
     MAX_VALS_KEY = "per_channel_max"
 
-    def __init__(self, ratio_threshold: float = 3.5, reference_percentile: float = 0.975, ch_axis: int = 1):
+    def __init__(
+        self,
+        ratio_threshold: float = 3.5,
+        reference_percentile: float = 0.975,
+        fraction_batches_used_threshold: float = 0.95,
+        ch_axis: int = 1,
+    ):
         # initialize the variables of interest
         self.ratio_threshold = ratio_threshold
 
         # make sure passed in percentile is valid
         assert reference_percentile >= 0 and reference_percentile <= 1
+        assert fraction_batches_used_threshold >= 0 and fraction_batches_used_threshold <= 1
         self.reference_percentile = reference_percentile
+        self.fraction_batches_used_threshold = fraction_batches_used_threshold
         self.ch_axis = ch_axis
 
     def get_detector_name(self) -> str:
@@ -1051,7 +1068,9 @@
 
         return obs_fqn_to_info
 
-    def _calculate_outlier_info(self, percentile_ratios: torch.Tensor, counted_batches: torch.Tensor) -> Dict[str, List[bool]]:
+    def _calculate_outlier_info(
+            self, percentile_ratios: torch.Tensor, counted_batches: torch.Tensor, total_batches: int
+    ) -> Dict[str, List[bool]]:
         r"""
         Gives info on whether the percentile ratios cacluated would be considered outliers
         Also gives information on whether the collected data is statistically significant to make this claim
@@ -1059,20 +1078,24 @@
         Args:
             percentile_ratios (torch.Tensor): The average percentile_ratios per channel calculated by the observer
             counted_batches (torch.Tensor): The number of batches used for average calculation per tensor
+            total_batches (int): The total number of batches that passed through observer in this epoch
 
         Returns a dictionary mapping:
             "outliers_detected" : list of bools per channel that are true if it is considered an outlier
-            "above_sample_threshold_count": if the per channel calculation had at least 30 samples (a random threshold)
+            "is_sufficient_batches": if o_r was >= fraction_batches_used_threshold:
+                where o_r = counted_batches / total_batches
         """
-        outlier_dict: Dict[str, List[bool]] = {self.OUTLIER_KEY: [], self.SUFFICIENT_BATCHES_KEY: []}
+        outlier_dict: Dict[str, List[bool]] = {self.OUTLIER_KEY: [], self.IS_SUFFICIENT_BATCHES_KEY: []}
 
         # get both as flattened lists for easy mapping
         ratios_list: List = percentile_ratios.tolist()
         num_batches_list: List = counted_batches.tolist()
 
         # calculate whether channels were statistically significant
-        significant_size = [batch_size >= 30 for batch_size in num_batches_list]
-        outlier_dict[self.SUFFICIENT_BATCHES_KEY] = significant_size
+        significant_size = [
+            batch_size / total_batches >= self.fraction_batches_used_threshold for batch_size in num_batches_list
+        ]
+        outlier_dict[self.IS_SUFFICIENT_BATCHES_KEY] = significant_size
 
         # calculate for each channel whether it's an outlier or not based on ratio
         outlier_detected = [ratio > self.ratio_threshold for ratio in ratios_list]
@@ -1092,7 +1115,7 @@
         Returns a dict mapping relavent module fqns to:
             whether there were outliers found in activation before
             the number of batches used for each channel
-            whether the number of applicable batches is above a minimum set threshold (30)
+            whether fraction of applicable batches used is above fraction_batches_used_threshold
             their p_r metric compared to the threshold
             the threshold used to make the recommendation
             the reference_percentile used to make the recommendation
@@ -1111,6 +1134,7 @@
                 # get the number of batches and calculated ratio thresholds
                 num_batches: torch.Tensor = pre_obs.percentile_batches_tracked
                 average_ratios: torch.Tensor = pre_obs.average_percentile_ratio
+                total_batches: int = pre_obs.num_batches_tracked
 
                 # also get the max values
                 max_vals: torch.Tensor = pre_obs.max_val
@@ -1129,7 +1153,7 @@
                         # if it's less than 1 we have the flip it as well
                         average_ratios[index] = 1 / ratio_val
 
-                outlier_calcs = self._calculate_outlier_info(average_ratios, num_batches)
+                outlier_calcs = self._calculate_outlier_info(average_ratios, num_batches, total_batches)
 
                 # calculate whether ratios were outliers
                 info_dict[fqn] = {
@@ -1139,7 +1163,7 @@
                     self.COMP_METRIC_KEY: average_ratios,
                     self.NUM_BATCHES_KEY: num_batches,
                     self.OUTLIER_KEY: outlier_calcs[self.OUTLIER_KEY],
-                    self.SUFFICIENT_BATCHES_KEY: outlier_calcs[self.SUFFICIENT_BATCHES_KEY],
+                    self.IS_SUFFICIENT_BATCHES_KEY: outlier_calcs[self.IS_SUFFICIENT_BATCHES_KEY],
                     self.MAX_VALS_KEY: max_vals
                 }
 
@@ -1159,7 +1183,7 @@
             Dictionary mapping modules of interest to:
                 whether there were outliers found in activation before
                 the number of batches used for each channel
-                whether the number of applicable batches is above a minimum set threshold (30)
+                whether fraction of applicable batches used is above fraction_batches_used_threshold
                 their p_r metric compared to the threshold
                 the threshold used to make the recommendation
                 the reference_percentile used to make the recommendation