[primTorch] Rewrite nan_to_num ref in terms of aten functions (#93952)
This de-duplicates `_refs.nan_to_num` with the inductor decomposition
and simplifies it to not reimplement `isnan`, `isposinf` and `isneginf`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93952
Approved by: https://github.com/lezcano
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index b4c3ebb..4511779 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -420,6 +420,7 @@
"isinf",
"isposinf",
"isneginf",
+ "nan_to_num",
"mT",
"mH",
}
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index e3b9b86..1d91a1e 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -241,6 +241,7 @@
aten.mse_loss,
aten.mse_loss_backward,
aten.mv,
+ aten.nan_to_num,
aten.narrow,
aten.native_batch_norm,
aten._native_batch_norm_legit,
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 4e8d597..b59d20f 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -7,7 +7,6 @@
import torch._decomp as decomp
from torch import Tensor
from torch._decomp import core_aten_decompositions, get_decompositions
-from torch._prims_common import is_boolean_dtype, is_integer_dtype
from torch.utils._mode_utils import no_dispatch
from . import config, utils
@@ -321,26 +320,6 @@
return b - a
-@register_decomposition([aten.nan_to_num])
-def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
- if is_boolean_dtype(x.dtype) or is_integer_dtype(x.dtype):
- return x
-
- if nan is None:
- nan = 0.0
- if posinf is None:
- posinf = torch.finfo(x.dtype).max
- if neginf is None:
- neginf = torch.finfo(x.dtype).min
- nan, posinf, neginf = (
- torch.tensor(v, dtype=x.dtype, device=x.device) for v in (nan, posinf, neginf)
- )
- x = torch.where(x != x, nan, x)
- x = torch.where(x == float("inf"), posinf, x)
- x = torch.where(x == float("-inf"), neginf, x)
- return x
-
-
@register_decomposition([aten.all.default])
def all(input):
return torch.logical_not(torch.any(torch.logical_not(input)))
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index c8adabf..68bd53e 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -83,6 +83,8 @@
"index_fill_",
"isfinite",
"isinf",
+ "isposinf",
+ "isneginf",
"isnan",
"isreal",
"i0",
@@ -736,7 +738,7 @@
assert isinstance(a, TensorLike)
if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
- return clone(a)
+ return a.clone()
if nan is None:
nan = 0.0
@@ -747,14 +749,9 @@
if neginf is None:
neginf = torch.finfo(a.dtype).min
- result = where(isnan(a), nan, a)
-
- is_neg = signbit(a)
- is_neginf = bitwise_and(isinf(a), is_neg)
- result = where(is_neginf, neginf, result)
-
- is_posinf = bitwise_and(isinf(a), bitwise_not(is_neg))
- result = where(is_posinf, posinf, result)
+ result = torch.where(torch.isnan(a), nan, a) # type: ignore[call-overload]
+ result = torch.where(torch.isneginf(a), neginf, result) # type: ignore[call-overload]
+ result = torch.where(torch.isposinf(a), posinf, result) # type: ignore[call-overload]
return result