add support for sparse tensors in `torch.testing.assert_close` (#58844)
Summary:
This adds support for sparse tensors, following the same approach as `torch.testing._internal.common_utils.TestCase.assertEqual`:
https://github.com/pytorch/pytorch/blob/5c7dace309cc84cc17629172ce97566285a9da58/torch/testing/_internal/common_utils.py#L1287-L1313
- Tensors are coalesced before comparison.
- Indices and values are compared individually (see the usage sketch below).
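
For illustration, a minimal usage sketch of the new behavior (not part of the diff; it assumes `torch.testing.assert_close` is importable at this commit, and the inputs are chosen to mirror the new tests below):

```python
import torch
from torch.testing import assert_close

indices = ((0, 1), (1, 0))
values = (1.0, 2.0)

# A freshly constructed sparse COO tensor is uncoalesced; `expected` is
# explicitly coalesced, so the two differ only in their coalesced state.
actual = torch.sparse_coo_tensor(indices, values, size=(2, 2))
expected = torch.sparse_coo_tensor(indices, values, size=(2, 2)).coalesce()

# With the default check_is_coalesced=True, the mismatching coalesced state
# is reported as an AssertionError mentioning "is_coalesced()".
try:
    assert_close(actual, expected)
except AssertionError as error:
    print(error)

# With the check disabled, both tensors are coalesced first; the indices are
# then compared for equality and the values for closeness.
assert_close(actual, expected, check_is_coalesced=False)
```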
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58844
Reviewed By: zou3519
Differential Revision: D29160250
Pulled By: mruberry
fbshipit-source-id: b0955656c2c7ff3db37a1367427ca54ca14f2e87
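
Before the diff itself, a companion sketch for the CSR path (hypothetical inputs modeled on the new `TestAssertCloseSparseCSR` tests; it assumes sparse CSR tensors can be constructed in this build):

```python
import torch
from torch.testing import assert_close

crow_indices = (0, 1, 2)
col_indices = (1, 0)

actual = torch.sparse_csr_tensor(crow_indices, col_indices, (1.0, 2.0), size=(2, 2))
expected = torch.sparse_csr_tensor(crow_indices, col_indices, (1.0, 3.0), size=(2, 2))

# crow_indices and col_indices are compared for exact equality, the values
# for closeness; here only the values differ, so the raised AssertionError
# ends with "The failure occurred for the values."
try:
    assert_close(actual, expected)
except AssertionError as error:
    print(error)
```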
diff --git a/test/test_testing.py b/test/test_testing.py
index cf87e0b..a96c7fd 100644
--- a/test/test_testing.py
+++ b/test/test_testing.py
@@ -11,7 +11,7 @@
import torch
from torch.testing._internal.common_utils import \
- (IS_SANDCASTLE, IS_WINDOWS, TestCase, make_tensor, run_tests, skipIfRocm, slowTest)
+ (IS_FBCODE, IS_SANDCASTLE, IS_WINDOWS, TestCase, make_tensor, run_tests, skipIfRocm, slowTest)
from torch.testing._internal.framework_utils import calculate_shards
from torch.testing._internal.common_device_type import \
(PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes,
@@ -817,14 +817,6 @@
class TestAssertClose(TestCase):
- def test_sparse_support(self):
- actual = torch.empty(())
- expected = torch.sparse_coo_tensor(size=())
-
- for fn in assert_close_with_inputs(actual, expected):
- with self.assertRaises(UsageError):
- fn()
-
def test_quantized_support(self):
val = 1
actual = torch.tensor([val], dtype=torch.int32)
@@ -859,6 +851,25 @@
with self.assertRaisesRegex(AssertionError, "shape"):
fn()
+ @unittest.skipIf(not torch.backends.mkldnn.is_available(), reason="MKLDNN is not available.")
+ def test_unknown_layout(self):
+ actual = torch.empty((2, 2))
+ expected = actual.to_mkldnn()
+
+ for fn in assert_close_with_inputs(actual, expected):
+ with self.assertRaises(UsageError):
+ fn()
+
+ def test_mismatching_layout(self):
+ strided = torch.empty((2, 2))
+ sparse_coo = strided.to_sparse()
+ sparse_csr = strided.to_sparse_csr()
+
+ for actual, expected in itertools.combinations((strided, sparse_coo, sparse_csr), 2):
+ for fn in assert_close_with_inputs(actual, expected):
+ with self.assertRaisesRegex(AssertionError, "layout"):
+ fn()
+
def test_mismatching_dtype(self):
actual = torch.empty((), dtype=torch.float)
expected = actual.clone().to(torch.int)
@@ -1158,5 +1169,180 @@
fn()
+class TestAssertCloseSparseCOO(TestCase):
+ def test_matching_coalesced(self):
+ indices = (
+ (0, 1),
+ (1, 0),
+ )
+ values = (1, 2)
+ actual = torch.sparse_coo_tensor(indices, values, size=(2, 2)).coalesce()
+ expected = actual.clone()
+
+ for fn in assert_close_with_inputs(actual, expected):
+ fn()
+
+ def test_matching_uncoalesced(self):
+ indices = (
+ (0, 1),
+ (1, 0),
+ )
+ values = (1, 2)
+ actual = torch.sparse_coo_tensor(indices, values, size=(2, 2))
+ expected = actual.clone()
+
+ for fn in assert_close_with_inputs(actual, expected):
+ fn()
+
+ def test_mismatching_is_coalesced(self):
+ indices = (
+ (0, 1),
+ (1, 0),
+ )
+ values = (1, 2)
+ actual = torch.sparse_coo_tensor(indices, values, size=(2, 2))
+ expected = actual.clone().coalesce()
+
+ for fn in assert_close_with_inputs(actual, expected):
+ with self.assertRaisesRegex(AssertionError, "is_coalesced"):
+ fn()
+
+ def test_mismatching_is_coalesced_no_check(self):
+ actual_indices = (
+ (0, 1),
+ (1, 0),
+ )
+ actual_values = (1, 2)
+ actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2)).coalesce()
+
+ expected_indices = (
+ (0, 1, 1,),
+ (1, 0, 0,),
+ )
+ expected_values = (1, 1, 1)
+ expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
+
+ for fn in assert_close_with_inputs(actual, expected):
+ fn(check_is_coalesced=False)
+
+ def test_mismatching_nnz(self):
+ actual_indices = (
+ (0, 1),
+ (1, 0),
+ )
+ actual_values = (1, 2)
+ actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2))
+
+ expected_indices = (
+ (0, 1, 1,),
+ (1, 0, 0,),
+ )
+ expected_values = (1, 1, 1)
+ expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
+
+ for fn in assert_close_with_inputs(actual, expected):
+ with self.assertRaisesRegex(AssertionError, re.escape("number of specified values")):
+ fn()
+
+ def test_mismatching_indices_msg(self):
+ actual_indices = (
+ (0, 1),
+ (1, 0),
+ )
+ actual_values = (1, 2)
+ actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2))
+
+ expected_indices = (
+ (0, 1),
+ (1, 1),
+ )
+ expected_values = (1, 2)
+ expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
+
+ for fn in assert_close_with_inputs(actual, expected):
+ with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the indices")):
+ fn()
+
+ def test_mismatching_values_msg(self):
+ actual_indices = (
+ (0, 1),
+ (1, 0),
+ )
+ actual_values = (1, 2)
+ actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2))
+
+ expected_indices = (
+ (0, 1),
+ (1, 0),
+ )
+ expected_values = (1, 3)
+ expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
+
+ for fn in assert_close_with_inputs(actual, expected):
+ with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the values")):
+ fn()
+
+
+@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSR testing")
+class TestAssertCloseSparseCSR(TestCase):
+ def test_matching(self):
+ crow_indices = (0, 1, 2)
+ col_indices = (1, 0)
+ values = (1, 2)
+ actual = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 2))
+ # TODO: replace this by actual.clone() after https://github.com/pytorch/pytorch/issues/59285 is fixed
+ expected = torch.sparse_csr_tensor(
+ actual.crow_indices(), actual.col_indices(), actual.values(), size=actual.size(), device=actual.device
+ )
+
+ for fn in assert_close_with_inputs(actual, expected):
+ fn()
+
+ def test_mismatching_crow_indices_msg(self):
+ actual_crow_indices = (0, 1, 2)
+ actual_col_indices = (1, 0)
+ actual_values = (1, 2)
+ actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))
+
+ expected_crow_indices = (0, 2, 2)
+ expected_col_indices = actual_col_indices
+ expected_values = actual_values
+ expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))
+
+ for fn in assert_close_with_inputs(actual, expected):
+ with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the crow_indices")):
+ fn()
+
+ def test_mismatching_col_indices_msg(self):
+ actual_crow_indices = (0, 1, 2)
+ actual_col_indices = (1, 0)
+ actual_values = (1, 2)
+ actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))
+
+ expected_crow_indices = actual_crow_indices
+ expected_col_indices = (1, 1)
+ expected_values = actual_values
+ expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))
+
+ for fn in assert_close_with_inputs(actual, expected):
+ with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the col_indices")):
+ fn()
+
+ def test_mismatching_values_msg(self):
+ actual_crow_indices = (0, 1, 2)
+ actual_col_indices = (1, 0)
+ actual_values = (1, 2)
+ actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))
+
+ expected_crow_indices = actual_crow_indices
+ expected_col_indices = actual_col_indices
+ expected_values = (1, 3)
+ expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))
+
+ for fn in assert_close_with_inputs(actual, expected):
+ with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the values")):
+ fn()
+
+
if __name__ == '__main__':
run_tests()
diff --git a/torch/testing/_asserts.py b/torch/testing/_asserts.py
index 5ff73c3..349988d 100644
--- a/torch/testing/_asserts.py
+++ b/torch/testing/_asserts.py
@@ -77,6 +77,7 @@
if actual.dtype not in (torch.complex32, torch.complex64, torch.complex128):
return check_tensors(actual, expected, equal_nan=equal_nan, **kwargs)
+
if relaxed_complex_nan:
actual, expected = [
t.clone().masked_fill(
@@ -98,18 +99,87 @@
return wrapper
-def _check_supported_tensor(input: Tensor) -> Optional[_TestingErrorMeta]:
- """Checks if the tensors are supported by the current infrastructure.
+def _check_sparse_coo_members_individually(
+ check_tensors: Callable[..., Optional[_TestingErrorMeta]]
+) -> Callable[..., Optional[_TestingErrorMeta]]:
+ """Decorates strided tensor check functions to individually handle sparse COO members.
- All checks are temporary and will be relaxed in the future.
+ If the inputs are not sparse COO, this decorator is a no-op.
+
+ Args:
+ check_tensors (Callable[[Tensor, Tensor], Optional[_TestingErrorMeta]]): Tensor check function for strided tensors.
+ """
+
+ @functools.wraps(check_tensors)
+ def wrapper(actual: Tensor, expected: Tensor, **kwargs: Any) -> Optional[_TestingErrorMeta]:
+ if not actual.is_sparse:
+ return check_tensors(actual, expected, **kwargs)
+
+ if actual._nnz() != expected._nnz():
+ return _TestingErrorMeta(
+ AssertionError, f"The number of specified values does not match: {actual._nnz()} != {expected._nnz()}"
+ )
+
+ kwargs_equal = dict(kwargs, rtol=0, atol=0)
+ error_meta = check_tensors(actual._indices(), expected._indices(), **kwargs_equal)
+ if error_meta:
+ return error_meta.amend_msg(postfix="\n\nThe failure occurred for the indices.")
+
+ error_meta = check_tensors(actual._values(), expected._values(), **kwargs)
+ if error_meta:
+ return error_meta.amend_msg(postfix="\n\nThe failure occurred for the values.")
+
+ return None
+
+ return wrapper
+
+
+def _check_sparse_csr_members_individually(
+ check_tensors: Callable[..., Optional[_TestingErrorMeta]]
+) -> Callable[..., Optional[_TestingErrorMeta]]:
+ """Decorates strided tensor check functions to individually handle sparse CSR members.
+
+ If the inputs are not sparse CSR, this decorator is a no-op.
+
+ Args:
+ check_tensors (Callable[[Tensor, Tensor], Optional[_TestingErrorMeta]]): Tensor check function for strided
+ tensors.
+ """
+
+ @functools.wraps(check_tensors)
+ def wrapper(actual: Tensor, expected: Tensor, **kwargs: Any) -> Optional[_TestingErrorMeta]:
+ if not actual.is_sparse_csr:
+ return check_tensors(actual, expected, **kwargs)
+
+ kwargs_equal = dict(kwargs, rtol=0, atol=0)
+ error_meta = check_tensors(actual.crow_indices(), expected.crow_indices(), **kwargs_equal)
+ if error_meta:
+ return error_meta.amend_msg(postfix="\n\nThe failure occurred for the crow_indices.")
+
+ error_meta = check_tensors(actual.col_indices(), expected.col_indices(), **kwargs_equal)
+ if error_meta:
+ return error_meta.amend_msg(postfix="\n\nThe failure occurred for the col_indices.")
+
+ error_meta = check_tensors(actual.values(), expected.values(), **kwargs)
+ if error_meta:
+ return error_meta.amend_msg(postfix="\n\nThe failure occurred for the values.")
+
+ return None
+
+ return wrapper
+
+
+def _check_supported_tensor(input: Tensor) -> Optional[_TestingErrorMeta]:
+ """Checks if the tensor is supported by the current infrastructure.
Returns:
(Optional[_TestingErrorMeta]): If check did not pass.
"""
if input.is_quantized:
return _TestingErrorMeta(UsageError, "Comparison for quantized tensors is not supported yet.")
- if input.is_sparse:
- return _TestingErrorMeta(UsageError, "Comparison for sparse tensors is not supported yet.")
+
+ if input.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}: # type: ignore[attr-defined]
+ return _TestingErrorMeta(UsageError, f"Unsupported tensor layout {input.layout}")
return None
@@ -121,11 +191,13 @@
check_device: bool = True,
check_dtype: bool = True,
check_stride: bool = True,
+ check_is_coalesced: bool = True,
) -> Optional[_TestingErrorMeta]:
"""Checks if the attributes of two tensors match.
- Always checks the :attr:`~torch.Tensor.shape`. Checks for :attr:`~torch.Tensor.device`,
- :attr:`~torch.Tensor.dtype`, and :meth:`~torch.Tensor.stride` are optional and can be disabled.
+ Always checks the :attr:`~torch.Tensor.shape` and :attr:`~torch.Tensor.layout`. Checks for
+ :attr:`~torch.Tensor.device`, :attr:`~torch.Tensor.dtype`, :meth:`~torch.Tensor.stride` if the tensors are strided,
+ and :meth:`~torch.Tensor.is_coalesced` if the tensors are sparse COO are optional and can be disabled.
Args:
actual (Tensor): Actual tensor.
@@ -134,8 +206,10 @@
same :attr:`~torch.Tensor.device`.
check_dtype (bool): If ``True`` (default), checks that both :attr:`actual` and :attr:`expected` have the same
``dtype``.
- check_stride (bool): If ``True`` (default), checks that both :attr:`actual` and :attr:`expected` have the same
- stride.
+ check_stride (bool): If ``True`` (default) and the tensors are strided, checks that both :attr:`actual` and
+ :attr:`expected` have the same stride.
+ check_is_coalesced (bool): If ``True`` (default) and the tensors are sparse COO, checks that
+ :attr:`actual` and :attr:`expected` are either both coalesced or both uncoalesced.
Returns:
(Optional[_TestingErrorMeta]): If checks did not pass.
@@ -145,15 +219,21 @@
if actual.shape != expected.shape:
return _TestingErrorMeta(AssertionError, msg_fmtstr.format("shape", actual.shape, expected.shape))
+ if actual.layout != expected.layout:
+ return _TestingErrorMeta(AssertionError, msg_fmtstr.format("layout", actual.layout, expected.layout))
+ elif actual.layout == torch.strided and check_stride and actual.stride() != expected.stride():
+ return _TestingErrorMeta(AssertionError, msg_fmtstr.format("stride()", actual.stride(), expected.stride()))
+ elif actual.layout == torch.sparse_coo and check_is_coalesced and actual.is_coalesced() != expected.is_coalesced():
+ return _TestingErrorMeta(
+ AssertionError, msg_fmtstr.format("is_coalesced()", actual.is_coalesced(), expected.is_coalesced())
+ )
+
if check_device and actual.device != expected.device:
return _TestingErrorMeta(AssertionError, msg_fmtstr.format("device", actual.device, expected.device))
if check_dtype and actual.dtype != expected.dtype:
return _TestingErrorMeta(AssertionError, msg_fmtstr.format("dtype", actual.dtype, expected.dtype))
- if check_stride and actual.stride() != expected.stride():
- return _TestingErrorMeta(AssertionError, msg_fmtstr.format("stride()", actual.stride(), expected.stride()))
-
return None
@@ -181,6 +261,10 @@
actual = actual.to(dtype)
expected = expected.to(dtype)
+ if actual.is_sparse and actual.is_coalesced() != expected.is_coalesced():
+ actual = actual.coalesce()
+ expected = expected.coalesce()
+
return actual, expected
@@ -239,6 +323,8 @@
)
+@_check_sparse_coo_members_individually
+@_check_sparse_csr_members_individually
@_check_complex_components_individually
def _check_values_close(
actual: Tensor,
@@ -292,6 +378,7 @@
check_device: bool = True,
check_dtype: bool = True,
check_stride: bool = True,
+ check_is_coalesced: bool = True,
msg: Optional[Union[str, Callable[[Tensor, Tensor, SimpleNamespace], str]]] = None,
) -> Optional[_TestingErrorMeta]:
r"""Checks that the values of :attr:`actual` and :attr:`expected` are close.
@@ -323,7 +410,12 @@
rtol, atol = _get_default_rtol_and_atol(actual, expected)
error_meta = _check_attributes_equal(
- actual, expected, check_device=check_device, check_dtype=check_dtype, check_stride=check_stride
+ actual,
+ expected,
+ check_device=check_device,
+ check_dtype=check_dtype,
+ check_stride=check_stride,
+ check_is_coalesced=check_is_coalesced,
)
if error_meta:
return error_meta
@@ -537,11 +629,12 @@
check_device: bool = True,
check_dtype: bool = True,
check_stride: bool = True,
+ check_is_coalesced: bool = True,
msg: Optional[Union[str, Callable[[Tensor, Tensor, SimpleNamespace], str]]] = None,
) -> None:
r"""Asserts that :attr:`actual` and :attr:`expected` are close.
- If :attr:`actual` and :attr:`expected` are real-valued and finite, they are considered close if
+ If :attr:`actual` and :attr:`expected` are strided, real-valued, and finite, they are considered close if
.. math::
@@ -555,6 +648,12 @@
If :attr:`actual` and :attr:`expected` are complex-valued, they are considered close if both their real and
imaginary components are considered close according to the definition above.
+ If :attr:`actual` and :attr:`expected` are sparse (either having COO or CSR layout), their strided members are
+ checked individually. Indices, namely ``indices`` for COO or ``crow_indices`` and ``col_indices`` for CSR layout,
+ are always checked for equality whereas the values are checked for closeness according to the definition above.
+ If :attr:`check_is_coalesced` is ``True``, sparse COO tensors are only considered close if both are coalesced
+ or both are uncoalesced.
+
:attr:`actual` and :attr:`expected` can be :class:`~torch.Tensor`'s or any array-or-scalar-like of the same type,
from which :class:`torch.Tensor`'s can be constructed with :func:`torch.as_tensor`. In addition, :attr:`actual` and
:attr:`expected` can be :class:`~collections.abc.Sequence`'s or :class:`~collections.abc.Mapping`'s in which case
@@ -576,24 +675,30 @@
check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this
check is disabled, tensors with different ``dtype``'s are promoted to a common ``dtype`` (according to
:func:`torch.promote_types`) before being compared.
- check_stride (bool): If ``True`` (default), asserts that corresponding tensors have the same stride.
+ check_stride (bool): If ``True`` (default) and corresponding tensors are strided, asserts that they have the
+ same stride.
+ check_is_coalesced (bool): If ``True`` (default) and corresponding tensors are sparse COO, checks that
+ :attr:`actual` and :attr:`expected` are either both coalesced or both uncoalesced. If this check is disabled,
+ tensors are :meth:`~torch.Tensor.coalesce`'ed before being compared.
msg (Optional[Union[str, Callable[[Tensor, Tensor, DiagnosticInfo], str]]]): Optional error message to use if
the values of corresponding tensors mismatch. Can be passed as callable in which case it will be called
with the mismatching tensors and a namespace of diagnostic info about the mismatches. See below for details.
Raises:
UsageError: If a :class:`torch.Tensor` can't be constructed from an array-or-scalar-like.
- UsageError: If any tensor is quantized or sparse. This is a temporary restriction and will be relaxed in the
- future.
+ UsageError: If any tensor is quantized. This is a temporary restriction and will be relaxed in the future.
UsageError: If only :attr:`rtol` or :attr:`atol` is specified.
AssertionError: If corresponding array-likes have different types.
AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their length does not match.
AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their set of keys do not match.
AssertionError: If corresponding tensors do not have the same :attr:`~torch.Tensor.shape`.
+ AssertionError: If corresponding tensors do not have the same :attr:`~torch.Tensor.layout`.
AssertionError: If :attr:`check_device`, but corresponding tensors are not on the same
:attr:`~torch.Tensor.device`.
AssertionError: If :attr:`check_dtype`, but corresponding tensors do not have the same ``dtype``.
- AssertionError: If :attr:`check_stride`, but corresponding tensors do not have the same stride.
+ AssertionError: If :attr:`check_stride`, but corresponding strided tensors do not have the same stride.
+ AssertionError: If :attr:`check_is_coalesced`, but corresponding sparse COO tensors are neither both
+ coalesced nor both uncoalesced.
AssertionError: If the values of corresponding tensors are not close.
The following table displays the default ``rtol`` and ``atol`` for different ``dtype``'s. Note that the ``dtype``
@@ -743,6 +848,7 @@
check_device=check_device,
check_dtype=check_dtype,
check_stride=check_stride,
+ check_is_coalesced=check_is_coalesced,
msg=msg,
)
if error_meta: