| import math |
| import torch |
| from torch.testing._internal.common_utils import TestCase, run_tests, TEST_NUMPY |
| import unittest |
| |
| if TEST_NUMPY: |
| import numpy as np |
| |
# Device handles for this module: the CPU and the first CUDA device.
# NOTE(review): not referenced by the tests visible here; creating a
# torch.device handle does not itself require CUDA to be available.
devices = tuple(torch.device(spec) for spec in ('cpu', 'cuda:0'))
| |
| |
| class TestComplexTensor(TestCase): |
| def test_to_list_with_complex_64(self): |
| # test that the complex float tensor has expected values and |
| # there's no garbage value in the resultant list |
| self.assertEqual(torch.zeros((2, 2), dtype=torch.complex64).tolist(), [[0j, 0j], [0j, 0j]]) |
| |
| @unittest.skipIf(not TEST_NUMPY, "Numpy not found") |
| def test_exp(self): |
| def exp_fn(dtype): |
| a = torch.tensor(1j, dtype=dtype) * torch.arange(18) / 3 * math.pi |
| expected = np.exp(a.numpy()) |
| actual = torch.exp(a) |
| self.assertEqual(actual, torch.from_numpy(expected)) |
| |
| exp_fn(torch.complex64) |
| exp_fn(torch.complex128) |
| |
| def test_dtype_inference(self): |
| # issue: https://github.com/pytorch/pytorch/issues/36834 |
| torch.set_default_dtype(torch.double) |
| x = torch.tensor([3., 3. + 5.j]) |
| self.assertEqual(x.dtype, torch.cdouble) |
| |
if __name__ == '__main__':
    # Script entry point: delegate to run_tests (imported from common_utils).
    run_tests()