Implements torch.isclose for complex tensors (#36456)
Summary:
Previously torch.isclose would RuntimeError when called on complex tensors. This update updates torch.isclose to run on complex tensors and be consistent with [NumPy](https://numpy.org/doc/1.18/reference/generated/numpy.isclose.html). However, NumPy's handling of NaN, -inf, and inf values is odd, so I adopted Python's [cmath.isclose](https://docs.python.org/3/library/cmath.html) behavior when dealing with them. See https://github.com/numpy/numpy/issues/15959 for more on NumPy's behavior.
While implementing complex isclose I also simplified the isclose algorithm to:
- A is close to B if A and B are equal, if equal_nan is true then NaN is equal to NaN
- If A and B are finite, then A is close to B if `abs(a - b) <= (atol + abs(rtol * b))`
This PR also documents torch.isclose, since it was undocumented, and adds multiple tests for its behavior to test_torch.py since it had no dedicated tests.
The PR leaves equal_nan=True with complex inputs an error for now, pending the outcome of https://github.com/numpy/numpy/issues/15959.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36456
Differential Revision: D21159853
Pulled By: mruberry
fbshipit-source-id: fb18fa7048e6104cc24f5ce308fdfb0ba5e4bb30
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index 9529b5a..c30c77d 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -18,42 +18,60 @@
return at::isclose(self, other, rtol, atol, equal_nan).all().item<uint8_t>();
}
+// Note [closeness]
+// A number A is close to B when either:
+//
+// (1) A is equal to B, with NaNs comparing equal when equal_nan is true.
+// (2) The error abs(A - B) is finite and less than the max error
+// (atol + abs(rtol * B)).
+//
+// Note that this is consistent with NumPy's isclose but divergent from
+// Python's isclose, which computes the max error symmetrically as
+// max(rtol * max(abs(A), abs(B)), atol).
+// TODO: use bitwise operator overloads once we add them
+// TODO: revisit complex inputs and equal_nan=true after
+// https://github.com/numpy/numpy/issues/15959 is resolved
Tensor isclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) {
- // TODO: use bitwise operator overloads once we add them
+ TORCH_CHECK(self.scalar_type() == other.scalar_type(), self.scalar_type(), " did not match ", other.scalar_type());
+ TORCH_CHECK(!(self.is_complex() && equal_nan),
+ "isclose with equal_nan=True is not supported for complex inputs.");
- TORCH_CHECK(self.scalar_type() == other.scalar_type(), self.scalar_type(), " did not match ", other.scalar_type())
+ // Checks that rtol and atol are non-negative
+ // Note: consistent with Python's isclose but divergent from NumPy's, which
+ // allows negative atol and rtol.
+ TORCH_CHECK(rtol >= 0, "rtol must be greater than or equal to zero, but got ", rtol);
+ TORCH_CHECK(atol >= 0, "atol must be greater than or equal to zero, but got ", atol);
- // The original formula `atol + rtol * other.abs()` works incorrectly when
- // `other` has integral dtype and `other == min_value` and `abs(min_value)` is negative:
- // std::abs(std::numeric_limits<int64_t>::lowest()) == std::numeric_limits<int64_t>::lowest() < 0
- auto max_error = atol + (rtol * other).abs();
+ // Computes equality closeness
+ Tensor close = self == other;
+ if (equal_nan && self.is_floating_point()) {
+ close.__ior__((self != self).__iand__(other != other));
+ }
- // `max_error` could be a float or double depending on the type of the input
- // tensors.
- // Specifically, if other is an int tensor, multiplying by rtol results in
- // float tensor.
- // It is also possible for parameters to be 'wrapped_number's, in which case
- // max_error could be promoted to double when actual error is still a float.
- Tensor actual_error;
- if (actual_error.scalar_type() != max_error.scalar_type()) {
- // To silence ASAN that does not like (x - std::numeric_limits<int64_t>::lowest())
- actual_error = (self - other.to(max_error.scalar_type())).abs();
+ // Note [closeness error computation]
+ // atol and rtol are provided as doubles, so the computation
+ // rtol * other will produce a float or complex tensor.
+ // When the difference (self - other) is compared to it then the
+ // tensor representing the difference will also be cast to float or complex.
+ // However, since (self - other) in uint8 is very likely to produce a
+ // negative value, this moves the cast forward so the difference is
+ // always computed in a float or complex type.
+ // If the values of the integer tensors cannot be exactly represented
+ // by the default scalar type then this may cause an incorrect result.
+
+ // Computes allowed and actual error
+ Tensor cast_other;
+ if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
+ cast_other = other.to(at::get_default_dtype());
} else {
- actual_error = (self - other).abs();
+ cast_other = other;
}
+ Tensor allowed_error = atol + (rtol * cast_other).abs();
+ Tensor actual_error = (self - cast_other).abs();
- auto close = actual_error <= max_error;
+ // Computes finite closeness
+ close.__ior__(at::isfinite(actual_error).__iand__(actual_error <= allowed_error));
- if (isFloatingType(self.scalar_type()) && isFloatingType(other.scalar_type())) {
- // Handle +/-inf
- close.__ior__(self == other);
- close.__iand__((self == INFINITY) == (other == INFINITY));
- close.__iand__((self == -INFINITY) == (other == -INFINITY));
-
- if (equal_nan) {
- close.__ior__((self != self).__and__((other != other)));
- }
- }
return close;
}
@@ -87,8 +105,7 @@
// Note: a complex value is finite iff both parts are finite
if (self.is_complex()) {
- const auto float_type = c10::toValueType(self.scalar_type());
- return at::isfinite(self.abs().to(float_type));
+ return at::isfinite(self.abs());
}
return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "isfinite", [&]() {
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index 12338ef..3e3f138 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -312,6 +312,7 @@
.. automethod:: int_repr
.. automethod:: inverse
.. automethod:: irfft
+ .. automethod:: isclose
.. automethod:: is_contiguous
.. automethod:: is_complex
.. automethod:: is_floating_point
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 8e1a4aa..6224fe2 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -284,6 +284,7 @@
.. autofunction:: equal
.. autofunction:: ge
.. autofunction:: gt
+.. autofunction:: isclose
.. autofunction:: isfinite
.. autofunction:: isinf
.. autofunction:: isnan
diff --git a/test/test_torch.py b/test/test_torch.py
index fef421b..36acd43 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -170,7 +170,6 @@
'is_distributed',
'is_nonzero',
'is_same_size',
- 'isclose',
'log_softmax',
'map2_',
'new',
@@ -193,25 +192,6 @@
# TODO: add torch.* tests when we have proper namespacing on ATen functions
# test_namespace(torch)
- def test_allclose(self):
- x = torch.tensor([1.0, 2.0, 3.0])
- y = torch.tensor([1.01, 2.01, 3.01])
- self.assertTrue(torch.allclose(x, y, rtol=0, atol=0.02))
- self.assertTrue(torch.allclose(x, y, rtol=0.01, atol=0.0))
- self.assertFalse(torch.allclose(x, y))
- self.assertTrue(torch.allclose(torch.tensor([0.0]), torch.tensor([1e-8])))
- x = torch.tensor([2.0, 3.0, nan])
- y = torch.tensor([2.01, 3.01, nan])
- self.assertFalse(torch.allclose(x, y, rtol=1e-2))
- self.assertTrue(torch.allclose(x, y, rtol=1e-2, equal_nan=True))
- self.assertFalse(torch.allclose(x, y, rtol=1e-3, equal_nan=True))
- inf_t = torch.tensor([inf])
- self.assertTrue(torch.allclose(inf_t, inf_t))
- self.assertTrue(torch.allclose(-inf_t, -inf_t))
- self.assertFalse(torch.allclose(inf_t, -inf_t))
- self.assertFalse(torch.allclose(inf_t, torch.tensor([1e20])))
- self.assertFalse(torch.allclose(-inf_t, torch.tensor([-1e20])))
-
def test_linear_algebra_scalar_raises(self):
m = torch.randn(5, 5)
v = torch.randn(5)
@@ -5308,6 +5288,197 @@
class TestTorchDeviceType(TestCase):
exact_dtype = True
+ def _isclose_helper(self, tests, device, dtype, equal_nan, atol=1e-08, rtol=1e-05):
+ for test in tests:
+ a = torch.tensor((test[0],), device=device, dtype=dtype)
+ b = torch.tensor((test[1],), device=device, dtype=dtype)
+
+ actual = torch.isclose(a, b, equal_nan=equal_nan, atol=atol, rtol=rtol)
+ expected = test[2]
+ self.assertEqual(actual.item(), expected)
+
+ # torch.close is not implemented for bool tensors
+ # see https://github.com/pytorch/pytorch/issues/33048
+ def test_isclose_bool(self, device):
+ tests = (
+ (True, True, True),
+ (False, False, True),
+ (True, False, False),
+ (False, True, False),
+ )
+
+ with self.assertRaises(RuntimeError):
+ self._isclose_helper(tests, device, torch.bool, False)
+
+ @dtypes(torch.uint8,
+ torch.int8, torch.int16, torch.int32, torch.int64)
+ def test_isclose_integer(self, device, dtype):
+ tests = (
+ (0, 0, True),
+ (0, 1, False),
+ (1, 0, False),
+ )
+
+ self._isclose_helper(tests, device, dtype, False)
+
+ # atol and rtol tests
+ tests = [
+ (0, 1, True),
+ (1, 0, False),
+ (1, 3, True),
+ ]
+
+ self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
+
+ if dtype is torch.uint8:
+ tests = [
+ (-1, 1, False),
+ (1, -1, False)
+ ]
+ else:
+ tests = [
+ (-1, 1, True),
+ (1, -1, True)
+ ]
+
+ self._isclose_helper(tests, device, dtype, False, atol=1.5, rtol=.5)
+
+ # torch.close is not implemented for cpu half tensors
+ # see https://github.com/pytorch/pytorch/issues/36451
+ @dtypes(torch.float16, torch.float32, torch.float64)
+ def test_isclose_float(self, device, dtype):
+ tests = (
+ (0, 0, True),
+ (0, -1, False),
+ (float('inf'), float('inf'), True),
+ (-float('inf'), float('inf'), False),
+ (float('inf'), float('nan'), False),
+ (float('nan'), float('nan'), False),
+ (0, float('nan'), False),
+ (1, 1, True),
+ )
+
+ if dtype is torch.half and self.device_type == 'cpu':
+ with self.assertRaises(RuntimeError):
+ self._isclose_helper(tests, device, dtype, False)
+ else:
+ self._isclose_helper(tests, device, dtype, False)
+
+ # atol and rtol tests
+ eps = 1e-2 if dtype is torch.half else 1e-6
+ tests = (
+ (0, 1, True),
+ (0, 1 + eps, False),
+ (1, 0, False),
+ (1, 3, True),
+ (1 - eps, 3, False),
+ (-.25, .5, True),
+ (-.25 - eps, .5, False),
+ (.25, -.5, True),
+ (.25 + eps, -.5, False),
+ )
+
+ if dtype is torch.half and self.device_type == 'cpu':
+ with self.assertRaises(RuntimeError):
+ self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
+ else:
+ self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
+
+ # equal_nan = True tests
+ tests = (
+ (0, float('nan'), False),
+ (float('inf'), float('nan'), False),
+ (float('nan'), float('nan'), True),
+ )
+
+ if dtype is torch.half and self.device_type == 'cpu':
+ with self.assertRaises(RuntimeError):
+ self._isclose_helper(tests, device, dtype, True)
+ else:
+ self._isclose_helper(tests, device, dtype, True)
+
+ # torch.close with equal_nan=True is not implemented for complex inputs
+ # see https://github.com/numpy/numpy/issues/15959
+ @dtypes(torch.complex64, torch.complex128)
+ def test_isclose_complex(self, device, dtype):
+ tests = (
+ (complex(1, 1), complex(1, 1 + 1e-8), True),
+ (complex(0, 1), complex(1, 1), False),
+ (complex(1, 1), complex(1, 0), False),
+ (complex(1, 1), complex(1, float('nan')), False),
+ (complex(1, float('nan')), complex(1, float('nan')), False),
+ (complex(1, 1), complex(1, float('inf')), False),
+ (complex(float('inf'), 1), complex(1, float('inf')), False),
+ (complex(-float('inf'), 1), complex(1, float('inf')), False),
+ (complex(-float('inf'), 1), complex(float('inf'), 1), False),
+ (complex(float('inf'), 1), complex(float('inf'), 1), True),
+ (complex(float('inf'), 1), complex(float('inf'), 1 + 1e-4), False),
+ )
+
+ self._isclose_helper(tests, device, dtype, False)
+
+ # atol and rtol tests
+
+ # atol and rtol tests
+ eps = 1e-6
+ tests = (
+ # Complex versions of float tests (real part)
+ (complex(0, 0), complex(1, 0), True),
+ (complex(0, 0), complex(1 + eps, 0), False),
+ (complex(1, 0), complex(0, 0), False),
+ (complex(1, 0), complex(3, 0), True),
+ (complex(1 - eps, 0), complex(3, 0), False),
+ (complex(-.25, 0), complex(.5, 0), True),
+ (complex(-.25 - eps, 0), complex(.5, 0), False),
+ (complex(.25, 0), complex(-.5, 0), True),
+ (complex(.25 + eps, 0), complex(-.5, 0), False),
+ # Complex versions of float tests (imaginary part)
+ (complex(0, 0), complex(0, 1), True),
+ (complex(0, 0), complex(0, 1 + eps), False),
+ (complex(0, 1), complex(0, 0), False),
+ (complex(0, 1), complex(0, 3), True),
+ (complex(0, 1 - eps), complex(0, 3), False),
+ (complex(0, -.25), complex(0, .5), True),
+ (complex(0, -.25 - eps), complex(0, .5), False),
+ (complex(0, .25), complex(0, -.5), True),
+ (complex(0, .25 + eps), complex(0, -.5), False),
+ # Complex-specific tests
+ (complex(1, -1), complex(-1, 1), False),
+ (complex(1, -1), complex(2, -2), True),
+ (complex(-math.sqrt(2), math.sqrt(2)),
+ complex(-math.sqrt(.5), math.sqrt(.5)), True),
+ (complex(-math.sqrt(2), math.sqrt(2)),
+ complex(-math.sqrt(.501), math.sqrt(.499)), False),
+ (complex(2, 4), complex(1., 8.8523607), True),
+ (complex(2, 4), complex(1., 8.8523607 + eps), False),
+ )
+
+ self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
+
+ # equal_nan = True tests
+ tests = (
+ (complex(1, 1), complex(1, float('nan')), False),
+ (complex(float('nan'), 1), complex(1, float('nan')), False),
+ (complex(float('nan'), 1), complex(float('nan'), 1), True),
+ )
+
+ with self.assertRaises(RuntimeError):
+ self._isclose_helper(tests, device, dtype, True)
+
+ # Tests that rtol or atol values less than zero thow RuntimeErrors
+ @dtypes(torch.bool, torch.uint8,
+ torch.int8, torch.int16, torch.int32, torch.int64,
+ torch.float16, torch.float32, torch.float64)
+ def test_isclose_atol_rtol_greater_than_zero(self, device, dtype):
+ t = torch.tensor((1,), device=device, dtype=dtype)
+
+ with self.assertRaises(RuntimeError):
+ torch.isclose(t, t, atol=-1, rtol=1)
+ with self.assertRaises(RuntimeError):
+ torch.isclose(t, t, atol=1, rtol=-1)
+ with self.assertRaises(RuntimeError):
+ torch.isclose(t, t, atol=-1, rtol=-1)
+
def check_internal_mem_overlap(self, inplace_op, num_inputs,
dtype, device,
expected_failure=False):
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index 02b2920..0652ccd 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -1535,6 +1535,13 @@
See :func:`torch.inverse`
""")
+add_docstr_all('isclose',
+ r"""
+isclose(other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor
+
+See :func:`torch.isclose`
+""")
+
add_docstr_all('is_contiguous',
r"""
is_contiguous(memory_format=torch.contiguous_format) -> bool
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index c567aa3..29b7be2 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -483,7 +483,7 @@
other (Tensor): second tensor to compare
atol (float, optional): absolute tolerance. Default: 1e-08
rtol (float, optional): relative tolerance. Default: 1e-05
- equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be compared as equal. Default: ``False``
+ equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be considered equal. Default: ``False``
Example::
@@ -2605,6 +2605,38 @@
tensor([False, True, False, True, False])
""")
+add_docstr(torch.isclose,
+ r"""
+isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor
+
+Returns a new tensor with boolean elements representing if each element of
+:attr:`input` is "close" to the corresponding element of :attr:`other`.
+Closeness is defined as:
+
+. math::
+ \lvert \text{input} - \text{other} \rvert \leq \texttt{atol} + \texttt{rtol} \times \lvert \text{other} \rvert
+""" + r"""
+
+where :attr:`input` and :attr:`other` are finite. Where :attr:`input`
+and/or :attr:`other` are nonfinite they are close if and only if
+they are equal, with NaNs being considered equal to each other when
+:attr:`equal_nan` is True.
+
+Args:
+ input (Tensor): first tensor to compare
+ other (Tensor): second tensor to compare
+ atol (float, optional): absolute tolerance. Default: 1e-08
+ rtol (float, optional): relative tolerance. Default: 1e-05
+ equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be considered equal. Default: ``False``
+
+Examples::
+
+ >>> torch.isclose(torch.tensor((1., 2, 3)), torch.tensor((1 + 1e-10, 3, 4)))
+ tensor([ True, False, False])
+ >>> torch.isclose(torch.tensor((float('inf'), 4)), torch.tensor((float('inf'), 6)), rtol=.5)
+ tensor([True, True])
+""")
+
add_docstr(torch.isfinite,
r"""
Returns a new tensor with boolean elements representing if each element is `finite` or not.