[ao] Added statistical threshold arg in Outlier Detector (#81174)
Summary: The outlier detector can notify the user when fewer than the
full set of batches that passed through were used in the outlier
calculation, which mainly happens as a result of 0-errors. Instead of
comparing the number of batches used against a fixed value (30), as
before, the detector now accepts an optional fractional threshold,
fraction_batches_used_threshold (default 0.95). If the ratio of batches
used to total batches for a channel falls below that threshold, the
detector alerts the user.
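
As a rough illustration of the new check, the per-channel comparison can be
sketched in plain Python without torch. The helper name is_sufficient_batches
and the early return for total_batches == 0 are assumptions made for this
sketch only; the detector itself does the comparison inside
_calculate_outlier_info on values read from the observer.

```python
from typing import List


def is_sufficient_batches(
    counted_batches: List[int],
    total_batches: int,
    fraction_batches_used_threshold: float = 0.95,
) -> List[bool]:
    """Per channel: was a large enough fraction of all batches used?"""
    # Same range check the detector applies to the threshold.
    assert 0 <= fraction_batches_used_threshold <= 1
    # Assumption for this sketch only: with no batches seen, no channel
    # can have a sufficient fraction, so avoid dividing by zero.
    if total_batches == 0:
        return [False for _ in counted_batches]
    # o_r = counted_batches / total_batches, compared per channel.
    return [
        used / total_batches >= fraction_batches_used_threshold
        for used in counted_batches
    ]


# Three channels observed over 30 batches; only the first used enough of them.
flags = is_sufficient_batches([30, 28, 12], total_batches=30)
print(flags)  # [True, False, False]
```

Under the old fixed cutoff of 30, a run of fewer than 30 batches could never
be marked sufficient; with a fraction, the result depends only on how many of
the batches seen were usable.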
Test Plan: python test/test_quantization.py TestFxDetectOutliers
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81174
Approved by: https://github.com/andrewor14
diff --git a/test/quantization/fx/test_model_report_fx.py b/test/quantization/fx/test_model_report_fx.py
index a932160..8039947 100644
--- a/test/quantization/fx/test_model_report_fx.py
+++ b/test/quantization/fx/test_model_report_fx.py
@@ -1438,10 +1438,6 @@
# get the info for the specific module
module_dict = outlier_dict[module_fqn]
- # because we only ran once, all batches run should say statisitically insignificant amount of data
- sufficient_batches_info = module_dict[OutlierDetector.SUFFICIENT_BATCHES_KEY]
- self.assertEqual(sum(sufficient_batches_info), 0)
-
# there really should not be any outliers since we used a normal distribution to perform this calculation
outlier_info = module_dict[OutlierDetector.OUTLIER_KEY]
self.assertEqual(sum(outlier_info), 0)
@@ -1493,10 +1489,6 @@
# get the info for the specific module
module_dict = outlier_dict[module_fqn]
- # because we only ran once, all batches run should say statisitically insignificant amount of data
- sufficient_batches_info = module_dict[OutlierDetector.SUFFICIENT_BATCHES_KEY]
- self.assertEqual(sum(sufficient_batches_info), 0)
-
# everything should be an outlier because we said that the max should be equal to the min for all of them
# however we will just test and say most should be in case we have several 0 channel values
outlier_info = module_dict[OutlierDetector.OUTLIER_KEY]
@@ -1554,7 +1546,7 @@
# because we ran 30 times, we should have at least a couple be significant
# could be less because some channels could possibly be all 0
- sufficient_batches_info = module_dict[OutlierDetector.SUFFICIENT_BATCHES_KEY]
+ sufficient_batches_info = module_dict[OutlierDetector.IS_SUFFICIENT_BATCHES_KEY]
assert sum(sufficient_batches_info) >= len(sufficient_batches_info) / 2
# half of them should be outliers, because we set a really high value every 2 channels
diff --git a/torch/ao/quantization/fx/_model_report/detector.py b/torch/ao/quantization/fx/_model_report/detector.py
index fe5ac6d..c1e2a3f 100644
--- a/torch/ao/quantization/fx/_model_report/detector.py
+++ b/torch/ao/quantization/fx/_model_report/detector.py
@@ -938,6 +938,11 @@
reference_percentile (float, optional): The denominator to find the relative scale of the 100th percentile
Should be between 0 and 1
Default: 0.975
+ fraction_batches_used_threshold (float, optional): Threshold of fraction of batches per channel to determine outlier
+ If fraction is below this, we deem number of samples used to calculate outliers as insignificant and alert user
+ regardless of whether we detected outliers or not in channel to take a closer look at channel results
+ Should be between 0 and 1
+ Default: 0.95
ch_axis (int, optional): The channel axis being observed to determine input weight equalization
Default: 1
@@ -949,6 +954,10 @@
* :attr:`reference_percentile`: The denominator of the top fraction to find the relative scale of the 100th percentile
Should be between 0 and 1
+ * :attr:`fraction_batches_used_threshold`: The fraction of batches to determine outliers for each channel should be above this
+ Some batches may not be used because of 0-based errors, so this is to ensure a good amount of the total batches are used
+ Should be between 0 and 1
+
* :attr:`ch_axis`: The channel axis being observed to determine outliers
* :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
@@ -960,20 +969,28 @@
# names for dict keys
OUTLIER_KEY = "outliers_detected"
NUM_BATCHES_KEY = "batches_used"
- SUFFICIENT_BATCHES_KEY = "sufficient_batches"
+ IS_SUFFICIENT_BATCHES_KEY = "is_sufficient_batches"
COMP_METRIC_KEY = "percentile_ratios"
RATIO_THRES_KEY = "ratio_threshold"
REF_PERCENTILE_KEY = "reference_percentile"
CHANNEL_AXIS_KEY = "channel_axis"
MAX_VALS_KEY = "per_channel_max"
- def __init__(self, ratio_threshold: float = 3.5, reference_percentile: float = 0.975, ch_axis: int = 1):
+ def __init__(
+ self,
+ ratio_threshold: float = 3.5,
+ reference_percentile: float = 0.975,
+ fraction_batches_used_threshold: float = 0.95,
+ ch_axis: int = 1,
+ ):
# initialize the variables of interest
self.ratio_threshold = ratio_threshold
# make sure passed in percentile is valid
assert reference_percentile >= 0 and reference_percentile <= 1
+ assert fraction_batches_used_threshold >= 0 and fraction_batches_used_threshold <= 1
self.reference_percentile = reference_percentile
+ self.fraction_batches_used_threshold = fraction_batches_used_threshold
self.ch_axis = ch_axis
def get_detector_name(self) -> str:
@@ -1051,7 +1068,9 @@
return obs_fqn_to_info
- def _calculate_outlier_info(self, percentile_ratios: torch.Tensor, counted_batches: torch.Tensor) -> Dict[str, List[bool]]:
+ def _calculate_outlier_info(
+ self, percentile_ratios: torch.Tensor, counted_batches: torch.Tensor, total_batches: int
+ ) -> Dict[str, List[bool]]:
r"""
Gives info on whether the percentile ratios cacluated would be considered outliers
Also gives information on whether the collected data is statistically significant to make this claim
@@ -1059,20 +1078,24 @@
Args:
percentile_ratios (torch.Tensor): The average percentile_ratios per channel calculated by the observer
counted_batches (torch.Tensor): The number of batches used for average calculation per tensor
+ total_batches (int): The total number of batches that passed through observer in this epoch
Returns a dictionary mapping:
"outliers_detected" : list of bools per channel that are true if it is considered an outlier
- "above_sample_threshold_count": if the per channel calculation had at least 30 samples (a random threshold)
+ "is_sufficient_batches": if o_r was >= fraction_batches_used_threshold:
+ where o_r = counted_batches / total_batches
"""
- outlier_dict: Dict[str, List[bool]] = {self.OUTLIER_KEY: [], self.SUFFICIENT_BATCHES_KEY: []}
+ outlier_dict: Dict[str, List[bool]] = {self.OUTLIER_KEY: [], self.IS_SUFFICIENT_BATCHES_KEY: []}
# get both as flattened lists for easy mapping
ratios_list: List = percentile_ratios.tolist()
num_batches_list: List = counted_batches.tolist()
# calculate whether channels were statistically significant
- significant_size = [batch_size >= 30 for batch_size in num_batches_list]
- outlier_dict[self.SUFFICIENT_BATCHES_KEY] = significant_size
+ significant_size = [
+ batch_size / total_batches >= self.fraction_batches_used_threshold for batch_size in num_batches_list
+ ]
+ outlier_dict[self.IS_SUFFICIENT_BATCHES_KEY] = significant_size
# calculate for each channel whether it's an outlier or not based on ratio
outlier_detected = [ratio > self.ratio_threshold for ratio in ratios_list]
@@ -1092,7 +1115,7 @@
Returns a dict mapping relavent module fqns to:
whether there were outliers found in activation before
the number of batches used for each channel
- whether the number of applicable batches is above a minimum set threshold (30)
+ whether fraction of applicable batches used is above fraction_batches_used_threshold
their p_r metric compared to the threshold
the threshold used to make the recommendation
the reference_percentile used to make the recommendation
@@ -1111,6 +1134,7 @@
# get the number of batches and calculated ratio thresholds
num_batches: torch.Tensor = pre_obs.percentile_batches_tracked
average_ratios: torch.Tensor = pre_obs.average_percentile_ratio
+ total_batches: int = pre_obs.num_batches_tracked
# also get the max values
max_vals: torch.Tensor = pre_obs.max_val
@@ -1129,7 +1153,7 @@
# if it's less than 1 we have the flip it as well
average_ratios[index] = 1 / ratio_val
- outlier_calcs = self._calculate_outlier_info(average_ratios, num_batches)
+ outlier_calcs = self._calculate_outlier_info(average_ratios, num_batches, total_batches)
# calculate whether ratios were outliers
info_dict[fqn] = {
@@ -1139,7 +1163,7 @@
self.COMP_METRIC_KEY: average_ratios,
self.NUM_BATCHES_KEY: num_batches,
self.OUTLIER_KEY: outlier_calcs[self.OUTLIER_KEY],
- self.SUFFICIENT_BATCHES_KEY: outlier_calcs[self.SUFFICIENT_BATCHES_KEY],
+ self.IS_SUFFICIENT_BATCHES_KEY: outlier_calcs[self.IS_SUFFICIENT_BATCHES_KEY],
self.MAX_VALS_KEY: max_vals
}
@@ -1159,7 +1183,7 @@
Dictionary mapping modules of interest to:
whether there were outliers found in activation before
the number of batches used for each channel
- whether the number of applicable batches is above a minimum set threshold (30)
+ whether fraction of applicable batches used is above fraction_batches_used_threshold
their p_r metric compared to the threshold
the threshold used to make the recommendation
the reference_percentile used to make the recommendation