add support for None in assert_close (#67795)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67795

Closes #61035.

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D32532207

Pulled By: mruberry

fbshipit-source-id: 6a2b4245e0effce4ddea7d89eca63e3b163951a7
diff --git a/test/test_testing.py b/test/test_testing.py
index 4617d84..8c7b4ad 100644
--- a/test/test_testing.py
+++ b/test/test_testing.py
@@ -937,6 +937,21 @@
         for fn in assert_close_with_inputs(actual, expected):
             fn()
 
+    def test_none(self):
+        actual = expected = None
+
+        for fn in assert_close_with_inputs(actual, expected):
+            fn()
+
+    def test_none_mismatch(self):
+        expected = None
+
+        for actual in (False, 0, torch.nan, torch.tensor(torch.nan)):
+            for fn in assert_close_with_inputs(actual, expected):
+                with self.assertRaises(AssertionError):
+                    fn()
+
+
     def test_docstring_examples(self):
         finder = doctest.DocTestFinder(verbose=False)
         runner = doctest.DocTestRunner(verbose=False, optionflags=doctest.NORMALIZE_WHITESPACE)
diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py
index 847398b..219b30d 100644
--- a/torch/testing/_comparison.py
+++ b/torch/testing/_comparison.py
@@ -321,6 +321,20 @@
             raise self._make_error_meta(AssertionError, f"{self.actual} != {self.expected}")
 
 
+class NonePair(Pair):
+    """Pair for ``None`` inputs."""
+
+    def __init__(self, actual: Any, expected: Any, **other_parameters: Any) -> None:
+        if not (actual is None or expected is None):
+            raise UnsupportedInputs()
+
+        super().__init__(actual, expected, **other_parameters)
+
+    def compare(self) -> None:
+        if not (self.actual is None and self.expected is None):
+            raise self._make_error_meta(AssertionError, f"None mismatch: {self.actual} is not {self.expected}")
+
+
 class BooleanPair(Pair):
     """Pair for :class:`bool` inputs.
 
@@ -1128,7 +1142,12 @@
     assert_equal(
         actual,
         expected,
-        pair_types=(BooleanPair, NumberPair, TensorLikePair),
+        pair_types=(
+            NonePair,
+            BooleanPair,
+            NumberPair,
+            TensorLikePair,
+        ),
         allow_subclasses=allow_subclasses,
         rtol=rtol,
         atol=atol,