Removes torchtest, expands generic device testing (#26374)

Summary:
- Removes torchtest
- ~~Moves test_torch tests skipped on ROCm to generic device test class~~
- Creates test_nn generic device test class

Next: adding dtypes to the generic device testing framework.
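
For reference, the generic device-type pattern that test_nn moves onto looks roughly like the sketch below. Only the shape of the API visible in this diff is shown (a generic test class whose methods take a `device` argument, instantiated via `instantiate_device_type_tests`); the class name `TestExampleDeviceType` and the generated per-device variants are illustrative assumptions, not code from this PR.

```python
# Minimal sketch of a device-generic test file (illustrative only).
import torch

from common_utils import TestCase, run_tests
from common_device_type import instantiate_device_type_tests


class TestExampleDeviceType(TestCase):
    # The framework passes the device string ('cpu', 'cuda', ...) to each test.
    def test_zeros_like(self, device):
        expected = torch.zeros((100, 100), device=device)
        self.assertEqual(expected, torch.zeros_like(expected))


# Generates and registers one concrete test class per available device type
# (e.g. a CPU and a CUDA variant) in this module's namespace.
instantiate_device_type_tests(TestExampleDeviceType, globals())

if __name__ == '__main__':
    run_tests()
```
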
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26374

Test Plan: The change is to the tests themselves.

Differential Revision: D17442218

Pulled By: mruberry

fbshipit-source-id: d7e4451d09fc9049478b35a7efb8bb580071e8c8
diff --git a/test/common_utils.py b/test/common_utils.py
index d4b2f71..6040a11 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -216,52 +216,6 @@
     return run_test_function
 
 
-class torchtest():
-    """Allows to generate and run per-device unittests.
-
-    This decorator class allows to generate and run per-device unittest.
-
-    Example:
-
-    class _TestTorchMixin(torchtest):
-
-        @torchtest.for_all_device_types()
-        def test_zeros_like(self, device):
-            expected = torch.zeros((100, 100,), device=device)
-
-    Will execute:
-
-        test_zeros_like (__main__.TestTorch) ... skipped 'Look at test_zeros_like_cpu, test_zeros_like_cuda results.'
-        test_zeros_like_cpu (__main__.TestTorch) ... ok
-        test_zeros_like_cuda (__main__.TestTorch) ... ok
-
-    To work properly, test class should be inherited from `torchtest`.
-    for_all_device_types decorator does not guarantee proper functionality in
-    combination with other decorators.
-
-    Please do not extend this decorator to support other cases (such as dtype,
-    layouts, etc) without consulting with bigger group. Devices is the special
-    case as build flags control additions/removals (see
-    https://github.com/pytorch/pytorch/pull/23824 for the reference).
-    """
-    @classmethod
-    def for_all_device_types(cls):
-        def wrapper(fn):
-            test_names = []
-
-            for device in torch.testing.get_all_device_types():
-                test_name = fn.__name__ + '_' + device
-                assert not hasattr(cls, test_name), "Duplicated test name: " + test_name
-                setattr(cls, test_name, _test_function(fn, device))
-                test_names.append(test_name)
-
-            @wraps(fn)
-            def empty_test(*args, **kwargs):
-                raise unittest.SkipTest("Look at {} results.".format(", ".join(test_names)))
-            return empty_test
-        return wrapper
-
-
 def skipIfNoLapack(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):
diff --git a/test/test_nn.py b/test/test_nn.py
index 7eed6b8..87721a7 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -35,6 +35,7 @@
 from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
     module_tests, criterion_tests, new_criterion_tests, loss_reference_fns, \
     ctcloss_reference, new_module_tests
+from common_device_type import instantiate_device_type_tests
 
 from torch.nn import MultiheadAttention
 
@@ -966,34 +967,6 @@
             with self.assertRaisesRegex(RuntimeError, 'negative stride is not supported'):
                 module(input)
 
-    def _test_dropout(self, cls, cuda, input):
-        p = 0.2
-        device = torch.device("cuda") if cuda else torch.device("cpu")
-        input = input.to(device).fill_(1 - p)
-
-        module = cls(p)
-        input_var = input.clone().requires_grad_()
-        output = module(input_var)
-        self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
-        output.backward(input)
-        self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)
-
-        module = cls(p, True)
-        input_var = input.clone().requires_grad_()
-        output = module(input_var + 0)
-        self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
-        output.backward(input)
-        self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)
-
-        # check eval mode doesn't change anything
-        for inplace in [True, False]:
-            module = cls(p, inplace).eval()
-            self.assertEqual(input, module(input))
-
-        # Check that these don't raise errors
-        module.__repr__()
-        str(module)
-
     def _test_alpha_dropout(self, cls, input):
         mean = input.mean()
         std = input.std()
@@ -3160,51 +3133,6 @@
         gradcheck(func, [x])
         gradgradcheck(func, [x])
 
-    def test_Dropout(self):
-        input = torch.Tensor(1000)
-        self._test_dropout(nn.Dropout, False, input)
-
-    def test_Dropout2d(self):
-        b = random.randint(1, 5)
-        w = random.randint(1, 5)
-        h = random.randint(1, 5)
-        num_features = 1000
-        input = torch.Tensor(num_features, b, w, h)
-        self._test_dropout(nn.Dropout2d, False, input)
-
-    def test_Dropout3d(self):
-        b = random.randint(1, 5)
-        w = random.randint(1, 5)
-        h = random.randint(1, 5)
-        d = random.randint(1, 2)
-        num_features = 1000
-        input = torch.Tensor(num_features, b, d, w, h)
-        self._test_dropout(nn.Dropout3d, False, input)
-
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    def test_Dropout_cuda(self):
-        input = torch.Tensor(1000)
-        self._test_dropout(nn.Dropout, True, input)
-
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    def test_Dropout2d_cuda(self):
-        b = random.randint(1, 5)
-        w = random.randint(1, 5)
-        h = random.randint(1, 5)
-        num_features = 1000
-        input = torch.Tensor(num_features, b, w, h)
-        self._test_dropout(nn.Dropout2d, True, input)
-
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    def test_Dropout3d_cuda(self):
-        b = random.randint(1, 5)
-        w = random.randint(1, 5)
-        h = random.randint(1, 5)
-        d = random.randint(1, 2)
-        num_features = 1000
-        input = torch.Tensor(num_features, b, d, w, h)
-        self._test_dropout(nn.Dropout3d, True, input)
-
     def test_AlphaDropout(self):
         # generate random tensor with zero mean and unit std
         input = torch.randn(5000)
@@ -3219,259 +3147,6 @@
         input = torch.randn(num_features, b, d, w, h)
         self._test_alpha_dropout(nn.FeatureAlphaDropout, input)
 
-    def _test_InstanceNorm_general(self, cls, input, device="cpu", dtype=torch.float):
-        # default case track_running_stats=False
-        b, c = input.size(0), input.size(1)
-        input_var = input.to(device=device, dtype=dtype).requires_grad_()
-
-        IN = cls(c, eps=0).to(device, dtype)
-
-        output = IN(input_var)
-        out_reshaped = output.view(b * c, -1)
-
-        mean = out_reshaped.mean(1)
-        var = out_reshaped.var(1, unbiased=False)
-
-        self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5)
-        self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5)
-
-        # check that eval mode doesn't change behavior
-        grad_out = torch.randn_like(output)
-        res1 = output.data.clone()
-        output.backward(grad_out)
-        grad1 = input_var.grad.data.clone()
-
-        IN.eval()
-        output = IN(input_var)
-        input_var.grad = None
-        output.backward(grad_out)
-        res2 = output.data
-        grad2 = input_var.grad.data
-        self.assertEqual(res1, res2)
-        self.assertEqual(grad1, grad2)
-
-        # If track_running_stats=True and momentum=1, running_mean/var should be
-        # equal to mean/var of the input (with unbias correction)
-        IN = cls(c, momentum=1, eps=0, track_running_stats=True).to(device, dtype)
-
-        output = IN(input_var)
-
-        input_reshaped = input_var.transpose(1, 0).reshape(c, -1)
-        mean = input_reshaped.mean(1)
-
-        input_reshaped = input_var.transpose(1, 0).reshape(c, b, -1)
-        var = input_reshaped.var(2, unbiased=True)[:, :]
-
-        self.assertAlmostEqual(torch.abs(mean.data - IN.running_mean).mean(), 0, delta=1e-5)
-        self.assertAlmostEqual(torch.abs(var.data.mean(1) - IN.running_var).mean(), 0, delta=1e-5)
-
-        # in eval mode, adding X * std to a channel in input should make the
-        # corresponding channel in output have mean X
-        IN.eval()
-        delta = IN.running_var.sqrt() * torch.arange(c, device=device, dtype=dtype)
-        delta = delta.view(-1, *[1 for _ in range(2, input.dim())])
-        output = IN(input_var + delta)
-        self.assertEqual(output.transpose(0, 1).reshape(c, -1).mean(1), torch.arange(c))
-
-    def _test_InstanceNorm_cuda_half(self, cls, input):
-        # THNN
-        input = input.to(device='cuda', dtype=torch.half).random_(1, 10).requires_grad_(True)
-        m = cls(input.size(1), affine=True, track_running_stats=True).to("cuda", torch.half)
-        thnn_output = m(input)
-        thnn_output.sum().backward()
-        thnn_input_grad = input.grad.data.clone()
-        self.assertEqual(thnn_output.type(), input.type())
-        # cuDNN
-        if TEST_CUDNN:
-            input.grad = None
-            m = m.float()
-            cudnn_output = m(input)
-            cudnn_output.sum().backward()
-            cudnn_input_grad = input.grad.data.clone()
-            self.assertEqual(cudnn_output.type(), input.type())
-            self.assertAlmostEqual(cudnn_output, thnn_output, delta=1e-4)
-            self.assertAlmostEqual(cudnn_input_grad, thnn_input_grad, delta=1e-3)
-
-    def test_InstanceNorm1d_general(self):
-        b = random.randint(3, 5)
-        c = random.randint(3, 5)
-        d = random.randint(8, 10)
-
-        input = torch.rand(b, c, d)
-        self._test_InstanceNorm_general(nn.InstanceNorm1d, input, dtype=torch.float)
-
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    def test_InstanceNorm1d_general_cuda(self):
-        b = random.randint(3, 5)
-        c = random.randint(3, 5)
-        d = random.randint(8, 10)
-
-        input = torch.rand(b, c, d)
-        self._test_InstanceNorm_general(nn.InstanceNorm1d, input, "cuda", torch.float)
-        self._test_InstanceNorm_cuda_half(nn.InstanceNorm1d, input)
-
-    def test_InstanceNorm2d_general(self):
-        b = random.randint(3, 5)
-        c = random.randint(3, 5)
-        w = random.randint(3, 6)
-        h = random.randint(6, 8)
-
-        input = torch.rand(b, c, h, w)
-        self._test_InstanceNorm_general(nn.InstanceNorm2d, input, dtype=torch.float)
-
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    def test_InstanceNorm2d_general_cuda(self):
-        b = random.randint(3, 5)
-        c = random.randint(3, 5)
-        w = random.randint(3, 6)
-        h = random.randint(6, 8)
-
-        input = torch.rand(b, c, h, w)
-        self._test_InstanceNorm_general(nn.InstanceNorm2d, input, "cuda", torch.float)
-        self._test_InstanceNorm_cuda_half(nn.InstanceNorm2d, input)
-
-    def test_InstanceNorm3d_general(self):
-        b = random.randint(3, 5)
-        c = random.randint(3, 5)
-        w = random.randint(2, 5)
-        h = random.randint(2, 5)
-        d = random.randint(2, 5)
-
-        input = torch.rand(b, c, h, w, d)
-        self._test_InstanceNorm_general(nn.InstanceNorm3d, input, dtype=torch.float)
-
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    def test_InstanceNorm3d_general_cuda(self):
-        b = random.randint(3, 5)
-        c = random.randint(2, 5)
-        w = random.randint(2, 5)
-        h = random.randint(2, 5)
-        d = random.randint(2, 5)
-
-        input = torch.rand(b, c, h, w, d)
-        self._test_InstanceNorm_general(nn.InstanceNorm3d, input, "cuda", torch.float)
-        self._test_InstanceNorm_cuda_half(nn.InstanceNorm3d, input)
-
-    def _test_LayerNorm_general(self, device="cpu", dtype=torch.float):
-        for i in range(2, 6):
-            shape = torch.randint(3, 6, (i,), dtype=torch.long).tolist()
-            x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
-            normalized_ndim = random.randint(1, i - 1)  # inclusive
-            normalized_shape = shape[-normalized_ndim:]
-            unnormalized_shape = shape[:-normalized_ndim]
-
-            # test that LN normalizes to mean 0 and stddev 1
-            ln = nn.LayerNorm(normalized_shape, eps=0).to(device, dtype)
-            ln.weight.data.fill_(1)
-            ln.bias.data.fill_(0)
-            output = ln(x)
-            out_reshaped = output.view(*(unnormalized_shape + [-1]))
-            mean = out_reshaped.mean(-1)
-            var = out_reshaped.var(-1, unbiased=False)
-            self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5)
-            self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5)
-
-            # test that LN applies weight and bias correctly
-            scale, bias = torch.empty(2).uniform_(0.2, 2).tolist()
-            ln.weight.data.fill_(scale)
-            ln.bias.data.fill_(bias)
-            output = ln(x)
-            out_reshaped = output.view(*(unnormalized_shape + [-1]))
-            mean = out_reshaped.mean(-1)
-            var = out_reshaped.var(-1, unbiased=False)
-            self.assertAlmostEqual(torch.abs(mean.data).mean(), bias, delta=1e-5)
-            self.assertAlmostEqual(torch.abs(var.data).mean(), scale ** 2, delta=1e-5)
-
-        bad_norm_shape_input_shape = {
-            (): (),
-            (2, 3): (3,),
-            (2,): (1, 2, 3),
-            (10,): (2, 3),
-            10: (2, 3),
-        }
-        for norm_shape, input_shape in bad_norm_shape_input_shape.items():
-            ln = nn.LayerNorm(norm_shape)
-            input = torch.empty(input_shape, device=device, dtype=dtype).uniform_(0, 10)
-            self.assertRaises(RuntimeError, lambda: ln(input))
-
-    def _test_LayerNorm_cuda_half(self):
-        input = torch.empty(2, 3, 3, 2, device="cuda", dtype=torch.half).random_(1, 10).requires_grad_(True)
-        m = nn.LayerNorm([3, 2]).to("cuda", torch.half)
-        output = m(input)
-        output.sum().backward()
-        self.assertEqual(output.type(), input.type())
-
-    def test_LayerNorm_general(self):
-        self._test_LayerNorm_general()
-
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    def test_LayerNorm_general_cuda(self):
-        self._test_LayerNorm_general("cuda")
-        self._test_LayerNorm_cuda_half()
-
-    def _test_GroupNorm_general(self, device="cpu", dtype=torch.float):
-        good_shape_g = {
-            (1, 2, 3, 4): 2,
-            (2, 3, 10): 3,
-            (3, 1, 1, 1, 2): 1,
-            (2, 6, 4, 2, 2): 3,
-        }
-        for shape, g in good_shape_g.items():
-            x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
-            b = shape[0]
-            c = shape[1]
-
-            # test that GN normalizes to mean 0 and stddev 1
-            gn = nn.GroupNorm(g, c, eps=0).to(device, dtype)
-            gn.weight.data.fill_(1)
-            gn.bias.data.fill_(0)
-            output = gn(x)
-            out_reshaped = output.view(b, g, -1)
-            mean = out_reshaped.mean(-1)
-            var = out_reshaped.var(-1, unbiased=False)
-            self.assertAlmostEqual(torch.abs(mean).mean(), 0, delta=1e-5)
-            self.assertAlmostEqual(torch.abs(var).mean(), 1, delta=1e-5)
-
-            # test that GN applies weight and bias correctly
-            scale = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2)
-            bias = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2)
-            gn.weight.data.copy_(scale)
-            gn.bias.data.copy_(bias)
-            output = gn(x)
-            out_reshaped = output.view(b, c, -1)
-            out_normed = (out_reshaped - bias.view(c, 1)) / scale.view(c, 1)
-            out_normed_reshaped = out_normed.view(b, g, -1)
-            mean = out_normed_reshaped.mean(-1)
-            var = out_normed_reshaped.var(-1, unbiased=False)
-            self.assertAlmostEqual(torch.abs(mean).mean(), 0, delta=1e-5)
-            self.assertAlmostEqual(torch.abs(var).mean(), 1, delta=1e-5)
-
-        bad_shape_g = {
-            (1, 2, 3, 4): 3,
-            (2, 3, 10): 2,
-            (3, 1, 1, 1, 2): 10,
-            (2, 6, 4, 2, 2): 4,
-        }
-        for shape, g in bad_shape_g.items():
-            gn = nn.GroupNorm(g, shape[1])
-            input = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
-            self.assertRaises(RuntimeError, lambda: gn(input))
-
-    def _test_GroupNorm_cuda_half(self):
-        input = torch.zeros(2, 4, 3, 2, requires_grad=True).cuda().half().random_(1, 10)
-        m = nn.GroupNorm(2, 4).to("cuda", torch.half)
-        output = m(input)
-        output.sum().backward()
-        self.assertEqual(output.type(), input.type())
-
-    def test_GroupNorm_general(self):
-        self._test_GroupNorm_general(dtype=torch.float)
-
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    def test_GroupNorm_general_cuda(self):
-        self._test_GroupNorm_general("cuda", torch.float)
-        self._test_GroupNorm_cuda_half()
-
     def test_pad(self):
         inputs = torch.randn(1, 3, 4, 4, requires_grad=True)
         _assertGradAndGradgradChecks(self, lambda x: F.pad(x, (1, 1, 1, 1)), (inputs,))
@@ -3490,115 +3165,11 @@
         self.assertRaisesRegex(RuntimeError, expected_err_msg,
                                lambda: F.pad(torch.randn(1, 1, 2), (2, 1), mode='reflect'))
 
-    @staticmethod
-    def _test_one_hot(self, use_cuda=False):
-        device = torch.device('cuda' if use_cuda else 'cpu')
-        with self.assertRaises(RuntimeError):
-            torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1)
-
-        with self.assertRaises(RuntimeError):
-            torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 3)
-
-        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device))
-        expected = torch.tensor([[0, 0, 0, 1, 0],
-                                 [0, 0, 0, 0, 1],
-                                 [0, 1, 0, 0, 0],
-                                 [1, 0, 0, 0, 0]], device=device)
-        self.assertEqual(t, expected)
-
-        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -1)
-        expected = torch.tensor([[0, 0, 0, 1, 0],
-                                 [0, 0, 0, 0, 1],
-                                 [0, 1, 0, 0, 0],
-                                 [1, 0, 0, 0, 0]], device=device)
-        self.assertEqual(t, expected)
-
-        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 6)
-        expected = torch.tensor([[0, 0, 0, 1, 0, 0],
-                                 [0, 0, 0, 0, 1, 0],
-                                 [0, 1, 0, 0, 0, 0],
-                                 [1, 0, 0, 0, 0, 0]], device=device)
-        self.assertEqual(t, expected)
-
-        t = torch.nn.functional.one_hot(torch.tensor([[3, 4], [1, 0]], device=device))
-        expected = torch.tensor([[[0, 0, 0, 1, 0],
-                                  [0, 0, 0, 0, 1]],
-                                 [[0, 1, 0, 0, 0],
-                                  [1, 0, 0, 0, 0]]], device=device)
-        self.assertEqual(t, expected)
-
-        t = torch.nn.functional.one_hot(torch.tensor(4, device=device))
-        expected = torch.tensor([0, 0, 0, 0, 1], device=device)
-        self.assertEqual(t, expected)
-
-        t = torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device), 100)
-        expected = torch.empty([4, 0, 100])
-        self.assertEqual(t, expected)
-
-        with self.assertRaises(RuntimeError):
-            torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device))
-
-        with self.assertRaises(RuntimeError):
-            torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2)
-
-    def test_one_hot(self):
-        self._test_one_hot(self)
-
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    def test_one_hot_cuda(self):
-        self._test_one_hot(self, use_cuda=True)
-
     def test_pad_scalar_error(self):
         inputs = torch.tensor(0., requires_grad=True)
         self.assertRaises(AssertionError, lambda: F.pad(inputs, (1, 1)))
         self.assertRaises(AssertionError, lambda: F.pad(inputs, (1,)))
 
-    def test_nn_scalars(self):
-        # One off tests to ensure scalars from nn.yaml are properly applied
-        def verify_scalars(input, output):
-            if input.dim() == 0:
-                self.assertEqual((), output.shape)
-            else:
-                self.assertNotEqual((), output.shape)
-            output.sum().backward()
-            self.assertEqual(input.shape, input.grad.shape)
-
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
-            for input_shape in [(5, 6), ()]:
-                for module in [torch.nn.ELU, torch.nn.Hardtanh, torch.nn.LeakyReLU, torch.nn.LogSigmoid,
-                               torch.nn.RReLU, torch.nn.Softshrink, torch.nn.Softplus, torch.nn.Sigmoid,
-                               torch.nn.Tanh]:
-                    input = torch.randn(input_shape, device=device, requires_grad=True)
-                    m = module()
-                    output = m(input)
-                    verify_scalars(input, output)
-
-    def test_nn_scalars_reductions(self):
-        # One off tests to ensure scalars from nn.yaml are properly applied
-        def verify_reduction_scalars(input, reduction, output):
-            if reduction != 'none' or input.dim() == 0:
-                self.assertEqual((), output.shape)
-            else:
-                self.assertNotEqual((), output.shape)
-            output.sum().backward()
-            self.assertEqual(input.shape, input.grad.shape)
-
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
-            for input_shape in [(5, 6), ()]:
-                for reduction in ['none', 'mean', 'sum']:
-                    for module in [torch.nn.BCELoss, torch.nn.L1Loss, torch.nn.MSELoss,
-                                   torch.nn.SmoothL1Loss, torch.nn.SoftMarginLoss]:
-                        input = torch.randn(input_shape, device=device, requires_grad=True)
-                        target = torch.empty(input_shape, device=device).random_(2)
-                        sigmoid = nn.Sigmoid()
-
-                        input = torch.randn(input_shape, device=device, requires_grad=True)
-                        m = module(reduction=reduction)
-                        output = m(sigmoid(input), target)
-                        verify_reduction_scalars(input, reduction, output)
-
     @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
                      "Scipy v1.0 and/or numpy not found")
     def test_multihead_attention(self):
@@ -10090,6 +9661,374 @@
     return transform_tensor, transform_ary, grid_ary
 # end TestNN.test_affine_* helpers
 
+class GenericDeviceTypeHelpers(object):
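+    """Helpers shared by the device-generic tests in TestNNDeviceType below.
+
+    Each helper takes an explicit `device` argument (replacing the old
+    `cuda` flag and per-device test variants) so one code path covers every
+    device type the generic framework instantiates.
+    """
+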
+    def _test_dropout(self, cls, device, input):
+        p = 0.2
+        input = input.to(device).fill_(1 - p)
+
+        module = cls(p)
+        input_var = input.clone().requires_grad_()
+        output = module(input_var)
+        self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
+        output.backward(input)
+        self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)
+
+        module = cls(p, True)
+        input_var = input.clone().requires_grad_()
+        output = module(input_var + 0)
+        self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
+        output.backward(input)
+        self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)
+
+        # check eval mode doesn't change anything
+        for inplace in [True, False]:
+            module = cls(p, inplace).eval()
+            self.assertEqual(input, module(input))
+
+        # Check that these don't raise errors
+        module.__repr__()
+        str(module)
+
+    def _test_InstanceNorm_general(self, cls, input, device, dtype=torch.float):
+        # default case track_running_stats=False
+        b, c = input.size(0), input.size(1)
+        input_var = input.to(device=device, dtype=dtype).requires_grad_()
+
+        IN = cls(c, eps=0).to(device, dtype)
+
+        output = IN(input_var)
+        out_reshaped = output.view(b * c, -1)
+
+        mean = out_reshaped.mean(1)
+        var = out_reshaped.var(1, unbiased=False)
+
+        self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5)
+        self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5)
+
+        # check that eval mode doesn't change behavior
+        grad_out = torch.randn_like(output)
+        res1 = output.data.clone()
+        output.backward(grad_out)
+        grad1 = input_var.grad.data.clone()
+
+        IN.eval()
+        output = IN(input_var)
+        input_var.grad = None
+        output.backward(grad_out)
+        res2 = output.data
+        grad2 = input_var.grad.data
+        self.assertEqual(res1, res2)
+        self.assertEqual(grad1, grad2)
+
+        # If track_running_stats=True and momentum=1, running_mean/var should be
+        # equal to mean/var of the input (with unbias correction)
+        IN = cls(c, momentum=1, eps=0, track_running_stats=True).to(device, dtype)
+
+        output = IN(input_var)
+
+        input_reshaped = input_var.transpose(1, 0).reshape(c, -1)
+        mean = input_reshaped.mean(1)
+
+        input_reshaped = input_var.transpose(1, 0).reshape(c, b, -1)
+        var = input_reshaped.var(2, unbiased=True)[:, :]
+
+        self.assertAlmostEqual(torch.abs(mean.data - IN.running_mean).mean(), 0, delta=1e-5)
+        self.assertAlmostEqual(torch.abs(var.data.mean(1) - IN.running_var).mean(), 0, delta=1e-5)
+
+        # in eval mode, adding X * std to a channel in input should make the
+        # corresponding channel in output have mean X
+        IN.eval()
+        delta = IN.running_var.sqrt() * torch.arange(c, device=device, dtype=dtype)
+        delta = delta.view(-1, *[1 for _ in range(2, input.dim())])
+        output = IN(input_var + delta)
+        self.assertEqual(output.transpose(0, 1).reshape(c, -1).mean(1), torch.arange(c))
+
+    def _test_InstanceNorm_cuda_half(self, cls, input):
+        # THNN
+        input = input.to(device='cuda', dtype=torch.half).random_(1, 10).requires_grad_(True)
+        m = cls(input.size(1), affine=True, track_running_stats=True).to("cuda", torch.half)
+        thnn_output = m(input)
+        thnn_output.sum().backward()
+        thnn_input_grad = input.grad.data.clone()
+        self.assertEqual(thnn_output.type(), input.type())
+        # cuDNN
+        if TEST_CUDNN:
+            input.grad = None
+            m = m.float()
+            cudnn_output = m(input)
+            cudnn_output.sum().backward()
+            cudnn_input_grad = input.grad.data.clone()
+            self.assertEqual(cudnn_output.type(), input.type())
+            self.assertAlmostEqual(cudnn_output, thnn_output, delta=1e-4)
+            self.assertAlmostEqual(cudnn_input_grad, thnn_input_grad, delta=1e-3)
+
+    def _test_LayerNorm_general(self, device, dtype=torch.float):
+        for i in range(2, 6):
+            shape = torch.randint(3, 6, (i,), dtype=torch.long).tolist()
+            x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
+            normalized_ndim = random.randint(1, i - 1)  # inclusive
+            normalized_shape = shape[-normalized_ndim:]
+            unnormalized_shape = shape[:-normalized_ndim]
+
+            # test that LN normalizes to mean 0 and stddev 1
+            ln = nn.LayerNorm(normalized_shape, eps=0).to(device, dtype)
+            ln.weight.data.fill_(1)
+            ln.bias.data.fill_(0)
+            output = ln(x)
+            out_reshaped = output.view(*(unnormalized_shape + [-1]))
+            mean = out_reshaped.mean(-1)
+            var = out_reshaped.var(-1, unbiased=False)
+            self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5)
+            self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5)
+
+            # test that LN applies weight and bias correctly
+            scale, bias = torch.empty(2).uniform_(0.2, 2).tolist()
+            ln.weight.data.fill_(scale)
+            ln.bias.data.fill_(bias)
+            output = ln(x)
+            out_reshaped = output.view(*(unnormalized_shape + [-1]))
+            mean = out_reshaped.mean(-1)
+            var = out_reshaped.var(-1, unbiased=False)
+            self.assertAlmostEqual(torch.abs(mean.data).mean(), bias, delta=1e-5)
+            self.assertAlmostEqual(torch.abs(var.data).mean(), scale ** 2, delta=1e-5)
+
+        bad_norm_shape_input_shape = {
+            (): (),
+            (2, 3): (3,),
+            (2,): (1, 2, 3),
+            (10,): (2, 3),
+            10: (2, 3),
+        }
+        for norm_shape, input_shape in bad_norm_shape_input_shape.items():
+            ln = nn.LayerNorm(norm_shape)
+            input = torch.empty(input_shape, device=device, dtype=dtype).uniform_(0, 10)
+            self.assertRaises(RuntimeError, lambda: ln(input))
+
+    def _test_LayerNorm_cuda_half(self):
+        input = torch.empty(2, 3, 3, 2, device="cuda", dtype=torch.half).random_(1, 10).requires_grad_(True)
+        m = nn.LayerNorm([3, 2]).to("cuda", torch.half)
+        output = m(input)
+        output.sum().backward()
+        self.assertEqual(output.type(), input.type())
+
+    def _test_GroupNorm_general(self, device, dtype=torch.float):
+        good_shape_g = {
+            (1, 2, 3, 4): 2,
+            (2, 3, 10): 3,
+            (3, 1, 1, 1, 2): 1,
+            (2, 6, 4, 2, 2): 3,
+        }
+        for shape, g in good_shape_g.items():
+            x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
+            b = shape[0]
+            c = shape[1]
+
+            # test that GN normalizes to mean 0 and stddev 1
+            gn = nn.GroupNorm(g, c, eps=0).to(device, dtype)
+            gn.weight.data.fill_(1)
+            gn.bias.data.fill_(0)
+            output = gn(x)
+            out_reshaped = output.view(b, g, -1)
+            mean = out_reshaped.mean(-1)
+            var = out_reshaped.var(-1, unbiased=False)
+            self.assertAlmostEqual(torch.abs(mean).mean(), 0, delta=1e-5)
+            self.assertAlmostEqual(torch.abs(var).mean(), 1, delta=1e-5)
+
+            # test that GN applies weight and bias correctly
+            scale = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2)
+            bias = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2)
+            gn.weight.data.copy_(scale)
+            gn.bias.data.copy_(bias)
+            output = gn(x)
+            out_reshaped = output.view(b, c, -1)
+            out_normed = (out_reshaped - bias.view(c, 1)) / scale.view(c, 1)
+            out_normed_reshaped = out_normed.view(b, g, -1)
+            mean = out_normed_reshaped.mean(-1)
+            var = out_normed_reshaped.var(-1, unbiased=False)
+            self.assertAlmostEqual(torch.abs(mean).mean(), 0, delta=1e-5)
+            self.assertAlmostEqual(torch.abs(var).mean(), 1, delta=1e-5)
+
+        bad_shape_g = {
+            (1, 2, 3, 4): 3,
+            (2, 3, 10): 2,
+            (3, 1, 1, 1, 2): 10,
+            (2, 6, 4, 2, 2): 4,
+        }
+        for shape, g in bad_shape_g.items():
+            gn = nn.GroupNorm(g, shape[1])
+            input = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
+            self.assertRaises(RuntimeError, lambda: gn(input))
+
+    def _test_GroupNorm_cuda_half(self):
+        input = torch.zeros(2, 4, 3, 2, requires_grad=True).cuda().half().random_(1, 10)
+        m = nn.GroupNorm(2, 4).to("cuda", torch.half)
+        output = m(input)
+        output.sum().backward()
+        self.assertEqual(output.type(), input.type())
+
+class TestNNDeviceType(NNTestCase, GenericDeviceTypeHelpers):
+    def test_Dropout(self, device):
+        input = torch.Tensor(1000)
+        self._test_dropout(nn.Dropout, device, input)
+
+    def test_Dropout2d(self, device):
+        b = random.randint(1, 5)
+        w = random.randint(1, 5)
+        h = random.randint(1, 5)
+        num_features = 1000
+        input = torch.Tensor(num_features, b, w, h)
+        self._test_dropout(nn.Dropout2d, device, input)
+
+    def test_Dropout3d(self, device):
+        b = random.randint(1, 5)
+        w = random.randint(1, 5)
+        h = random.randint(1, 5)
+        d = random.randint(1, 2)
+        num_features = 1000
+        input = torch.Tensor(num_features, b, d, w, h)
+        self._test_dropout(nn.Dropout3d, device, input)
+
+    def test_InstanceNorm1d_general(self, device):
+        b = random.randint(3, 5)
+        c = random.randint(3, 5)
+        d = random.randint(8, 10)
+
+        input = torch.rand(b, c, d)
+        self._test_InstanceNorm_general(nn.InstanceNorm1d, input, device)
+
+        if device == 'cuda':
+            self._test_InstanceNorm_cuda_half(nn.InstanceNorm1d, input)
+
+    def test_InstanceNorm2d_general(self, device):
+        b = random.randint(3, 5)
+        c = random.randint(3, 5)
+        w = random.randint(3, 6)
+        h = random.randint(6, 8)
+
+        input = torch.rand(b, c, h, w)
+        self._test_InstanceNorm_general(nn.InstanceNorm2d, input, device)
+
+        if device == 'cuda':
+            self._test_InstanceNorm_cuda_half(nn.InstanceNorm2d, input)
+
+    def test_InstanceNorm3d_general(self, device):
+        b = random.randint(3, 5)
+        c = random.randint(3, 5)
+        w = random.randint(2, 5)
+        h = random.randint(2, 5)
+        d = random.randint(2, 5)
+
+        input = torch.rand(b, c, h, w, d)
+        self._test_InstanceNorm_general(nn.InstanceNorm3d, input, device)
+
+        if device == 'cuda':
+            self._test_InstanceNorm_cuda_half(nn.InstanceNorm3d, input)
+
+    def test_LayerNorm_general(self, device):
+        self._test_LayerNorm_general(device)
+
+        if device == 'cuda':
+            self._test_LayerNorm_cuda_half()
+
+    def test_GroupNorm_general(self, device):
+        self._test_GroupNorm_general(device)
+
+        if device == 'cuda':
+            self._test_GroupNorm_cuda_half()
+
+    def test_one_hot(self, device):
+        with self.assertRaises(RuntimeError):
+            torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1)
+
+        with self.assertRaises(RuntimeError):
+            torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 3)
+
+        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device))
+        expected = torch.tensor([[0, 0, 0, 1, 0],
+                                 [0, 0, 0, 0, 1],
+                                 [0, 1, 0, 0, 0],
+                                 [1, 0, 0, 0, 0]], device=device)
+        self.assertEqual(t, expected)
+
+        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -1)
+        expected = torch.tensor([[0, 0, 0, 1, 0],
+                                 [0, 0, 0, 0, 1],
+                                 [0, 1, 0, 0, 0],
+                                 [1, 0, 0, 0, 0]], device=device)
+        self.assertEqual(t, expected)
+
+        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 6)
+        expected = torch.tensor([[0, 0, 0, 1, 0, 0],
+                                 [0, 0, 0, 0, 1, 0],
+                                 [0, 1, 0, 0, 0, 0],
+                                 [1, 0, 0, 0, 0, 0]], device=device)
+        self.assertEqual(t, expected)
+
+        t = torch.nn.functional.one_hot(torch.tensor([[3, 4], [1, 0]], device=device))
+        expected = torch.tensor([[[0, 0, 0, 1, 0],
+                                  [0, 0, 0, 0, 1]],
+                                 [[0, 1, 0, 0, 0],
+                                  [1, 0, 0, 0, 0]]], device=device)
+        self.assertEqual(t, expected)
+
+        t = torch.nn.functional.one_hot(torch.tensor(4, device=device))
+        expected = torch.tensor([0, 0, 0, 0, 1], device=device)
+        self.assertEqual(t, expected)
+
+        t = torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device), 100)
+        expected = torch.empty([4, 0, 100])
+        self.assertEqual(t, expected)
+
+        with self.assertRaises(RuntimeError):
+            torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device))
+
+        with self.assertRaises(RuntimeError):
+            torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2)
+
+    def test_nn_scalars(self, device):
+        # One off tests to ensure scalars from nn.yaml are properly applied
+        def verify_scalars(input, output):
+            if input.dim() == 0:
+                self.assertEqual((), output.shape)
+            else:
+                self.assertNotEqual((), output.shape)
+            output.sum().backward()
+            self.assertEqual(input.shape, input.grad.shape)
+
+        for input_shape in [(5, 6), ()]:
+            for module in [torch.nn.ELU, torch.nn.Hardtanh, torch.nn.LeakyReLU, torch.nn.LogSigmoid,
+                           torch.nn.RReLU, torch.nn.Softshrink, torch.nn.Softplus, torch.nn.Sigmoid,
+                           torch.nn.Tanh]:
+                input = torch.randn(input_shape, device=device, requires_grad=True)
+                m = module()
+                output = m(input)
+                verify_scalars(input, output)
+
+    def test_nn_scalars_reductions(self, device):
+        # One off tests to ensure scalars from nn.yaml are properly applied
+        def verify_reduction_scalars(input, reduction, output):
+            if reduction != 'none' or input.dim() == 0:
+                self.assertEqual((), output.shape)
+            else:
+                self.assertNotEqual((), output.shape)
+            output.sum().backward()
+            self.assertEqual(input.shape, input.grad.shape)
+
+        for input_shape in [(5, 6), ()]:
+            for reduction in ['none', 'mean', 'sum']:
+                for module in [torch.nn.BCELoss, torch.nn.L1Loss, torch.nn.MSELoss,
+                               torch.nn.SmoothL1Loss, torch.nn.SoftMarginLoss]:
+                    input = torch.randn(input_shape, device=device, requires_grad=True)
+                    target = torch.empty(input_shape, device=device).random_(2)
+                    sigmoid = nn.Sigmoid()
+
+                    input = torch.randn(input_shape, device=device, requires_grad=True)
+                    m = module(reduction=reduction)
+                    output = m(sigmoid(input), target)
+                    verify_reduction_scalars(input, reduction, output)
+
+
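+# Instantiate per-device versions of the TestNNDeviceType tests above and
+# register them in this module's namespace so the unittest runner discovers them.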
+instantiate_device_type_tests(TestNNDeviceType, globals())
 
 if __name__ == '__main__':
     run_tests()
diff --git a/test/test_torch.py b/test/test_torch.py
index 7abdce0..69876d0 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -30,7 +30,7 @@
 from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
     TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \
     IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, skipIfRocm, do_test_dtypes, do_test_empty_full, \
-    IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, torchtest, \
+    IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, \
     skipCUDANonDefaultStreamIf
 from multiprocessing.reduction import ForkingPickler
 from common_device_type import instantiate_device_type_tests, \
@@ -104,7 +104,7 @@
 
 # This is intentionally prefixed by an underscore. Otherwise pytest will try to
 # run its methods as test cases.
-class _TestTorchMixin(torchtest):
+class _TestTorchMixin(object):
     def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True):
         float_types = [torch.double,
                        torch.float]