resolve conjugate bit in `torch.testing.assert_close` (#60522)
Summary:
We need to resolve the conjugate bit for complex tensors, because otherwise we may not be able to access the imaginary component:
```python
>>> torch.tensor(complex(1, 1)).conj().imag
RuntimeError: view_as_real doesn't work on unresolved conjugated tensors. To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60522
Reviewed By: ngimel
Differential Revision: D29353095
Pulled By: mruberry
fbshipit-source-id: c36eaf883dd55041166f692f7b1d35cd2a34acfb
diff --git a/test/test_testing.py b/test/test_testing.py
index 8e47eaf..3fead51 100644
--- a/test/test_testing.py
+++ b/test/test_testing.py
@@ -1075,6 +1075,13 @@
with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the imaginary part")):
fn()
+ def test_matching_conjugate_bit(self):
+ actual = torch.tensor(complex(1, 1)).conj()
+ expected = torch.tensor(complex(1, -1))
+
+ for fn in assert_close_with_inputs(actual, expected):
+ fn()
+
class TestAssertCloseSparseCOO(TestCase):
def test_matching_coalesced(self):
diff --git a/torch/testing/_asserts.py b/torch/testing/_asserts.py
index 349988d..6c7c826 100644
--- a/torch/testing/_asserts.py
+++ b/torch/testing/_asserts.py
@@ -78,6 +78,9 @@
if actual.dtype not in (torch.complex32, torch.complex64, torch.complex128):
return check_tensors(actual, expected, equal_nan=equal_nan, **kwargs)
+ actual = actual.resolve_conj()
+ expected = expected.resolve_conj()
+
if relaxed_complex_nan:
actual, expected = [
t.clone().masked_fill(