add a equality comparison helper for assert_close internals (#69750)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69750
Test Plan: Imported from OSS
Reviewed By: ngimel
Differential Revision: D33542993
Pulled By: mruberry
fbshipit-source-id: 0de0559c33ec0f1dad205113cb363a652140b62d
diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py
index be8bba0..47ae53e 100644
--- a/torch/testing/_comparison.py
+++ b/torch/testing/_comparison.py
@@ -709,7 +709,7 @@
elif actual.is_sparse_csr:
compare_fn = self._compare_sparse_csr_values
else:
- compare_fn = self._compare_regular_values
+ compare_fn = self._compare_regular_values_close
compare_fn(actual, expected, rtol=self.rtol, atol=self.atol, equal_nan=self.equal_nan)
@@ -724,7 +724,7 @@
the individual quantization parameters for closeness and the integer representation for equality can be
found in https://github.com/pytorch/pytorch/issues/68548.
"""
- return self._compare_regular_values(
+ return self._compare_regular_values_close(
actual.dequantize(),
expected.dequantize(),
rtol=rtol,
@@ -761,15 +761,12 @@
),
)
- self._compare_regular_values(
+ self._compare_regular_values_equal(
actual._indices(),
expected._indices(),
- rtol=0,
- atol=0,
- equal_nan=False,
identifier="Sparse COO indices",
)
- self._compare_regular_values(
+ self._compare_regular_values_close(
actual._values(),
expected._values(),
rtol=rtol,
@@ -797,23 +794,17 @@
),
)
- self._compare_regular_values(
+ self._compare_regular_values_equal(
actual.crow_indices(),
expected.crow_indices(),
- rtol=0,
- atol=0,
- equal_nan=False,
identifier="Sparse CSR crow_indices",
)
- self._compare_regular_values(
+ self._compare_regular_values_equal(
actual.col_indices(),
expected.col_indices(),
- rtol=0,
- atol=0,
- equal_nan=False,
identifier="Sparse CSR col_indices",
)
- self._compare_regular_values(
+ self._compare_regular_values_close(
actual.values(),
expected.values(),
rtol=rtol,
@@ -822,7 +813,18 @@
identifier="Sparse CSR values",
)
- def _compare_regular_values(
+ def _compare_regular_values_equal(
+ self,
+ actual: torch.Tensor,
+ expected: torch.Tensor,
+ *,
+ equal_nan: bool = False,
+ identifier: Optional[Union[str, Callable[[str], str]]] = None,
+ ) -> None:
+ """Checks if the values of two tensors are equal."""
+ self._compare_regular_values_close(actual, expected, rtol=0, atol=0, equal_nan=equal_nan, identifier=identifier)
+
+ def _compare_regular_values_close(
self,
actual: torch.Tensor,
expected: torch.Tensor,