[primTorch] refs: isclose - throw error (#78922)
Just checked the condition but didn't throw.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78922
Approved by: https://github.com/mruberry
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 62d06af..ba6e2c9 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -883,15 +883,25 @@
atol: float = 1e-08,
equal_nan: bool = False,
) -> TensorLikeType:
- if a.dtype != b.dtype:
- msg = "Attempting to compare tensors of different dtypes {0} and {1}!".format(
+ check(
+ a.dtype == b.dtype,
+ lambda: "torch.isclose: Attempting to compare tensors of different dtypes {0} and {1}!".format(
a.dtype, b.dtype
- )
- raise ValueError(a, b)
- if rtol < 0:
- msg = "rtol must be greater than or equal to zero, but got {0}!".format(rtol)
- if atol < 0:
- msg = "atol must be greater than or equal to zero, but got {0}!".format(atol)
+ ),
+ ValueError,
+ )
+ check(
+ rtol >= 0,
+ lambda: "torch.isclose: rtol must be greater than or equal to zero, but got {0}!".format(
+ rtol
+ ),
+ )
+ check(
+ atol >= 0,
+ lambda: "torch.isclose: atol must be greater than or equal to zero, but got {0}!".format(
+ atol
+ ),
+ )
close = eq(a, b)
if equal_nan and (utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)):
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index fdd2526..f7d72cb 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -2931,6 +2931,19 @@
yield SampleInput(lhs, args=(rhs,),
kwargs=dict(rtol=rtol, atol=atol, equal_nan=equal_nan))
+
+def error_inputs_isclose(op, device, **kwargs):
+ make_float_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)
+
+ yield ErrorInput(SampleInput(make_float_arg(()), args=(make_float_arg(()),), kwargs={'rtol': -0.4}),
+ error_type=RuntimeError,
+ error_regex='rtol must be greater than or equal to zero')
+
+ yield ErrorInput(SampleInput(make_float_arg(()), args=(make_float_arg(()),), kwargs={'atol': -0.4}),
+ error_type=RuntimeError,
+ error_regex='atol must be greater than or equal to zero')
+
+
def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
return (SampleInput(make_arg((1, 2))),
@@ -13006,6 +13019,7 @@
ref=np.isclose,
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
sample_inputs_func=sample_inputs_isclose,
+ error_inputs_func=error_inputs_isclose,
supports_autograd=False,
supports_out=False,
supports_rhs_python_scalar=False,