Creates device generic cuDNN decorators (#26791)

Summary:
- Creates skipCUDAIfNoCudnn, skipCUDAIfCudnnVersionLessThan decorators
- Makes several test_nn.py tests generic

Many tests in test_nn.py exercise cuDNN. These tests are guarded by various conditionals built on TEST_CUDNN and TEST_CUDNN_VERSION, imported from common_cuda.py, and use ad hoc skip messages like 'CUDNN not available' and 'needs cudnn.'

This PR proposes using the CUDA base test class, instead of common_cuda.py, to check cuDNN's availability, at least in device-generic tests. The CUDA base test class is preferable because it only creates a CUDA context when its tests actually run, whereas importing from common_cuda.py always creates a CUDA context. Using the CUDA base test class is also consistent with how other generic tests are guarded and produces consistent skip messages.
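
For illustration, a device-generic test guarded by the proposed decorators could look like the sketch below. Only skipCUDAIfNoCudnn, skipCUDAIfCudnnVersionLessThan, and instantiate_device_type_tests come from this PR; the class and test names are hypothetical, and the usual test_nn.py imports (torch, torch.nn.functional as F, TestCase) are assumed:

    from common_device_type import instantiate_device_type_tests, \
        skipCUDAIfNoCudnn, skipCUDAIfCudnnVersionLessThan

    class TestExampleDeviceType(TestCase):
        @skipCUDAIfNoCudnn
        def test_uses_cudnn(self, device):
            # Runs on every registered device type; on CUDA it is skipped
            # with a consistent message when cuDNN is unavailable.
            x = torch.randn(1, 3, 8, 8, device=device)
            w = torch.randn(4, 3, 3, 3, device=device)
            F.conv2d(x, w)

        @skipCUDAIfCudnnVersionLessThan(7000)
        def test_needs_cudnn7(self, device):
            # Skipped on CUDA unless cuDNN >= 7.0 is available.
            pass

    instantiate_device_type_tests(TestExampleDeviceType, globals())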

One quirk of this approach is that it uses the self argument of the test functions to check for cuDNN availability during a test; see test_rnn_retain_variables. The self argument can also be used to check the device type, which is less verbose than torch.device(device).type == 'cuda'.
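
As a sketch of that pattern (test_example and _run_example are hypothetical names; dtypes, device_type, and has_cudnn come from common_device_type.py as extended by this PR):

    @dtypes(torch.double)
    def test_example(self, device, dtype):
        # Always exercise the default path for the current device.
        self._run_example(device, dtype)

        # On CUDA, also exercise the non-cuDNN kernels when cuDNN is
        # present, querying the test instance instead of common_cuda.py.
        if self.device_type == 'cuda' and self.has_cudnn():
            with torch.backends.cudnn.flags(enabled=False):
                self._run_example(device, dtype)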

An alternative approach to making test_nn.py generic would be to keep using the common_cuda.py imports, try to keep their skip messages consistent, and accept the unnecessary CUDA contexts. That would preclude writing generic tests that run on CUDA only when cuDNN is available, however, so tests like _test_RNN_cpu_vs_cudnn would still require additional changes before becoming device-generic precision tests like _test_RNN_cpu_vs_xla.

For consistency, simplicity, and ease of use, I recommend we adopt the proposed decorators and make use of the self argument when productive.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26791

Differential Revision: D17678325

Pulled By: mruberry

fbshipit-source-id: 1794735ede9bc9f36856e72b3804b136ad3e0de2
diff --git a/test/common_device_type.py b/test/common_device_type.py
index 131f1f9..101a303 100644
--- a/test/common_device_type.py
+++ b/test/common_device_type.py
@@ -184,6 +184,9 @@
     _do_cuda_memory_leak_check = True
     _do_cuda_non_default_stream = True
 
+    def has_cudnn(self):
+        return not self.no_cudnn
+
     @classmethod
     def get_primary_device(cls):
         return cls.primary_device
@@ -202,9 +205,13 @@
     @classmethod
     def setUpClass(cls):
         # has_magma shows up after cuda is initialized
-        torch.ones(1).cuda()
+        t = torch.ones(1).cuda()
         cls.no_magma = not torch.cuda.has_magma
 
+        # Determines if cuDNN is available and its version
+        cls.no_cudnn = not (TEST_WITH_ROCM or torch.backends.cudnn.is_acceptable(t))
+        cls.cudnn_version = None if cls.no_cudnn else torch.backends.cudnn.version()
+
         # Acquires the current device as the primary (test) device
         cls.primary_device = 'cuda:{0}'.format(torch.cuda.current_device())
 
@@ -423,3 +430,27 @@
 # Skips a test on CUDA when using ROCm.
 def skipCUDAIfRocm(fn):
     return skipCUDAIf(TEST_WITH_ROCM, "test doesn't currently work on the ROCm stack")(fn)
+
+
+# Skips a test on CUDA if cuDNN is unavailable or its version is lower than requested.
+def skipCUDAIfCudnnVersionLessThan(version=0):
+
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, device, *args, **kwargs):
+            if self.device_type == 'cuda':
+                if self.no_cudnn:
+                    reason = "cuDNN not available"
+                    raise unittest.SkipTest(reason)
+                if self.cudnn_version < version:
+                    reason = "cuDNN version {0} is available but {1} required".format(self.cudnn_version, version)
+                    raise unittest.SkipTest(reason)
+
+            return fn(self, device, *args, **kwargs)
+
+        return wrap_fn
+    return dec_fn
+
+
+def skipCUDAIfNoCudnn(fn):
+    return skipCUDAIfCudnnVersionLessThan(0)(fn)
diff --git a/test/test_nn.py b/test/test_nn.py
index 65f7c4d..7ec149c 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -37,7 +37,7 @@
     module_tests, criterion_tests, new_criterion_tests, loss_reference_fns, \
     ctcloss_reference, new_module_tests
 from common_device_type import instantiate_device_type_tests, dtypes, \
-    dtypesIfCUDA
+    dtypesIfCUDA, skipCUDAIfNoCudnn, skipCUDAIfCudnnVersionLessThan, onlyCUDA
 
 from torch.nn import MultiheadAttention
 
@@ -2478,91 +2478,6 @@
         expected_output = fc_op(X, W, b)
         torch.testing.assert_allclose(expected_output, actual_output.cpu(), atol=1e-3, rtol=1e-3)
 
-    def _test_gumbel_softmax_st_shapes(self, cuda, dtype, shape, dim, count_expected):
-        logits = torch.randn(shape, dtype=torch.float)
-        logits = logits.to(dtype)
-        if cuda:
-            logits = logits.cuda()
-
-        y_draw = F.gumbel_softmax(logits, hard=True, dim=dim)
-
-        # All values positive
-        self.assertGreaterEqual(y_draw.min(), 0)
-        # Shape unchanged
-        self.assertTrue(y_draw.shape == logits.shape)
-        # One choice per draw
-        self.assertEqual(y_draw.sum(), count_expected, prec=torch.finfo(y_draw.dtype).eps)
-
-    def _test_gumbel_softmax_straight_through(self, cuda, dtype):
-        num_draws = 100
-
-        logits = torch.tensor([[0.2, 0.8, 0.1]])
-        logits = logits.reshape([1, 3])
-        logits = logits.to(dtype).requires_grad_()
-        if cuda:
-            logits = logits.cuda()
-        probs = logits.softmax(dim=-1)
-
-        counts = torch.zeros_like(logits)
-        for _ in range(num_draws):
-            y_draw = F.gumbel_softmax(logits, hard=True)
-            counts = counts + y_draw
-
-        # All values positive
-        self.assertGreaterEqual(y_draw.min(), 0)
-        # Each experiment should result in 1 draw.
-        self.assertEqual(counts.sum(), num_draws, prec=torch.finfo(counts.dtype).eps)
-
-        # check results is asymptotically as expected.
-        expected = probs * num_draws
-        # ~z is approximately N(0,1) for unbiased count
-        z = (counts - expected) / (expected * (1 - probs)).sqrt()
-        # A (lazy) approximate 99% two-sided test:
-        # occurs with prob alpha~>=0.01 if unbiased
-        self.assertLess(z.abs().max().item(), 2.58)
-
-    def _test_gumbel_softmax_grad(self, cuda, dtype):
-        # "hard" and "not hard" should propagate same gradient.
-        device = torch.device("cuda") if cuda else torch.device("cpu")
-        logits_soft = torch.zeros(10, 10, dtype=dtype, device=device, requires_grad=True)
-        logits_hard = torch.zeros(10, 10, dtype=dtype, device=device, requires_grad=True)
-
-        seed = torch.random.get_rng_state()
-        y_soft = F.gumbel_softmax(logits_soft, hard=False)
-        torch.random.set_rng_state(seed)
-        y_hard = F.gumbel_softmax(logits_hard, hard=True)
-
-        y_soft.sum().backward()
-        y_hard.sum().backward()
-
-        # 2eps = 1x addition + 1x subtraction.
-        tol = 2 * torch.finfo(dtype).eps
-        self.assertAlmostEqual(logits_soft.grad, logits_hard.grad, delta=tol)
-
-    @repeat_test_for_types(NO_HALF_TENSORTYPES)
-    def test_gumbel_softmax(self, dtype=torch.float):
-        """
-        NO_HALF_TENSORTYPES because many half-ops doesnt work on cpu.
-        """
-        self._test_gumbel_softmax_st_shapes(cuda=False, dtype=dtype, shape=[5], dim=0, count_expected=1)
-        self._test_gumbel_softmax_st_shapes(cuda=False, dtype=dtype, shape=[5], dim=-1, count_expected=1)
-        self._test_gumbel_softmax_st_shapes(cuda=False, dtype=dtype, shape=[5, 4], dim=1, count_expected=5)
-        self._test_gumbel_softmax_st_shapes(cuda=False, dtype=dtype, shape=[5, 4, 3], dim=1, count_expected=5 * 3)
-        self._test_gumbel_softmax_st_shapes(cuda=False, dtype=dtype, shape=[5, 4, 3], dim=-1, count_expected=5 * 4)
-        self._test_gumbel_softmax_straight_through(cuda=False, dtype=dtype)
-        self._test_gumbel_softmax_grad(cuda=False, dtype=dtype)
-
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    @repeat_test_for_types(ALL_TENSORTYPES)
-    def test_gumbel_softmax_cuda(self, dtype=torch.float):
-        self._test_gumbel_softmax_st_shapes(cuda=True, dtype=dtype, shape=[5], dim=0, count_expected=1)
-        self._test_gumbel_softmax_st_shapes(cuda=True, dtype=dtype, shape=[5], dim=-1, count_expected=1)
-        self._test_gumbel_softmax_st_shapes(cuda=True, dtype=dtype, shape=[5, 4], dim=1, count_expected=5)
-        self._test_gumbel_softmax_st_shapes(cuda=True, dtype=dtype, shape=[5, 4, 3], dim=1, count_expected=5 * 3)
-        self._test_gumbel_softmax_st_shapes(cuda=True, dtype=dtype, shape=[5, 4, 3], dim=-1, count_expected=5 * 4)
-        self._test_gumbel_softmax_straight_through(cuda=True, dtype=dtype)
-        self._test_gumbel_softmax_grad(cuda=True, dtype=dtype)
-
     def _test_EmbeddingBag_vs_Embedding(self, N, D, B, L, max_norm=None,
                                         mode='mean',
                                         device='cpu',
@@ -2794,18 +2709,6 @@
 
         y.backward(grad)
 
-    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
-    def test_contig_wrong_stride_cudnn(self):
-        # x has to have batch_size 1 to test contiguous checks
-        x = torch.randn(1, 16, 5, 5, device="cuda")
-        stride = list(x.stride())
-        stride[0] = 20
-        # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1
-        x.set_(x.storage(), 0, x.size(), stride)
-        self.assertTrue(x.is_contiguous())
-        F.conv_transpose2d(x, torch.randn(16, 1, 1, 1, device="cuda"))
-        F.conv2d(x, torch.randn(1, 16, 1, 1, device="cuda"))
-
     def test_embedding_bag(self):
         for dtype in [torch.double, torch.float]:
             self._test_EmbeddingBag(False, 'sum', False, dtype=dtype)
@@ -5023,19 +4926,6 @@
 
         self.assertEqual(l, expected)
 
-    @unittest.skipIf(not (TEST_CUDNN and (TEST_CUDNN_VERSION if TEST_CUDNN_VERSION else 0) >= 7000), "needs cudnn >= 7.0")
-    def test_CTCLoss_cudnn(self):
-        target_lengths = [30, 25, 20]
-        input_lengths = [50, 50, 50]
-        targets = torch.randint(1, 15, (sum(target_lengths),), dtype=torch.int)
-        log_probs = torch.randn(50, 3, 15, dtype=torch.float, device='cuda').log_softmax(2)
-        res = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)
-        expected = ctcloss_reference(log_probs, targets.cuda(), input_lengths, target_lengths).float()
-        with torch.backends.cudnn.flags(enabled=False):
-            res2 = torch.nn.functional.ctc_loss(log_probs, targets.cuda().long(), input_lengths, target_lengths)
-        self.assertEqual(res, expected)
-        self.assertEqual(res2, res)
-
     def test_CTCLoss_typechecks(self):
         target_lengths = torch.tensor([30, 25, 20])
         input_lengths = torch.tensor([50, 50, 50])
@@ -6343,32 +6233,6 @@
             self.assertEqual(output1, output2)
             self.assertEqual(hidden1, hidden2)
 
-    def _test_rnn_retain_variables(self, device="cpu", dtype=torch.double):
-        rnns = [nn.LSTM(10, 20, num_layers=2).to(device, dtype),
-                nn.GRU(10, 20, num_layers=2).to(device, dtype),
-                nn.RNN(10, 20, num_layers=2).to(device, dtype)]
-        for rnn in rnns:
-            input = torch.randn(5, 6, 10, device=device, dtype=dtype, requires_grad=True)
-            output = rnn(input)
-            output[0].sum().backward(retain_graph=True)
-            grads = [input.grad.data.clone()] + [p.grad.data.clone() for p in rnn.parameters()]
-            for _ in range(4):
-                rnn.zero_grad()
-                input.grad.data.zero_()
-                output[0].sum().backward(retain_graph=True)
-                grads2 = [input.grad.data] + [p.grad.data for p in rnn.parameters()]
-                self.assertEqual(grads, grads2)
-
-    def test_rnn_retain_variables(self):
-        self._test_rnn_retain_variables()
-
-    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
-    @repeat_test_for_types(ALL_TENSORTYPES)
-    def test_rnn_retain_variables_cuda(self, dtype=torch.float):
-        with torch.backends.cudnn.flags(enabled=False):
-            self._test_rnn_retain_variables("cuda", dtype)
-        self._test_rnn_retain_variables("cuda", dtype)
-
     def _test_RNN_cpu_vs_cudnn(self, dropout):
 
         def forward_backward(cuda, rnn, input_val, hx_val, grad_output, grad_hy, weights_val):
@@ -9994,6 +9858,126 @@
             out2 = conv1(input_c)
             self.assertEqual(out1, out2)
 
+    def _test_gumbel_softmax_st_shapes(self, device, dtype, shape, dim, count_expected):
+        logits = torch.randn(shape, dtype=torch.float, device=device)
+        logits = logits.to(dtype)
+
+        y_draw = F.gumbel_softmax(logits, hard=True, dim=dim)
+
+        # All values positive
+        self.assertGreaterEqual(y_draw.min(), 0)
+        # Shape unchanged
+        self.assertTrue(y_draw.shape == logits.shape)
+        # One choice per draw
+        self.assertEqual(y_draw.sum(), count_expected, prec=torch.finfo(y_draw.dtype).eps)
+
+    def _test_gumbel_softmax_straight_through(self, device, dtype):
+        num_draws = 100
+
+        logits = torch.tensor([[0.2, 0.8, 0.1]], device=device)
+        logits = logits.reshape([1, 3])
+        logits = logits.to(dtype).requires_grad_()
+        probs = logits.softmax(dim=-1)
+
+        counts = torch.zeros_like(logits)
+        for _ in range(num_draws):
+            y_draw = F.gumbel_softmax(logits, hard=True)
+            counts = counts + y_draw
+
+        # All values positive
+        self.assertGreaterEqual(y_draw.min(), 0)
+        # Each experiment should result in 1 draw.
+        self.assertEqual(counts.sum(), num_draws, prec=torch.finfo(counts.dtype).eps)
+
+        # check results is asymptotically as expected.
+        expected = probs * num_draws
+        # ~z is approximately N(0,1) for unbiased count
+        z = (counts - expected) / (expected * (1 - probs)).sqrt()
+        # A (lazy) approximate 99% two-sided test:
+        # occurs with prob alpha~>=0.01 if unbiased
+        self.assertLess(z.abs().max().item(), 2.58)
+
+    def _test_gumbel_softmax_grad(self, device, dtype):
+        # "hard" and "not hard" should propagate same gradient.
+        logits_soft = torch.zeros(10, 10, dtype=dtype, device=device, requires_grad=True)
+        logits_hard = torch.zeros(10, 10, dtype=dtype, device=device, requires_grad=True)
+
+        seed = torch.random.get_rng_state()
+        y_soft = F.gumbel_softmax(logits_soft, hard=False)
+        torch.random.set_rng_state(seed)
+        y_hard = F.gumbel_softmax(logits_hard, hard=True)
+
+        y_soft.sum().backward()
+        y_hard.sum().backward()
+
+        # 2eps = 1x addition + 1x subtraction.
+        tol = 2 * torch.finfo(dtype).eps
+        self.assertAlmostEqual(logits_soft.grad, logits_hard.grad, delta=tol)
+
+    @dtypesIfCUDA(torch.half, torch.float, torch.double)
+    @dtypes(torch.float, torch.double)
+    def test_gumbel_softmax(self, device, dtype):
+        self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5], dim=0, count_expected=1)
+        self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5], dim=-1, count_expected=1)
+        self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5, 4], dim=1, count_expected=5)
+        self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5, 4, 3], dim=1, count_expected=5 * 3)
+        self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5, 4, 3], dim=-1, count_expected=5 * 4)
+        self._test_gumbel_softmax_straight_through(device, dtype)
+        self._test_gumbel_softmax_grad(device, dtype)
+
+    def _test_rnn_retain_variables(self, device, dtype):
+        rnns = [nn.LSTM(10, 20, num_layers=2).to(device, dtype),
+                nn.GRU(10, 20, num_layers=2).to(device, dtype),
+                nn.RNN(10, 20, num_layers=2).to(device, dtype)]
+        for rnn in rnns:
+            input = torch.randn(5, 6, 10, device=device, dtype=dtype, requires_grad=True)
+            output = rnn(input)
+            output[0].sum().backward(retain_graph=True)
+            grads = [input.grad.data.clone()] + [p.grad.data.clone() for p in rnn.parameters()]
+            for _ in range(4):
+                rnn.zero_grad()
+                input.grad.data.zero_()
+                output[0].sum().backward(retain_graph=True)
+                grads2 = [input.grad.data] + [p.grad.data for p in rnn.parameters()]
+                self.assertEqual(grads, grads2)
+
+    @dtypesIfCUDA(torch.half, torch.float, torch.double)
+    @dtypes(torch.double)
+    def test_rnn_retain_variables(self, device, dtype):
+        self._test_rnn_retain_variables(device, dtype)
+
+        if self.device_type == 'cuda' and self.has_cudnn():
+            with torch.backends.cudnn.flags(enabled=False):
+                self._test_rnn_retain_variables(device, dtype)
+
+    @onlyCUDA
+    @skipCUDAIfCudnnVersionLessThan(7000)
+    def test_CTCLoss_cudnn(self, device):
+        target_lengths = [30, 25, 20]
+        input_lengths = [50, 50, 50]
+        targets = torch.randint(1, 15, (sum(target_lengths),), dtype=torch.int)
+        log_probs = torch.randn(50, 3, 15, dtype=torch.float, device=device).log_softmax(2)
+        res = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)
+        expected = ctcloss_reference(log_probs, targets.cuda(), input_lengths, target_lengths).float()
+        with torch.backends.cudnn.flags(enabled=False):
+            res2 = torch.nn.functional.ctc_loss(log_probs, targets.cuda().long(), input_lengths, target_lengths)
+        self.assertEqual(res, expected)
+        self.assertEqual(res2, res)
+
+    @onlyCUDA
+    @skipCUDAIfNoCudnn
+    def test_contig_wrong_stride_cudnn(self, device):
+        # x has to have batch_size 1 to test contiguous checks
+        x = torch.randn(1, 16, 5, 5, device=device)
+        stride = list(x.stride())
+        stride[0] = 20
+        # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1
+        x.set_(x.storage(), 0, x.size(), stride)
+        self.assertTrue(x.is_contiguous())
+        F.conv_transpose2d(x, torch.randn(16, 1, 1, 1, device=device))
+        F.conv2d(x, torch.randn(1, 16, 1, 1, device=device))
+
+
 
 instantiate_device_type_tests(TestNNDeviceType, globals())