Fix torch complex exp CPU implementation (#35532) (#35715)
Summary:
There was a permutation operation missing in each of the complex vector files. I also added some test cases, the last two of which fail under the current implementation. This PR fixes that: all the testcases pass.
Fixes https://github.com/pytorch/pytorch/issues/35532
dylanbespalko
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35715
Differential Revision: D20857024
Pulled By: anjali411
fbshipit-source-id: 4eecd8f0863faa838300951626f26b89e6cc9c6b
diff --git a/aten/src/ATen/cpu/vec256/vec256_complex_double.h b/aten/src/ATen/cpu/vec256/vec256_complex_double.h
index fc6fac1..149eb95 100644
--- a/aten/src/ATen/cpu/vec256/vec256_complex_double.h
+++ b/aten/src/ATen/cpu/vec256/vec256_complex_double.h
@@ -210,7 +210,8 @@
exp = _mm256_blend_pd(exp, _mm256_permute_pd(exp, 0x05), 0x0A); //exp(a) exp(a)
auto sin_cos = Sleef_sincosd4_u10(values); //[sin(a), cos(a)] [sin(b), cos(b)]
- auto cos_sin = _mm256_blend_pd(sin_cos.y, sin_cos.x, 0x0A); //cos(b) sin(b)
+ auto cos_sin = _mm256_blend_pd(_mm256_permute_pd(sin_cos.y, 0x05),
+ sin_cos.x, 0x0A); //cos(b) sin(b)
return _mm256_mul_pd(exp, cos_sin);
}
Vec256<std::complex<double>> expm1() const {
diff --git a/aten/src/ATen/cpu/vec256/vec256_complex_float.h b/aten/src/ATen/cpu/vec256/vec256_complex_float.h
index 92711c9..7886f13 100644
--- a/aten/src/ATen/cpu/vec256/vec256_complex_float.h
+++ b/aten/src/ATen/cpu/vec256/vec256_complex_float.h
@@ -248,7 +248,8 @@
exp = _mm256_blend_ps(exp, _mm256_permute_ps(exp, 0xB1), 0xAA); //exp(a) exp(a)
auto sin_cos = Sleef_sincosf8_u10(values); //[sin(a), cos(a)] [sin(b), cos(b)]
- auto cos_sin = _mm256_blend_ps(sin_cos.y, sin_cos.x, 0xAA); //cos(b) sin(b)
+ auto cos_sin = _mm256_blend_ps(_mm256_permute_ps(sin_cos.y, 0xB1),
+ sin_cos.x, 0xAA); //cos(b) sin(b)
return _mm256_mul_ps(exp, cos_sin);
}
Vec256<std::complex<float>> expm1() const {
diff --git a/test/test_complex.py b/test/test_complex.py
index 3912fe7..d17a927 100644
--- a/test/test_complex.py
+++ b/test/test_complex.py
@@ -1,13 +1,31 @@
+import math
import torch
-from torch.testing._internal.common_utils import TestCase, run_tests
+from torch.testing._internal.common_utils import TestCase, run_tests, TEST_NUMPY
+import unittest
+
+if TEST_NUMPY:
+ import numpy as np
devices = (torch.device('cpu'), torch.device('cuda:0'))
+
class TestComplexTensor(TestCase):
def test_to_list_with_complex_64(self):
# test that the complex float tensor has expected values and
# there's no garbage value in the resultant list
self.assertEqual(torch.zeros((2, 2), dtype=torch.complex64).tolist(), [[0j, 0j], [0j, 0j]])
+ @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+ def test_exp(self):
+ def exp_fn(dtype):
+ a = torch.tensor(1j, dtype=dtype) * torch.arange(18) / 3 * math.pi
+ expected = np.exp(a.numpy())
+ actual = torch.exp(a)
+ self.assertEqual(actual, torch.from_numpy(expected))
+
+ exp_fn(torch.complex64)
+ exp_fn(torch.complex128)
+
+
if __name__ == '__main__':
run_tests()