[CUDNN NHWC CONVOLUTION] Re-stride input tensors for wgrad in cudnn_convolution (#33784)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33784

In the cuDNN convolution weight-gradient (wgrad) path, the input tensor is now made contiguous in the chosen memory format and resize_()'d so its strides follow that format before it is handed to cuDNN, mirroring the existing handling of grad_output. The NHWC (channels-last) path is additionally disabled when the input or weight is float64, and test_conv_cudnn_nhwc_support is added to cover float and double.
Differential Revision: D20127485
Pulled By: VitalyFedyunin
fbshipit-source-id: 9d893ffe7ff9499e7e9a7e8bed720e9441d1018e
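
The kind of mixed-layout call this change targets, sketched below; it mirrors the regression test added to test_nn.py in this diff and assumes a CUDA build with cuDNN available:

    import torch

    # NCHW (contiguous) input together with an NHWC (channels-last) weight:
    # the weight-gradient (wgrad) computation for this case is what the patch
    # re-strides the input for.
    x = torch.randn(1, 16, 1, 1, device="cuda", requires_grad=True)
    w = torch.randn(8, 16, 3, 3, device="cuda", requires_grad=True)
    w = w.to(memory_format=torch.channels_last)

    out = torch.conv2d(x, w, None, (2, 1), (1, 1), (1, 1), 1)
    out.sum().backward()  # backward hits the cuDNN wgrad path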
diff --git a/aten/src/ATen/native/ConvUtils.h b/aten/src/ATen/native/ConvUtils.h
index b313611..274b5cc 100644
--- a/aten/src/ATen/native/ConvUtils.h
+++ b/aten/src/ATen/native/ConvUtils.h
@@ -82,7 +82,10 @@
}
static inline bool cudnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
- if (!detail::getCUDAHooks().compiledWithCuDNN()) {
+  // disable NHWC for float64 inputs and weights.
+ if (!detail::getCUDAHooks().compiledWithCuDNN() ||
+ input.scalar_type() == at::kDouble ||
+ weight.scalar_type() == at::kDouble) {
return false;
}
long cudnn_version = detail::getCUDAHooks().versionCuDNN();
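
As a sanity check (not in this diff), one way to exercise the new float64 gate is a double-precision gradcheck through a channels-last input; with the gate in place the call should be routed to the plain NCHW cuDNN path rather than NHWC. A rough sketch, assuming a CUDA build with cuDNN:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 8, 4, 4, device="cuda", dtype=torch.double)
    x = x.to(memory_format=torch.channels_last).requires_grad_()
    w = torch.randn(4, 8, 3, 3, device="cuda", dtype=torch.double, requires_grad=True)

    # numerical gradient check covers forward, dgrad and wgrad in one go
    torch.autograd.gradcheck(lambda x, w: F.conv2d(x, w, padding=1), (x, w))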
diff --git a/aten/src/ATen/native/cudnn/Conv.cpp b/aten/src/ATen/native/cudnn/Conv.cpp
index 82b76e6..d600298 100644
--- a/aten/src/ATen/native/cudnn/Conv.cpp
+++ b/aten/src/ATen/native/cudnn/Conv.cpp
@@ -1194,8 +1194,6 @@
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic)
{
- TensorArg input{ input_t, "input", 2};
-
auto layout = cudnn_conv_use_channels_last(input_t, grad_output_t) ?
at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous;
@@ -1203,6 +1201,10 @@
// Make sure that NC11 strides follow formula
grad_output_contig_t.resize_(grad_output_contig_t.sizes(), layout);
TensorArg grad_output_contig{ grad_output_contig_t, "grad_output", 1 };
+
+ Tensor input_contig_t = input_t.contiguous(layout);
+ input_contig_t.resize_(input_contig_t.sizes(), layout);
+ TensorArg input{ input_contig_t, "input", 2};
checkAllSameType(c, {grad_output_contig, input});
checkAllSameGPU(c, {grad_output_contig, input});
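
The contiguous(layout) + resize_() pair above mirrors the grad_output handling a few lines earlier: for NC11-shaped tensors the channels-last check is ambiguous, so resize_() forces the strides to follow the channels-last formula. In Python terms, roughly (illustrative sketch only):

    import torch

    x = torch.randn(1, 16, 1, 1)  # NC11: size-1 spatial dims
    y = x.contiguous(memory_format=torch.channels_last)
    print(y.stride())   # likely still (16, 1, 1, 1); size-1 dims leave the layout ambiguous
    y.resize_(y.shape, memory_format=torch.channels_last)
    print(y.stride())   # (16, 1, 16, 16): strides now follow the channels-last formula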
diff --git a/test/test_nn.py b/test/test_nn.py
index 2e0cf0f..64c94ab 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -10835,6 +10835,22 @@
self._test_conv_cudnn_nhwc_nchw(nn.Conv2d, n, c, h, w, k, filter_size, device)
self._test_conv_cudnn_nhwc_nchw(nn.ConvTranspose2d, n, c, h, w, k, filter_size, device)
+    # torch.half is erroring out on Windows with CUDA 10.1 + cuDNN 7.6.4,
+    # returning CUDNN_STATUS_BAD_PARAM.
+    # Disabling the half-precision case for now [see issue #33918]
+    @onlyCUDA
+    @skipCUDAIfRocm
+    @skipCUDAIfNoCudnn
+    @dtypes(torch.float, torch.double)
+    def test_conv_cudnn_nhwc_support(self, device, dtype):
+        input = torch.randn((1, 16, 1, 1), dtype=dtype, device=device, requires_grad=True)
+        weight = torch.randn((8, 16, 3, 3), dtype=dtype, device=device, requires_grad=True)
+        weight = weight.to(memory_format=torch.channels_last)
+        o = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1)
+        self.assertTrue(o.is_contiguous(memory_format=torch.channels_last))
+        o.sum().backward()
+
+
@onlyCUDA
@skipCUDAIfRocm
@skipCUDAIfCudnnVersionLessThan(7603)