Removes torchtest, expands generic device testing (#26374)
Summary:
- Removes torchtest
- <s>Moves test_torch tests skipped on ROCm to generic device test class</s>
- Creates a generic device test class for test_nn (TestNNDeviceType)
Next: add dtype support to the generic device testing framework.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26374
Test Plan: The change is to the tests themselves.
Differential Revision: D17442218
Pulled By: mruberry
fbshipit-source-id: d7e4451d09fc9049478b35a7efb8bb580071e8c8
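For context, a minimal sketch of the device-generic pattern this PR moves test_nn onto. The class and test names here (TestFooDeviceType, test_zeros_like) are illustrative; the only APIs assumed are instantiate_device_type_tests, TestCase, and run_tests, all visible in the diff below.

```python
# Sketch: a device-generic test class. instantiate_device_type_tests generates
# one concrete class per available device type (e.g. a CPU variant and, when
# CUDA is available, a CUDA variant), passing the device string into each test
# method instead of requiring hand-written *_cuda duplicates.
import torch
from common_utils import TestCase, run_tests
from common_device_type import instantiate_device_type_tests


class TestFooDeviceType(TestCase):
    def test_zeros_like(self, device):
        # 'device' is supplied by the framework, e.g. 'cpu' or 'cuda'
        t = torch.zeros((100, 100), device=device)
        self.assertEqual(t.abs().sum().item(), 0)


# Registers the per-device test classes in this module's namespace
instantiate_device_type_tests(TestFooDeviceType, globals())

if __name__ == '__main__':
    run_tests()
```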
diff --git a/test/common_utils.py b/test/common_utils.py
index d4b2f71..6040a11 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -216,52 +216,6 @@
return run_test_function
-class torchtest():
- """Allows to generate and run per-device unittests.
-
- This decorator class allows to generate and run per-device unittest.
-
- Example:
-
- class _TestTorchMixin(torchtest):
-
- @torchtest.for_all_device_types()
- def test_zeros_like(self, device):
- expected = torch.zeros((100, 100,), device=device)
-
- Will execute:
-
- test_zeros_like (__main__.TestTorch) ... skipped 'Look at test_zeros_like_cpu, test_zeros_like_cuda results.'
- test_zeros_like_cpu (__main__.TestTorch) ... ok
- test_zeros_like_cuda (__main__.TestTorch) ... ok
-
- To work properly, test class should be inherited from `torchtest`.
- for_all_device_types decorator does not guarantee proper functionality in
- combination with other decorators.
-
- Please do not extend this decorator to support other cases (such as dtype,
- layouts, etc) without consulting with bigger group. Devices is the special
- case as build flags control additions/removals (see
- https://github.com/pytorch/pytorch/pull/23824 for the reference).
- """
- @classmethod
- def for_all_device_types(cls):
- def wrapper(fn):
- test_names = []
-
- for device in torch.testing.get_all_device_types():
- test_name = fn.__name__ + '_' + device
- assert not hasattr(cls, test_name), "Duplicated test name: " + test_name
- setattr(cls, test_name, _test_function(fn, device))
- test_names.append(test_name)
-
- @wraps(fn)
- def empty_test(*args, **kwargs):
- raise unittest.SkipTest("Look at {} results.".format(", ".join(test_names)))
- return empty_test
- return wrapper
-
-
def skipIfNoLapack(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
diff --git a/test/test_nn.py b/test/test_nn.py
index 7eed6b8..87721a7 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -35,6 +35,7 @@
from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
module_tests, criterion_tests, new_criterion_tests, loss_reference_fns, \
ctcloss_reference, new_module_tests
+from common_device_type import instantiate_device_type_tests
from torch.nn import MultiheadAttention
@@ -966,34 +967,6 @@
with self.assertRaisesRegex(RuntimeError, 'negative stride is not supported'):
module(input)
- def _test_dropout(self, cls, cuda, input):
- p = 0.2
- device = torch.device("cuda") if cuda else torch.device("cpu")
- input = input.to(device).fill_(1 - p)
-
- module = cls(p)
- input_var = input.clone().requires_grad_()
- output = module(input_var)
- self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
- output.backward(input)
- self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)
-
- module = cls(p, True)
- input_var = input.clone().requires_grad_()
- output = module(input_var + 0)
- self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
- output.backward(input)
- self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)
-
- # check eval mode doesn't change anything
- for inplace in [True, False]:
- module = cls(p, inplace).eval()
- self.assertEqual(input, module(input))
-
- # Check that these don't raise errors
- module.__repr__()
- str(module)
-
def _test_alpha_dropout(self, cls, input):
mean = input.mean()
std = input.std()
@@ -3160,51 +3133,6 @@
gradcheck(func, [x])
gradgradcheck(func, [x])
- def test_Dropout(self):
- input = torch.Tensor(1000)
- self._test_dropout(nn.Dropout, False, input)
-
- def test_Dropout2d(self):
- b = random.randint(1, 5)
- w = random.randint(1, 5)
- h = random.randint(1, 5)
- num_features = 1000
- input = torch.Tensor(num_features, b, w, h)
- self._test_dropout(nn.Dropout2d, False, input)
-
- def test_Dropout3d(self):
- b = random.randint(1, 5)
- w = random.randint(1, 5)
- h = random.randint(1, 5)
- d = random.randint(1, 2)
- num_features = 1000
- input = torch.Tensor(num_features, b, d, w, h)
- self._test_dropout(nn.Dropout3d, False, input)
-
- @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
- def test_Dropout_cuda(self):
- input = torch.Tensor(1000)
- self._test_dropout(nn.Dropout, True, input)
-
- @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
- def test_Dropout2d_cuda(self):
- b = random.randint(1, 5)
- w = random.randint(1, 5)
- h = random.randint(1, 5)
- num_features = 1000
- input = torch.Tensor(num_features, b, w, h)
- self._test_dropout(nn.Dropout2d, True, input)
-
- @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
- def test_Dropout3d_cuda(self):
- b = random.randint(1, 5)
- w = random.randint(1, 5)
- h = random.randint(1, 5)
- d = random.randint(1, 2)
- num_features = 1000
- input = torch.Tensor(num_features, b, d, w, h)
- self._test_dropout(nn.Dropout3d, True, input)
-
def test_AlphaDropout(self):
# generate random tensor with zero mean and unit std
input = torch.randn(5000)
@@ -3219,259 +3147,6 @@
input = torch.randn(num_features, b, d, w, h)
self._test_alpha_dropout(nn.FeatureAlphaDropout, input)
- def _test_InstanceNorm_general(self, cls, input, device="cpu", dtype=torch.float):
- # default case track_running_stats=False
- b, c = input.size(0), input.size(1)
- input_var = input.to(device=device, dtype=dtype).requires_grad_()
-
- IN = cls(c, eps=0).to(device, dtype)
-
- output = IN(input_var)
- out_reshaped = output.view(b * c, -1)
-
- mean = out_reshaped.mean(1)
- var = out_reshaped.var(1, unbiased=False)
-
- self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5)
- self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5)
-
- # check that eval mode doesn't change behavior
- grad_out = torch.randn_like(output)
- res1 = output.data.clone()
- output.backward(grad_out)
- grad1 = input_var.grad.data.clone()
-
- IN.eval()
- output = IN(input_var)
- input_var.grad = None
- output.backward(grad_out)
- res2 = output.data
- grad2 = input_var.grad.data
- self.assertEqual(res1, res2)
- self.assertEqual(grad1, grad2)
-
- # If track_running_stats=True and momentum=1, running_mean/var should be
- # equal to mean/var of the input (with unbias correction)
- IN = cls(c, momentum=1, eps=0, track_running_stats=True).to(device, dtype)
-
- output = IN(input_var)
-
- input_reshaped = input_var.transpose(1, 0).reshape(c, -1)
- mean = input_reshaped.mean(1)
-
- input_reshaped = input_var.transpose(1, 0).reshape(c, b, -1)
- var = input_reshaped.var(2, unbiased=True)[:, :]
-
- self.assertAlmostEqual(torch.abs(mean.data - IN.running_mean).mean(), 0, delta=1e-5)
- self.assertAlmostEqual(torch.abs(var.data.mean(1) - IN.running_var).mean(), 0, delta=1e-5)
-
- # in eval mode, adding X * std to a channel in input should make the
- # corresponding channel in output have mean X
- IN.eval()
- delta = IN.running_var.sqrt() * torch.arange(c, device=device, dtype=dtype)
- delta = delta.view(-1, *[1 for _ in range(2, input.dim())])
- output = IN(input_var + delta)
- self.assertEqual(output.transpose(0, 1).reshape(c, -1).mean(1), torch.arange(c))
-
- def _test_InstanceNorm_cuda_half(self, cls, input):
- # THNN
- input = input.to(device='cuda', dtype=torch.half).random_(1, 10).requires_grad_(True)
- m = cls(input.size(1), affine=True, track_running_stats=True).to("cuda", torch.half)
- thnn_output = m(input)
- thnn_output.sum().backward()
- thnn_input_grad = input.grad.data.clone()
- self.assertEqual(thnn_output.type(), input.type())
- # cuDNN
- if TEST_CUDNN:
- input.grad = None
- m = m.float()
- cudnn_output = m(input)
- cudnn_output.sum().backward()
- cudnn_input_grad = input.grad.data.clone()
- self.assertEqual(cudnn_output.type(), input.type())
- self.assertAlmostEqual(cudnn_output, thnn_output, delta=1e-4)
- self.assertAlmostEqual(cudnn_input_grad, thnn_input_grad, delta=1e-3)
-
- def test_InstanceNorm1d_general(self):
- b = random.randint(3, 5)
- c = random.randint(3, 5)
- d = random.randint(8, 10)
-
- input = torch.rand(b, c, d)
- self._test_InstanceNorm_general(nn.InstanceNorm1d, input, dtype=torch.float)
-
- @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
- def test_InstanceNorm1d_general_cuda(self):
- b = random.randint(3, 5)
- c = random.randint(3, 5)
- d = random.randint(8, 10)
-
- input = torch.rand(b, c, d)
- self._test_InstanceNorm_general(nn.InstanceNorm1d, input, "cuda", torch.float)
- self._test_InstanceNorm_cuda_half(nn.InstanceNorm1d, input)
-
- def test_InstanceNorm2d_general(self):
- b = random.randint(3, 5)
- c = random.randint(3, 5)
- w = random.randint(3, 6)
- h = random.randint(6, 8)
-
- input = torch.rand(b, c, h, w)
- self._test_InstanceNorm_general(nn.InstanceNorm2d, input, dtype=torch.float)
-
- @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
- def test_InstanceNorm2d_general_cuda(self):
- b = random.randint(3, 5)
- c = random.randint(3, 5)
- w = random.randint(3, 6)
- h = random.randint(6, 8)
-
- input = torch.rand(b, c, h, w)
- self._test_InstanceNorm_general(nn.InstanceNorm2d, input, "cuda", torch.float)
- self._test_InstanceNorm_cuda_half(nn.InstanceNorm2d, input)
-
- def test_InstanceNorm3d_general(self):
- b = random.randint(3, 5)
- c = random.randint(3, 5)
- w = random.randint(2, 5)
- h = random.randint(2, 5)
- d = random.randint(2, 5)
-
- input = torch.rand(b, c, h, w, d)
- self._test_InstanceNorm_general(nn.InstanceNorm3d, input, dtype=torch.float)
-
- @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
- def test_InstanceNorm3d_general_cuda(self):
- b = random.randint(3, 5)
- c = random.randint(2, 5)
- w = random.randint(2, 5)
- h = random.randint(2, 5)
- d = random.randint(2, 5)
-
- input = torch.rand(b, c, h, w, d)
- self._test_InstanceNorm_general(nn.InstanceNorm3d, input, "cuda", torch.float)
- self._test_InstanceNorm_cuda_half(nn.InstanceNorm3d, input)
-
- def _test_LayerNorm_general(self, device="cpu", dtype=torch.float):
- for i in range(2, 6):
- shape = torch.randint(3, 6, (i,), dtype=torch.long).tolist()
- x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
- normalized_ndim = random.randint(1, i - 1) # inclusive
- normalized_shape = shape[-normalized_ndim:]
- unnormalized_shape = shape[:-normalized_ndim]
-
- # test that LN normalizes to mean 0 and stddev 1
- ln = nn.LayerNorm(normalized_shape, eps=0).to(device, dtype)
- ln.weight.data.fill_(1)
- ln.bias.data.fill_(0)
- output = ln(x)
- out_reshaped = output.view(*(unnormalized_shape + [-1]))
- mean = out_reshaped.mean(-1)
- var = out_reshaped.var(-1, unbiased=False)
- self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5)
- self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5)
-
- # test that LN applies weight and bias correctly
- scale, bias = torch.empty(2).uniform_(0.2, 2).tolist()
- ln.weight.data.fill_(scale)
- ln.bias.data.fill_(bias)
- output = ln(x)
- out_reshaped = output.view(*(unnormalized_shape + [-1]))
- mean = out_reshaped.mean(-1)
- var = out_reshaped.var(-1, unbiased=False)
- self.assertAlmostEqual(torch.abs(mean.data).mean(), bias, delta=1e-5)
- self.assertAlmostEqual(torch.abs(var.data).mean(), scale ** 2, delta=1e-5)
-
- bad_norm_shape_input_shape = {
- (): (),
- (2, 3): (3,),
- (2,): (1, 2, 3),
- (10,): (2, 3),
- 10: (2, 3),
- }
- for norm_shape, input_shape in bad_norm_shape_input_shape.items():
- ln = nn.LayerNorm(norm_shape)
- input = torch.empty(input_shape, device=device, dtype=dtype).uniform_(0, 10)
- self.assertRaises(RuntimeError, lambda: ln(input))
-
- def _test_LayerNorm_cuda_half(self):
- input = torch.empty(2, 3, 3, 2, device="cuda", dtype=torch.half).random_(1, 10).requires_grad_(True)
- m = nn.LayerNorm([3, 2]).to("cuda", torch.half)
- output = m(input)
- output.sum().backward()
- self.assertEqual(output.type(), input.type())
-
- def test_LayerNorm_general(self):
- self._test_LayerNorm_general()
-
- @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
- def test_LayerNorm_general_cuda(self):
- self._test_LayerNorm_general("cuda")
- self._test_LayerNorm_cuda_half()
-
- def _test_GroupNorm_general(self, device="cpu", dtype=torch.float):
- good_shape_g = {
- (1, 2, 3, 4): 2,
- (2, 3, 10): 3,
- (3, 1, 1, 1, 2): 1,
- (2, 6, 4, 2, 2): 3,
- }
- for shape, g in good_shape_g.items():
- x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
- b = shape[0]
- c = shape[1]
-
- # test that GN normalizes to mean 0 and stddev 1
- gn = nn.GroupNorm(g, c, eps=0).to(device, dtype)
- gn.weight.data.fill_(1)
- gn.bias.data.fill_(0)
- output = gn(x)
- out_reshaped = output.view(b, g, -1)
- mean = out_reshaped.mean(-1)
- var = out_reshaped.var(-1, unbiased=False)
- self.assertAlmostEqual(torch.abs(mean).mean(), 0, delta=1e-5)
- self.assertAlmostEqual(torch.abs(var).mean(), 1, delta=1e-5)
-
- # test that GN applies weight and bias correctly
- scale = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2)
- bias = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2)
- gn.weight.data.copy_(scale)
- gn.bias.data.copy_(bias)
- output = gn(x)
- out_reshaped = output.view(b, c, -1)
- out_normed = (out_reshaped - bias.view(c, 1)) / scale.view(c, 1)
- out_normed_reshaped = out_normed.view(b, g, -1)
- mean = out_normed_reshaped.mean(-1)
- var = out_normed_reshaped.var(-1, unbiased=False)
- self.assertAlmostEqual(torch.abs(mean).mean(), 0, delta=1e-5)
- self.assertAlmostEqual(torch.abs(var).mean(), 1, delta=1e-5)
-
- bad_shape_g = {
- (1, 2, 3, 4): 3,
- (2, 3, 10): 2,
- (3, 1, 1, 1, 2): 10,
- (2, 6, 4, 2, 2): 4,
- }
- for shape, g in bad_shape_g.items():
- gn = nn.GroupNorm(g, shape[1])
- input = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
- self.assertRaises(RuntimeError, lambda: gn(input))
-
- def _test_GroupNorm_cuda_half(self):
- input = torch.zeros(2, 4, 3, 2, requires_grad=True).cuda().half().random_(1, 10)
- m = nn.GroupNorm(2, 4).to("cuda", torch.half)
- output = m(input)
- output.sum().backward()
- self.assertEqual(output.type(), input.type())
-
- def test_GroupNorm_general(self):
- self._test_GroupNorm_general(dtype=torch.float)
-
- @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
- def test_GroupNorm_general_cuda(self):
- self._test_GroupNorm_general("cuda", torch.float)
- self._test_GroupNorm_cuda_half()
-
def test_pad(self):
inputs = torch.randn(1, 3, 4, 4, requires_grad=True)
_assertGradAndGradgradChecks(self, lambda x: F.pad(x, (1, 1, 1, 1)), (inputs,))
@@ -3490,115 +3165,11 @@
self.assertRaisesRegex(RuntimeError, expected_err_msg,
lambda: F.pad(torch.randn(1, 1, 2), (2, 1), mode='reflect'))
- @staticmethod
- def _test_one_hot(self, use_cuda=False):
- device = torch.device('cuda' if use_cuda else 'cpu')
- with self.assertRaises(RuntimeError):
- torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1)
-
- with self.assertRaises(RuntimeError):
- torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 3)
-
- t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device))
- expected = torch.tensor([[0, 0, 0, 1, 0],
- [0, 0, 0, 0, 1],
- [0, 1, 0, 0, 0],
- [1, 0, 0, 0, 0]], device=device)
- self.assertEqual(t, expected)
-
- t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -1)
- expected = torch.tensor([[0, 0, 0, 1, 0],
- [0, 0, 0, 0, 1],
- [0, 1, 0, 0, 0],
- [1, 0, 0, 0, 0]], device=device)
- self.assertEqual(t, expected)
-
- t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 6)
- expected = torch.tensor([[0, 0, 0, 1, 0, 0],
- [0, 0, 0, 0, 1, 0],
- [0, 1, 0, 0, 0, 0],
- [1, 0, 0, 0, 0, 0]], device=device)
- self.assertEqual(t, expected)
-
- t = torch.nn.functional.one_hot(torch.tensor([[3, 4], [1, 0]], device=device))
- expected = torch.tensor([[[0, 0, 0, 1, 0],
- [0, 0, 0, 0, 1]],
- [[0, 1, 0, 0, 0],
- [1, 0, 0, 0, 0]]], device=device)
- self.assertEqual(t, expected)
-
- t = torch.nn.functional.one_hot(torch.tensor(4, device=device))
- expected = torch.tensor([0, 0, 0, 0, 1], device=device)
- self.assertEqual(t, expected)
-
- t = torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device), 100)
- expected = torch.empty([4, 0, 100])
- self.assertEqual(t, expected)
-
- with self.assertRaises(RuntimeError):
- torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device))
-
- with self.assertRaises(RuntimeError):
- torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2)
-
- def test_one_hot(self):
- self._test_one_hot(self)
-
- @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
- def test_one_hot_cuda(self):
- self._test_one_hot(self, use_cuda=True)
-
def test_pad_scalar_error(self):
inputs = torch.tensor(0., requires_grad=True)
self.assertRaises(AssertionError, lambda: F.pad(inputs, (1, 1)))
self.assertRaises(AssertionError, lambda: F.pad(inputs, (1,)))
- def test_nn_scalars(self):
- # One off tests to ensure scalars from nn.yaml are properly applied
- def verify_scalars(input, output):
- if input.dim() == 0:
- self.assertEqual((), output.shape)
- else:
- self.assertNotEqual((), output.shape)
- output.sum().backward()
- self.assertEqual(input.shape, input.grad.shape)
-
- devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
- for device in devices:
- for input_shape in [(5, 6), ()]:
- for module in [torch.nn.ELU, torch.nn.Hardtanh, torch.nn.LeakyReLU, torch.nn.LogSigmoid,
- torch.nn.RReLU, torch.nn.Softshrink, torch.nn.Softplus, torch.nn.Sigmoid,
- torch.nn.Tanh]:
- input = torch.randn(input_shape, device=device, requires_grad=True)
- m = module()
- output = m(input)
- verify_scalars(input, output)
-
- def test_nn_scalars_reductions(self):
- # One off tests to ensure scalars from nn.yaml are properly applied
- def verify_reduction_scalars(input, reduction, output):
- if reduction != 'none' or input.dim() == 0:
- self.assertEqual((), output.shape)
- else:
- self.assertNotEqual((), output.shape)
- output.sum().backward()
- self.assertEqual(input.shape, input.grad.shape)
-
- devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
- for device in devices:
- for input_shape in [(5, 6), ()]:
- for reduction in ['none', 'mean', 'sum']:
- for module in [torch.nn.BCELoss, torch.nn.L1Loss, torch.nn.MSELoss,
- torch.nn.SmoothL1Loss, torch.nn.SoftMarginLoss]:
- input = torch.randn(input_shape, device=device, requires_grad=True)
- target = torch.empty(input_shape, device=device).random_(2)
- sigmoid = nn.Sigmoid()
-
- input = torch.randn(input_shape, device=device, requires_grad=True)
- m = module(reduction=reduction)
- output = m(sigmoid(input), target)
- verify_reduction_scalars(input, reduction, output)
-
@unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
"Scipy v1.0 and/or numpy not found")
def test_multihead_attention(self):
@@ -10090,6 +9661,374 @@
return transform_tensor, transform_ary, grid_ary
# end TestNN.test_affine_* helpers
+class GenericDeviceTypeHelpers(object):
+ def _test_dropout(self, cls, device, input):
+ p = 0.2
+ input = input.to(device).fill_(1 - p)
+
+ module = cls(p)
+ input_var = input.clone().requires_grad_()
+ output = module(input_var)
+ self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
+ output.backward(input)
+ self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)
+
+ module = cls(p, True)
+ input_var = input.clone().requires_grad_()
+ output = module(input_var + 0)
+ self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
+ output.backward(input)
+ self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)
+
+ # check eval mode doesn't change anything
+ for inplace in [True, False]:
+ module = cls(p, inplace).eval()
+ self.assertEqual(input, module(input))
+
+ # Check that these don't raise errors
+ module.__repr__()
+ str(module)
+
+ def _test_InstanceNorm_general(self, cls, input, device, dtype=torch.float):
+ # default case track_running_stats=False
+ b, c = input.size(0), input.size(1)
+ input_var = input.to(device=device, dtype=dtype).requires_grad_()
+
+ IN = cls(c, eps=0).to(device, dtype)
+
+ output = IN(input_var)
+ out_reshaped = output.view(b * c, -1)
+
+ mean = out_reshaped.mean(1)
+ var = out_reshaped.var(1, unbiased=False)
+
+ self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5)
+ self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5)
+
+ # check that eval mode doesn't change behavior
+ grad_out = torch.randn_like(output)
+ res1 = output.data.clone()
+ output.backward(grad_out)
+ grad1 = input_var.grad.data.clone()
+
+ IN.eval()
+ output = IN(input_var)
+ input_var.grad = None
+ output.backward(grad_out)
+ res2 = output.data
+ grad2 = input_var.grad.data
+ self.assertEqual(res1, res2)
+ self.assertEqual(grad1, grad2)
+
+ # If track_running_stats=True and momentum=1, running_mean/var should be
+ # equal to mean/var of the input (with unbias correction)
+ IN = cls(c, momentum=1, eps=0, track_running_stats=True).to(device, dtype)
+
+ output = IN(input_var)
+
+ input_reshaped = input_var.transpose(1, 0).reshape(c, -1)
+ mean = input_reshaped.mean(1)
+
+ input_reshaped = input_var.transpose(1, 0).reshape(c, b, -1)
+ var = input_reshaped.var(2, unbiased=True)[:, :]
+
+ self.assertAlmostEqual(torch.abs(mean.data - IN.running_mean).mean(), 0, delta=1e-5)
+ self.assertAlmostEqual(torch.abs(var.data.mean(1) - IN.running_var).mean(), 0, delta=1e-5)
+
+ # in eval mode, adding X * std to a channel in input should make the
+ # corresponding channel in output have mean X
+ IN.eval()
+ delta = IN.running_var.sqrt() * torch.arange(c, device=device, dtype=dtype)
+ delta = delta.view(-1, *[1 for _ in range(2, input.dim())])
+ output = IN(input_var + delta)
+ self.assertEqual(output.transpose(0, 1).reshape(c, -1).mean(1), torch.arange(c))
+
+ def _test_InstanceNorm_cuda_half(self, cls, input):
+ # THNN
+ input = input.to(device='cuda', dtype=torch.half).random_(1, 10).requires_grad_(True)
+ m = cls(input.size(1), affine=True, track_running_stats=True).to("cuda", torch.half)
+ thnn_output = m(input)
+ thnn_output.sum().backward()
+ thnn_input_grad = input.grad.data.clone()
+ self.assertEqual(thnn_output.type(), input.type())
+ # cuDNN
+ if TEST_CUDNN:
+ input.grad = None
+ m = m.float()
+ cudnn_output = m(input)
+ cudnn_output.sum().backward()
+ cudnn_input_grad = input.grad.data.clone()
+ self.assertEqual(cudnn_output.type(), input.type())
+ self.assertAlmostEqual(cudnn_output, thnn_output, delta=1e-4)
+ self.assertAlmostEqual(cudnn_input_grad, thnn_input_grad, delta=1e-3)
+
+ def _test_LayerNorm_general(self, device, dtype=torch.float):
+ for i in range(2, 6):
+ shape = torch.randint(3, 6, (i,), dtype=torch.long).tolist()
+ x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
+ normalized_ndim = random.randint(1, i - 1) # inclusive
+ normalized_shape = shape[-normalized_ndim:]
+ unnormalized_shape = shape[:-normalized_ndim]
+
+ # test that LN normalizes to mean 0 and stddev 1
+ ln = nn.LayerNorm(normalized_shape, eps=0).to(device, dtype)
+ ln.weight.data.fill_(1)
+ ln.bias.data.fill_(0)
+ output = ln(x)
+ out_reshaped = output.view(*(unnormalized_shape + [-1]))
+ mean = out_reshaped.mean(-1)
+ var = out_reshaped.var(-1, unbiased=False)
+ self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5)
+ self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5)
+
+ # test that LN applies weight and bias correctly
+ scale, bias = torch.empty(2).uniform_(0.2, 2).tolist()
+ ln.weight.data.fill_(scale)
+ ln.bias.data.fill_(bias)
+ output = ln(x)
+ out_reshaped = output.view(*(unnormalized_shape + [-1]))
+ mean = out_reshaped.mean(-1)
+ var = out_reshaped.var(-1, unbiased=False)
+ self.assertAlmostEqual(torch.abs(mean.data).mean(), bias, delta=1e-5)
+ self.assertAlmostEqual(torch.abs(var.data).mean(), scale ** 2, delta=1e-5)
+
+ bad_norm_shape_input_shape = {
+ (): (),
+ (2, 3): (3,),
+ (2,): (1, 2, 3),
+ (10,): (2, 3),
+ 10: (2, 3),
+ }
+ for norm_shape, input_shape in bad_norm_shape_input_shape.items():
+ ln = nn.LayerNorm(norm_shape)
+ input = torch.empty(input_shape, device=device, dtype=dtype).uniform_(0, 10)
+ self.assertRaises(RuntimeError, lambda: ln(input))
+
+ def _test_LayerNorm_cuda_half(self):
+ input = torch.empty(2, 3, 3, 2, device="cuda", dtype=torch.half).random_(1, 10).requires_grad_(True)
+ m = nn.LayerNorm([3, 2]).to("cuda", torch.half)
+ output = m(input)
+ output.sum().backward()
+ self.assertEqual(output.type(), input.type())
+
+ def _test_GroupNorm_general(self, device, dtype=torch.float):
+ good_shape_g = {
+ (1, 2, 3, 4): 2,
+ (2, 3, 10): 3,
+ (3, 1, 1, 1, 2): 1,
+ (2, 6, 4, 2, 2): 3,
+ }
+ for shape, g in good_shape_g.items():
+ x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
+ b = shape[0]
+ c = shape[1]
+
+ # test that GN normalizes to mean 0 and stddev 1
+ gn = nn.GroupNorm(g, c, eps=0).to(device, dtype)
+ gn.weight.data.fill_(1)
+ gn.bias.data.fill_(0)
+ output = gn(x)
+ out_reshaped = output.view(b, g, -1)
+ mean = out_reshaped.mean(-1)
+ var = out_reshaped.var(-1, unbiased=False)
+ self.assertAlmostEqual(torch.abs(mean).mean(), 0, delta=1e-5)
+ self.assertAlmostEqual(torch.abs(var).mean(), 1, delta=1e-5)
+
+ # test that GN applies weight and bias correctly
+ scale = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2)
+ bias = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2)
+ gn.weight.data.copy_(scale)
+ gn.bias.data.copy_(bias)
+ output = gn(x)
+ out_reshaped = output.view(b, c, -1)
+ out_normed = (out_reshaped - bias.view(c, 1)) / scale.view(c, 1)
+ out_normed_reshaped = out_normed.view(b, g, -1)
+ mean = out_normed_reshaped.mean(-1)
+ var = out_normed_reshaped.var(-1, unbiased=False)
+ self.assertAlmostEqual(torch.abs(mean).mean(), 0, delta=1e-5)
+ self.assertAlmostEqual(torch.abs(var).mean(), 1, delta=1e-5)
+
+ bad_shape_g = {
+ (1, 2, 3, 4): 3,
+ (2, 3, 10): 2,
+ (3, 1, 1, 1, 2): 10,
+ (2, 6, 4, 2, 2): 4,
+ }
+ for shape, g in bad_shape_g.items():
+ gn = nn.GroupNorm(g, shape[1])
+ input = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
+ self.assertRaises(RuntimeError, lambda: gn(input))
+
+ def _test_GroupNorm_cuda_half(self):
+ input = torch.zeros(2, 4, 3, 2, requires_grad=True).cuda().half().random_(1, 10)
+ m = nn.GroupNorm(2, 4).to("cuda", torch.half)
+ output = m(input)
+ output.sum().backward()
+ self.assertEqual(output.type(), input.type())
+
+class TestNNDeviceType(NNTestCase, GenericDeviceTypeHelpers):
+ def test_Dropout(self, device):
+ input = torch.Tensor(1000)
+ self._test_dropout(nn.Dropout, device, input)
+
+ def test_Dropout2d(self, device):
+ b = random.randint(1, 5)
+ w = random.randint(1, 5)
+ h = random.randint(1, 5)
+ num_features = 1000
+ input = torch.Tensor(num_features, b, w, h)
+ self._test_dropout(nn.Dropout2d, device, input)
+
+ def test_Dropout3d(self, device):
+ b = random.randint(1, 5)
+ w = random.randint(1, 5)
+ h = random.randint(1, 5)
+ d = random.randint(1, 2)
+ num_features = 1000
+ input = torch.Tensor(num_features, b, d, w, h)
+ self._test_dropout(nn.Dropout3d, device, input)
+
+ def test_InstanceNorm1d_general(self, device):
+ b = random.randint(3, 5)
+ c = random.randint(3, 5)
+ d = random.randint(8, 10)
+
+ input = torch.rand(b, c, d)
+ self._test_InstanceNorm_general(nn.InstanceNorm1d, input, device)
+
+ if device == 'cuda':
+ self._test_InstanceNorm_cuda_half(nn.InstanceNorm1d, input)
+
+ def test_InstanceNorm2d_general(self, device):
+ b = random.randint(3, 5)
+ c = random.randint(3, 5)
+ w = random.randint(3, 6)
+ h = random.randint(6, 8)
+
+ input = torch.rand(b, c, h, w)
+ self._test_InstanceNorm_general(nn.InstanceNorm2d, input, device)
+
+ if device == 'cuda':
+ self._test_InstanceNorm_cuda_half(nn.InstanceNorm2d, input)
+
+ def test_InstanceNorm3d_general(self, device):
+ b = random.randint(3, 5)
+ c = random.randint(3, 5)
+ w = random.randint(2, 5)
+ h = random.randint(2, 5)
+ d = random.randint(2, 5)
+
+ input = torch.rand(b, c, h, w, d)
+ self._test_InstanceNorm_general(nn.InstanceNorm3d, input, device)
+
+ if device == 'cuda':
+ self._test_InstanceNorm_cuda_half(nn.InstanceNorm3d, input)
+
+ def test_LayerNorm_general(self, device):
+ self._test_LayerNorm_general(device)
+
+ if device == 'cuda':
+ self._test_LayerNorm_cuda_half()
+
+ def test_GroupNorm_general(self, device):
+ self._test_GroupNorm_general(device)
+
+ if device == 'cuda':
+ self._test_GroupNorm_cuda_half()
+
+ def test_one_hot(self, device):
+ with self.assertRaises(RuntimeError):
+ torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1)
+
+ with self.assertRaises(RuntimeError):
+ torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 3)
+
+ t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device))
+ expected = torch.tensor([[0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 1],
+ [0, 1, 0, 0, 0],
+ [1, 0, 0, 0, 0]], device=device)
+ self.assertEqual(t, expected)
+
+ t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -1)
+ expected = torch.tensor([[0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 1],
+ [0, 1, 0, 0, 0],
+ [1, 0, 0, 0, 0]], device=device)
+ self.assertEqual(t, expected)
+
+ t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 6)
+ expected = torch.tensor([[0, 0, 0, 1, 0, 0],
+ [0, 0, 0, 0, 1, 0],
+ [0, 1, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0]], device=device)
+ self.assertEqual(t, expected)
+
+ t = torch.nn.functional.one_hot(torch.tensor([[3, 4], [1, 0]], device=device))
+ expected = torch.tensor([[[0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 1]],
+ [[0, 1, 0, 0, 0],
+ [1, 0, 0, 0, 0]]], device=device)
+ self.assertEqual(t, expected)
+
+ t = torch.nn.functional.one_hot(torch.tensor(4, device=device))
+ expected = torch.tensor([0, 0, 0, 0, 1], device=device)
+ self.assertEqual(t, expected)
+
+ t = torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device), 100)
+ expected = torch.empty([4, 0, 100])
+ self.assertEqual(t, expected)
+
+ with self.assertRaises(RuntimeError):
+ torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device))
+
+ with self.assertRaises(RuntimeError):
+ torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2)
+
+ def test_nn_scalars(self, device):
+ # One off tests to ensure scalars from nn.yaml are properly applied
+ def verify_scalars(input, output):
+ if input.dim() == 0:
+ self.assertEqual((), output.shape)
+ else:
+ self.assertNotEqual((), output.shape)
+ output.sum().backward()
+ self.assertEqual(input.shape, input.grad.shape)
+
+ for input_shape in [(5, 6), ()]:
+ for module in [torch.nn.ELU, torch.nn.Hardtanh, torch.nn.LeakyReLU, torch.nn.LogSigmoid,
+ torch.nn.RReLU, torch.nn.Softshrink, torch.nn.Softplus, torch.nn.Sigmoid,
+ torch.nn.Tanh]:
+ input = torch.randn(input_shape, device=device, requires_grad=True)
+ m = module()
+ output = m(input)
+ verify_scalars(input, output)
+
+ def test_nn_scalars_reductions(self, device):
+ # One off tests to ensure scalars from nn.yaml are properly applied
+ def verify_reduction_scalars(input, reduction, output):
+ if reduction != 'none' or input.dim() == 0:
+ self.assertEqual((), output.shape)
+ else:
+ self.assertNotEqual((), output.shape)
+ output.sum().backward()
+ self.assertEqual(input.shape, input.grad.shape)
+
+ for input_shape in [(5, 6), ()]:
+ for reduction in ['none', 'mean', 'sum']:
+ for module in [torch.nn.BCELoss, torch.nn.L1Loss, torch.nn.MSELoss,
+ torch.nn.SmoothL1Loss, torch.nn.SoftMarginLoss]:
+ input = torch.randn(input_shape, device=device, requires_grad=True)
+ target = torch.empty(input_shape, device=device).random_(2)
+ sigmoid = nn.Sigmoid()
+
+ input = torch.randn(input_shape, device=device, requires_grad=True)
+ m = module(reduction=reduction)
+ output = m(sigmoid(input), target)
+ verify_reduction_scalars(input, reduction, output)
+
+
+instantiate_device_type_tests(TestNNDeviceType, globals())
if __name__ == '__main__':
run_tests()
diff --git a/test/test_torch.py b/test/test_torch.py
index 7abdce0..69876d0 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -30,7 +30,7 @@
from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \
IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, skipIfRocm, do_test_dtypes, do_test_empty_full, \
- IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, torchtest, \
+ IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, \
skipCUDANonDefaultStreamIf
from multiprocessing.reduction import ForkingPickler
from common_device_type import instantiate_device_type_tests, \
@@ -104,7 +104,7 @@
# This is intentionally prefixed by an underscore. Otherwise pytest will try to
# run its methods as test cases.
-class _TestTorchMixin(torchtest):
+class _TestTorchMixin(object):
def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True):
float_types = [torch.double,
torch.float]