| # Owner(s): ["module: intel"] |
| |
| import itertools |
| import math |
| import random |
| from functools import partial |
| from itertools import product |
| |
| import numpy as np |
| |
| import torch |
| from torch.testing import make_tensor |
| from torch.testing._internal.common_device_type import ( |
| dtypes, |
| instantiate_device_type_tests, |
| precisionOverride, |
| ) |
| from torch.testing._internal.common_utils import iter_indices, run_tests, TestCase |
| |
| |
| class TestBasicGEMM(TestCase): |
| def _test_addmm_addmv( |
| self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False, activation=None |
| ): |
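| """Compare f(t, m, v, alpha=alpha, beta=beta) against a NumPy reference. |
| Checks both the functional form and the out= form (optionally through a |
| transposed output buffer), applying an optional relu/gelu epilogue to the |
| reference before the comparison.""" |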
| dtype = t.dtype |
| numpy_dtype = dtype |
| if dtype in {torch.bfloat16, torch.half}: |
| numpy_dtype = torch.float |
| if dtype.is_complex: |
| alpha = 0.9 + 0.3j if alpha is None else alpha |
| beta = 0.5 + 0.6j if beta is None else beta |
| else: |
| alpha = 1.2 if alpha is None else alpha |
| beta = 0.8 if beta is None else beta |
| if activation == "gelu": |
| res1 = f(t, m, v, alpha=alpha, beta=beta, use_gelu=True) |
| else: |
| res1 = f(t, m, v, alpha=alpha, beta=beta) |
| res2 = torch.full_like(res1, math.nan) |
| if transpose_out: |
| res2 = res2.t().clone(memory_format=torch.contiguous_format).t() |
| if activation == "gelu": |
| f(t, m, v, alpha=alpha, beta=beta, out=res2, use_gelu=True) |
| else: |
| f(t, m, v, alpha=alpha, beta=beta, out=res2) |
| res3 = alpha * ( |
| m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy() |
| ) |
| if beta != 0: |
| res3 += (beta * t).to(numpy_dtype).cpu().numpy() |
| if activation == "relu": |
| res3 = res3 * (res3 > 0) |
| elif activation == "gelu": |
| res3_t = torch.from_numpy(res3).to(dtype) |
| approximate = "tanh" if t.is_cuda else "none" |
| res3_t = torch.nn.functional.gelu(res3_t, approximate=approximate) |
| res3 = res3_t.to(numpy_dtype).cpu().numpy() |
| else: |
| assert activation is None, f"unsupported activation {activation}" |
| res3 = torch.from_numpy(res3).to(dtype) |
| self.assertEqual(res1, res2) |
| self.assertEqual(res1, res3) |
| |
| def _test_addmm_impl(self, func, activation, device, dtype): |
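| """Run _test_addmm_addmv for `func` over: contiguous inputs, a vector bias |
| with beta=1, zero-strided (expanded) inputs, a NaN bias with beta=0, and |
| all combinations of transposed operands and transposed output.""" |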
| M = torch.randn(10, 25, device="cpu", dtype=torch.float32).to(dtype).to(device) |
| m1 = torch.randn(10, 50, device="cpu", dtype=torch.float32).to(dtype).to(device) |
| m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device) |
| self._test_addmm_addmv(func, M, m1, m2, activation=activation) |
| |
| # vector-shaped bias and beta=1 result in epilogue fusion in CUDA |
| V = torch.randn(25, device="cpu", dtype=torch.float32).to(dtype).to(device) |
| self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation) |
| |
| # Test 0-strided |
| M = ( |
| torch.randn(10, 1, device="cpu", dtype=torch.float32) |
| .to(dtype) |
| .expand(10, 25) |
| .to(device) |
| ) |
| m1 = ( |
| torch.randn(10, 1, device="cpu", dtype=torch.float32) |
| .to(dtype) |
| .expand(10, 50) |
| .to(device) |
| ) |
| m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device) |
| self._test_addmm_addmv(func, M, m1, m2, activation=activation) |
| |
| # Test beta=0, M=nan |
| M = ( |
| torch.full((10, 25), math.nan, device="cpu", dtype=torch.float32) |
| .to(dtype) |
| .to(device) |
| ) |
| m1 = torch.randn(10, 50, device="cpu", dtype=torch.float32).to(dtype).to(device) |
| m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device) |
| self._test_addmm_addmv(func, M, m1, m2, beta=0, activation=activation) |
| |
| # Test transpose |
| for t1, t2, t3, t4 in itertools.product([True, False], repeat=4): |
| |
| def maybe_transpose(cond, m): |
| if not cond: |
| return m |
| return m.t().clone(memory_format=torch.contiguous_format).t() |
| |
| M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype)) |
| m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype)) |
| m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype)) |
| self._test_addmm_addmv( |
| func, M, m1, m2, transpose_out=t4, activation=activation |
| ) |
| |
| if t1: |
| # a vector bias V with beta=1 exercises the epilogue-fusion path; the result |
| # does not depend on t1, so run it only once (when t1 is True) |
| self._test_addmm_addmv( |
| func, |
| V, |
| m1, |
| m2, |
| beta=1, |
| transpose_out=t4, |
| activation=activation, |
| ) |
| |
| @precisionOverride( |
| { |
| torch.float: 1e-4, |
| torch.half: 1e-1, |
| } |
| ) |
| @dtypes(torch.float32, torch.half) |
| def test_addmm(self, device, dtype): |
| self._test_addmm_impl(torch.addmm, None, device, dtype) |
| |
| @precisionOverride({torch.bfloat16: 1e-0, torch.half: 1e-3, torch.float: 1e-4}) |
| @dtypes(torch.bfloat16, torch.half, torch.float) |
| def test_addmv(self, device, dtype): |
| # have to use torch.randn(...).to(bfloat16) instead of |
| # torch.randn(..., dtype=bfloat16). randn does not support |
| # bfloat16 yet. |
| # "*0.2" to reduce errors for low precision |
| ts = [ |
| 0.2 * torch.randn(50, device=device).to(dtype), |
| 0.2 * torch.randn(1, device=device).to(dtype).expand(50), |
| ] |
| vs = [ |
| 0.2 * torch.randn(100, device=device).to(dtype), |
| 0.2 |
| * torch.ones(1, device=device) |
| .to(dtype) |
| .expand(100), # to reduce errors for low precision |
| ] |
| ms = [ |
| # 0d |
| 0.2 |
| * torch.ones((), device=device) |
| .to(dtype) |
| .expand(50, 100), # to reduce errors for low precision |
| # 1d |
| 0.2 * torch.randn((1, 100), device=device).to(dtype).expand(50, 100), |
| # this initialization reduces low-precision error for broadcasted matrices |
| # by making sure that intermediate and result values are exactly |
| # representable in the low-precision type |
| 0.2 |
| * torch.randint(3, (50, 1), dtype=torch.float, device=device) |
| .to(dtype) |
| .expand(50, 100), |
| # 2d |
| 0.2 * torch.randn((50, 100), device=device).to(dtype), |
| 0.2 * torch.randn((100, 50), device=device).to(dtype).t(), |
| ] |
| for m, v, t in itertools.product(ms, vs, ts): |
| self._test_addmm_addmv(torch.addmv, t, m, v) |
| # Test beta=0, t=nan |
| t = torch.full((50,), math.nan, device=device).to(dtype) |
| for m, v in itertools.product(ms, vs): |
| self._test_addmm_addmv(torch.addmv, t, m, v, beta=0) |
| |
| @dtypes( |
| torch.half, |
| torch.float32, |
| ) |
| def test_mm(self, device, dtype): |
| def _test_mm(n, m, p, dtype, genf): |
| # reference implementation: naive triple-loop matmul (accumulates in float32 for half inputs) |
| def matrixmultiply(mat1, mat2): |
| n = mat1.size(0) |
| m = mat1.size(1) |
| p = mat2.size(1) |
| dtype_ = torch.float if dtype == torch.half else dtype |
| if dtype == torch.half: |
| mat1 = mat1.float() |
| mat2 = mat2.float() |
| res = torch.zeros(n, p, dtype=dtype_, device=device) |
| for i, j in iter_indices(res): |
| res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m)) |
| return res.half() if dtype == torch.half else res |
| |
| # contiguous case |
| mat1 = genf(n, m) |
| mat2 = genf(m, p) |
| res = torch.mm(mat1, mat2) |
| |
| res2 = matrixmultiply(mat1, mat2) |
| self.assertEqual(res, res2) |
| |
| # non contiguous case 1 |
| mat1 = genf(n, m) |
| mat2 = genf(p, m).t() |
| res = torch.mm(mat1, mat2) |
| |
| res2 = matrixmultiply(mat1, mat2) |
| self.assertEqual(res, res2) |
| |
| # non contiguous case 2 |
| mat1 = genf(m, n).t() |
| mat2 = genf(m, p) |
| res = torch.mm(mat1, mat2) |
| |
| res2 = matrixmultiply(mat1, mat2) |
| self.assertEqual(res, res2) |
| |
| # non contiguous case 3 |
| mat1 = genf(m, n).t() |
| mat2 = genf(p, m).t() |
| res = torch.mm(mat1, mat2) |
| |
| res2 = matrixmultiply(mat1, mat2) |
| self.assertEqual(res, res2) |
| |
| # test with zero stride |
| mat1 = genf(n, m) |
| mat2 = genf(m, 1).expand(m, p) |
| res = torch.mm(mat1, mat2) |
| |
| res2 = matrixmultiply(mat1, mat2) |
| self.assertEqual(res, res2) |
| |
| # explicitly exercise the _out variant in torch.mm(). |
| # contiguous case |
| mat1 = genf(n, m) |
| mat2 = genf(m, p) |
| res = genf(n, p) |
| torch.mm(mat1, mat2, out=res) |
| |
| res2 = matrixmultiply(mat1, mat2) |
| self.assertEqual(res, res2) |
| |
| # explicitly exercise the _out variant in torch.mm(). |
| # non contiguous case 3 |
| mat1 = genf(m, n).t() |
| mat2 = genf(p, m).t() |
| res = genf(n, p) |
| torch.mm(mat1, mat2, out=res) |
| |
| res2 = matrixmultiply(mat1, mat2) |
| self.assertEqual(res, res2) |
| |
| def genf_int(x, y): |
| return torch.randint(0, 100, (x, y), dtype=dtype, device=device) |
| |
| def genf_bfloat(x, y): |
| return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype) * 0.1 |
| |
| def genf_float(x, y): |
| return torch.randn(x, y, dtype=dtype, device=device) |
| |
| def genf_Half(x, y): |
| return torch.randn(x, y, dtype=dtype, device=device) |
| |
| for n, m, p in [(20, 10, 15), (15, 20, 10), (25, 18, 10)]: |
| if (dtype == torch.int32) or (dtype == torch.int64): |
| genf = genf_int |
| elif dtype == torch.bfloat16: |
| genf = genf_bfloat |
| elif dtype == torch.half: |
| genf = genf_Half |
| else: |
| genf = genf_float |
| |
| _test_mm(n, m, p, dtype, genf) |
| |
| @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05}) |
| @dtypes(torch.float32, torch.bfloat16, torch.half) |
| def test_bmm(self, device, dtype): |
| batch_sizes = [1, 10] |
| M, N, O = 23, 15, 12 |
| numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32 |
| |
| def invert_perm(p): |
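| # return the permutation q such that t.permute(p).permute(q) restores t's original layout |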
| d = {x: i for i, x in enumerate(p)} |
| return (d[0], d[1], d[2]) |
| |
| def generate_inputs(num_batches): |
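| # Yields (b1, b2) operand pairs in three flavors: permuted-then-restored |
| # layouts, broadcasted (expanded) operands, and zero-sized dimensions. |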
| # transposed tensors |
| for perm1, perm2 in itertools.product( |
| itertools.permutations((0, 1, 2)), repeat=2 |
| ): |
| b1 = make_tensor( |
| (num_batches, M, N), dtype=dtype, device=device, low=-0.1, high=0.1 |
| ) |
| b2 = make_tensor( |
| (num_batches, N, O), dtype=dtype, device=device, low=-0.1, high=0.1 |
| ) |
| b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) |
| b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) |
| yield b1, b2 |
| # broadcasting tensors |
| for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6): |
| shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1) |
| shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1) |
| b1 = make_tensor( |
| shape1, dtype=dtype, device=device, low=-0.1, high=0.1 |
| ).expand(num_batches, M, N) |
| b2 = make_tensor( |
| shape2, dtype=dtype, device=device, low=-0.1, high=0.1 |
| ).expand(num_batches, N, O) |
| yield b1, b2 |
| # zero-sized tensors |
| for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): |
| shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0) |
| shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0) |
| b1 = torch.randn(shape1, dtype=dtype, device=device) |
| b2 = torch.randn(shape2, dtype=dtype, device=device) |
| yield b1, b2 |
| |
| for num_batches in batch_sizes: |
| for (b1, b2), perm3 in itertools.product( |
| generate_inputs(num_batches), itertools.permutations((0, 1, 2)) |
| ): |
| res1 = torch.bmm(b1, b2) |
| res2 = ( |
| torch.full( |
| (num_batches, M, O), math.nan, dtype=dtype, device=device |
| ) |
| .permute(perm3) |
| .contiguous() |
| .permute(invert_perm(perm3)) |
| ) |
| torch.bmm(b1, b2, out=res2) |
| expect = torch.from_numpy( |
| b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() |
| ).to(device=device, dtype=dtype) |
| self.assertEqual(expect, res1) |
| self.assertEqual(expect, res2) |
| |
| if self.device_type == "cuda": |
| # check that mixed arguments are rejected |
| self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cpu())) |
| self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cpu(), b2)) |
| self.assertRaises( |
| RuntimeError, lambda: torch.bmm(b1, b2, out=res2.cpu()) |
| ) |
| |
| def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor): |
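| """Shared checks for the addbmm/baddbmm family: the in-place op, the |
| deprecated positional alpha/beta overloads (which must warn), keyword |
| alpha/beta, beta=0 with NaN inputs (NaNs must not propagate), and the |
| out= variant.""" |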
| getattr(out_tensor, func + "_")(b1, b2) |
| self.assertEqual(out_tensor, ref) |
| res3 = out_tensor.clone() |
| |
| with self.assertWarnsOnceRegex( |
| UserWarning, f"This overload of {func}_ is deprecated" |
| ): |
| getattr(out_tensor, func + "_")(1, b1, b2) |
| self.assertEqual(out_tensor, ref * 2) |
| getattr(res3, func + "_")(b1, b2, beta=1) |
| self.assertEqual(out_tensor, res3) |
| |
| with self.assertWarnsOnceRegex( |
| UserWarning, f"This overload of {func}_ is deprecated" |
| ): |
| getattr(out_tensor, func + "_")(1.0, 0.5, b1, b2) |
| self.assertEqual(out_tensor, ref * 2.5) |
| getattr(res3, func + "_")(b1, b2, beta=1.0, alpha=0.5) |
| self.assertEqual(out_tensor, res3) |
| |
| with self.assertWarnsOnceRegex( |
| UserWarning, f"This overload of {func} is deprecated" |
| ): |
| self.assertEqual(out_tensor, getattr(torch, func)(1, out_tensor, 0, b1, b2)) |
| |
| res4 = getattr(torch, func)(out_tensor, b1, b2, beta=1, alpha=0.5) |
| self.assertEqual(res4, ref * 3) |
| |
| nan = torch.full_like(out_tensor, math.nan) |
| res5 = getattr(torch, func)(nan, b1, b2, beta=0, alpha=1) |
| self.assertEqual(res5, ref) |
| |
| if b1.is_complex(): |
| res6 = getattr(torch, func)(out_tensor, b1, b2, beta=0.1j, alpha=0.5j) |
| self.assertEqual(res6, out_tensor * 0.1j + 0.5j * ref) |
| else: |
| res6 = getattr(torch, func)(out_tensor, b1, b2, beta=0.1, alpha=0.5) |
| self.assertEqual(res6, out_tensor * 0.1 + 0.5 * ref) |
| |
| res7 = torch.full_like(out_tensor, math.nan) |
| getattr(torch, func)(nan, b1, b2, beta=0, out=res7) |
| self.assertEqual(res7, ref) |
| |
| @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05}) |
| @dtypes(torch.float32, torch.bfloat16, torch.half) |
| def test_addbmm(self, device, dtype): |
| num_batches = 2 |
| M, N, O = 16, 17, 18 |
| |
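| # is_supported is hardcoded to True, so the unsupported-dtype branch below |
| # is never taken; it is kept for parity with the upstream test. |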
| is_supported = True |
| |
| if not is_supported: |
| b1 = make_tensor( |
| (num_batches, M, N), dtype=dtype, device=device, low=-1, high=1 |
| ) |
| b2 = make_tensor( |
| (num_batches, N, O), dtype=dtype, device=device, low=-1, high=1 |
| ) |
| t = make_tensor((M, O), dtype=dtype, device=device, low=-1, high=1) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED", |
| lambda: torch.addbmm(t, b1, b2), |
| ) |
| return |
| |
| def invert_perm(p): |
| d = {x: i for i, x in enumerate(p)} |
| return (d[0], d[1], d[2]) |
| |
| def generate_tensor(): |
| numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32 |
| # transposed tensors |
| for perm1, perm2 in itertools.product( |
| itertools.permutations((0, 1, 2)), repeat=2 |
| ): |
| for perm3 in itertools.permutations((0, 1)): |
| b1 = ( |
| make_tensor( |
| (num_batches, M, N), |
| dtype=dtype, |
| device=device, |
| low=-1, |
| high=1, |
| ) |
| * 0.1 |
| ) |
| b2 = ( |
| make_tensor( |
| (num_batches, N, O), |
| dtype=dtype, |
| device=device, |
| low=-1, |
| high=1, |
| ) |
| * 0.1 |
| ) |
| b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) |
| b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) |
| ref = ( |
| torch.from_numpy( |
| b1.to(numpy_dtype).cpu().numpy() |
| @ b2.to(numpy_dtype).cpu().numpy() |
| ) |
| .to(device=device, dtype=dtype) |
| .sum(0) |
| ) |
| out_tensor = ( |
| torch.zeros_like(ref).permute(perm3).contiguous().permute(perm3) |
| ) |
| yield b1, b2, ref, out_tensor |
| # broadcasting tensors |
| for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6): |
| shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1) |
| shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1) |
| b1 = ( |
| make_tensor( |
| shape1, dtype=dtype, device=device, low=-1, high=1 |
| ).expand(num_batches, M, N) |
| * 0.1 |
| ) |
| b2 = ( |
| make_tensor( |
| shape2, dtype=dtype, device=device, low=-1, high=1 |
| ).expand(num_batches, N, O) |
| * 0.1 |
| ) |
| ref = ( |
| torch.from_numpy( |
| b1.to(numpy_dtype).cpu().numpy() |
| @ b2.to(numpy_dtype).cpu().numpy() |
| ) |
| .to(device=device, dtype=dtype) |
| .sum(0) |
| ) |
| out_tensor = torch.zeros_like(ref) |
| yield b1, b2, ref, out_tensor |
| # zero-sized tensors |
| for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): |
| shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0) |
| shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0) |
| b1 = ( |
| make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1) |
| * 0.1 |
| ) |
| b2 = ( |
| make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1) |
| * 0.1 |
| ) |
| ref = ( |
| torch.from_numpy( |
| b1.to(numpy_dtype).cpu().numpy() |
| @ b2.to(numpy_dtype).cpu().numpy() |
| ) |
| .to(device=device, dtype=dtype) |
| .sum(0) |
| ) |
| out_tensor = torch.zeros_like(ref) |
| yield b1, b2, ref, out_tensor |
| |
| for b1, b2, ref, out_tensor in generate_tensor(): |
| self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor) |
| |
| @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5}) |
| @dtypes(torch.float32, torch.bfloat16, torch.half) |
| def test_baddbmm(self, device, dtype): |
| num_batches = 10 |
| M, N, O = 12, 8, 50 |
| |
| def invert_perm(p): |
| d = {x: i for i, x in enumerate(p)} |
| return (d[0], d[1], d[2]) |
| |
| def generate_tensor(): |
| numpy_dtype = ( |
| dtype if dtype not in [torch.bfloat16, torch.half] else torch.float32 |
| ) |
| # transposed tensors |
| for perm1, perm2, perm3 in itertools.product( |
| itertools.permutations((0, 1, 2)), repeat=3 |
| ): |
| b1 = make_tensor( |
| (num_batches, M, N), dtype=dtype, device=device, low=-1, high=1 |
| ) |
| b2 = make_tensor( |
| (num_batches, N, O), dtype=dtype, device=device, low=-1, high=1 |
| ) |
| b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) |
| b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) |
| ref = torch.from_numpy( |
| b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() |
| ).to(device=device, dtype=dtype) |
| out_tensor = torch.zeros_like(ref) |
| out_tensor = ( |
| out_tensor.permute(perm3).contiguous().permute(invert_perm(perm3)) |
| ) |
| yield b1, b2, ref, out_tensor |
| # broadcasting tensors |
| for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6): |
| shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1) |
| shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1) |
| b1 = make_tensor( |
| shape1, dtype=dtype, device=device, low=-1, high=1 |
| ).expand(num_batches, M, N) |
| b2 = make_tensor( |
| shape2, dtype=dtype, device=device, low=-1, high=1 |
| ).expand(num_batches, N, O) |
| ref = torch.from_numpy( |
| b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() |
| ).to(device=device, dtype=dtype) |
| out_tensor = torch.zeros_like(ref) |
| yield b1, b2, ref, out_tensor |
| # zero-sized tensors |
| for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): |
| shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0) |
| shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0) |
| b1 = make_tensor(shape1, dtype=dtype, device=device, low=-2, high=2) |
| b2 = make_tensor(shape2, dtype=dtype, device=device, low=-2, high=2) |
| ref = torch.from_numpy( |
| b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() |
| ).to(device=device, dtype=dtype) |
| out_tensor = torch.zeros_like(ref) |
| yield b1, b2, ref, out_tensor |
| |
| for b1, b2, ref, out_tensor in generate_tensor(): |
| self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor) |
| |
| def test_tensordot(self, device): |
| a = torch.arange(60.0, device=device).reshape(3, 4, 5) |
| b = torch.arange(24.0, device=device).reshape(4, 3, 2) |
| c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu() |
| cn = torch.from_numpy( |
| np.tensordot(a.cpu().numpy(), b.cpu().numpy(), axes=([1, 0], [0, 1])) |
| ) |
| self.assertEqual(c, cn) |
| |
| cout = torch.zeros((5, 2), device=device) |
| torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=cout).cpu() |
| self.assertEqual(c, cout) |
| |
| a = torch.randn(2, 3, 4, 5, device=device) |
| b = torch.randn(4, 5, 6, 7, device=device) |
| c = torch.tensordot(a, b, dims=2).cpu() |
| cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), axes=2)) |
| |
| with self.assertRaisesRegex(RuntimeError, "expects dims >= 0"): |
| torch.tensordot(a, b, dims=-1) |
| |
| self.assertEqual(c, cn) |
| c = torch.tensordot(a, b).cpu() |
| cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy())) |
| self.assertEqual(c, cn) |
| |
| a = torch.tensordot(torch.tensor(0.0), torch.tensor(0.0), 0) |
| an = torch.from_numpy( |
| np.tensordot( |
| np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0 |
| ) |
| ) |
| self.assertEqual(a, an) |
| |
| @dtypes(torch.float) |
| @precisionOverride({torch.float32: 1e-4}) |
| def test_1_sized_with_0_strided(self, device, dtype): |
| a = make_tensor((8, 1, 64), dtype=dtype, device=device) |
| a_strided = torch.as_strided(a, size=[8, 1, 64], stride=[64, 0, 1]) |
| b = make_tensor((8, 64, 512), dtype=dtype, device=device) |
| b_strided = torch.as_strided(b, size=[8, 64, 512], stride=[64, 1, 512]) |
| res = torch.bmm(a_strided, b_strided) |
| expect = torch.from_numpy(a_strided.cpu().numpy() @ b_strided.cpu().numpy()).to( |
| device=device, dtype=dtype |
| ) |
| self.assertEqual(expect, res) |
| |
| def _select_broadcastable_dims(self, dims_full=None): |
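| """Randomly pick a broadcast-compatible 'small' shape and a 'large' shape |
| derived from dims_full; returns (dims_small, dims_large, dims_full).""" |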
| # select full dimensionality |
| if dims_full is None: |
| dims_full = [] |
| ndims = random.randint(1, 4) |
| dims_full = [random.randint(1, 8) for _ in range(ndims)] |
| else: |
| ndims = len(dims_full) |
| |
| # select actual dimensions for ops: |
| # larger: full ndims, individual sizes may be reduced |
| # smaller: possibly reduced ndims, sizes may be reduced |
| smaller_ndims = random.randint(1, ndims) |
| dims_small = [] |
| dims_large = [] |
| for i in range(ndims - 1, -1, -1): |
| j = random.randint(1, 3) |
| if j == 1: # no reduced singleton dimension |
| ds = dims_full[i] |
| dl = dims_full[i] |
| elif j == 2: # larger may have reduced singleton dimension |
| ds = dims_full[i] |
| dl = 1 if len(dims_small) < smaller_ndims else dims_full[i] |
| elif j == 3: # smaller may have reduced singleton dimension |
| ds = 1 |
| dl = dims_full[i] |
| dims_large = [dl] + dims_large |
| if len(dims_small) < smaller_ndims: |
| dims_small = [ds] + dims_small |
| return (dims_small, dims_large, dims_full) |
| |
| def test_broadcast_fused_matmul(self, device): |
| fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"] |
| |
| for fn in fns: |
| batch_dim = random.randint(1, 8) |
| n_dim = random.randint(1, 8) |
| m_dim = random.randint(1, 8) |
| p_dim = random.randint(1, 8) |
| |
| def dims_full_for_fn(): |
| if fn == "baddbmm": |
| return ( |
| [batch_dim, n_dim, p_dim], |
| [batch_dim, n_dim, m_dim], |
| [batch_dim, m_dim, p_dim], |
| ) |
| elif fn == "addbmm": |
| return ( |
| [n_dim, p_dim], |
| [batch_dim, n_dim, m_dim], |
| [batch_dim, m_dim, p_dim], |
| ) |
| elif fn == "addmm": |
| return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim]) |
| elif fn == "addmv": |
| return ([n_dim], [n_dim, m_dim], [m_dim]) |
| elif fn == "addr": |
| return ([n_dim, m_dim], [n_dim], [m_dim]) |
| else: |
| raise AssertionError("unknown function") |
| |
| (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn() |
| (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full) |
| |
| t0_small = torch.randn(*t0_dims_small, device=device).float() |
| t1 = torch.randn(*t1_dims, device=device).float() |
| t2 = torch.randn(*t2_dims, device=device).float() |
| |
| t0_full = t0_small.expand(*t0_dims_full).to(device) |
| |
| fntorch = getattr(torch, fn) |
| r0 = fntorch(t0_small, t1, t2) |
| r1 = fntorch(t0_full, t1, t2) |
| self.assertEqual(r0, r1) |
| |
| @dtypes(torch.float32) |
| def test_strided_mm_bmm(self, device, dtype): |
| # Tests strided view case with stride smaller than corresponding dimension size |
| x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype, device=device) |
| new_shape = [2, 2, 2] |
| new_stride = [3, 1, 1] |
| sx = torch.as_strided(x, size=new_shape, stride=new_stride) |
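| # with stride [3, 1, 1] the view aliases itself: sx[b, 0, 1] and sx[b, 1, 0] |
| # both map to storage offset 3 * b + 1 |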
| |
| torch_fn = lambda x: torch.bmm(x, x) # noqa: E731 |
| np_fn = lambda x: np.matmul(x, x) # noqa: E731 |
| self.compare_with_numpy(torch_fn, np_fn, sx) |
| |
| torch_fn = lambda x: torch.mm(x, x) # noqa: E731 |
| self.compare_with_numpy(torch_fn, np_fn, sx[0]) |
| |
| def test_mm_empty_inputs_mixed_dtype_errors(self, device): |
| a = torch.randint(0, 10, [1, 10], dtype=torch.int16, device=device) |
| b = torch.randn(10, 20, dtype=torch.float32, device=device) |
| with self.assertRaisesRegex( |
| RuntimeError, "expected .* and .* to have the same dtype, but got:" |
| ): |
| torch.mm(a, b) |
| |
| def test_matmul_45724(self, device): |
| # https://github.com/pytorch/pytorch/issues/45724 |
| a = torch.rand(65537, 22, 64, device=device, dtype=torch.half) |
| b = torch.rand(65537, 64, 22, device=device, dtype=torch.half) |
| c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device=device) |
| cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).half() |
| torch.matmul(a, b, out=c) |
| self.assertEqual(c, cpu_result) |
| |
| @dtypes( |
| torch.int16, |
| torch.int32, |
| torch.int64, |
| torch.float16, |
| torch.float32, |
| torch.float64, |
| ) |
| def test_baddbmm_input_dtypes_compatibility(self, device, dtype): |
| batch1 = torch.rand((1, 2, 2), dtype=torch.float32, device=device) |
| batch2 = torch.rand((1, 2, 2), dtype=torch.float32, device=device) |
| input_tensor = torch.rand((1, 2, 2), device=device).to(dtype) |
| if dtype != torch.float32: |
| with self.assertRaisesRegex(RuntimeError, "Input dtypes must be the same"): |
| y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0) |
| else: |
| out = torch.randn((1, 2, 2), dtype=dtype, device=device).fill_(torch.nan) |
| y_ref = torch.bmm(batch1, batch2) |
| y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out) |
| self.assertEqual(out, y_ref) |
| |
| @dtypes(torch.float) |
| def test_baddbmm_nan_input_with_zero_beta(self, device, dtype): |
| for shape in [[3, 2, 2], [2, 20, 20]]: |
| mat1, mat2 = ( |
| torch.randn(shape, dtype=dtype, device=device) for _ in range(2) |
| ) |
| inputs = [ |
| torch.randn(shape, dtype=dtype, device=device), |
| torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan), |
| ] |
| outs = [ |
| None, |
| torch.randn(shape, dtype=dtype, device=device), |
| torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan), |
| ] |
| options = itertools.product(inputs, outs) |
| for input, out in options: |
| y_ref = torch.bmm(mat1, mat2) |
| y = torch.baddbmm(input, mat1, mat2, beta=0.0, out=out) |
| self.assertEqual(y_ref, y) |
| |
| @dtypes(torch.float) |
| def test_addmm_sizes(self, device, dtype): |
| for m in [0, 1, 25]: |
| for n in [0, 1, 10]: |
| for k in [0, 1, 8]: |
| M = torch.randn(n, m, device=device).to(dtype) |
| m1 = torch.randn(n, k, device=device).to(dtype) |
| m2 = torch.randn(k, m, device=device).to(dtype) |
| self._test_addmm_addmv(torch.addmm, M, m1, m2) |
| |
| m1 = torch.randn(n, k + 1, device=device).to(dtype) |
| m2 = torch.randn(k, m, device=device).to(dtype) |
| self.assertRaisesRegex( |
| RuntimeError, |
| f"{n}x{k + 1}.*{k}x{m}", |
| lambda: torch.addmm(M, m1, m2), |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.mm(m1, m2) |
| ) |
| |
| @precisionOverride( |
| { |
| torch.double: 1e-8, |
| torch.float: 1e-4, |
| torch.bfloat16: 5e-2, |
| torch.half: 5e-2, |
| torch.cfloat: 1e-4, |
| torch.cdouble: 1e-8, |
| } |
| ) |
| @dtypes(torch.float32, torch.bfloat16, torch.half) |
| def test_addmm_gelu(self, device, dtype): |
| self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype) |
| |
| @precisionOverride( |
| { |
| torch.double: 1e-8, |
| torch.float: 1e-4, |
| torch.bfloat16: 5e-2, |
| torch.half: 5e-2, |
| torch.cfloat: 1e-4, |
| torch.cdouble: 1e-8, |
| } |
| ) |
| @dtypes(torch.float32, torch.bfloat16, torch.half) |
| def test_addmm_relu(self, device, dtype): |
| self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype) |
| |
| @dtypes(torch.float, torch.bfloat16, torch.half) |
| def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype): |
| # tests (o, s)*(s). o is output size, s is summed size. |
| o = 5 |
| s = 3 |
| a_data = torch.arange(1, o * s + 1, device=device, dtype=dtype).view(o, s) |
| x_data = torch.arange(1, s + 1, 1, device=device, dtype=dtype) |
| y_data = torch.ones(o, device=device, dtype=dtype) |
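| # control = a_data @ x_data + y_data, kept as a human-readable reference; |
| # _test_addmm_addmv applies its own alpha/beta, so it is not compared directly |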
| control = torch.tensor( |
| [15.0, 33.0, 51.0, 69.0, 87.0], device=device, dtype=dtype |
| ) |
| |
| def _test(row_major, incx, incy, lda_tail): |
| if row_major: |
| a_storage = torch.full( |
| (o, s + lda_tail), float("nan"), device=device, dtype=dtype |
| ) |
| else: |
| a_storage = torch.full( |
| (s, o + lda_tail), float("nan"), device=device, dtype=dtype |
| ).permute(1, 0) |
| a = a_storage[:o, :s].copy_(a_data) |
| |
| x_storage = torch.full((s, incx), float("nan"), device=device, dtype=dtype) |
| x = x_storage[:, 0].copy_(x_data) |
| |
| y_storage = torch.full((o, incy), float("nan"), device=device, dtype=dtype) |
| y = y_storage[:, 0].copy_(y_data) |
| |
| self._test_addmm_addmv(torch.addmv, y, a, x) |
| |
| for row_major, incx, incy, lda_tail in itertools.product( |
| (False, True), (1, 2), (1, 2), (0, 1) |
| ): |
| _test(row_major, incx, incy, lda_tail) |
| |
| @precisionOverride( |
| { |
| torch.double: 1e-8, |
| torch.float: 1e-4, |
| torch.bfloat16: 0.6, |
| torch.half: 1e-1, |
| torch.cfloat: 1e-4, |
| torch.cdouble: 1e-8, |
| } |
| ) |
| @dtypes(torch.bfloat16, torch.half, torch.float32) |
| def test_corner_cases_of_cublasltmatmul(self, device, dtype): |
| # common case |
| M = torch.randn(128, device=device).to(dtype) |
| m1 = torch.randn(2048, 2400, device=device).to(dtype) |
| m2 = torch.randn(128, 2400, device=device).to(dtype) |
| torch.nn.functional.linear(m1, m2, M) |
| # Ntrans_B has leading dimension (ld) >> rows |
| m1 = torch.rand([128, 2400]).to(dtype).to(device).t() |
| m2 = torch.rand([2048, 25272]).to(dtype).to(device).t()[21940:24340] |
| M = torch.rand([128]).to(dtype).to(device) |
| torch.addmm(M, m2.t(), m1) |
| # trans_A has leading dimension (ld) >> rows |
| m1 = torch.rand([128, 25272]).to(dtype).to(device)[:, 21940:24340].t() |
| m2 = torch.randn(2048, 2400, device=device).to(dtype) |
| M = torch.rand([128]).to(dtype).to(device) |
| torch.addmm(M, m2, m1) |
| # large tensor dim > 65535 |
| M = torch.randn(16, device=device).to(dtype) |
| m1 = torch.randn(32, 131071, device=device).to(dtype) |
| m2 = torch.randn(16, 131071, device=device).to(dtype) |
| torch.nn.functional.linear(m1, m2, M) |
| |
| def test_blas_empty(self, device): |
| def fn(torchfn, *args, test_out=False, **kwargs): |
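| # Positional args that are shape tuples become random tensors on `device`; |
| # anything else is passed through. With test_out=True the result is also |
| # written into a NaN-filled out= buffer, which is returned instead. |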
| def call_torch_fn(*args, **kwargs): |
| return torchfn( |
| *tuple( |
| torch.randn(shape, device=device) |
| if isinstance(shape, tuple) |
| else shape |
| for shape in args |
| ), |
| **kwargs, |
| ) |
| |
| result = call_torch_fn(*args, **kwargs) |
| if not test_out: |
| return result |
| else: |
| out = torch.full_like(result, math.nan) |
| call_torch_fn(*args, **kwargs, out=out) |
| return out |
| |
| # mm, addmm |
| self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape) |
| self.assertEqual((0, 5), fn(torch.mm, (0, 0), (0, 5)).shape) |
| self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape) |
| self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape) |
| self.assertEqual( |
| torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6)) |
| ) |
| self.assertEqual( |
| torch.zeros((5, 6), device=device), |
| fn(torch.mm, (5, 0), (0, 6), test_out=True), |
| ) |
| |
| self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape) |
| self.assertEqual((0, 1), fn(torch.addmm, (1,), (0, 17), (17, 1)).shape) |
| t = torch.randn((5, 6), device=device) |
| self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6))) |
| self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6), test_out=True)) |
| |
| # mv, addmv |
| self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape) |
| self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape) |
| self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,))) |
| self.assertEqual( |
| torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,), test_out=True) |
| ) |
| |
| self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape) |
| t = torch.randn((3,), device=device) |
| self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,))) |
| self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,), test_out=True)) |
| |
| # bmm, baddbmm |
| self.assertEqual((0, 0, 0), fn(torch.bmm, (0, 0, 0), (0, 0, 0)).shape) |
| self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape) |
| self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape) |
| self.assertEqual( |
| torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6)) |
| ) |
| self.assertEqual( |
| torch.zeros((3, 5, 6), device=device), |
| fn(torch.bmm, (3, 5, 0), (3, 0, 6), test_out=True), |
| ) |
| |
| self.assertEqual( |
| (0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape |
| ) |
| self.assertEqual( |
| (3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape |
| ) |
| self.assertEqual( |
| (0, 5, 6), fn(torch.baddbmm, (0, 5, 6), (0, 5, 0), (0, 0, 6)).shape |
| ) |
| self.assertEqual( |
| (3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape |
| ) |
| c = torch.arange(30, dtype=torch.float32, device=device).reshape(3, 2, 5) |
| self.assertEqual( |
| -2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2) |
| ) # Issue #33467 |
| self.assertEqual( |
| -2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2, test_out=True) |
| ) # Issue #33467 |
| |
| # addbmm |
| self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape) |
| self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape) |
| t = torch.randn((5, 6), device=device) |
| self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6))) |
| self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6), test_out=True)) |
| |
| # matmul |
| self.assertEqual(torch.tensor(0.0, device=device), fn(torch.matmul, (0,), (0,))) |
| self.assertEqual( |
| torch.tensor(0.0, device=device), |
| fn(torch.matmul, (0,), (0,), test_out=True), |
| ) |
| self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape) |
| self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape) |
| self.assertEqual((5, 0, 0), fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape) |
| self.assertEqual( |
| torch.zeros((5, 3, 4), device=device), |
| fn(torch.matmul, (5, 3, 0), (5, 0, 4)), |
| ) |
| self.assertEqual( |
| torch.zeros((5, 3, 4), device=device), |
| fn(torch.matmul, (5, 3, 0), (5, 0, 4), test_out=True), |
| ) |
| |
| # dot |
| self.assertEqual(torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,))) |
| self.assertEqual( |
| torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,), test_out=True) |
| ) |
| |
| def test_large_bmm_backward(self, device): |
| A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT |
| B = torch.randn([1, 1024, 65536], device=device, requires_grad=True) |
| G = torch.randn([1024, 2, 65536], device=device) |
| |
| # Should not materialize an intermediate tensor of size [1024, 1024, 65536] (256 GB of memory), which would OOM |
| (A @ B).backward(G) |
| |
| def test_large_bmm_mm_backward(self, device): |
| A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT |
| B = torch.randn([1024, 65536], device=device, requires_grad=True) |
| G = torch.randn([1024, 2, 65536], device=device) |
| |
| # Should not materialize an intermediate tensor of size [1024, 1024, 65536] (256 GB of memory), which would OOM |
| (A @ B).backward(G) |
| |
| def check_single_matmul(self, x, y): |
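| """Compare torch.matmul(x, y) and its out= variant against np.matmul, |
| scaling the absolute tolerance with the contraction size for floating |
| and complex dtypes.""" |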
| def assertEqual(answer, expected): |
| if x.dtype.is_floating_point or x.dtype.is_complex: |
| k = max(x.shape[-1], 1) # Scale the atol with the size of the matrix |
| self.assertEqual( |
| answer, |
| expected, |
| msg=f"{x.shape} x {y.shape} = {answer.shape}", |
| atol=k * 5e-5, |
| rtol=1e-4, |
| ) |
| else: |
| self.assertEqual( |
| answer, expected, msg=f"{x.shape} x {y.shape} = {answer.shape}" |
| ) |
| |
| # test x @ y |
| expected = np.matmul(x.cpu(), y.cpu()) |
| ans = torch.matmul(x, y) |
| self.assertTrue(ans.is_contiguous()) |
| assertEqual(ans, expected) |
| |
| # test out |
| out = torch.empty_like(ans) |
| ans = torch.matmul(x, y, out=out) |
| self.assertIs(ans, out) |
| self.assertTrue(ans.is_contiguous()) |
| assertEqual(ans, expected) |
| |
| def gen_sizes_matmul(self, x_dim, y_dim=4, matrix_size=4, batch_size=3): |
| """ |
| Generate pairs (size_x, size_y) of tensor shapes with len(size_x) == x_dim |
| and len(size_y) <= y_dim that are compatible w.r.t. matmul. |
| """ |
| assert x_dim >= 1 |
| assert y_dim >= 2 |
| x = x_dim |
| for y in range(1, y_dim + 1): |
| for batch, mn in product( |
| product(range(batch_size), repeat=max(x - 2, y - 2, 0)), |
| product(range(matrix_size), repeat=min(y, 2)), |
| ): |
| if x == 1: |
| size_x = mn[:1] |
| size_y = batch + mn |
| yield size_x, size_y |
| else: |
| for k in range(matrix_size): |
| size_x = (k,) + mn[:1] |
| if x > 2: |
| size_x = batch[-(x - 2) :] + size_x |
| size_y = mn |
| if y > 2: |
| size_y = batch[-(y - 2) :] + size_y |
| yield size_x, size_y |
| |
| @dtypes(torch.float) |
| def test_matmul_small_brute_force_1d_Nd(self, device, dtype): |
| make_arg = partial(make_tensor, device=device, dtype=dtype) |
| |
| for (size_x, size_y), nctg_x, nctg_y in product( |
| self.gen_sizes_matmul(1), (True, False), (True, False) |
| ): |
| x = make_arg(size_x, noncontiguous=nctg_x) |
| y = make_arg(size_y, noncontiguous=nctg_y) |
| self.check_single_matmul(x, y) |
| |
| @dtypes(torch.float) |
| def test_matmul_small_brute_force_2d_Nd(self, device, dtype): |
| make_arg = partial(make_tensor, device=device, dtype=dtype) |
| |
| for (size_x, size_y), nctg_x, nctg_y in product( |
| self.gen_sizes_matmul(2), (True, False), (True, False) |
| ): |
| x = make_arg(size_x, noncontiguous=nctg_x) |
| y = make_arg(size_y, noncontiguous=nctg_y) |
| self.check_single_matmul(x, y) |
| |
| @dtypes(torch.float) |
| def test_matmul_small_brute_force_3d_Nd(self, device, dtype): |
| make_arg = partial(make_tensor, device=device, dtype=dtype) |
| |
| for (size_x, size_y), nctg_x, nctg_y in product( |
| self.gen_sizes_matmul(3), (True, False), (True, False) |
| ): |
| x = make_arg(size_x, noncontiguous=nctg_x) |
| y = make_arg(size_y, noncontiguous=nctg_y) |
| self.check_single_matmul(x, y) |
| |
| @dtypes(torch.float) |
| def test_matmul_out_kernel_errors_with_autograd(self, device, dtype): |
| a = torch.empty( |
| (256, 512), device=device, dtype=dtype, requires_grad=True |
| ).unsqueeze(0) |
| b = torch.empty( |
| (4, 128, 512), device=device, dtype=dtype, requires_grad=True |
| ).transpose(-1, -2) |
| c = torch.empty((256, 4, 128), device=device, dtype=dtype).movedim(1, 0) |
| |
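| # out= is allowed when the inputs are detached (no autograd involvement) |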
| torch.matmul(a.detach(), b.detach(), out=c) |
| |
| with self.assertRaisesRegex( |
| RuntimeError, |
| "functions with out=... arguments don't support automatic differentiation", |
| ): |
| torch.matmul(a, b, out=c) |
| |
| with torch.no_grad(): |
| torch.matmul(a, b, out=c) |
| |
| |
| instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu") |
| |
| if __name__ == "__main__": |
| run_tests() |