[primTorch] refs: isclose - throw error (#78922)

Just checked the condition but didn't throw.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78922
Approved by: https://github.com/mruberry
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 62d06af..ba6e2c9 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -883,15 +883,25 @@
     atol: float = 1e-08,
     equal_nan: bool = False,
 ) -> TensorLikeType:
-    if a.dtype != b.dtype:
-        msg = "Attempting to compare tensors of different dtypes {0} and {1}!".format(
+    check(
+        a.dtype == b.dtype,
+        lambda: "torch.isclose: Attempting to compare tensors of different dtypes {0} and {1}!".format(
             a.dtype, b.dtype
-        )
-        raise ValueError(a, b)
-    if rtol < 0:
-        msg = "rtol must be greater than or equal to zero, but got {0}!".format(rtol)
-    if atol < 0:
-        msg = "atol must be greater than or equal to zero, but got {0}!".format(atol)
+        ),
+        ValueError,
+    )
+    check(
+        rtol >= 0,
+        lambda: "torch.isclose: rtol must be greater than or equal to zero, but got {0}!".format(
+            rtol
+        ),
+    )
+    check(
+        atol >= 0,
+        lambda: "torch.isclose: atol must be greater than or equal to zero, but got {0}!".format(
+            atol
+        ),
+    )
 
     close = eq(a, b)
     if equal_nan and (utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)):
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index fdd2526..f7d72cb 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -2931,6 +2931,19 @@
         yield SampleInput(lhs, args=(rhs,),
                           kwargs=dict(rtol=rtol, atol=atol, equal_nan=equal_nan))
 
+
+def error_inputs_isclose(op, device, **kwargs):
+    make_float_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)
+
+    yield ErrorInput(SampleInput(make_float_arg(()), args=(make_float_arg(()),), kwargs={'rtol': -0.4}),
+                     error_type=RuntimeError,
+                     error_regex='rtol must be greater than or equal to zero')
+
+    yield ErrorInput(SampleInput(make_float_arg(()), args=(make_float_arg(()),), kwargs={'atol': -0.4}),
+                     error_type=RuntimeError,
+                     error_regex='atol must be greater than or equal to zero')
+
+
 def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     return (SampleInput(make_arg((1, 2))),
@@ -13006,6 +13019,7 @@
                     ref=np.isclose,
                     dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
                     sample_inputs_func=sample_inputs_isclose,
+                    error_inputs_func=error_inputs_isclose,
                     supports_autograd=False,
                     supports_out=False,
                     supports_rhs_python_scalar=False,