Adds generic device tests to test_autograd.py (#26248)
Summary:
- Adds new decorators for skipping on ROCm, skipping when MKL is unavailable, running only on the CPU, and running only on CUDA (see the usage sketch below)
- Makes decorator skip semantics consistent
- Adds the CUDA default stream requirement to the MAGMA decorator
- Creates TestAutogradDeviceType
Note: this PR originally moved test_cdist as well, but moving it caused CI failures. There may be an undiagnosed issue with cdist or its test; the issue does not reproduce locally.
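Usage sketch (illustrative only; the class and test names here are hypothetical, while the decorators, the `device` string argument, and instantiate_device_type_tests are the ones added or used in the diff below):

    import torch
    from common_utils import TestCase, run_tests
    from common_device_type import (instantiate_device_type_tests, onlyCUDA,
                                    skipCUDAIfRocm, skipCPUIfNoLapack)

    class TestFooDeviceType(TestCase):

        # Runs for every instantiated device type; `device` is 'cpu', 'cuda', etc.
        def test_add(self, device):
            x = torch.ones(3, device=device)
            self.assertEqual(6, (x + x).sum().item())

        # The CUDA variant is skipped when running on the ROCm stack.
        @skipCUDAIfRocm
        def test_flaky_on_rocm(self, device):
            pass

        # Only the CUDA variant runs; other device types raise SkipTest.
        @onlyCUDA
        def test_cuda_only(self, device):
            pass

        # The CPU variant is skipped if PyTorch was built without LAPACK.
        @skipCPUIfNoLapack
        def test_needs_lapack(self, device):
            pass

    # Generates per-device classes (e.g. TestFooDeviceTypeCPU and, when CUDA is
    # available, TestFooDeviceTypeCUDA) in the given scope.
    instantiate_device_type_tests(TestFooDeviceType, globals())

    if __name__ == '__main__':
        run_tests()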
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26248
Test Plan: The change is to the tests themselves.
Differential Revision: D17410386
Pulled By: mruberry
fbshipit-source-id: 8459df44f2a00f0e71680fbe713587a01d4b0300
diff --git a/test/common_device_type.py b/test/common_device_type.py
index 5563f75..2b47282 100644
--- a/test/common_device_type.py
+++ b/test/common_device_type.py
@@ -2,7 +2,8 @@
from functools import wraps
import unittest
import torch
-from common_utils import TestCase
+from common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \
+ skipCUDANonDefaultStreamIf
# Note: Generic Device-Type Testing
#
@@ -109,7 +110,7 @@
def setUpClass(cls):
# has_magma shows up after cuda is initialized
torch.ones(1).cuda()
- cls.has_magma = torch.cuda.has_magma
+ cls.no_magma = not torch.cuda.has_magma
# Adds available device-type-specific test base classes
@@ -169,11 +170,11 @@
scope[class_name] = device_type_test_class
-# Decorator that specifies a test dependency.
+# Decorator that skips a test if the given condition is true.
# Notes:
-# (1) Dependencies stack. Multiple dependencies are all evaluated.
-# (2) Dependencies can either be bools or strings. If a string the
-# test base must have defined the corresponding attribute to be True
+# (1) Skip conditions stack.
+# (2) Skip conditions can be bools or strings. If a string the
+# test base must have defined the corresponding attribute to be False
# for the test to run. If you want to use a string argument you should
# probably define a new decorator instead (see below).
# (3) Prefer the existing decorators to defining the 'device_type' kwarg.
@@ -189,32 +190,68 @@
@wraps(fn)
def dep_fn(slf, device, *args, **kwargs):
if self.device_type is None or self.device_type == slf.device_type:
- if not self.dep or (isinstance(self.dep, str) and not getattr(slf, self.dep, False)):
+ if (isinstance(self.dep, str) and getattr(slf, self.dep, True)) or (isinstance(self.dep, bool) and self.dep):
raise unittest.SkipTest(self.reason)
return fn(slf, device, *args, **kwargs)
return dep_fn
-# Specifies a CPU dependency.
+# Skips a test on CPU if the condition is true.
class skipCPUIf(skipIf):
def __init__(self, dep, reason):
super(skipCPUIf, self).__init__(dep, reason, device_type='cpu')
-# Specifies a CUDA dependency.
+# Skips a test on CUDA if the condition is true.
class skipCUDAIf(skipIf):
def __init__(self, dep, reason):
super(skipCUDAIf, self).__init__(dep, reason, device_type='cuda')
-# Specifies LAPACK as a CPU dependency.
+class onlyOn(object):
+
+ def __init__(self, device_type):
+ self.device_type = device_type
+
+ def __call__(self, fn):
+
+ @wraps(fn)
+ def only_fn(slf, device, *args, **kwargs):
+ if self.device_type != slf.device_type:
+ reason = "Only runs on {0}".format(self.device_type)
+ raise unittest.SkipTest(reason)
+
+ return fn(slf, device, *args, **kwargs)
+
+ return only_fn
+
+
+def onlyCPU(fn):
+ return onlyOn('cpu')(fn)
+
+
+def onlyCUDA(fn):
+ return onlyOn('cuda')(fn)
+
+
+# Skips a test on CPU if LAPACK is not available.
def skipCPUIfNoLapack(fn):
- return skipCPUIf(torch._C.has_lapack, "PyTorch compiled without Lapack")(fn)
+ return skipCPUIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")(fn)
-# Specifies MAGMA as a CUDA dependency.
+# Skips a test on CPU if MKL is not available.
+def skipCPUIfNoMkl(fn):
+ return skipCPUIf(not TEST_MKL, "PyTorch is built without MKL support")(fn)
+
+
+# Skips a test on CUDA if MAGMA is not available.
def skipCUDAIfNoMagma(fn):
- return skipCUDAIf('has_magma', "no MAGMA library detected")(fn)
+ return skipCUDAIf('no_magma', "no MAGMA library detected")(skipCUDANonDefaultStreamIf(True)(fn))
+
+
+# Skips a test on CUDA when using ROCm.
+def skipCUDAIfRocm(fn):
+ return skipCUDAIf(TEST_WITH_ROCM, "test doesn't currently work on the ROCm stack")(fn)
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 6fcf840..9b082e2 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -34,6 +34,8 @@
exclude_tensor_method,
mask_not_all_zeros,
S)
+from common_device_type import (instantiate_device_type_tests, skipCUDAIfRocm,
+ onlyCUDA)
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
@@ -682,47 +684,6 @@
"calculating the gradient of a sparse Tensor argument to mm is not supported."):
z.sum().backward()
- # NOTE: flaky on ROCm CI
- @skipIfRocm
- def test_sparse_ctor_getter_backward(self):
- # See NOTE [ Sparse: autograd and API ] on the expected behavior of this test
- def test(size, sparse_dim, nnz, device):
- v_size = [nnz] + list(size[sparse_dim:])
- i = torch.rand(sparse_dim, nnz)
- i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))
- i = i.to(torch.long)
-
- inp = torch.randn(v_size, requires_grad=True)
- other = self.genSparseTensor(size, sparse_dim, nnz, is_uncoalesced=True)[0]
- other = other.to(device)
-
- def fn(v):
- x = torch.sparse_coo_tensor(i, v, size, device=device)
- y = (x + other).coalesce()
- yv = y.values()
- new_v = yv.tanh()
- z = torch.sparse_coo_tensor(y.indices(), new_v, y.size())
- return z.coalesce().values()
-
- gradcheck(fn, (inp,))
- # FIXME: make gradgradcheck work.
- # gradgradcheck(fn, (inp,))
-
- # assert that _values is non-differentiable
- with self.assertRaisesRegex(RuntimeError, "does not have a grad_fn"):
- other.detach().requires_grad_()._values().backward(torch.ones_like(other._values()))
-
- devices = ['cpu']
-
- if torch.cuda.is_available():
- devices.append('cuda')
-
- for empty_i, empty_v, empty_nnz in product([True, False], repeat=3):
- sparse_size = [] if empty_i else [2, 1]
- dense_size = [1, 0, 2] if empty_v else [1, 2]
- nnz = 0 if empty_nnz else 5
- for device in devices:
- test(sparse_size + dense_size, len(sparse_size), nnz, device)
def test_multi_backward(self):
x = torch.randn(5, 5, requires_grad=True)
@@ -1712,19 +1673,6 @@
def test_sparse_gather_both_scalar(self):
self._test_sparse_gather((), (), 0)
- # autograd tests via common_method_invocations don't allow input tensors to
- # be sparse (RuntimeError: gradcheck expects all tensor inputs are dense when
- # check_sparse_nnz is set to False.)
- def test_sparse_mask_autograd(self):
- for device in ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']:
- tensor = torch.randn(3, requires_grad=True, device=device)
- mask = torch.ones(3, device=device)
- mask[1] = 0
- mask = mask.to_sparse()
- converted = tensor.sparse_mask(mask).to_dense()
- converted.sum().backward()
- self.assertEqual(tensor.grad, mask.to_dense())
-
def test_gc_in_destructor(self):
"""
Previously, if a Function destructor triggered a garbage collection,
@@ -2064,65 +2012,6 @@
self._test_type_conversion_backward(lambda x: x.cuda(0))
self._test_type_conversion_backward(lambda x: x.cuda(1))
- def _test_pyscalar_conversions(self, t, integral_conv):
- # integral -> integral
- l = t(torch.zeros(1, 1, 1, dtype=torch.long))
- pyscalar = -12345
- l[0] = pyscalar
- self.assertEqual(integral_conv(l), pyscalar)
-
- # floating point -> floating point
- f = Variable(t(torch.randn(1, 1)))
- pyscalar = -12345.1
- f[0] = pyscalar
- self.assertEqual(float(f), pyscalar)
- f[0] = nan
- self.assertTrue(math.isnan(float(f)))
- f[0] = inf
- self.assertEqual(float(f), inf, allow_inf=True)
- f[0] = -inf
- self.assertEqual(float(f), -inf, allow_inf=True)
-
- # integral -> floating point
- # check we can convert something that loses precision
- pyscalar = 1234567890123456789
- self.assertNotEqual(pyscalar, integral_conv(float(pyscalar)))
- l[0] = pyscalar
- self.assertEqual(float(l), float(pyscalar))
-
- # floating point -> integral
- f[0] = nan
- self.assertRaises(ValueError, lambda: integral_conv(f[0]))
- f[0] = inf
- self.assertRaises(OverflowError, lambda: integral_conv(f[0]))
- f[0] = -inf
- self.assertRaises(OverflowError, lambda: integral_conv(f[0]))
- f[0] = sys.float_info.max
- self.assertEqual(integral_conv(f), sys.float_info.max)
-
- # bool, nonzero
- def test_nonzero(tensor, value, expected):
- tensor[0] = value
- self.assertEqual(expected, bool(tensor))
- self.assertEqual(expected, True if tensor else False)
-
- test_nonzero(l, 0, False)
- test_nonzero(l, -2, True)
- test_nonzero(f, 0.0, False)
- test_nonzero(f, sys.float_info.min, True)
- test_nonzero(f, nan, bool(nan))
- test_nonzero(f, inf, bool(inf))
- test_nonzero(f, -inf, bool(-inf))
-
- def test_pyscalar_conversions(self):
- self._test_pyscalar_conversions(lambda x: x, lambda x: int(x))
- if sys.version_info[0] == 2:
- self._test_pyscalar_conversions(lambda x: x, lambda x: long(x))
- if torch.cuda.is_available():
- self._test_pyscalar_conversions(lambda x: x.cuda(), lambda x: int(x))
- if sys.version_info[0] == 2:
- self._test_pyscalar_conversions(lambda x: x.cuda(), lambda x: long(x))
-
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_pin_memory(self):
x = torch.randn(2, 2, requires_grad=True)
@@ -2473,6 +2362,7 @@
lambda y, x: torch.trapz(y, x),
True, f_args_variable, f_args_tensor)
+
# skip this test if running on rocm, because in cdist
# we use __shfl_down_sync on CUDA for fast reduction
# and it gives incorrect results on rocm platform
@@ -2506,6 +2396,7 @@
_test_cdist_for_size((2, 3, 5))
_test_cdist_for_size((1, 2, 3))
+
def test_var_mean_differentiable(self):
dim = [2, 4]
keepdim = False
@@ -2746,22 +2637,6 @@
a = torch.arange(1, 13, dtype=torch.double).view(3, 4).requires_grad_()
gradcheck(lambda a: torch.pow(2, a), (a,))
- # test for backward in https://github.com/pytorch/pytorch/issues/15511
- def test_pdist_large(self):
- def func(x):
- return torch.pdist(x, p=2)
-
- devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
- for device in devices:
- # shape[0] should be able to be (roughly) arbitrarily large, but the kernel
- # is currently limited to smaller sizes (see issue above); this is just testing
- # a floor.
- shape = (1000, 1)
- x = torch.randn(shape, device=device).requires_grad_()
- output = torch.pdist(x, p=2)
- # just run a single backward, as gradcheck/gradgradcheck is expensive here
- output.sum().backward()
-
@skipIfNoLapack
def test_pinverse(self):
# Why is pinverse tested this way, and not ordinarily as other linear algebra methods?
@@ -2993,29 +2868,6 @@
# test select on expanded input case
test(torch.randn(2, 3), lambda x: x.expand(10, 2, 3), [2, 3], [3, 1], 0)
- def _test_where_functional(self, t):
- x = Variable(t(torch.randn(5, 5)), requires_grad=True)
- y = Variable(t(torch.randn(5, 5)), requires_grad=True)
- cond = Variable(t(mask_not_all_zeros((5, 5))), requires_grad=False)
-
- def where(cond, x, y):
- return torch.where(cond, x, y)
-
- gradcheck(where, [cond, x, y], raise_exception=True)
- gradgradcheck(where, [cond, x, y], [Variable(t(torch.randn(5, 5)))])
-
- x = Variable(t(torch.randn(5, 1, 5)), requires_grad=True)
- y = Variable(t(torch.randn(5, 5, 1)), requires_grad=True)
- gradcheck(where, [cond, x, y], raise_exception=True)
- gradgradcheck(where, [cond, x, y], [Variable(t(torch.randn(5, 5, 5)))])
-
- def test_where_functional(self):
- self._test_where_functional(lambda t: t)
-
- @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
- def test_where_functional_cuda(self):
- self._test_where_functional(lambda t: t.cuda())
-
def _test_lerp_tensor_weights(self, cast):
def construct_inputs(*shapes):
start = cast(torch.randn(shapes[0])).requires_grad_()
@@ -3287,43 +3139,6 @@
d, = torch.autograd.grad(c, a, retain_graph=True, create_graph=True)
self.assertTrue(d.requires_grad)
- @staticmethod
- def _test_set_requires_grad_only_for_floats(self, cuda):
- dtypes = [torch.int64, torch.int32, torch.int16, torch.int8,
- torch.float, torch.double]
- if cuda:
- dtypes.append(torch.half)
-
- def f1(dt):
- a = torch.ones(1, dtype=dt, device='cuda' if cuda else 'cpu')
- a.requires_grad_()
-
- def f2(dt):
- a = torch.ones(1, dtype=dt, device='cuda' if cuda else 'cpu')
- a.requires_grad = True
-
- def f3(dt):
- torch.ones(1, dtype=dt, device='cuda' if cuda else 'cpu', requires_grad=True)
-
- for dt in dtypes:
- a = torch.ones(1, dtype=dt, device='cuda' if cuda else 'cpu')
- a.requires_grad = False # should always work
- a.requires_grad_(False)
-
- for f in [f1, f2, f3]:
- if dt.is_floating_point:
- f(dt)
- else:
- with self.assertRaisesRegex(RuntimeError, 'floating point',
- msg="dt: {} device: {}".format(a.dtype, a.device)):
- f(dt)
-
- @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
- def test_set_requires_grad_only_for_floats_cuda(self):
- self._test_set_requires_grad_only_for_floats(self, True)
-
- def test_set_requires_grad_only_for_floats(self):
- self._test_set_requires_grad_only_for_floats(self, False)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_rnn_backward_to_input_but_not_parameters_cuda(self):
@@ -3597,15 +3412,6 @@
# in the same thread recursively
DeepReentrant.apply(v).sum().backward()
- @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
- def test_advanced_indexing_backwards_large(self):
- # See https://github.com/pytorch/pytorch/issues/22843
- n = (1 << 16)
- x = torch.rand(n, 1, device='cuda', requires_grad=True)
- a = x[:, [0]]
- a.sum().backward()
- self.assertEqual(x.grad, torch.ones(n, 1, device='cuda'))
-
def test_reentrant_priority(self):
order = []
@@ -3900,5 +3706,184 @@
for test in method_tests():
add_test(*test)
+# Generic device type autograd tests.
+class TestAutogradDeviceType(TestCase):
+
+ # NOTE: flaky on ROCm CI
+ @skipCUDAIfRocm
+ def test_sparse_ctor_getter_backward(self, device):
+ # See NOTE [ Sparse: autograd and API ] on the expected behavior of this test
+ def _test(size, sparse_dim, nnz, device):
+ v_size = [nnz] + list(size[sparse_dim:])
+ i = torch.rand(sparse_dim, nnz)
+ i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))
+ i = i.to(torch.long)
+
+ inp = torch.randn(v_size, requires_grad=True)
+ other = self.genSparseTensor(size, sparse_dim, nnz, is_uncoalesced=True)[0]
+ other = other.to(device)
+
+ def fn(v):
+ x = torch.sparse_coo_tensor(i, v, size, device=device)
+ y = (x + other).coalesce()
+ yv = y.values()
+ new_v = yv.tanh()
+ z = torch.sparse_coo_tensor(y.indices(), new_v, y.size())
+ return z.coalesce().values()
+
+ gradcheck(fn, (inp,))
+ # FIXME: make gradgradcheck work.
+ # gradgradcheck(fn, (inp,))
+
+ # assert that _values is non-differentiable
+ with self.assertRaisesRegex(RuntimeError, "does not have a grad_fn"):
+ other.detach().requires_grad_()._values().backward(torch.ones_like(other._values()))
+
+ for empty_i, empty_v, empty_nnz in product([True, False], repeat=3):
+ sparse_size = [] if empty_i else [2, 1]
+ dense_size = [1, 0, 2] if empty_v else [1, 2]
+ nnz = 0 if empty_nnz else 5
+ _test(sparse_size + dense_size, len(sparse_size), nnz, device)
+
+ # autograd tests via common_method_invocations don't allow input tensors to
+ # be sparse (RuntimeError: gradcheck expects all tensor inputs are dense when
+ # check_sparse_nnz is set to False.)
+ def test_sparse_mask_autograd(self, device):
+ tensor = torch.randn(3, requires_grad=True, device=device)
+ mask = torch.ones(3, device=device)
+ mask[1] = 0
+ mask = mask.to_sparse()
+ converted = tensor.sparse_mask(mask).to_dense()
+ converted.sum().backward()
+ self.assertEqual(tensor.grad, mask.to_dense())
+
+ def test_pyscalar_conversions(self, device):
+ def _test_pyscalar_conversions(t, integral_conv):
+ # integral -> integral
+ l = t(torch.zeros(1, 1, 1, dtype=torch.long))
+ pyscalar = -12345
+ l[0] = pyscalar
+ self.assertEqual(integral_conv(l), pyscalar)
+
+ # floating point -> floating point
+ f = Variable(t(torch.randn(1, 1)))
+ pyscalar = -12345.1
+ f[0] = pyscalar
+ self.assertEqual(float(f), pyscalar)
+ f[0] = nan
+ self.assertTrue(math.isnan(float(f)))
+ f[0] = inf
+ self.assertEqual(float(f), inf, allow_inf=True)
+ f[0] = -inf
+ self.assertEqual(float(f), -inf, allow_inf=True)
+
+ # integral -> floating point
+ # check we can convert something that loses precision
+ pyscalar = 1234567890123456789
+ self.assertNotEqual(pyscalar, integral_conv(float(pyscalar)))
+ l[0] = pyscalar
+ self.assertEqual(float(l), float(pyscalar))
+
+ # floating point -> integral
+ f[0] = nan
+ self.assertRaises(ValueError, lambda: integral_conv(f[0]))
+ f[0] = inf
+ self.assertRaises(OverflowError, lambda: integral_conv(f[0]))
+ f[0] = -inf
+ self.assertRaises(OverflowError, lambda: integral_conv(f[0]))
+ f[0] = sys.float_info.max
+ self.assertEqual(integral_conv(f), sys.float_info.max)
+
+ # bool, nonzero
+ def test_nonzero(tensor, value, expected):
+ tensor[0] = value
+ self.assertEqual(expected, bool(tensor))
+ self.assertEqual(expected, True if tensor else False)
+
+ test_nonzero(l, 0, False)
+ test_nonzero(l, -2, True)
+ test_nonzero(f, 0.0, False)
+ test_nonzero(f, sys.float_info.min, True)
+ test_nonzero(f, nan, bool(nan))
+ test_nonzero(f, inf, bool(inf))
+ test_nonzero(f, -inf, bool(-inf))
+
+
+ _test_pyscalar_conversions(lambda x: x.to(device), lambda x: int(x))
+ if sys.version_info[0] == 2:
+ _test_pyscalar_conversions(lambda x: x.to(device), lambda x: long(x))
+
+ def test_set_requires_grad_only_for_floats(self, device):
+ dtypes = [torch.int64, torch.int32, torch.int16, torch.int8,
+ torch.float, torch.double]
+ if device == 'cuda':
+ dtypes.append(torch.half)
+
+ def f1(dt):
+ a = torch.ones(1, dtype=dt, device=device)
+ a.requires_grad_()
+
+ def f2(dt):
+ a = torch.ones(1, dtype=dt, device=device)
+ a.requires_grad = True
+
+ def f3(dt):
+ torch.ones(1, dtype=dt, device=device, requires_grad=True)
+
+ for dt in dtypes:
+ a = torch.ones(1, dtype=dt, device=device)
+ a.requires_grad = False # should always work
+ a.requires_grad_(False)
+
+ for f in [f1, f2, f3]:
+ if dt.is_floating_point:
+ f(dt)
+ else:
+ with self.assertRaisesRegex(RuntimeError, 'floating point',
+ msg="dt: {} device: {}".format(a.dtype, a.device)):
+ f(dt)
+
+ @onlyCUDA
+ def test_advanced_indexing_backwards_large(self, device):
+ # See https://github.com/pytorch/pytorch/issues/22843
+ n = (1 << 16)
+ x = torch.rand(n, 1, device=device, requires_grad=True)
+ a = x[:, [0]]
+ a.sum().backward()
+ self.assertEqual(x.grad, torch.ones(n, 1, device=device))
+
+ # test for backward in https://github.com/pytorch/pytorch/issues/15511
+ def test_pdist_large(self, device):
+ def func(x):
+ return torch.pdist(x, p=2)
+
+ # shape[0] should be able to be (roughly) arbitrarily large, but the kernel
+ # is currently limited to smaller sizes (see issue above); this is just testing
+ # a floor.
+ shape = (1000, 1)
+ x = torch.randn(shape, device=device).requires_grad_()
+ output = torch.pdist(x, p=2)
+ # just run a single backward, as gradcheck/gradgradcheck is expensive here
+ output.sum().backward()
+
+ def test_where_functional(self, device):
+ x = torch.randn(5, 5, device=device, requires_grad=True)
+ y = torch.randn(5, 5, device=device, requires_grad=True)
+ cond = mask_not_all_zeros((5, 5)).to(device=device)
+
+ def where(cond, x, y):
+ return torch.where(cond, x, y)
+
+ gradcheck(where, [cond, x, y], raise_exception=True)
+ gradgradcheck(where, [cond, x, y], [torch.randn(5, 5, device=device)])
+
+ x = torch.randn(5, 1, 5, device=device, requires_grad=True)
+ y = torch.randn(5, 5, 1, device=device, requires_grad=True)
+ gradcheck(where, [cond, x, y], raise_exception=True)
+ gradgradcheck(where, [cond, x, y], [torch.randn(5, 5, 5, device=device)])
+
+
+instantiate_device_type_tests(TestAutogradDeviceType, globals())
+
if __name__ == '__main__':
run_tests()
diff --git a/test/test_torch.py b/test/test_torch.py
index 1006ec4..fcd8fd9 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -8027,7 +8027,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_inverse(self, device):
from common_utils import random_fullrank_matrix_distinct_singular_value
@@ -8314,7 +8313,6 @@
@slowTest
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_inverse_many_batches(self, device):
from common_utils import random_fullrank_matrix_distinct_singular_value
@@ -8330,7 +8328,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_pinverse(self, device):
def run_test(M):
# Testing against definition for pseudo-inverses
@@ -8355,7 +8352,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_matrix_rank(self, device):
a = torch.eye(10, device=device)
self.assertEqual(torch.matrix_rank(a).item(), 10)
@@ -8389,7 +8385,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_matrix_power(self, device):
def run_test(M, sign=1):
if sign == -1:
@@ -8449,7 +8444,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_det_logdet_slogdet(self, device):
def reference_slogdet(M):
if TEST_NUMPY:
@@ -8630,7 +8624,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_det_logdet_slogdet_batched(self, device):
from common_utils import (random_symmetric_matrix, random_symmetric_psd_matrix,
random_symmetric_pd_matrix, random_square_matrix_of_rank)
@@ -8681,7 +8674,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_solve(self, device):
from common_utils import solve_test_helper
for (k, n) in zip([2, 3, 5], [3, 5, 7]):
@@ -8691,7 +8683,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_solve_batched(self, device):
from common_utils import solve_test_helper
@@ -8710,7 +8701,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_solve_batched_non_contiguous(self, device):
from numpy.linalg import solve
@@ -8724,7 +8714,6 @@
@slowTest
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_solve_batched_many_batches(self, device):
from common_utils import solve_test_helper
@@ -8738,7 +8727,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_solve_batched_broadcasting(self, device):
from numpy.linalg import solve
@@ -8764,7 +8752,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_cholesky_solve(self, device):
from common_utils import cholesky_solve_test_helper
for (k, n), upper in product(zip([2, 3, 5], [3, 5, 7]), [True, False]):
@@ -8774,7 +8761,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_cholesky_solve_batched(self, device):
from common_utils import cholesky_solve_test_helper
@@ -8793,7 +8779,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_cholesky_solve_batched_non_contiguous(self, device):
from numpy.linalg import solve
@@ -8813,7 +8798,6 @@
@slowTest
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_cholesky_solve_batched_many_batches(self, device):
from common_utils import cholesky_solve_test_helper
@@ -8828,7 +8812,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_cholesky_solve_batched_broadcasting(self, device):
from numpy.linalg import solve
@@ -8857,7 +8840,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_cholesky_inverse(self, device):
from common_utils import random_symmetric_pd_matrix
a = random_symmetric_pd_matrix(5).to(device)
@@ -8883,7 +8865,6 @@
@slowTest
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_cholesky_batched_many_batches(self, device):
from common_utils import random_symmetric_pd_matrix
@@ -8906,7 +8887,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_cholesky_batched(self, device):
from common_utils import random_symmetric_pd_matrix
@@ -8921,7 +8901,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_cholesky(self, device):
x = torch.rand(10, 10, device=device) + 1e-1
A = torch.mm(x, x.t())
@@ -10063,7 +10042,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_lu_solve_batched_non_contiguous(self, device):
from numpy.linalg import solve
@@ -10082,7 +10060,6 @@
@slowTest
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_lu_solve_batched_many_batches(self, device):
from common_utils import lu_solve_test_helper
@@ -10100,7 +10077,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_lu_solve_batched_broadcasting(self, device):
from numpy.linalg import solve
@@ -10270,7 +10246,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_symeig(self, device):
from common_utils import random_symmetric_matrix
@@ -10313,7 +10288,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_svd(self, device):
def run_test(dims, some, compute_uv):
x = torch.randn(*dims, device=device)
@@ -10370,7 +10344,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_svd_no_singularvectors(self, device):
for size in [(5, 5), (5, 20), (20, 5)]:
a = torch.randn(*size, device=device)
@@ -10424,7 +10397,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_norm(self, device):
# full reduction
@@ -10455,7 +10427,6 @@
self.assertEqual(2 * torch.norm(torch.ones(10000)), torch.norm(torch.ones(40000)))
@skipCUDAIfNoMagma
- @skipCUDANonDefaultStreamIf(True)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_nuclear_norm_axes_small_brute_force(self, device):
def check_single_nuclear_norm(x, axes):
@@ -10533,7 +10504,6 @@
check_single_nuclear_norm(x, axes)
@skipCUDAIfNoMagma
- @skipCUDANonDefaultStreamIf(True)
def test_nuclear_norm_exceptions(self, device):
for lst in [], [1], [1, 2]:
for axes in (), (0,), (0, 1):
@@ -10560,7 +10530,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_geqrf(self, device):
a = torch.randn(5, 5, device=device)
b, c = torch.geqrf(a)
@@ -10571,7 +10540,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_triangular_solve(self, device):
from common_utils import triangular_solve_test_helper
for (k, n), (upper, unitriangular, transpose) in product(zip([2, 3, 5], [3, 5, 7]),
@@ -10586,7 +10554,6 @@
@slowTest
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_triangular_solve_batched_many_batches(self, device):
from common_utils import triangular_solve_test_helper
@@ -10610,7 +10577,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
@unittest.skipIf(not TEST_SCIPY, "SciPy not found")
def test_triangular_solve_batched_broadcasting(self, device):
from scipy.linalg import solve_triangular as tri_solve
@@ -10646,7 +10612,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_lstsq(self, device):
def cast_fn(tensor):
return tensor.to(device=device)
@@ -10769,7 +10734,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_qr(self, device):
def run_test(tensor_dims, some):
A = torch.randn(*tensor_dims, device=device)
@@ -11019,7 +10983,6 @@
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
- @skipCUDANonDefaultStreamIf(True)
def test_lapack_empty(self, device):
# FIXME: these are just a selection of LAPACK functions -- we need a general strategy here.
# The LAPACK functions themselves generally do NOT work with zero sized dimensions, although