Expand the coverage of test_addmm and test_addmm_sizes (#43831)

Summary:
- This test is very fast and very important, so it makes no sense to mark it as slowTest
- This test should also run on CUDA
- This test should check alpha and beta support
- This test should check `out=` support
- Manual computation should use a list instead of index_put because a list is much faster
- Precision for TF32 needs to be fixed; this will be done in a future PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43831

Reviewed By: ailzhang

Differential Revision: D23435032

Pulled By: ngimel

fbshipit-source-id: d1b8350addf1e2fe180fdf3df243f38d95aa3f5a
diff --git a/aten/src/ATen/native/cuda/LinearAlgebra.cu b/aten/src/ATen/native/cuda/LinearAlgebra.cu
index e02b77c..c78029d 100644
--- a/aten/src/ATen/native/cuda/LinearAlgebra.cu
+++ b/aten/src/ATen/native/cuda/LinearAlgebra.cu
@@ -74,7 +74,7 @@
 
   if (&result != &self) {
     at::native::resize_as_(result, self_);
-    if (beta.to<double>() != 0.0) {
+    if (beta.toComplexDouble() != 0.0) {
       at::native::copy_(result, self_);
     }
   }
diff --git a/test/test_torch.py b/test/test_torch.py
index ffff551..2e79582 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -16379,48 +16379,48 @@
         for use_out, row_major, incx, incy, lda_tail in product((False, True), (False, True), (1, 2), (1, 2), (0, 1)):
             _test(use_out, row_major, incx, incy, lda_tail)
 
-    @slowTest
-    @onlyCPU
-    def test_addmm(self, device):
-        dtypes = {
-            torch.double: 1e-8,
-            torch.float: 1e-4,
-            torch.bfloat16: 1e-1,
-            torch.half: 1e-1,
-            torch.cfloat: 1e-4,
-            torch.cdouble: 1e-8
-        }
-        for dtype, prec in dtypes.items():
-            M = torch.randn(10, 25).to(device=device, dtype=dtype)
-            m1 = torch.randn(10, 50).to(device=device, dtype=dtype)
-            m2 = torch.randn(50, 25).to(device=device, dtype=dtype)
-            res1 = torch.addmm(M, m1, m2)
-            res2 = torch.zeros(10, 25, device=device, dtype=dtype)
-            res2 += M
-            for i in range(10):
-                for j in range(25):
-                    for k in range(50):
-                        res2[i, j] += m1[i, k] * m2[k, j]
-            self.assertEqual(res1, res2, atol=prec, rtol=0)
+    def _test_addmm(self, M, m1, m2):
+        dtype = M.dtype
+        numpy_dtype = dtype
+        if dtype in {torch.bfloat16}:
+            numpy_dtype = torch.float
+        if dtype.is_complex:
+            alpha = 0.9 + 0.3j
+            beta = 0.5 + 0.6j
+        else:
+            alpha = 1.2
+            beta = 0.8
+        res1 = torch.addmm(M, m1, m2, alpha=alpha, beta=beta)
+        res2 = torch.full_like(res1, math.nan)
+        torch.addmm(M, m1, m2, alpha=alpha, beta=beta, out=res2)
+        res3 = (beta * M).to(numpy_dtype).cpu().numpy() + alpha * (
+            m1.to(numpy_dtype).cpu().numpy() @ m2.to(numpy_dtype).cpu().numpy())
+        res3 = torch.from_numpy(res3).to(dtype)
+        self.assertEqual(res1, res2)
+        self.assertEqual(res1, res3)
+
+    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
+                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
+    @dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes(include_bfloat16=False))
+    @dtypes(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes())
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    def test_addmm(self, device, dtype):
+        M = torch.randn(10, 25).to(device=device, dtype=dtype)
+        m1 = torch.randn(10, 50).to(device=device, dtype=dtype)
+        m2 = torch.randn(50, 25).to(device=device, dtype=dtype)
+        self._test_addmm(M, m1, m2)
 
         # Test 0-strided
-        for dtype, prec in dtypes.items():
-            M = torch.randn(10, 1).to(device=device, dtype=dtype).expand(10, 25)
-            m1 = torch.randn(10, 1).to(device=device, dtype=dtype).expand(10, 50)
-            m2 = torch.randn(50, 25).to(device=device, dtype=dtype)
-            res1 = torch.addmm(M, m1, m2)
-            res2 = torch.zeros(10, 25, device=device, dtype=dtype)
-            res2 += M
-            for i in range(10):
-                for j in range(25):
-                    for k in range(50):
-                        res2[i, j] += m1[i, k] * m2[k, j]
-            self.assertEqual(res1, res2, atol=prec, rtol=0)
+        M = torch.randn(10, 1).to(device=device, dtype=dtype).expand(10, 25)
+        m1 = torch.randn(10, 1).to(device=device, dtype=dtype).expand(10, 50)
+        m2 = torch.randn(50, 25).to(device=device, dtype=dtype)
+        self._test_addmm(M, m1, m2)
 
     @dtypes(torch.float, torch.double)
     @dtypesIfCUDA(*([torch.float, torch.double] +
                     ([] if TEST_WITH_ROCM else torch.testing.get_all_complex_dtypes())))
     @tf32_on_and_off(0.005)
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     def test_addmm_sizes(self, device, dtype):
         for m in [0, 1, 25]:
             for n in [0, 1, 10]:
@@ -16428,14 +16428,7 @@
                     M = torch.randn(n, m, device=device, dtype=dtype)
                     m1 = torch.randn(n, k, device=device, dtype=dtype)
                     m2 = torch.randn(k, m, device=device, dtype=dtype)
-                    res1 = torch.addmm(M, m1, m2)
-                    res2 = torch.zeros(n, m, device=device, dtype=dtype)
-                    res2 += M
-                    for i in range(n):
-                        for j in range(m):
-                            for l in range(k):
-                                res2[i, j] += m1[i, l] * m2[l, j]
-                    self.assertEqual(res1, res2)
+                    self._test_addmm(M, m1, m2)
 
     def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn):
         def compare_with_numpy_bin_op(torch_fn, np_fn, x, y):