blob: aae5af1fef6db705c4be66efcf1ab562bfbd481f [file] [log] [blame]
import sys
from collections import namedtuple
from typing import Any, Optional, Tuple, Type
import torch
from ._core import _unravel_index
__all__ = ["assert_tensors_equal", "assert_tensors_close"]
# The UsageError should be raised in case the test function is not used correctly. With this the user is able to
# differentiate between a test failure (there is a bug in the tested code) and a test error (there is a bug in the
# test). If pytest is the test runner, we use the built-in UsageError instead our custom one.
try:
# The module 'pytest' will be imported if the 'pytest' runner is used. This will only give false-positives in case
# a previously imported module already directly or indirectly imported 'pytest', but the test is run by another
# runner such as 'unittest'.
# 'mypy' is not able to handle this within a type annotation
# (see https://mypy.readthedocs.io/en/latest/common_issues.html#variables-vs-type-aliases for details). In case
# 'UsageError' is used in an annotation, add a 'type: ignore[valid-type]' comment.
UsageError: Type[Exception] = sys.modules["pytest"].UsageError # type: ignore[attr-defined]
except (KeyError, AttributeError):
class UsageError(Exception): # type: ignore[no-redef]
pass
# This is copy-pasted from torch.testing._internal.common_utils.TestCase.dtype_precisions. With this we avoid a
# dependency on torch.testing._internal at import. See
# https://github.com/pytorch/pytorch/pull/54769#issuecomment-813174256 for details.
# {dtype: (rtol, atol)}
_DTYPE_PRECISIONS = {
torch.float16: (0.001, 1e-5),
torch.bfloat16: (0.016, 1e-5),
torch.float32: (1.3e-6, 1e-5),
torch.float64: (1e-7, 1e-7),
torch.complex32: (0.001, 1e-5),
torch.complex64: (1.3e-6, 1e-5),
torch.complex128: (1e-7, 1e-7),
}
def _get_default_rtol_and_atol(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[float, float]:
dtype = actual.dtype if actual.dtype == expected.dtype else torch.promote_types(actual.dtype, expected.dtype)
return _DTYPE_PRECISIONS.get(dtype, (0.0, 0.0))
def _check_are_tensors(actual: Any, expected: Any) -> Optional[AssertionError]:
"""Checks if both inputs are tensors.
Args:
actual (Any): Actual input.
expected (Any): Actual input.
Returns:
(Optional[AssertionError]): If check did not pass.
"""
if not (isinstance(actual, torch.Tensor) and isinstance(expected, torch.Tensor)):
return AssertionError(f"Both inputs have to be tensors, but got {type(actual)} and {type(expected)} instead.")
return None
def _check_supported_tensors(
actual: torch.Tensor,
expected: torch.Tensor,
) -> Optional[UsageError]: # type: ignore[valid-type]
"""Checks if the tensors are supported by the current infrastructure.
All checks are temporary and will be relaxed in the future.
Args:
actual (torch.Tensor): Actual tensor.
expected (torch.Tensor): Expected tensor.
Returns:
(Optional[UsageError]): If check did not pass.
"""
if any(t.dtype in (torch.complex32, torch.complex64, torch.complex128) for t in (actual, expected)):
return UsageError("Comparison for complex tensors is not supported yet.")
if any(t.is_quantized for t in (actual, expected)):
return UsageError("Comparison for quantized tensors is not supported yet.")
if any(t.is_sparse for t in (actual, expected)):
return UsageError("Comparison for sparse tensors is not supported yet.")
return None
def _check_attributes_equal(
actual: torch.Tensor,
expected: torch.Tensor,
*,
check_device: bool = True,
check_dtype: bool = True,
check_stride: bool = True,
) -> Optional[AssertionError]:
"""Checks if the attributes of two tensors match.
Always checks the :attr:`~torch.Tensor.shape`. Checks for :attr:`~torch.Tensor.device`,
:attr:`~torch.Tensor.dtype`, and :meth:`~torch.Tensor.stride` are optional and can be disabled.
Args:
actual (torch.Tensor): Actual tensor.
expected (torch.Tensor): Expected tensor.
check_device (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` are on the
same :attr:`~torch.Tensor.device` memory.
check_dtype (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` have the same
:attr:`~torch.Tensor.dtype`.
check_stride (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` have the same
:meth:`~torch.Tensor.stride`.
Returns:
(Optional[AssertionError]): If check did not pass.
"""
msg_fmtstr = "The values for attribute '{}' do not match: {} != {}."
if actual.shape != expected.shape:
return AssertionError(msg_fmtstr.format("shape", actual.shape, expected.shape))
if check_device and actual.device != expected.device:
return AssertionError(msg_fmtstr.format("device", actual.device, expected.device))
if check_dtype and actual.dtype != expected.dtype:
return AssertionError(msg_fmtstr.format("dtype", actual.dtype, expected.dtype))
if check_stride and actual.stride() != expected.stride():
return AssertionError(msg_fmtstr.format("stride()", actual.stride(), expected.stride()))
return None
def _equalize_attributes(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Equalizes some attributes of two tensors for value comparison.
If :attr:`actual` and :attr:`expected`
- are not onn the same memory :attr:`~torch.Tensor.device`, they are moved CPU memory, and
- do not have the same :attr:`~torch.Tensor.dtype`, they are copied to the :class:`~torch.dtype` returned by
:func:`torch.promote_types`.
Args:
actual (torch.Tensor): Actual tensor.
expected (torch.Tensor): Expected tensor.
Returns:
Tuple(torch.Tensor, torch.Tensor): Equalized tensors.
"""
if actual.device != expected.device:
actual = actual.cpu()
expected = expected.cpu()
if actual.dtype != expected.dtype:
dtype = torch.promote_types(actual.dtype, expected.dtype)
actual = actual.to(dtype)
expected = expected.to(dtype)
return actual, expected
_Trace = namedtuple(
"_Trace",
(
"total_elements",
"total_mismatches",
"mismatch_ratio",
"max_abs_diff",
"max_abs_diff_idx",
"max_rel_diff",
"max_rel_diff_idx",
),
)
def _trace_mismatches(actual: torch.Tensor, expected: torch.Tensor, mismatches: torch.Tensor) -> _Trace:
"""Traces mismatches.
Args:
actual (torch.Tensor): Actual tensor.
expected (torch.Tensor): Expected tensor.
mismatches (torch.Tensor): Boolean mask of the same shape as :attr:`actual` and :attr:`expected` that indicates
the location of mismatches.
Returns:
(NamedTuple): Mismatch diagnostics with the following fields:
- total_elements (int): Total number of values.
- total_mismatches (int): Total number of mismatches.
- mismatch_ratio (float): Quotient of total mismatches and total elements.
- max_abs_diff (Union[int, float]): Greatest absolute difference of :attr:`actual` and :attr:`expected`.
- max_abs_diff_idx (Union[int, Tuple[int, ...]]): Index of greatest absolute difference.
- max_rel_diff (Union[int, float]): Greatest relative difference of :attr:`actual` and :attr:`expected`.
- max_rel_diff_idx (Union[int, Tuple[int, ...]]): Index of greatest relative difference.
The returned type of ``max_abs_diff`` and ``max_rel_diff`` depends on the :attr:`~torch.Tensor.dtype` of
:attr:`actual` and :attr:`expected`.
"""
total_elements = mismatches.numel()
total_mismatches = torch.sum(mismatches).item()
mismatch_ratio = total_mismatches / total_elements
dtype = torch.float64 if actual.dtype.is_floating_point else torch.int64
a_flat = actual.flatten().to(dtype)
b_flat = expected.flatten().to(dtype)
abs_diff = torch.abs(a_flat - b_flat)
max_abs_diff, max_abs_diff_flat_idx = torch.max(abs_diff, 0)
rel_diff = abs_diff / torch.abs(b_flat)
max_rel_diff, max_rel_diff_flat_idx = torch.max(rel_diff, 0)
return _Trace(
total_elements=total_elements,
total_mismatches=total_mismatches,
mismatch_ratio=mismatch_ratio,
max_abs_diff=max_abs_diff.item(),
max_abs_diff_idx=_unravel_index(max_abs_diff_flat_idx.item(), mismatches.shape),
max_rel_diff=max_rel_diff.item(),
max_rel_diff_idx=_unravel_index(max_rel_diff_flat_idx.item(), mismatches.shape),
)
def _check_values_equal(actual: torch.Tensor, expected: torch.Tensor) -> Optional[AssertionError]:
"""Checks if the values of two tensors are bitwise equal.
Args:
actual (torch.Tensor): Actual tensor.
expected (torch.Tensor): Expected tensor.
Returns:
(Optional[AssertionError]): If check did not pass.
"""
mismatches = torch.ne(actual, expected)
if not torch.any(mismatches):
return None
trace = _trace_mismatches(actual, expected, mismatches)
return AssertionError(
f"Tensors are not equal!\n\n"
f"Mismatched elements: {trace.total_mismatches} / {trace.total_elements} ({trace.mismatch_ratio:.1%})\n"
f"Greatest absolute difference: {trace.max_abs_diff} at {trace.max_abs_diff_idx}\n"
f"Greatest relative difference: {trace.max_rel_diff} at {trace.max_rel_diff_idx}"
)
def _check_values_close(
actual: torch.Tensor,
expected: torch.Tensor,
*,
rtol,
atol,
) -> Optional[AssertionError]:
"""Checks if the values of two tensors are close up to a desired tolerance.
Args:
actual (torch.Tensor): Actual tensor.
expected (torch.Tensor): Expected tensor.
rtol (float): Relative tolerance.
atol (float): Absolute tolerance.
Returns:
(Optional[AssertionError]): If check did not pass.
"""
mismatches = ~torch.isclose(actual, expected, rtol=rtol, atol=atol)
if not torch.any(mismatches):
return None
trace = _trace_mismatches(actual, expected, mismatches)
return AssertionError(
f"Tensors are not close!\n\n"
f"Mismatched elements: {trace.total_mismatches} / {trace.total_elements} ({trace.mismatch_ratio:.1%})\n"
f"Greatest absolute difference: {trace.max_abs_diff} at {trace.max_abs_diff_idx} (up to {atol} allowed)\n"
f"Greatest relative difference: {trace.max_rel_diff} at {trace.max_rel_diff_idx} (up to {rtol} allowed)"
)
def assert_tensors_equal(
actual: torch.Tensor,
expected: torch.Tensor,
*,
check_device: bool = True,
check_dtype: bool = True,
check_stride: bool = True,
) -> None:
"""Asserts that the values of two tensors are bitwise equal.
Optionally, checks that some attributes of both tensors are equal.
Args:
actual (torch.Tensor): Actual tensor.
expected (torch.Tensor): Expected tensor.
check_device (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` are on the
same :attr:`~torch.Tensor.device` memory. If this check is disabled **and** :attr:`actual` and
:attr:`expected` are not on the same memory :attr:`~torch.Tensor.device`, they are moved CPU memory before
their values are compared.
check_dtype (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` have the same
:attr:`~torch.Tensor.dtype`. If this check is disabled **and** :attr:`actual` and :attr:`expected` do not
have the same :attr:`~torch.Tensor.dtype`, they are copied to the :class:`~torch.dtype` returned by
:func:`torch.promote_types` before their values are compared.
check_stride (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` have the same
stride.
Raises:
UsageError: If :attr:`actual` or :attr:`expected` is complex, quantized, or sparse. This is a temporary
restriction and will be relaxed in the future.
AssertionError: If :attr:`actual` and :attr:`expected` do not have the same :attr:`~torch.Tensor.shape`.
AssertionError: If :attr:`check_device`, but :attr:`actual` and :attr:`expected` are not on the same
:attr:`~torch.Tensor.device` memory.
AssertionError: If :attr:`check_dtype`, but :attr:`actual` and :attr:`expected` do not have the same
:attr:`~torch.Tensor.dtype`.
AssertionError: If :attr:`check_stride`, but :attr:`actual` and :attr:`expected` do not have the same stride.
AssertionError: If the values of :attr:`actual` and :attr:`expected` are not bitwise equal.
.. seealso::
To assert that the values in two tensors are are close but are not required to be bitwise equal, use
:func:`assert_tensors_close` instead.
"""
exc: Optional[Exception] = _check_are_tensors(actual, expected)
if exc:
raise exc
exc = _check_supported_tensors(actual, expected)
if exc:
raise exc
exc = _check_attributes_equal(
actual, expected, check_device=check_device, check_dtype=check_dtype, check_stride=check_stride
)
if exc:
raise exc
actual, expected = _equalize_attributes(actual, expected)
exc = _check_values_equal(actual, expected)
if exc:
raise exc
def assert_tensors_close(
actual: torch.Tensor,
expected: torch.Tensor,
*,
rtol: Optional[float] = None,
atol: Optional[float] = None,
check_device: bool = True,
check_dtype: bool = True,
check_stride: bool = True,
) -> None:
"""Asserts that the values of two tensors are close up to a desired tolerance.
If both tolerances, :attr:`rtol` and :attr:`rtol`, are ``0``, asserts that :attr:`actual` and :attr:`expected` are bitwise
equal. Optionally, checks that some attributes of both tensors are equal.
Args:
actual (torch.Tensor): Actual tensor.
expected (torch.Tensor): Expected tensor.
rtol (Optional[float]): Relative tolerance. If specified :attr:`atol` must also be specified. If omitted,
default values based on the :attr:`~torch.Tensor.dtype` are selected with the below table.
atol (Optional[float]): Absolute tolerance. If specified :attr:`rtol` must also be specified. If omitted,
default values based on the :attr:`~torch.Tensor.dtype` are selected with the below table.
check_device (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` are on the
same :attr:`~torch.Tensor.device` memory. If this check is disabled **and** :attr:`actual` and
:attr:`expected` are not on the same memory :attr:`~torch.Tensor.device`, they are moved CPU memory before
their values are compared.
check_dtype (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` have the same
:attr:`~torch.Tensor.dtype`. If this check is disabled **and** :attr:`actual` and :attr:`expected` do not
have the same :attr:`~torch.Tensor.dtype`, they are copied to the :class:`~torch.dtype` returned by
:func:`torch.promote_types` before their values are compared.
check_stride (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` have the same
stride.
Raises:
UsageError: If :attr:`actual` or :attr:`expected` is complex, quantized, or sparse. This is a temporary
restriction and will be relaxed in the future.
AssertionError: If :attr:`actual` and :attr:`expected` do not have the same :attr:`~torch.Tensor.shape`.
AssertionError: If :attr:`check_device`, but :attr:`actual` and :attr:`expected` are not on the same
:attr:`~torch.Tensor.device` memory.
AssertionError: If :attr:`check_dtype`, but :attr:`actual` and :attr:`expected` do not have the same
:attr:`~torch.Tensor.dtype`.
AssertionError: If :attr:`check_stride`, but :attr:`actual` and :attr:`expected` do not have the same stride.
AssertionError: If the values of :attr:`actual` and :attr:`expected` are close up to a desired tolerance.
The following table displays the default ``rtol`` and ``atol`` for floating point :attr:`~torch.Tensor.dtype`'s.
For integer :attr:`~torch.Tensor.dtype`'s, ``rtol = atol = 0.0`` is used.
+===========================+============+==========+
| :class:`~torch.dtype` | ``rtol`` | ``atol`` |
+===========================+============+==========+
| :attr:`~torch.float16` | ``1e-3`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.bfloat16` | ``1.6e-2`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.float32` | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.float64` | ``1e-7`` | ``1e-7`` |
+---------------------------+------------+----------+
| :attr:`~torch.complex32` | ``1e-3`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.complex64` | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.complex128` | ``1e-7`` | ``1e-7`` |
+---------------------------+------------+----------+
.. seealso::
To assert that the values in two tensors are bitwise equal, use :func:`assert_tensors_equal` instead.
"""
exc: Optional[Exception] = _check_are_tensors(actual, expected)
if exc:
raise exc
exc = _check_supported_tensors(actual, expected)
if exc:
raise exc
if (rtol is None) ^ (atol is None):
# We require both tolerance to be omitted or specified, because specifying only one might lead to surprising
# results. Imagine setting atol=0.0 and the tensors still match because rtol>0.0.
raise UsageError(
f"Both 'rtol' and 'atol' must be omitted or specified, " f"but got rtol={rtol} and atol={atol} instead."
)
elif rtol is None:
rtol, atol = _get_default_rtol_and_atol(actual, expected)
exc = _check_attributes_equal(
actual, expected, check_device=check_device, check_dtype=check_dtype, check_stride=check_stride
)
if exc:
raise exc
actual, expected = _equalize_attributes(actual, expected)
if (rtol == 0.0) and (atol == 0.0):
exc = _check_values_equal(actual, expected)
else:
exc = _check_values_close(actual, expected, rtol=rtol, atol=atol)
if exc:
raise exc