Add option to warn if elements in a Compare table are suspect (#41011)

Summary:
This PR adds a `.highlight_warnings()` method to `Compare`, which will include a `(! XX%)` next to measurements with high variance to highlight that fact. For example:
```
[------------- Record function overhead ------------]
                      |    lstm_jit   |  resnet50_jit
1 threads: ------------------------------------------
      with_rec_fn     |   650         |  8600
      without_rec_fn  |   660         |  8000
2 threads: ------------------------------------------
      with_rec_fn     |   360         |  4200
      without_rec_fn  |   350         |  4000
4 threads: ------------------------------------------
      with_rec_fn     |   250         |  2100
      without_rec_fn  |   260         |  2000
8 threads: ------------------------------------------
      with_rec_fn     |   200 (! 6%)  |  1200
      without_rec_fn  |   210 (! 6%)  |  1100
16 threads: -----------------------------------------
      with_rec_fn     |   220 (! 8%)  |   900 (! 5%)
      without_rec_fn  |   200 (! 5%)  |  1000 (! 7%)
32 threads: -----------------------------------------
      with_rec_fn     |  1000 (! 7%)  |   920
      without_rec_fn  |  1000 (! 6%)  |   900 (! 6%)

Times are in milliseconds (ms).
(! XX%) Measurement has high variance, where XX is the median / IQR * 100.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/41011

Differential Revision: D22412905

Pulled By: robieta

fbshipit-source-id: 2c90e719d9a5a1c0267ed113dd1b1b1738fa8269
diff --git a/benchmarks/experimental_components/utils/compare.py b/benchmarks/experimental_components/utils/compare.py
index c2cc816..16429ab 100644
--- a/benchmarks/experimental_components/utils/compare.py
+++ b/benchmarks/experimental_components/utils/compare.py
@@ -1,7 +1,7 @@
 """Display class to aggregate and print the results of many measurements."""
 import collections
 import itertools as it
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import numpy as np
 
@@ -25,12 +25,14 @@
         time_scale: float,
         time_unit: str,
         trim_significant_figures: bool,
+        highlight_warnings: bool,
     ):
         self._grouped_results = grouped_results
         self._flat_results = list(it.chain(*grouped_results))
         self._time_scale = time_scale
         self._time_unit = time_unit
         self._trim_significant_figures = trim_significant_figures
+        self._highlight_warnings = highlight_warnings and any(r.has_warnings for r in self._flat_results)
         leading_digits = [
             int(np.ceil(np.log10(r.median / self._time_scale)))
             for r in self._flat_results
@@ -41,15 +43,17 @@
             for digits, m in zip(leading_digits, self._flat_results)
         ) if self._trim_significant_figures else 1
         length = unit_digits + decimal_digits + (1 if decimal_digits else 0)
-        self._template = f"{{:>{length}.{decimal_digits}f}}"
+        self._template = f"{{:>{length}.{decimal_digits}f}}{{:>{7 if self._highlight_warnings else 0}}}"
 
     def get_results_for(self, group):
         return self._grouped_results[group]
 
-    def num_to_str(self, value: float, estimated_sigfigs: int):
+    def num_to_str(self, value: float, estimated_sigfigs: int, spread: Optional[float]):
         if self._trim_significant_figures:
             value = common.trim_sigfig(value, estimated_sigfigs)
-        return self._template.format(value)
+        return self._template.format(
+            value,
+            f" (! {spread:.0f}%)" if self._highlight_warnings and spread is not None else "")
 
 
 class _Row(object):
@@ -74,7 +78,11 @@
         env = env.ljust(self._env_str_len + 4)
         output = ["  " + env + self._results[0].as_row_name]
         for m, col in zip(self._results, self._columns):
-            output.append(col.num_to_str(m.median / self._time_scale, m.significant_figures))
+            output.append(col.num_to_str(
+                m.median / self._time_scale,
+                m.significant_figures,
+                m.median / m._iqr if m.has_warnings else None
+            ))
         return output
 
     @staticmethod
@@ -110,12 +118,13 @@
 
 class Table(object):
     def __init__(self, results: List[common.Measurement], colorize: bool,
-                 trim_significant_figures: bool):
+                 trim_significant_figures: bool, highlight_warnings: bool):
         assert len(set(r.label for r in results)) == 1
 
         self.results = results
         self._colorize = colorize
         self._trim_significant_figures = trim_significant_figures
+        self._highlight_warnings = highlight_warnings
         self.label = results[0].label
         self.time_unit, self.time_scale = common.select_unit(
             min(r.median for r in results)
@@ -182,7 +191,8 @@
             column = _Column(
                 grouped_results=grouped_results, time_scale=self.time_scale,
                 time_unit=self.time_unit,
-                trim_significant_figures=self._trim_significant_figures)
+                trim_significant_figures=self._trim_significant_figures,
+                highlight_warnings=self._highlight_warnings,)
             columns.append(column)
 
         rows, columns = tuple(rows), tuple(columns)
@@ -206,7 +216,10 @@
             finalized_columns.append("  |  ".join(row.finalize_column_strings(string_row, col_widths)))
         print("[" + (" " + self.label + " ").center(overall_width - 2, "-") + "]")
         print("\n".join(finalized_columns))
-        print(f"\nTimes are in {common.unit_to_english(self.time_unit)}s ({self.time_unit}).", "\n" * 2)
+        print(f"\nTimes are in {common.unit_to_english(self.time_unit)}s ({self.time_unit}).")
+        if self._highlight_warnings and any(r.has_warnings for r in self.results):
+            print("(! XX%) Measurement has high variance, where XX is the median / IQR * 100.")
+        print("\n")
 
 
 class Compare(object):
@@ -215,6 +228,7 @@
         self.extend_results(results)
         self._trim_significant_figures = False
         self._colorize = False
+        self._highlight_warnings = False
 
     def extend_results(self, results):
         for r in results:
@@ -230,6 +244,9 @@
     def colorize(self):
         self._colorize = True
 
+    def highlight_warnings(self):
+        self._highlight_warnings = True
+
     def print(self):
         self._render()
 
@@ -246,5 +263,6 @@
         return grouped_results
 
     def _layout(self, results: List[common.Measurement]):
-        table = Table(results, self._colorize, self._trim_significant_figures)
+        table = Table(results, self._colorize, self._trim_significant_figures,
+                      self._highlight_warnings)
         table.render()