Update accuracy checking for nan, floats (#108202)
Fixes inference accuracy for `doctr_reco_predictor` and `pyhpc_turbulent_kinetic_energy`.
For the `same(float, float)` comparison we weren't going through the more rigorous tensor comparison path which takes into account the fp64 base results.
Also return True when the fp64 base results are not well formed (NaN).
I debugged these models and the sources of divergence were innocuous:
`doctr_reco_predictor` - can be fixed by turning off layout optimization, decomp for batch norm
`pyhpc_turbulent_kinetic_energy` - divergence caused because the fused kernel keeps intermediate values in fp32 instead of casting back and forth between fp32 and bf16. The fused kernel has better precision anyway.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108202
Approved by: https://github.com/jansel
diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py
index 0fa352b..64b29e3 100755
--- a/benchmarks/dynamo/torchbench.py
+++ b/benchmarks/dynamo/torchbench.py
@@ -233,6 +233,7 @@
"doctr_reco_predictor",
"Super_SloMo",
"tts_angular",
+ "pyhpc_turbulent_kinetic_energy",
}
# models in canary_models that we should run anyway
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 44769cf..899c00d 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -993,10 +993,15 @@
log_error("Accuracy failed for key name %s", k)
return False
return True
- elif isinstance(ref, torch.Tensor):
+ elif isinstance(ref, (torch.Tensor, float)):
assert not isinstance(ref, torch._subclasses.FakeTensor)
assert not isinstance(res, torch._subclasses.FakeTensor)
+ def to_tensor(t):
+ return t if isinstance(t, torch.Tensor) else torch.tensor(t)
+
+ ref, res, fp64_ref = (to_tensor(val) for val in (ref, res, fp64_ref))
+
if ref.is_sparse:
assert res.is_sparse
ref = ref.to_dense()
@@ -1043,6 +1048,12 @@
# Check error from fp64 version
if fp64_ref.dtype == torch.float64:
ref_error = rmse(fp64_ref, ref).item()
+ # ref unable to produce this with stable numerics in this precision, ignore
+ if math.isnan(ref_error):
+ log.warning(
+ "Found nan in reference. Consider running in higher precision."
+ )
+
res_error = rmse(fp64_ref, res).item()
multiplier = 2.0
@@ -1080,13 +1091,6 @@
if not r:
log_error("Accuracy failed (%s): %s != %s", type(ref), ref, res)
return r
- elif isinstance(ref, float):
- r = math.isclose(ref, res, rel_tol=tol, abs_tol=tol)
- if not r:
- log_error(
- "Accuracy failed (float): %s != %s (within tol=%s)", ref, res, tol
- )
- return r
elif is_numpy_int_type(ref) or is_numpy_float_type(ref):
if relax_numpy_equality and not (
is_numpy_int_type(res) or is_numpy_float_type(res)