add support for None in assert_close (#67795)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67795
Closes #61035.
Test Plan: Imported from OSS
Reviewed By: ngimel
Differential Revision: D32532207
Pulled By: mruberry
fbshipit-source-id: 6a2b4245e0effce4ddea7d89eca63e3b163951a7
diff --git a/test/test_testing.py b/test/test_testing.py
index 4617d84..8c7b4ad 100644
--- a/test/test_testing.py
+++ b/test/test_testing.py
@@ -937,6 +937,21 @@
for fn in assert_close_with_inputs(actual, expected):
fn()
+ def test_none(self):
+ actual = expected = None
+
+ for fn in assert_close_with_inputs(actual, expected):
+ fn()
+
+ def test_none_mismatch(self):
+ expected = None
+
+ for actual in (False, 0, torch.nan, torch.tensor(torch.nan)):
+ for fn in assert_close_with_inputs(actual, expected):
+ with self.assertRaises(AssertionError):
+ fn()
+
+
def test_docstring_examples(self):
finder = doctest.DocTestFinder(verbose=False)
runner = doctest.DocTestRunner(verbose=False, optionflags=doctest.NORMALIZE_WHITESPACE)
diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py
index 847398b..219b30d 100644
--- a/torch/testing/_comparison.py
+++ b/torch/testing/_comparison.py
@@ -321,6 +321,20 @@
raise self._make_error_meta(AssertionError, f"{self.actual} != {self.expected}")
+class NonePair(Pair):
+ """Pair for ``None`` inputs."""
+
+ def __init__(self, actual: Any, expected: Any, **other_parameters: Any) -> None:
+ if not (actual is None or expected is None):
+ raise UnsupportedInputs()
+
+ super().__init__(actual, expected, **other_parameters)
+
+ def compare(self) -> None:
+ if not (self.actual is None and self.expected is None):
+ raise self._make_error_meta(AssertionError, f"None mismatch: {self.actual} is not {self.expected}")
+
+
class BooleanPair(Pair):
"""Pair for :class:`bool` inputs.
@@ -1128,7 +1142,12 @@
assert_equal(
actual,
expected,
- pair_types=(BooleanPair, NumberPair, TensorLikePair),
+ pair_types=(
+ NonePair,
+ BooleanPair,
+ NumberPair,
+ TensorLikePair,
+ ),
allow_subclasses=allow_subclasses,
rtol=rtol,
atol=atol,