Expand the coverage of test_addmm and test_addmm_sizes (#43831)
Summary:
- This test is very fast and very important, so it makes no sense to mark it as slowTest
- This test should also run on CUDA
- This test should check alpha and beta support
- This test should check `out=` support
- manual computation should use list instead of index_put because list is much faster
- precision for TF32 needs to be fixed; this will be done in a future PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43831
Reviewed By: ailzhang
Differential Revision: D23435032
Pulled By: ngimel
fbshipit-source-id: d1b8350addf1e2fe180fdf3df243f38d95aa3f5a
diff --git a/aten/src/ATen/native/cuda/LinearAlgebra.cu b/aten/src/ATen/native/cuda/LinearAlgebra.cu
index e02b77c..c78029d 100644
--- a/aten/src/ATen/native/cuda/LinearAlgebra.cu
+++ b/aten/src/ATen/native/cuda/LinearAlgebra.cu
@@ -74,7 +74,7 @@
if (&result != &self) {
at::native::resize_as_(result, self_);
- if (beta.to<double>() != 0.0) {
+ if (beta.toComplexDouble() != 0.0) {
at::native::copy_(result, self_);
}
}
diff --git a/test/test_torch.py b/test/test_torch.py
index ffff551..2e79582 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -16379,48 +16379,48 @@
for use_out, row_major, incx, incy, lda_tail in product((False, True), (False, True), (1, 2), (1, 2), (0, 1)):
_test(use_out, row_major, incx, incy, lda_tail)
- @slowTest
- @onlyCPU
- def test_addmm(self, device):
- dtypes = {
- torch.double: 1e-8,
- torch.float: 1e-4,
- torch.bfloat16: 1e-1,
- torch.half: 1e-1,
- torch.cfloat: 1e-4,
- torch.cdouble: 1e-8
- }
- for dtype, prec in dtypes.items():
- M = torch.randn(10, 25).to(device=device, dtype=dtype)
- m1 = torch.randn(10, 50).to(device=device, dtype=dtype)
- m2 = torch.randn(50, 25).to(device=device, dtype=dtype)
- res1 = torch.addmm(M, m1, m2)
- res2 = torch.zeros(10, 25, device=device, dtype=dtype)
- res2 += M
- for i in range(10):
- for j in range(25):
- for k in range(50):
- res2[i, j] += m1[i, k] * m2[k, j]
- self.assertEqual(res1, res2, atol=prec, rtol=0)
+ def _test_addmm(self, M, m1, m2):
+ dtype = M.dtype
+ numpy_dtype = dtype
+ if dtype in {torch.bfloat16}:
+ numpy_dtype = torch.float
+ if dtype.is_complex:
+ alpha = 0.9 + 0.3j
+ beta = 0.5 + 0.6j
+ else:
+ alpha = 1.2
+ beta = 0.8
+ res1 = torch.addmm(M, m1, m2, alpha=alpha, beta=beta)
+ res2 = torch.full_like(res1, math.nan)
+ torch.addmm(M, m1, m2, alpha=alpha, beta=beta, out=res2)
+ res3 = (beta * M).to(numpy_dtype).cpu().numpy() + alpha * (
+ m1.to(numpy_dtype).cpu().numpy() @ m2.to(numpy_dtype).cpu().numpy())
+ res3 = torch.from_numpy(res3).to(dtype)
+ self.assertEqual(res1, res2)
+ self.assertEqual(res1, res3)
+
+ @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
+ torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
+ @dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes(include_bfloat16=False))
+ @dtypes(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes())
+ @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+ def test_addmm(self, device, dtype):
+ M = torch.randn(10, 25).to(device=device, dtype=dtype)
+ m1 = torch.randn(10, 50).to(device=device, dtype=dtype)
+ m2 = torch.randn(50, 25).to(device=device, dtype=dtype)
+ self._test_addmm(M, m1, m2)
# Test 0-strided
- for dtype, prec in dtypes.items():
- M = torch.randn(10, 1).to(device=device, dtype=dtype).expand(10, 25)
- m1 = torch.randn(10, 1).to(device=device, dtype=dtype).expand(10, 50)
- m2 = torch.randn(50, 25).to(device=device, dtype=dtype)
- res1 = torch.addmm(M, m1, m2)
- res2 = torch.zeros(10, 25, device=device, dtype=dtype)
- res2 += M
- for i in range(10):
- for j in range(25):
- for k in range(50):
- res2[i, j] += m1[i, k] * m2[k, j]
- self.assertEqual(res1, res2, atol=prec, rtol=0)
+ M = torch.randn(10, 1).to(device=device, dtype=dtype).expand(10, 25)
+ m1 = torch.randn(10, 1).to(device=device, dtype=dtype).expand(10, 50)
+ m2 = torch.randn(50, 25).to(device=device, dtype=dtype)
+ self._test_addmm(M, m1, m2)
@dtypes(torch.float, torch.double)
@dtypesIfCUDA(*([torch.float, torch.double] +
([] if TEST_WITH_ROCM else torch.testing.get_all_complex_dtypes())))
@tf32_on_and_off(0.005)
+ @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_addmm_sizes(self, device, dtype):
for m in [0, 1, 25]:
for n in [0, 1, 10]:
@@ -16428,14 +16428,7 @@
M = torch.randn(n, m, device=device, dtype=dtype)
m1 = torch.randn(n, k, device=device, dtype=dtype)
m2 = torch.randn(k, m, device=device, dtype=dtype)
- res1 = torch.addmm(M, m1, m2)
- res2 = torch.zeros(n, m, device=device, dtype=dtype)
- res2 += M
- for i in range(n):
- for j in range(m):
- for l in range(k):
- res2[i, j] += m1[i, l] * m2[l, j]
- self.assertEqual(res1, res2)
+ self._test_addmm(M, m1, m2)
def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn):
def compare_with_numpy_bin_op(torch_fn, np_fn, x, y):