re-add dynamic error messages to assert_close
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77601
Approved by: https://github.com/mruberry
diff --git a/test/test_testing.py b/test/test_testing.py
index 25f53e5..6c4d687 100644
--- a/test/test_testing.py
+++ b/test/test_testing.py
@@ -945,7 +945,7 @@
with self.assertRaisesRegex(AssertionError, re.escape(f"(up to {atol} allowed)")):
fn(rtol=0.0, atol=atol)
- def test_msg(self):
+ def test_msg_str(self):
msg = "Custom error message!"
actual = torch.tensor(1)
@@ -955,6 +955,16 @@
with self.assertRaisesRegex(AssertionError, msg):
fn(msg=msg)
+ def test_msg_callable(self):
+ msg = "Custom error message"
+
+ actual = torch.tensor(1)
+ expected = torch.tensor(2)
+
+ for fn in assert_close_with_inputs(actual, expected):
+ with self.assertRaisesRegex(AssertionError, msg):
+ fn(msg=lambda _: msg)
+
class TestAssertCloseContainer(TestCase):
def test_sequence_mismatching_len(self):
diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py
index a614060..8b243d4 100644
--- a/torch/testing/_comparison.py
+++ b/torch/testing/_comparison.py
@@ -28,10 +28,14 @@
self.msg = msg
self.id = id
- def to_error(self) -> Exception:
- msg = self.msg
- if self.id:
- msg += f"\n\nThe failure occurred for item {''.join(str([item]) for item in self.id)}"
+ def to_error(self, msg: Optional[Union[str, Callable[[str], str]]] = None) -> Exception:
+ if not isinstance(msg, str):
+ generated_msg = self.msg
+ if self.id:
+ generated_msg += f"\n\nThe failure occurred for item {''.join(str([item]) for item in self.id)}"
+
+ msg = msg(generated_msg) if callable(msg) else generated_msg
+
return self.type(msg)
@@ -159,7 +163,7 @@
msg += make_diff_msg(type="absolute", diff=abs_diff, idx=abs_diff_idx, tol=atol)
msg += make_diff_msg(type="relative", diff=rel_diff, idx=rel_diff_idx, tol=rtol)
- return msg
+ return msg.strip()
def make_scalar_mismatch_msg(
@@ -283,13 +287,11 @@
expected: Any,
*,
id: Tuple[Any, ...] = (),
- msg: Optional[str] = None,
**unknown_parameters: Any,
) -> None:
self.actual = actual
self.expected = expected
self.id = id
- self.msg = msg
self._unknown_parameters = unknown_parameters
@staticmethod
@@ -298,18 +300,15 @@
if not all(isinstance(input, cls) for input in inputs):
raise UnsupportedInputs()
- def _make_error_meta(self, type: Type[Exception], msg: str, *, id: Tuple[Any, ...] = ()) -> ErrorMeta:
+ def _make_error_meta(self, type: Type[Exception], msg: str) -> ErrorMeta:
"""Makes an :class:`ErrorMeta` from a given exception type and message and the stored id.
- If ``type`` is an :class:`AssertionError` and a ``msg`` was supplied during instantiation, this will override
- the passed ``msg``.
-
.. warning::
Since this method uses instance attributes of :class:`Pair`, it should not be used before the
``super().__init__(...)`` call in the constructor.
"""
- return ErrorMeta(type, self.msg if self.msg and type is AssertionError else msg, id=self.id or id)
+ return ErrorMeta(type, msg, id=self.id)
@abc.abstractmethod
def compare(self) -> None:
@@ -1028,6 +1027,7 @@
pair_types: Sequence[Type[Pair]] = (ObjectPair,),
sequence_types: Tuple[Type, ...] = (collections.abc.Sequence,),
mapping_types: Tuple[Type, ...] = (collections.abc.Mapping,),
+ msg: Optional[Union[str, Callable[[str], str]]] = None,
**options: Any,
) -> None:
"""Asserts that inputs are equal.
@@ -1083,7 +1083,7 @@
return
# TODO: compose all metas into one AssertionError
- raise error_metas[0].to_error()
+ raise error_metas[0].to_error(msg)
def assert_close(
@@ -1098,7 +1098,7 @@
check_dtype: bool = True,
check_layout: bool = True,
check_stride: bool = False,
- msg: Optional[str] = None,
+ msg: Optional[Union[str, Callable[[str], str]]] = None,
):
r"""Asserts that ``actual`` and ``expected`` are close.
@@ -1159,7 +1159,9 @@
check is disabled, tensors with different ``layout``'s are converted to strided tensors before being
compared.
check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride.
- msg (Optional[str]): Optional error message to use in case a failure occurs during the comparison.
+ msg (Optional[Union[str, Callable[[str], str]]]): Optional error message to use in case a failure occurs during
+ the comparison. Can also passed as callable in which case it will be called with the generated message and
+ should return the new message.
Raises:
ValueError: If no :class:`torch.Tensor` can be constructed from an input.
@@ -1310,6 +1312,22 @@
Traceback (most recent call last):
...
AssertionError: Argh, the tensors are not close!
+ >>> # If msg is a callable, it can be used to augment the generated message with
+ >>> # extra information
+ >>> torch.testing.assert_close(
+ ... actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter"
+ ... )
+ Traceback (most recent call last):
+ ...
+ AssertionError: Header
+ <BLANKLINE>
+ Tensor-likes are not close!
+ <BLANKLINE>
+ Mismatched elements: 2 / 3 (66.7%)
+ Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed)
+ Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed)
+ <BLANKLINE>
+ Footer
"""
# Hide this function from `pytest`'s traceback
__tracebackhide__ = True