Add option to warn if elements in a Compare table are suspect (#41011)
Summary:
This PR adds a `.highlight_warnings()` method to `Compare`, which will include a `(! XX%)` next to measurements with high variance to highlight that fact. For example:
```
[------------- Record function overhead ------------]
| lstm_jit | resnet50_jit
1 threads: ------------------------------------------
with_rec_fn | 650 | 8600
without_rec_fn | 660 | 8000
2 threads: ------------------------------------------
with_rec_fn | 360 | 4200
without_rec_fn | 350 | 4000
4 threads: ------------------------------------------
with_rec_fn | 250 | 2100
without_rec_fn | 260 | 2000
8 threads: ------------------------------------------
with_rec_fn | 200 (! 6%) | 1200
without_rec_fn | 210 (! 6%) | 1100
16 threads: -----------------------------------------
with_rec_fn | 220 (! 8%) | 900 (! 5%)
without_rec_fn | 200 (! 5%) | 1000 (! 7%)
32 threads: -----------------------------------------
with_rec_fn | 1000 (! 7%) | 920
without_rec_fn | 1000 (! 6%) | 900 (! 6%)
Times are in milliseconds (ms).
(! XX%) Measurement has high variance, where XX is the median / IQR * 100.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41011
Differential Revision: D22412905
Pulled By: robieta
fbshipit-source-id: 2c90e719d9a5a1c0267ed113dd1b1b1738fa8269
diff --git a/benchmarks/experimental_components/utils/compare.py b/benchmarks/experimental_components/utils/compare.py
index c2cc816..16429ab 100644
--- a/benchmarks/experimental_components/utils/compare.py
+++ b/benchmarks/experimental_components/utils/compare.py
@@ -1,7 +1,7 @@
"""Display class to aggregate and print the results of many measurements."""
import collections
import itertools as it
-from typing import List, Tuple
+from typing import List, Optional, Tuple
import numpy as np
@@ -25,12 +25,14 @@
time_scale: float,
time_unit: str,
trim_significant_figures: bool,
+ highlight_warnings: bool,
):
self._grouped_results = grouped_results
self._flat_results = list(it.chain(*grouped_results))
self._time_scale = time_scale
self._time_unit = time_unit
self._trim_significant_figures = trim_significant_figures
+ self._highlight_warnings = highlight_warnings and any(r.has_warnings for r in self._flat_results)
leading_digits = [
int(np.ceil(np.log10(r.median / self._time_scale)))
for r in self._flat_results
@@ -41,15 +43,17 @@
for digits, m in zip(leading_digits, self._flat_results)
) if self._trim_significant_figures else 1
length = unit_digits + decimal_digits + (1 if decimal_digits else 0)
- self._template = f"{{:>{length}.{decimal_digits}f}}"
+ self._template = f"{{:>{length}.{decimal_digits}f}}{{:>{7 if self._highlight_warnings else 0}}}"
def get_results_for(self, group):
return self._grouped_results[group]
- def num_to_str(self, value: float, estimated_sigfigs: int):
+ def num_to_str(self, value: float, estimated_sigfigs: int, spread: Optional[float]):
if self._trim_significant_figures:
value = common.trim_sigfig(value, estimated_sigfigs)
- return self._template.format(value)
+ return self._template.format(
+ value,
+ f" (! {spread:.0f}%)" if self._highlight_warnings and spread is not None else "")
class _Row(object):
@@ -74,7 +78,11 @@
env = env.ljust(self._env_str_len + 4)
output = [" " + env + self._results[0].as_row_name]
for m, col in zip(self._results, self._columns):
- output.append(col.num_to_str(m.median / self._time_scale, m.significant_figures))
+ output.append(col.num_to_str(
+ m.median / self._time_scale,
+ m.significant_figures,
+ m.median / m._iqr if m.has_warnings else None
+ ))
return output
@staticmethod
@@ -110,12 +118,13 @@
class Table(object):
def __init__(self, results: List[common.Measurement], colorize: bool,
- trim_significant_figures: bool):
+ trim_significant_figures: bool, highlight_warnings: bool):
assert len(set(r.label for r in results)) == 1
self.results = results
self._colorize = colorize
self._trim_significant_figures = trim_significant_figures
+ self._highlight_warnings = highlight_warnings
self.label = results[0].label
self.time_unit, self.time_scale = common.select_unit(
min(r.median for r in results)
@@ -182,7 +191,8 @@
column = _Column(
grouped_results=grouped_results, time_scale=self.time_scale,
time_unit=self.time_unit,
- trim_significant_figures=self._trim_significant_figures)
+ trim_significant_figures=self._trim_significant_figures,
+ highlight_warnings=self._highlight_warnings,)
columns.append(column)
rows, columns = tuple(rows), tuple(columns)
@@ -206,7 +216,10 @@
finalized_columns.append(" | ".join(row.finalize_column_strings(string_row, col_widths)))
print("[" + (" " + self.label + " ").center(overall_width - 2, "-") + "]")
print("\n".join(finalized_columns))
- print(f"\nTimes are in {common.unit_to_english(self.time_unit)}s ({self.time_unit}).", "\n" * 2)
+ print(f"\nTimes are in {common.unit_to_english(self.time_unit)}s ({self.time_unit}).")
+ if self._highlight_warnings and any(r.has_warnings for r in self.results):
+ print("(! XX%) Measurement has high variance, where XX is the median / IQR * 100.")
+ print("\n")
class Compare(object):
@@ -215,6 +228,7 @@
self.extend_results(results)
self._trim_significant_figures = False
self._colorize = False
+ self._highlight_warnings = False
def extend_results(self, results):
for r in results:
@@ -230,6 +244,9 @@
def colorize(self):
self._colorize = True
+ def highlight_warnings(self):
+ self._highlight_warnings = True
+
def print(self):
self._render()
@@ -246,5 +263,6 @@
return grouped_results
def _layout(self, results: List[common.Measurement]):
- table = Table(results, self._colorize, self._trim_significant_figures)
+ table = Table(results, self._colorize, self._trim_significant_figures,
+ self._highlight_warnings)
table.render()