[CUDNN NHWC CONVOLUTION] Re-stride input tensors for wgrad in cudnn_convolution (#33784)

Summary: Re-stride the input tensor to the memory format chosen for grad_output (channels-last or contiguous) before computing the weight gradient in cudnn_convolution_backward_weight, and disable the cuDNN channels-last (NHWC) path when the input or weight is float64. A test covering a channels-last weight with a contiguous input is added for float and double.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/33784
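
A minimal sketch of the case this targets (assuming a CUDA build with cuDNN available; it mirrors the new test below): a contiguous NCHW input combined with a channels-last weight selects the cuDNN channels-last path, and the backward pass then computes the weight gradient for that layout.

    import torch

    # Contiguous (NCHW) input paired with a channels-last (NHWC) weight;
    # this selects the cuDNN channels-last convolution path.
    input = torch.randn(1, 16, 1, 1, device="cuda", requires_grad=True)
    weight = torch.randn(8, 16, 3, 3, device="cuda", requires_grad=True)
    weight = weight.to(memory_format=torch.channels_last)

    out = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1)
    assert out.is_contiguous(memory_format=torch.channels_last)

    # The weight gradient goes through cudnn_convolution_backward_weight,
    # which is where the input is now re-strided to match grad_output.
    out.sum().backward()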

Differential Revision: D20127485

Pulled By: VitalyFedyunin

fbshipit-source-id: 9d893ffe7ff9499e7e9a7e8bed720e9441d1018e
diff --git a/aten/src/ATen/native/ConvUtils.h b/aten/src/ATen/native/ConvUtils.h
index b313611..274b5cc 100644
--- a/aten/src/ATen/native/ConvUtils.h
+++ b/aten/src/ATen/native/ConvUtils.h
@@ -82,7 +82,10 @@
 }
 
 static inline bool cudnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
-  if (!detail::getCUDAHooks().compiledWithCuDNN()) {
+  // Disable NHWC (channels-last) when the input or weight is float64.
+  if (!detail::getCUDAHooks().compiledWithCuDNN() ||
+      input.scalar_type() == at::kDouble ||
+      weight.scalar_type() == at::kDouble) {
     return false;
   }
   long cudnn_version = detail::getCUDAHooks().versionCuDNN();
diff --git a/aten/src/ATen/native/cudnn/Conv.cpp b/aten/src/ATen/native/cudnn/Conv.cpp
index 82b76e6..d600298 100644
--- a/aten/src/ATen/native/cudnn/Conv.cpp
+++ b/aten/src/ATen/native/cudnn/Conv.cpp
@@ -1194,8 +1194,6 @@
     IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
     bool benchmark, bool deterministic)
 {
-  TensorArg input{ input_t, "input", 2};
-
   auto layout = cudnn_conv_use_channels_last(input_t, grad_output_t) ?
       at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous;
 
@@ -1203,6 +1201,10 @@
   // Make sure that NC11 strides follow formula
   grad_output_contig_t.resize_(grad_output_contig_t.sizes(), layout);
   TensorArg grad_output_contig{ grad_output_contig_t, "grad_output", 1 };
+
+  Tensor input_contig_t = input_t.contiguous(layout);
+  input_contig_t.resize_(input_contig_t.sizes(), layout);
+  TensorArg input{ input_contig_t, "input", 2};
 
   checkAllSameType(c, {grad_output_contig, input});
   checkAllSameGPU(c, {grad_output_contig, input});
diff --git a/test/test_nn.py b/test/test_nn.py
index 2e0cf0f..64c94ab 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -10835,6 +10835,21 @@
             self._test_conv_cudnn_nhwc_nchw(nn.Conv2d, n, c, h, w, k, filter_size, device)
             self._test_conv_cudnn_nhwc_nchw(nn.ConvTranspose2d, n, c, h, w, k, filter_size, device)
 
+    # torch.half errors out on Windows with CUDA 10.1 + cuDNN 7.6.4,
+    # returning CUDNN_STATUS_BAD_PARAM, so that dtype is excluded for now
+    # (see issue #33918).
+    @onlyCUDA
+    @skipCUDAIfRocm
+    @skipCUDAIfNoCudnn
+    @dtypes(torch.float, torch.double)
+    def test_conv_cudnn_nhwc_support(self, device, dtype):
+        input = torch.randn((1, 16, 1, 1), dtype=dtype, device="cuda", requires_grad=True)
+        weight = torch.randn((8, 16, 3, 3), dtype=dtype, device="cuda", requires_grad=True)
+        weight = weight.to(memory_format=torch.channels_last)
+        o = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1)
+        self.assertTrue(o.is_contiguous(memory_format=torch.channels_last))
+        o.sum().backward()
+
     @onlyCUDA
     @skipCUDAIfRocm
     @skipCUDAIfCudnnVersionLessThan(7603)