| import torch | |
| from torch.testing._internal.common_utils import TestCase, run_tests | |
# Devices available to tests in this module: the CPU and the first CUDA device.
# Constructing a torch.device does not itself require CUDA to be present.
devices = tuple(torch.device(spec) for spec in ('cpu', 'cuda:0'))
| class TestComplexTensor(TestCase): | |
| def test_to_list_with_complex_64(self): | |
| # test that the complex float tensor has expected values and | |
| # there's no garbage value in the resultant list | |
| self.assertEqual(torch.zeros((2, 2), dtype=torch.complex64).tolist(), [[0j, 0j], [0j, 0j]]) | |
# Run this module's tests only when it is executed as a script, not on import.
if __name__ == "__main__":
    run_tests()