add support for sparse tensors in `torch.testing.assert_close` (#58844)

Summary:
This adds support for sparse tensors, handling them the same way `torch.testing._internal.common_utils.TestCase.assertEqual` does:

https://github.com/pytorch/pytorch/blob/5c7dace309cc84cc17629172ce97566285a9da58/torch/testing/_internal/common_utils.py#L1287-L1313

- Tensors are coalesced before comparison.
- Indices and values are compared individually (see the usage sketch below).
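
A minimal usage sketch of the new behavior (not part of the patch itself; it assumes a build that already contains these changes, and `check_is_coalesced` is the keyword argument introduced by this PR):

```python
import torch

indices = ((0, 1, 1), (1, 0, 0))
values = (1.0, 0.5, 0.5)
uncoalesced = torch.sparse_coo_tensor(indices, values, size=(2, 2))
coalesced = torch.sparse_coo_tensor(((0, 1), (1, 0)), (1.0, 1.0), size=(2, 2)).coalesce()

# Same layout and same coalesced state: indices are compared for equality, values with rtol/atol.
torch.testing.assert_close(uncoalesced, uncoalesced.clone())

# A mismatching is_coalesced() state raises by default ...
# torch.testing.assert_close(uncoalesced, coalesced)  # AssertionError mentioning is_coalesced()

# ... but with the check disabled, both tensors are coalesced before comparison and match.
torch.testing.assert_close(uncoalesced, coalesced, check_is_coalesced=False)
```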

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58844

Reviewed By: zou3519

Differential Revision: D29160250

Pulled By: mruberry

fbshipit-source-id: b0955656c2c7ff3db37a1367427ca54ca14f2e87
diff --git a/test/test_testing.py b/test/test_testing.py
index cf87e0b..a96c7fd 100644
--- a/test/test_testing.py
+++ b/test/test_testing.py
@@ -11,7 +11,7 @@
 import torch
 
 from torch.testing._internal.common_utils import \
-    (IS_SANDCASTLE, IS_WINDOWS, TestCase, make_tensor, run_tests, skipIfRocm, slowTest)
+    (IS_FBCODE, IS_SANDCASTLE, IS_WINDOWS, TestCase, make_tensor, run_tests, skipIfRocm, slowTest)
 from torch.testing._internal.framework_utils import calculate_shards
 from torch.testing._internal.common_device_type import \
     (PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes,
@@ -817,14 +817,6 @@
 
 
 class TestAssertClose(TestCase):
-    def test_sparse_support(self):
-        actual = torch.empty(())
-        expected = torch.sparse_coo_tensor(size=())
-
-        for fn in assert_close_with_inputs(actual, expected):
-            with self.assertRaises(UsageError):
-                fn()
-
     def test_quantized_support(self):
         val = 1
         actual = torch.tensor([val], dtype=torch.int32)
@@ -859,6 +851,25 @@
             with self.assertRaisesRegex(AssertionError, "shape"):
                 fn()
 
+    @unittest.skipIf(not torch.backends.mkldnn.is_available(), reason="MKLDNN is not available.")
+    def test_unknown_layout(self):
+        actual = torch.empty((2, 2))
+        expected = actual.to_mkldnn()
+
+        for fn in assert_close_with_inputs(actual, expected):
+            with self.assertRaises(UsageError):
+                fn()
+
+    def test_mismatching_layout(self):
+        strided = torch.empty((2, 2))
+        sparse_coo = strided.to_sparse()
+        sparse_csr = strided.to_sparse_csr()
+
+        for actual, expected in itertools.combinations((strided, sparse_coo, sparse_csr), 2):
+            for fn in assert_close_with_inputs(actual, expected):
+                with self.assertRaisesRegex(AssertionError, "layout"):
+                    fn()
+
     def test_mismatching_dtype(self):
         actual = torch.empty((), dtype=torch.float)
         expected = actual.clone().to(torch.int)
@@ -1158,5 +1169,180 @@
                 fn()
 
 
+class TestAssertCloseSparseCOO(TestCase):
+    def test_matching_coalesced(self):
+        indices = (
+            (0, 1),
+            (1, 0),
+        )
+        values = (1, 2)
+        actual = torch.sparse_coo_tensor(indices, values, size=(2, 2)).coalesce()
+        expected = actual.clone()
+
+        for fn in assert_close_with_inputs(actual, expected):
+            fn()
+
+    def test_matching_uncoalesced(self):
+        indices = (
+            (0, 1),
+            (1, 0),
+        )
+        values = (1, 2)
+        actual = torch.sparse_coo_tensor(indices, values, size=(2, 2))
+        expected = actual.clone()
+
+        for fn in assert_close_with_inputs(actual, expected):
+            fn()
+
+    def test_mismatching_is_coalesced(self):
+        indices = (
+            (0, 1),
+            (1, 0),
+        )
+        values = (1, 2)
+        actual = torch.sparse_coo_tensor(indices, values, size=(2, 2))
+        expected = actual.clone().coalesce()
+
+        for fn in assert_close_with_inputs(actual, expected):
+            with self.assertRaisesRegex(AssertionError, "is_coalesced"):
+                fn()
+
+    def test_mismatching_is_coalesced_no_check(self):
+        actual_indices = (
+            (0, 1),
+            (1, 0),
+        )
+        actual_values = (1, 2)
+        actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2)).coalesce()
+
+        expected_indices = (
+            (0, 1, 1,),
+            (1, 0, 0,),
+        )
+        expected_values = (1, 1, 1)
+        expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
+
+        for fn in assert_close_with_inputs(actual, expected):
+            fn(check_is_coalesced=False)
+
+    def test_mismatching_nnz(self):
+        actual_indices = (
+            (0, 1),
+            (1, 0),
+        )
+        actual_values = (1, 2)
+        actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2))
+
+        expected_indices = (
+            (0, 1, 1,),
+            (1, 0, 0,),
+        )
+        expected_values = (1, 1, 1)
+        expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
+
+        for fn in assert_close_with_inputs(actual, expected):
+            with self.assertRaisesRegex(AssertionError, re.escape("number of specified values")):
+                fn()
+
+    def test_mismatching_indices_msg(self):
+        actual_indices = (
+            (0, 1),
+            (1, 0),
+        )
+        actual_values = (1, 2)
+        actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2))
+
+        expected_indices = (
+            (0, 1),
+            (1, 1),
+        )
+        expected_values = (1, 2)
+        expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
+
+        for fn in assert_close_with_inputs(actual, expected):
+            with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the indices")):
+                fn()
+
+    def test_mismatching_values_msg(self):
+        actual_indices = (
+            (0, 1),
+            (1, 0),
+        )
+        actual_values = (1, 2)
+        actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2))
+
+        expected_indices = (
+            (0, 1),
+            (1, 0),
+        )
+        expected_values = (1, 3)
+        expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
+
+        for fn in assert_close_with_inputs(actual, expected):
+            with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the values")):
+                fn()
+
+
+@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSR testing")
+class TestAssertCloseSparseCSR(TestCase):
+    def test_matching(self):
+        crow_indices = (0, 1, 2)
+        col_indices = (1, 0)
+        values = (1, 2)
+        actual = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 2))
+        # TODO: replace this by actual.clone() after https://github.com/pytorch/pytorch/issues/59285 is fixed
+        expected = torch.sparse_csr_tensor(
+            actual.crow_indices(), actual.col_indices(), actual.values(), size=actual.size(), device=actual.device
+        )
+
+        for fn in assert_close_with_inputs(actual, expected):
+            fn()
+
+    def test_mismatching_crow_indices_msg(self):
+        actual_crow_indices = (0, 1, 2)
+        actual_col_indices = (1, 0)
+        actual_values = (1, 2)
+        actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))
+
+        expected_crow_indices = (0, 2, 2)
+        expected_col_indices = actual_col_indices
+        expected_values = actual_values
+        expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))
+
+        for fn in assert_close_with_inputs(actual, expected):
+            with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the crow_indices")):
+                fn()
+
+    def test_mismatching_col_indices_msg(self):
+        actual_crow_indices = (0, 1, 2)
+        actual_col_indices = (1, 0)
+        actual_values = (1, 2)
+        actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))
+
+        expected_crow_indices = actual_crow_indices
+        expected_col_indices = (1, 1)
+        expected_values = actual_values
+        expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))
+
+        for fn in assert_close_with_inputs(actual, expected):
+            with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the col_indices")):
+                fn()
+
+    def test_mismatching_values_msg(self):
+        actual_crow_indices = (0, 1, 2)
+        actual_col_indices = (1, 0)
+        actual_values = (1, 2)
+        actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))
+
+        expected_crow_indices = actual_crow_indices
+        expected_col_indices = actual_col_indices
+        expected_values = (1, 3)
+        expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))
+
+        for fn in assert_close_with_inputs(actual, expected):
+            with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the values")):
+                fn()
+
+
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/testing/_asserts.py b/torch/testing/_asserts.py
index 5ff73c3..349988d 100644
--- a/torch/testing/_asserts.py
+++ b/torch/testing/_asserts.py
@@ -77,6 +77,7 @@
 
         if actual.dtype not in (torch.complex32, torch.complex64, torch.complex128):
             return check_tensors(actual, expected, equal_nan=equal_nan, **kwargs)
+
         if relaxed_complex_nan:
             actual, expected = [
                 t.clone().masked_fill(
@@ -98,18 +99,87 @@
     return wrapper
 
 
-def _check_supported_tensor(input: Tensor) -> Optional[_TestingErrorMeta]:
-    """Checks if the tensors are supported by the current infrastructure.
+def _check_sparse_coo_members_individually(
+    check_tensors: Callable[..., Optional[_TestingErrorMeta]]
+) -> Callable[..., Optional[_TestingErrorMeta]]:
+    """Decorates strided tensor check functions to individually handle sparse COO members.
 
-    All checks are temporary and will be relaxed in the future.
+    If the inputs are not sparse COO, this decorator is a no-op.
+
+    Args:
+        check_tensors (Callable[[Tensor, Tensor], Optional[Exception]]): Tensor check function for strided tensors.
+    """
+
+    @functools.wraps(check_tensors)
+    def wrapper(actual: Tensor, expected: Tensor, **kwargs: Any) -> Optional[_TestingErrorMeta]:
+        if not actual.is_sparse:
+            return check_tensors(actual, expected, **kwargs)
+
+        if actual._nnz() != expected._nnz():
+            return _TestingErrorMeta(
+                AssertionError, f"The number of specified values does not match: {actual._nnz()} != {expected._nnz()}"
+            )
+
+        kwargs_equal = dict(kwargs, rtol=0, atol=0)
+        error_meta = check_tensors(actual._indices(), expected._indices(), **kwargs_equal)
+        if error_meta:
+            return error_meta.amend_msg(postfix="\n\nThe failure occurred for the indices.")
+
+        error_meta = check_tensors(actual._values(), expected._values(), **kwargs)
+        if error_meta:
+            return error_meta.amend_msg(postfix="\n\nThe failure occurred for the values.")
+
+        return None
+
+    return wrapper
+
+
+def _check_sparse_csr_members_individually(
+    check_tensors: Callable[..., Optional[_TestingErrorMeta]]
+) -> Callable[..., Optional[_TestingErrorMeta]]:
+    """Decorates strided tensor check functions to individually handle sparse CSR members.
+
+    If the inputs are not sparse CSR, this decorator is a no-op.
+
+    Args:
+        check_tensors (Callable[[Tensor, Tensor], Optional[Exception]]): Tensor check function for strided
+            tensors.
+    """
+
+    @functools.wraps(check_tensors)
+    def wrapper(actual: Tensor, expected: Tensor, **kwargs: Any) -> Optional[_TestingErrorMeta]:
+        if not actual.is_sparse_csr:
+            return check_tensors(actual, expected, **kwargs)
+
+        kwargs_equal = dict(kwargs, rtol=0, atol=0)
+        error_meta = check_tensors(actual.crow_indices(), expected.crow_indices(), **kwargs_equal)
+        if error_meta:
+            return error_meta.amend_msg(postfix="\n\nThe failure occurred for the crow_indices.")
+
+        error_meta = check_tensors(actual.col_indices(), expected.col_indices(), **kwargs_equal)
+        if error_meta:
+            return error_meta.amend_msg(postfix="\n\nThe failure occurred for the col_indices.")
+
+        error_meta = check_tensors(actual.values(), expected.values(), **kwargs)
+        if error_meta:
+            return error_meta.amend_msg(postfix="\n\nThe failure occurred for the values.")
+
+        return None
+
+    return wrapper
+
+
+def _check_supported_tensor(input: Tensor) -> Optional[_TestingErrorMeta]:
+    """Checks if the tensor is supported by the current infrastructure.
 
     Returns:
         (Optional[_TestingErrorMeta]): If check did not pass.
     """
     if input.is_quantized:
         return _TestingErrorMeta(UsageError, "Comparison for quantized tensors is not supported yet.")
-    if input.is_sparse:
-        return _TestingErrorMeta(UsageError, "Comparison for sparse tensors is not supported yet.")
+
+    if input.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}:  # type: ignore[attr-defined]
+        return _TestingErrorMeta(UsageError, f"Unsupported tensor layout {input.layout}")
 
     return None
 
@@ -121,11 +191,13 @@
     check_device: bool = True,
     check_dtype: bool = True,
     check_stride: bool = True,
+    check_is_coalesced: bool = True,
 ) -> Optional[_TestingErrorMeta]:
     """Checks if the attributes of two tensors match.
 
-    Always checks the :attr:`~torch.Tensor.shape`. Checks for :attr:`~torch.Tensor.device`,
-    :attr:`~torch.Tensor.dtype`, and :meth:`~torch.Tensor.stride` are optional and can be disabled.
+    Always checks the :attr:`~torch.Tensor.shape` and :attr:`~torch.Tensor.layout`. The checks for
+    :attr:`~torch.Tensor.device`, :attr:`~torch.Tensor.dtype`, :meth:`~torch.Tensor.stride` (if the tensors are
+    strided), and :meth:`~torch.Tensor.is_coalesced` (if the tensors are sparse COO) are optional and can be disabled.
 
     Args:
         actual (Tensor): Actual tensor.
@@ -134,8 +206,10 @@
             same :attr:`~torch.Tensor.device`.
         check_dtype (bool): If ``True`` (default), checks that both :attr:`actual` and :attr:`expected` have the same
             ``dtype``.
-        check_stride (bool): If ``True`` (default), checks that both :attr:`actual` and :attr:`expected` have the same
-            stride.
+        check_stride (bool): If ``True`` (default) and the tensors are strided, checks that both :attr:`actual` and
+            :attr:`expected` have the same stride.
+        check_is_coalesced (bool): If ``True`` (default) and the tensors are sparse COO, checks that
+            :attr:`actual` and :attr:`expected` are either both coalesced or both uncoalesced.
 
     Returns:
         (Optional[_TestingErrorMeta]): If checks did not pass.
@@ -145,15 +219,21 @@
     if actual.shape != expected.shape:
         return _TestingErrorMeta(AssertionError, msg_fmtstr.format("shape", actual.shape, expected.shape))
 
+    if actual.layout != expected.layout:
+        return _TestingErrorMeta(AssertionError, msg_fmtstr.format("layout", actual.layout, expected.layout))
+    elif actual.layout == torch.strided and check_stride and actual.stride() != expected.stride():
+        return _TestingErrorMeta(AssertionError, msg_fmtstr.format("stride()", actual.stride(), expected.stride()))
+    elif actual.layout == torch.sparse_coo and check_is_coalesced and actual.is_coalesced() != expected.is_coalesced():
+        return _TestingErrorMeta(
+            AssertionError, msg_fmtstr.format("is_coalesced()", actual.is_coalesced(), expected.is_coalesced())
+        )
+
     if check_device and actual.device != expected.device:
         return _TestingErrorMeta(AssertionError, msg_fmtstr.format("device", actual.device, expected.device))
 
     if check_dtype and actual.dtype != expected.dtype:
         return _TestingErrorMeta(AssertionError, msg_fmtstr.format("dtype", actual.dtype, expected.dtype))
 
-    if check_stride and actual.stride() != expected.stride():
-        return _TestingErrorMeta(AssertionError, msg_fmtstr.format("stride()", actual.stride(), expected.stride()))
-
     return None
 
 
@@ -181,6 +261,10 @@
         actual = actual.to(dtype)
         expected = expected.to(dtype)
 
+    if actual.is_sparse and actual.is_coalesced() != expected.is_coalesced():
+        actual = actual.coalesce()
+        expected = expected.coalesce()
+
     return actual, expected
 
 
@@ -239,6 +323,8 @@
     )
 
 
+@_check_sparse_coo_members_individually
+@_check_sparse_csr_members_individually
 @_check_complex_components_individually
 def _check_values_close(
     actual: Tensor,
@@ -292,6 +378,7 @@
     check_device: bool = True,
     check_dtype: bool = True,
     check_stride: bool = True,
+    check_is_coalesced: bool = True,
     msg: Optional[Union[str, Callable[[Tensor, Tensor, SimpleNamespace], str]]] = None,
 ) -> Optional[_TestingErrorMeta]:
     r"""Checks that the values of :attr:`actual` and :attr:`expected` are close.
@@ -323,7 +410,12 @@
         rtol, atol = _get_default_rtol_and_atol(actual, expected)
 
     error_meta = _check_attributes_equal(
-        actual, expected, check_device=check_device, check_dtype=check_dtype, check_stride=check_stride
+        actual,
+        expected,
+        check_device=check_device,
+        check_dtype=check_dtype,
+        check_stride=check_stride,
+        check_is_coalesced=check_is_coalesced,
     )
     if error_meta:
         return error_meta
@@ -537,11 +629,12 @@
     check_device: bool = True,
     check_dtype: bool = True,
     check_stride: bool = True,
+    check_is_coalesced: bool = True,
     msg: Optional[Union[str, Callable[[Tensor, Tensor, SimpleNamespace], str]]] = None,
 ) -> None:
     r"""Asserts that :attr:`actual` and :attr:`expected` are close.
 
-    If :attr:`actual` and :attr:`expected` are real-valued and finite, they are considered close if
+    If :attr:`actual` and :attr:`expected` are strided, real-valued, and finite, they are considered close if
 
     .. math::
 
@@ -555,6 +648,12 @@
     If :attr:`actual` and :attr:`expected` are complex-valued, they are considered close if both their real and
     imaginary components are considered close according to the definition above.
 
+    If :attr:`actual` and :attr:`expected` are sparse (either with COO or CSR layout), their strided members are
+    checked individually. Indices, namely ``indices`` for COO or ``crow_indices`` and ``col_indices`` for CSR layout,
+    are always checked for equality, whereas the values are checked for closeness according to the definition above.
+    If :attr:`check_is_coalesced` is ``True``, sparse COO tensors are only considered close if both are coalesced or
+    both are uncoalesced.
+
     :attr:`actual` and :attr:`expected` can be :class:`~torch.Tensor`'s or any array-or-scalar-like of the same type,
     from which :class:`torch.Tensor`'s can be constructed with :func:`torch.as_tensor`. In addition, :attr:`actual` and
     :attr:`expected` can be :class:`~collections.abc.Sequence`'s or :class:`~collections.abc.Mapping`'s in which case
@@ -576,24 +675,30 @@
         check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this
             check is disabled, tensors with different ``dtype``'s are promoted  to a common ``dtype`` (according to
             :func:`torch.promote_types`) before being compared.
-        check_stride (bool): If ``True`` (default), asserts that corresponding tensors have the same stride.
+        check_stride (bool): If ``True`` (default) and corresponding tensors are strided, asserts that they have the
+            same stride.
+        check_is_coalesced (bool): If ``True`` (default) and corresponding tensors are sparse COO, checks that
+            :attr:`actual` and :attr:`expected` are either both coalesced or both uncoalesced. If this check is
+            disabled, tensors are :meth:`~torch.Tensor.coalesce`'ed before being compared.
         msg (Optional[Union[str, Callable[[Tensor, Tensor, DiagnosticInfo], str]]]): Optional error message to use if
             the values of corresponding tensors mismatch. Can be passed as callable in which case it will be called
             with the mismatching tensors and a namespace of diagnostic info about the mismatches. See below for details.
 
     Raises:
         UsageError: If a :class:`torch.Tensor` can't be constructed from an array-or-scalar-like.
-        UsageError: If any tensor is quantized or sparse. This is a temporary restriction and will be relaxed in the
-            future.
+        UsageError: If any tensor is quantized. This is a temporary restriction and will be relaxed in the future.
         UsageError: If only :attr:`rtol` or :attr:`atol` is specified.
         AssertionError: If corresponding array-likes have different types.
         AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their length does not match.
         AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their set of keys do not match.
         AssertionError: If corresponding tensors do not have the same :attr:`~torch.Tensor.shape`.
+        AssertionError: If corresponding tensors do not have the same :attr:`~torch.Tensor.layout`.
         AssertionError: If :attr:`check_device`, but corresponding tensors are not on the same
             :attr:`~torch.Tensor.device`.
         AssertionError: If :attr:`check_dtype`, but corresponding tensors do not have the same ``dtype``.
-        AssertionError: If :attr:`check_stride`, but corresponding tensors do not have the same stride.
+        AssertionError: If :attr:`check_stride`, but corresponding strided tensors do not have the same stride.
+        AssertionError: If :attr:`check_is_coalesced`, but corresponding sparse COO tensors are neither both
+            coalesced nor both uncoalesced.
         AssertionError: If the values of corresponding tensors are not close.
 
     The following table displays the default ``rtol`` and ``atol`` for different ``dtype``'s. Note that the ``dtype``
@@ -743,6 +848,7 @@
         check_device=check_device,
         check_dtype=check_dtype,
         check_stride=check_stride,
+        check_is_coalesced=check_is_coalesced,
         msg=msg,
     )
     if error_meta: