Implements torch.isclose for complex tensors (#36456)

Summary:
Previously torch.isclose would throw a RuntimeError when called on complex tensors. This PR updates torch.isclose to support complex tensors and be consistent with [NumPy](https://numpy.org/doc/1.18/reference/generated/numpy.isclose.html). However, NumPy's handling of NaN, -inf, and inf values is odd, so I adopted Python's [cmath.isclose](https://docs.python.org/3/library/cmath.html) behavior when dealing with them. See https://github.com/numpy/numpy/issues/15959 for more on NumPy's behavior.
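
A quick illustration of the new complex behavior (a sketch assuming a build with this patch; the expected outputs mirror the tests added in this PR):

```python
import torch

a = torch.tensor([complex(1, 1), complex(float('inf'), 1)], dtype=torch.complex64)
b = torch.tensor([complex(1, 1 + 1e-8), complex(float('inf'), 1)], dtype=torch.complex64)

# The first pair is within tolerance; the second pair is equal, and equal
# nonfinite values compare close (cmath.isclose-style, unlike NumPy).
torch.isclose(a, b)  # tensor([True, True])
```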

While implementing complex isclose I also simplified the isclose algorithm to the following (a Python sketch follows the list):

- A is close to B if A and B are equal, with NaNs comparing equal when equal_nan is true
- If A and B are finite, then A is close to B if `abs(A - B) <= (atol + abs(rtol * B))`
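
Here is a rough Python translation of that logic (hedged: it ignores the integer-casting and input-validation details handled in the C++ implementation below):

```python
import torch

def isclose_sketch(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
    # (1) Equality closeness: exact equality, optionally with NaN == NaN.
    close = a == b
    if equal_nan and a.is_floating_point():
        close |= torch.isnan(a) & torch.isnan(b)

    # (2) Finite closeness: the error must be finite and within the
    #     allowed error (atol + abs(rtol * B)).
    error = (a - b).abs()
    close |= torch.isfinite(error) & (error <= atol + (rtol * b).abs())
    return close
```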

This PR also documents torch.isclose, which was previously undocumented, and adds dedicated tests for its behavior to test_torch.py, which previously had none.

For now, the PR keeps equal_nan=True an error for complex inputs, pending the outcome of https://github.com/numpy/numpy/issues/15959.
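
A minimal repro of that restriction (mirroring the equal_nan cases in the new test_isclose_complex test):

```python
import torch

t = torch.tensor([complex(float('nan'), 1)], dtype=torch.complex64)
torch.isclose(t, t, equal_nan=True)
# RuntimeError: isclose with equal_nan=True is not supported for complex inputs.
```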
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36456

Differential Revision: D21159853

Pulled By: mruberry

fbshipit-source-id: fb18fa7048e6104cc24f5ce308fdfb0ba5e4bb30
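
One detail worth calling out from the diff below (see Note [closeness error computation]): integer inputs are cast to the default scalar type before the difference is taken, because subtraction in an unsigned type wraps around. A small illustration of the pitfall (a sketch using plain PyTorch ops, not the patched code path):

```python
import torch

a = torch.tensor([0], dtype=torch.uint8)
b = torch.tensor([1], dtype=torch.uint8)

# Subtracting directly in uint8 wraps around instead of going negative...
(a - b).abs()          # tensor([255], dtype=torch.uint8)

# ...so isclose casts one operand to the default (float) dtype first,
# which promotes the subtraction and keeps the error correct.
(a - b.float()).abs()  # tensor([1.])
```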
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index 9529b5a..c30c77d 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -18,42 +18,60 @@
   return at::isclose(self, other, rtol, atol, equal_nan).all().item<uint8_t>();
 }
 
+// Note [closeness]
+// A number A is close to B when either:
+//
+// (1) A is equal to B, with NaNs comparing equal when equal_nan is true.
+// (2) The error abs(A - B) is finite and less than the max error
+//      (atol + abs(rtol * B)).
+//
+// Note that this is consistent with NumPy's isclose but divergent from
+// Python's isclose, which computes the max error symmetrically as
+// max(rtol * max(abs(A), abs(B)), atol).
+// TODO: use bitwise operator overloads once we add them
+// TODO: revisit complex inputs and equal_nan=true after
+//  https://github.com/numpy/numpy/issues/15959 is resolved
 Tensor isclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) {
-  // TODO: use bitwise operator overloads once we add them
+  TORCH_CHECK(self.scalar_type() == other.scalar_type(), self.scalar_type(), " did not match ", other.scalar_type());
+  TORCH_CHECK(!(self.is_complex() && equal_nan),
+    "isclose with equal_nan=True is not supported for complex inputs.");
 
-  TORCH_CHECK(self.scalar_type() == other.scalar_type(), self.scalar_type(), " did not match ", other.scalar_type())
+  // Checks that rtol and atol are non-negative
+  // Note: consistent with Python's isclose but divergent from NumPy's, which
+  //  allows negative atol and rtol.
+  TORCH_CHECK(rtol >= 0, "rtol must be greater than or equal to zero, but got ", rtol);
+  TORCH_CHECK(atol >= 0, "atol must be greater than or equal to zero, but got ", atol);
 
-  // The original formula `atol + rtol * other.abs()` works incorrectly when
-  // `other` has integral dtype and `other == min_value` and `abs(min_value)` is negative:
-  // std::abs(std::numeric_limits<int64_t>::lowest()) == std::numeric_limits<int64_t>::lowest() < 0
-  auto max_error = atol + (rtol * other).abs();
+  // Computes equality closeness
+  Tensor close = self == other;
+  if (equal_nan && self.is_floating_point()) {
+    close.__ior__((self != self).__iand__(other != other));
+  }
 
-  // `max_error` could be a float or double depending on the type of the input
-  // tensors.
-  // Specifically, if other is an int tensor, multiplying by rtol results in
-  // float tensor.
-  // It is also possible for parameters to be 'wrapped_number's, in which case
-  // max_error could be promoted to double when actual error is still a float.
-  Tensor actual_error;
-  if (actual_error.scalar_type() != max_error.scalar_type()) {
-    // To silence ASAN that does not like (x - std::numeric_limits<int64_t>::lowest())
-    actual_error = (self - other.to(max_error.scalar_type())).abs();
+  // Note [closeness error computation]
+  // atol and rtol are provided as doubles, so the computation
+  // rtol * other will produce a float or complex tensor.
+  // When the difference (self - other) is compared to it then the
+  // tensor representing the difference will also be cast to float or complex.
+  // However, since (self - other) in uint8 is very likely to produce a
+  // negative value, this moves the cast forward so the difference is
+  // always computed in a float or complex type.
+  // If the values of the integer tensors cannot be exactly represented
+  // by the default scalar type then this may cause an incorrect result.
+
+  // Computes allowed and actual error
+  Tensor cast_other;
+  if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
+    cast_other = other.to(at::get_default_dtype());
   } else {
-    actual_error = (self - other).abs();
+    cast_other = other;
   }
+  Tensor allowed_error = atol + (rtol * cast_other).abs();
+  Tensor actual_error = (self - cast_other).abs();
 
-  auto close = actual_error <= max_error;
+  // Computes finite closeness
+  close.__ior__(at::isfinite(actual_error).__iand__(actual_error <= allowed_error));
 
-  if (isFloatingType(self.scalar_type()) && isFloatingType(other.scalar_type())) {
-    // Handle +/-inf
-    close.__ior__(self == other);
-    close.__iand__((self == INFINITY) == (other == INFINITY));
-    close.__iand__((self == -INFINITY) == (other == -INFINITY));
-
-    if (equal_nan) {
-      close.__ior__((self != self).__and__((other != other)));
-    }
-  }
   return close;
 }
 
@@ -87,8 +105,7 @@
 
   // Note: a complex value is finite iff both parts are finite
   if (self.is_complex()) {
-    const auto float_type = c10::toValueType(self.scalar_type());
-    return at::isfinite(self.abs().to(float_type));
+    return at::isfinite(self.abs());
   }
 
   return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "isfinite", [&]() {
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index 12338ef..3e3f138 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -312,6 +312,7 @@
    .. automethod:: int_repr
    .. automethod:: inverse
    .. automethod:: irfft
+   .. automethod:: isclose
    .. automethod:: is_contiguous
    .. automethod:: is_complex
    .. automethod:: is_floating_point
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 8e1a4aa..6224fe2 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -284,6 +284,7 @@
 .. autofunction:: equal
 .. autofunction:: ge
 .. autofunction:: gt
+.. autofunction:: isclose
 .. autofunction:: isfinite
 .. autofunction:: isinf
 .. autofunction:: isnan
diff --git a/test/test_torch.py b/test/test_torch.py
index fef421b..36acd43 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -170,7 +170,6 @@
                        'is_distributed',
                        'is_nonzero',
                        'is_same_size',
-                       'isclose',
                        'log_softmax',
                        'map2_',
                        'new',
@@ -193,25 +192,6 @@
         # TODO: add torch.* tests when we have proper namespacing on ATen functions
         # test_namespace(torch)
 
-    def test_allclose(self):
-        x = torch.tensor([1.0, 2.0, 3.0])
-        y = torch.tensor([1.01, 2.01, 3.01])
-        self.assertTrue(torch.allclose(x, y, rtol=0, atol=0.02))
-        self.assertTrue(torch.allclose(x, y, rtol=0.01, atol=0.0))
-        self.assertFalse(torch.allclose(x, y))
-        self.assertTrue(torch.allclose(torch.tensor([0.0]), torch.tensor([1e-8])))
-        x = torch.tensor([2.0, 3.0, nan])
-        y = torch.tensor([2.01, 3.01, nan])
-        self.assertFalse(torch.allclose(x, y, rtol=1e-2))
-        self.assertTrue(torch.allclose(x, y, rtol=1e-2, equal_nan=True))
-        self.assertFalse(torch.allclose(x, y, rtol=1e-3, equal_nan=True))
-        inf_t = torch.tensor([inf])
-        self.assertTrue(torch.allclose(inf_t, inf_t))
-        self.assertTrue(torch.allclose(-inf_t, -inf_t))
-        self.assertFalse(torch.allclose(inf_t, -inf_t))
-        self.assertFalse(torch.allclose(inf_t, torch.tensor([1e20])))
-        self.assertFalse(torch.allclose(-inf_t, torch.tensor([-1e20])))
-
     def test_linear_algebra_scalar_raises(self):
         m = torch.randn(5, 5)
         v = torch.randn(5)
@@ -5308,6 +5288,197 @@
 class TestTorchDeviceType(TestCase):
     exact_dtype = True
 
+    def _isclose_helper(self, tests, device, dtype, equal_nan, atol=1e-08, rtol=1e-05):
+        for test in tests:
+            a = torch.tensor((test[0],), device=device, dtype=dtype)
+            b = torch.tensor((test[1],), device=device, dtype=dtype)
+
+            actual = torch.isclose(a, b, equal_nan=equal_nan, atol=atol, rtol=rtol)
+            expected = test[2]
+            self.assertEqual(actual.item(), expected)
+
+    # torch.isclose is not implemented for bool tensors
+    # see https://github.com/pytorch/pytorch/issues/33048
+    def test_isclose_bool(self, device):
+        tests = (
+            (True, True, True),
+            (False, False, True),
+            (True, False, False),
+            (False, True, False),
+        )
+
+        with self.assertRaises(RuntimeError):
+            self._isclose_helper(tests, device, torch.bool, False)
+
+    @dtypes(torch.uint8,
+            torch.int8, torch.int16, torch.int32, torch.int64)
+    def test_isclose_integer(self, device, dtype):
+        tests = (
+            (0, 0, True),
+            (0, 1, False),
+            (1, 0, False),
+        )
+
+        self._isclose_helper(tests, device, dtype, False)
+
+        # atol and rtol tests
+        tests = [
+            (0, 1, True),
+            (1, 0, False),
+            (1, 3, True),
+        ]
+
+        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
+
+        if dtype is torch.uint8:
+            tests = [
+                (-1, 1, False),
+                (1, -1, False)
+            ]
+        else:
+            tests = [
+                (-1, 1, True),
+                (1, -1, True)
+            ]
+
+        self._isclose_helper(tests, device, dtype, False, atol=1.5, rtol=.5)
+
+    # torch.isclose is not implemented for cpu half tensors
+    # see https://github.com/pytorch/pytorch/issues/36451
+    @dtypes(torch.float16, torch.float32, torch.float64)
+    def test_isclose_float(self, device, dtype):
+        tests = (
+            (0, 0, True),
+            (0, -1, False),
+            (float('inf'), float('inf'), True),
+            (-float('inf'), float('inf'), False),
+            (float('inf'), float('nan'), False),
+            (float('nan'), float('nan'), False),
+            (0, float('nan'), False),
+            (1, 1, True),
+        )
+
+        if dtype is torch.half and self.device_type == 'cpu':
+            with self.assertRaises(RuntimeError):
+                self._isclose_helper(tests, device, dtype, False)
+        else:
+            self._isclose_helper(tests, device, dtype, False)
+
+        # atol and rtol tests
+        eps = 1e-2 if dtype is torch.half else 1e-6
+        tests = (
+            (0, 1, True),
+            (0, 1 + eps, False),
+            (1, 0, False),
+            (1, 3, True),
+            (1 - eps, 3, False),
+            (-.25, .5, True),
+            (-.25 - eps, .5, False),
+            (.25, -.5, True),
+            (.25 + eps, -.5, False),
+        )
+
+        if dtype is torch.half and self.device_type == 'cpu':
+            with self.assertRaises(RuntimeError):
+                self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
+        else:
+            self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
+
+        # equal_nan = True tests
+        tests = (
+            (0, float('nan'), False),
+            (float('inf'), float('nan'), False),
+            (float('nan'), float('nan'), True),
+        )
+
+        if dtype is torch.half and self.device_type == 'cpu':
+            with self.assertRaises(RuntimeError):
+                self._isclose_helper(tests, device, dtype, True)
+        else:
+            self._isclose_helper(tests, device, dtype, True)
+
+    # torch.isclose with equal_nan=True is not implemented for complex inputs
+    # see https://github.com/numpy/numpy/issues/15959
+    @dtypes(torch.complex64, torch.complex128)
+    def test_isclose_complex(self, device, dtype):
+        tests = (
+            (complex(1, 1), complex(1, 1 + 1e-8), True),
+            (complex(0, 1), complex(1, 1), False),
+            (complex(1, 1), complex(1, 0), False),
+            (complex(1, 1), complex(1, float('nan')), False),
+            (complex(1, float('nan')), complex(1, float('nan')), False),
+            (complex(1, 1), complex(1, float('inf')), False),
+            (complex(float('inf'), 1), complex(1, float('inf')), False),
+            (complex(-float('inf'), 1), complex(1, float('inf')), False),
+            (complex(-float('inf'), 1), complex(float('inf'), 1), False),
+            (complex(float('inf'), 1), complex(float('inf'), 1), True),
+            (complex(float('inf'), 1), complex(float('inf'), 1 + 1e-4), False),
+        )
+
+        self._isclose_helper(tests, device, dtype, False)
+
+        # atol and rtol tests
+        eps = 1e-6
+        tests = (
+            # Complex versions of float tests (real part)
+            (complex(0, 0), complex(1, 0), True),
+            (complex(0, 0), complex(1 + eps, 0), False),
+            (complex(1, 0), complex(0, 0), False),
+            (complex(1, 0), complex(3, 0), True),
+            (complex(1 - eps, 0), complex(3, 0), False),
+            (complex(-.25, 0), complex(.5, 0), True),
+            (complex(-.25 - eps, 0), complex(.5, 0), False),
+            (complex(.25, 0), complex(-.5, 0), True),
+            (complex(.25 + eps, 0), complex(-.5, 0), False),
+            # Complex versions of float tests (imaginary part)
+            (complex(0, 0), complex(0, 1), True),
+            (complex(0, 0), complex(0, 1 + eps), False),
+            (complex(0, 1), complex(0, 0), False),
+            (complex(0, 1), complex(0, 3), True),
+            (complex(0, 1 - eps), complex(0, 3), False),
+            (complex(0, -.25), complex(0, .5), True),
+            (complex(0, -.25 - eps), complex(0, .5), False),
+            (complex(0, .25), complex(0, -.5), True),
+            (complex(0, .25 + eps), complex(0, -.5), False),
+            # Complex-specific tests
+            (complex(1, -1), complex(-1, 1), False),
+            (complex(1, -1), complex(2, -2), True),
+            (complex(-math.sqrt(2), math.sqrt(2)),
+             complex(-math.sqrt(.5), math.sqrt(.5)), True),
+            (complex(-math.sqrt(2), math.sqrt(2)),
+             complex(-math.sqrt(.501), math.sqrt(.499)), False),
+            (complex(2, 4), complex(1., 8.8523607), True),
+            (complex(2, 4), complex(1., 8.8523607 + eps), False),
+        )
+
+        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
+
+        # equal_nan = True tests
+        tests = (
+            (complex(1, 1), complex(1, float('nan')), False),
+            (complex(float('nan'), 1), complex(1, float('nan')), False),
+            (complex(float('nan'), 1), complex(float('nan'), 1), True),
+        )
+
+        with self.assertRaises(RuntimeError):
+            self._isclose_helper(tests, device, dtype, True)
+
+    # Tests that rtol or atol values less than zero throw RuntimeErrors
+    @dtypes(torch.bool, torch.uint8,
+            torch.int8, torch.int16, torch.int32, torch.int64,
+            torch.float16, torch.float32, torch.float64)
+    def test_isclose_atol_rtol_greater_than_zero(self, device, dtype):
+        t = torch.tensor((1,), device=device, dtype=dtype)
+
+        with self.assertRaises(RuntimeError):
+            torch.isclose(t, t, atol=-1, rtol=1)
+        with self.assertRaises(RuntimeError):
+            torch.isclose(t, t, atol=1, rtol=-1)
+        with self.assertRaises(RuntimeError):
+            torch.isclose(t, t, atol=-1, rtol=-1)
+
     def check_internal_mem_overlap(self, inplace_op, num_inputs,
                                    dtype, device,
                                    expected_failure=False):
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index 02b2920..0652ccd 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -1535,6 +1535,13 @@
 See :func:`torch.inverse`
 """)
 
+add_docstr_all('isclose',
+               r"""
+isclose(other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor
+
+See :func:`torch.isclose`
+""")
+
 add_docstr_all('is_contiguous',
                r"""
 is_contiguous(memory_format=torch.contiguous_format) -> bool
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index c567aa3..29b7be2 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -483,7 +483,7 @@
     other (Tensor): second tensor to compare
     atol (float, optional): absolute tolerance. Default: 1e-08
     rtol (float, optional): relative tolerance. Default: 1e-05
-    equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be compared as equal. Default: ``False``
+    equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be considered equal. Default: ``False``
 
 Example::
 
@@ -2605,6 +2605,38 @@
         tensor([False,  True,  False,  True,  False])
 """)
 
+add_docstr(torch.isclose,
+           r"""
+isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor
+
+Returns a new tensor with boolean elements representing if each element of
+:attr:`input` is "close" to the corresponding element of :attr:`other`.
+Closeness is defined as:
+
+.. math::
+    \lvert \text{input} - \text{other} \rvert \leq \texttt{atol} + \texttt{rtol} \times \lvert \text{other} \rvert
+""" + r"""
+
+where :attr:`input` and :attr:`other` are finite. Where :attr:`input`
+and/or :attr:`other` are nonfinite, they are close if and only if
+they are equal, with NaNs being considered equal to each other when
+:attr:`equal_nan` is True.
+
+Args:
+    input (Tensor): first tensor to compare
+    other (Tensor): second tensor to compare
+    atol (float, optional): absolute tolerance. Default: 1e-08
+    rtol (float, optional): relative tolerance. Default: 1e-05
+    equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be considered equal. Default: ``False``
+
+Examples::
+
+    >>> torch.isclose(torch.tensor((1., 2, 3)), torch.tensor((1 + 1e-10, 3, 4)))
+    tensor([ True, False, False])
+    >>> torch.isclose(torch.tensor((float('inf'), 4)), torch.tensor((float('inf'), 6)), rtol=.5)
+    tensor([True, True])
+""")
+
 add_docstr(torch.isfinite,
            r"""
 Returns a new tensor with boolean elements representing if each element is `finite` or not.