[primTorch] Rewrite nan_to_num ref in terms of aten functions (#93952)

This de-duplicates `_refs.nan_to_num` with the inductor decomposition
and simplifies it to not reimplement `isnan`, `isposinf` and `isneginf`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93952
Approved by: https://github.com/lezcano
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index b4c3ebb..4511779 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -420,6 +420,7 @@
     "isinf",
     "isposinf",
     "isneginf",
+    "nan_to_num",
     "mT",
     "mH",
 }
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index e3b9b86..1d91a1e 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -241,6 +241,7 @@
             aten.mse_loss,
             aten.mse_loss_backward,
             aten.mv,
+            aten.nan_to_num,
             aten.narrow,
             aten.native_batch_norm,
             aten._native_batch_norm_legit,
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 4e8d597..b59d20f 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -7,7 +7,6 @@
 import torch._decomp as decomp
 from torch import Tensor
 from torch._decomp import core_aten_decompositions, get_decompositions
-from torch._prims_common import is_boolean_dtype, is_integer_dtype
 from torch.utils._mode_utils import no_dispatch
 
 from . import config, utils
@@ -321,26 +320,6 @@
     return b - a
 
 
-@register_decomposition([aten.nan_to_num])
-def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
-    if is_boolean_dtype(x.dtype) or is_integer_dtype(x.dtype):
-        return x
-
-    if nan is None:
-        nan = 0.0
-    if posinf is None:
-        posinf = torch.finfo(x.dtype).max
-    if neginf is None:
-        neginf = torch.finfo(x.dtype).min
-    nan, posinf, neginf = (
-        torch.tensor(v, dtype=x.dtype, device=x.device) for v in (nan, posinf, neginf)
-    )
-    x = torch.where(x != x, nan, x)
-    x = torch.where(x == float("inf"), posinf, x)
-    x = torch.where(x == float("-inf"), neginf, x)
-    return x
-
-
 @register_decomposition([aten.all.default])
 def all(input):
     return torch.logical_not(torch.any(torch.logical_not(input)))
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index c8adabf..68bd53e 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -83,6 +83,8 @@
     "index_fill_",
     "isfinite",
     "isinf",
+    "isposinf",
+    "isneginf",
     "isnan",
     "isreal",
     "i0",
@@ -736,7 +738,7 @@
     assert isinstance(a, TensorLike)
 
     if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
-        return clone(a)
+        return a.clone()
 
     if nan is None:
         nan = 0.0
@@ -747,14 +749,9 @@
     if neginf is None:
         neginf = torch.finfo(a.dtype).min
 
-    result = where(isnan(a), nan, a)
-
-    is_neg = signbit(a)
-    is_neginf = bitwise_and(isinf(a), is_neg)
-    result = where(is_neginf, neginf, result)
-
-    is_posinf = bitwise_and(isinf(a), bitwise_not(is_neg))
-    result = where(is_posinf, posinf, result)
+    result = torch.where(torch.isnan(a), nan, a)  # type: ignore[call-overload]
+    result = torch.where(torch.isneginf(a), neginf, result)  # type: ignore[call-overload]
+    result = torch.where(torch.isposinf(a), posinf, result)  # type: ignore[call-overload]
     return result