blob: a9ef0c58cb9dccd945333e22c421ddb9c955656e [file] [log] [blame]
"""This module exists since the `torch.testing` exposed a lot of stuff that shouldn't have been public. Although this
was never documented anywhere, some other internal FB projects as well as downstream OSS projects might use this. Thus,
we don't internalize without warning, but still go through a deprecation cycle.
"""
import functools
import warnings
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
__all__ = ["assert_allclose"]
def warn_deprecated(
    instructions: Union[str, Callable[[str, Tuple[Any, ...], Dict[str, Any], Any], str]]
) -> Callable:
    """Decorator factory that makes the wrapped function emit a :class:`FutureWarning` on every call.

    Args:
        instructions: Upgrade guidance appended to the deprecation message. Either a plain
            string, or a callable ``(name, args, kwargs, return_value) -> str`` that builds
            the message from the actual call.

    Returns:
        A decorator that leaves the wrapped function's behavior intact but warns each time
        it is invoked.
    """

    def outer_wrapper(fn: Callable) -> Callable:
        name = fn.__name__
        head = f"torch.testing.{name}() is deprecated since 1.12 and will be removed in 1.14. "

        @functools.wraps(fn)
        def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
            # Call the function first so a callable `instructions` can inspect the result.
            return_value = fn(*args, **kwargs)
            tail = (
                instructions(name, args, kwargs, return_value)
                if callable(instructions)
                else instructions
            )
            msg = (head + tail).strip()
            # stacklevel=2 attributes the warning to the caller of the deprecated
            # function rather than to this wrapper, so users see their own call site.
            warnings.warn(msg, FutureWarning, stacklevel=2)
            return return_value

        return inner_wrapper

    return outer_wrapper
_DTYPE_PRECISIONS = {
torch.float16: (1e-3, 1e-3),
torch.float32: (1e-4, 1e-5),
torch.float64: (1e-5, 1e-8),
}
def _get_default_rtol_and_atol(
actual: torch.Tensor, expected: torch.Tensor
) -> Tuple[float, float]:
actual_rtol, actual_atol = _DTYPE_PRECISIONS.get(actual.dtype, (0.0, 0.0))
expected_rtol, expected_atol = _DTYPE_PRECISIONS.get(expected.dtype, (0.0, 0.0))
return max(actual_rtol, expected_rtol), max(actual_atol, expected_atol)
@warn_deprecated(
    "Use torch.testing.assert_close() instead. "
    "For detailed upgrade instructions see https://github.com/pytorch/pytorch/issues/61844."
)
def assert_allclose(
    actual: Any,
    expected: Any,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    equal_nan: bool = True,
    msg: str = "",
) -> None:
    """Deprecated comparison helper; thin shim over ``torch.testing.assert_close``.

    Non-tensor inputs are converted to tensors first, with ``expected`` coerced to
    ``actual``'s dtype. When neither tolerance is given, dtype-based defaults are
    looked up; otherwise both values are forwarded to ``assert_close`` unchanged.
    """
    actual = actual if isinstance(actual, torch.Tensor) else torch.tensor(actual)
    expected = (
        expected
        if isinstance(expected, torch.Tensor)
        # Historical behavior: a non-tensor `expected` follows `actual`'s dtype.
        else torch.tensor(expected, dtype=actual.dtype)
    )

    if rtol is atol is None:
        rtol, atol = _get_default_rtol_and_atol(actual, expected)

    torch.testing.assert_close(
        actual,
        expected,
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
        check_device=True,
        check_dtype=False,
        check_stride=False,
        msg=msg if msg else None,
    )