re-add dynamic error messages to assert_close

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77601

Approved by: https://github.com/mruberry
diff --git a/test/test_testing.py b/test/test_testing.py
index 25f53e5..6c4d687 100644
--- a/test/test_testing.py
+++ b/test/test_testing.py
@@ -945,7 +945,7 @@
             with self.assertRaisesRegex(AssertionError, re.escape(f"(up to {atol} allowed)")):
                 fn(rtol=0.0, atol=atol)
 
-    def test_msg(self):
+    def test_msg_str(self):
         msg = "Custom error message!"
 
         actual = torch.tensor(1)
@@ -955,6 +955,16 @@
             with self.assertRaisesRegex(AssertionError, msg):
                 fn(msg=msg)
 
+    def test_msg_callable(self):
+        msg = "Custom error message"
+
+        actual = torch.tensor(1)
+        expected = torch.tensor(2)
+
+        for fn in assert_close_with_inputs(actual, expected):
+            with self.assertRaisesRegex(AssertionError, msg):
+                fn(msg=lambda _: msg)
+
 
 class TestAssertCloseContainer(TestCase):
     def test_sequence_mismatching_len(self):
diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py
index a614060..8b243d4 100644
--- a/torch/testing/_comparison.py
+++ b/torch/testing/_comparison.py
@@ -28,10 +28,14 @@
         self.msg = msg
         self.id = id
 
-    def to_error(self) -> Exception:
-        msg = self.msg
-        if self.id:
-            msg += f"\n\nThe failure occurred for item {''.join(str([item]) for item in self.id)}"
+    def to_error(self, msg: Optional[Union[str, Callable[[str], str]]] = None) -> Exception:
+        if not isinstance(msg, str):
+            generated_msg = self.msg
+            if self.id:
+                generated_msg += f"\n\nThe failure occurred for item {''.join(str([item]) for item in self.id)}"
+
+            msg = msg(generated_msg) if callable(msg) else generated_msg
+
         return self.type(msg)
 
 
@@ -159,7 +163,7 @@
     msg += make_diff_msg(type="absolute", diff=abs_diff, idx=abs_diff_idx, tol=atol)
     msg += make_diff_msg(type="relative", diff=rel_diff, idx=rel_diff_idx, tol=rtol)
 
-    return msg
+    return msg.strip()
 
 
 def make_scalar_mismatch_msg(
@@ -283,13 +287,11 @@
         expected: Any,
         *,
         id: Tuple[Any, ...] = (),
-        msg: Optional[str] = None,
         **unknown_parameters: Any,
     ) -> None:
         self.actual = actual
         self.expected = expected
         self.id = id
-        self.msg = msg
         self._unknown_parameters = unknown_parameters
 
     @staticmethod
@@ -298,18 +300,15 @@
         if not all(isinstance(input, cls) for input in inputs):
             raise UnsupportedInputs()
 
-    def _make_error_meta(self, type: Type[Exception], msg: str, *, id: Tuple[Any, ...] = ()) -> ErrorMeta:
+    def _make_error_meta(self, type: Type[Exception], msg: str) -> ErrorMeta:
         """Makes an :class:`ErrorMeta` from a given exception type and message and the stored id.
 
-        If ``type`` is an :class:`AssertionError` and a ``msg`` was supplied during instantiation, this will override
-        the passed ``msg``.
-
         .. warning::
 
             Since this method uses instance attributes of :class:`Pair`, it should not be used before the
             ``super().__init__(...)`` call in the constructor.
         """
-        return ErrorMeta(type, self.msg if self.msg and type is AssertionError else msg, id=self.id or id)
+        return ErrorMeta(type, msg, id=self.id)
 
     @abc.abstractmethod
     def compare(self) -> None:
@@ -1028,6 +1027,7 @@
     pair_types: Sequence[Type[Pair]] = (ObjectPair,),
     sequence_types: Tuple[Type, ...] = (collections.abc.Sequence,),
     mapping_types: Tuple[Type, ...] = (collections.abc.Mapping,),
+    msg: Optional[Union[str, Callable[[str], str]]] = None,
     **options: Any,
 ) -> None:
     """Asserts that inputs are equal.
@@ -1083,7 +1083,7 @@
         return
 
     # TODO: compose all metas into one AssertionError
-    raise error_metas[0].to_error()
+    raise error_metas[0].to_error(msg)
 
 
 def assert_close(
@@ -1098,7 +1098,7 @@
     check_dtype: bool = True,
     check_layout: bool = True,
     check_stride: bool = False,
-    msg: Optional[str] = None,
+    msg: Optional[Union[str, Callable[[str], str]]] = None,
 ):
     r"""Asserts that ``actual`` and ``expected`` are close.
 
@@ -1159,7 +1159,9 @@
             check is disabled, tensors with different ``layout``'s are converted to strided tensors before being
             compared.
         check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride.
-        msg (Optional[str]): Optional error message to use in case a failure occurs during the comparison.
+        msg (Optional[Union[str, Callable[[str], str]]]): Optional error message to use in case a failure occurs during
+            the comparison. Can also passed as callable in which case it will be called with the generated message and
+            should return the new message.
 
     Raises:
         ValueError: If no :class:`torch.Tensor` can be constructed from an input.
@@ -1310,6 +1312,22 @@
         Traceback (most recent call last):
         ...
         AssertionError: Argh, the tensors are not close!
+        >>> # If msg is a callable, it can be used to augment the generated message with
+        >>> # extra information
+        >>> torch.testing.assert_close(
+        ...     actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter"
+        ... )
+        Traceback (most recent call last):
+        ...
+        AssertionError: Header
+        <BLANKLINE>
+        Tensor-likes are not close!
+        <BLANKLINE>
+        Mismatched elements: 2 / 3 (66.7%)
+        Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed)
+        Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed)
+        <BLANKLINE>
+        Footer
     """
     # Hide this function from `pytest`'s traceback
     __tracebackhide__ = True