Don't mutate tensor stride in place in cudnn conv (#126786)

Fix for https://github.com/pytorch/pytorch/issues/126241.

Within the cuDNN convolution, we were updating the tensor's strides in place to disambiguate size-1 dims between contiguous and channels-last tensors. Instead of mutating the tensor's strides, just use a temporary copy. cuDNN then copies the strides internally: https://github.com/NVIDIA/cudnn-frontend/blob/d7ccb5b3c47b4de709604cce463ad66b775b7812/include/cudnn_frontend_Tensor.h#L201-L203.
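
For context, a minimal repro sketch of the reported behavior (hypothetical shapes, mirroring the new test added below; requires a CUDA/cuDNN build): a channels-last convolution with a 1x1 kernel could previously leave the weight's reported strides changed.

```python
import torch
import torch.nn.functional as F

# Size-1 kernel dims make the weight's strides ambiguous between contiguous
# and channels-last; the in-place stride "fix" inside the cuDNN path leaked
# back into the tensor's metadata.
weight = torch.randn(64, 64, 1, 1, device="cuda")
x = torch.randn(2, 64, 10, 10, device="cuda").to(memory_format=torch.channels_last)

stride_before = weight.stride()
out = F.conv2d(x, weight)
# With this change the weight's strides are untouched; before, they could be mutated.
assert weight.stride() == stride_before
```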

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126786
Approved by: https://github.com/ezyang, https://github.com/shunting314, https://github.com/eqy
diff --git a/aten/src/ATen/cudnn/Descriptors.h b/aten/src/ATen/cudnn/Descriptors.h
index 79a2fe5..043658a 100644
--- a/aten/src/ATen/cudnn/Descriptors.h
+++ b/aten/src/ATen/cudnn/Descriptors.h
@@ -167,8 +167,9 @@
   void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad, bool nhwc);
 
   void set(cudnnDataType_t dataType, int dim, int* size, int* stride, bool nhwc) {
-    fixSizeOneDimStride<int>(dim, size, stride, nhwc);
-    AT_CUDNN_CHECK(cudnnSetTensorNdDescriptor(mut_desc(), dataType, dim, size, stride));
+    std::vector<int> strides_copy(stride, stride + dim);
+    fixSizeOneDimStride<int>(dim, size, strides_copy.data(), nhwc);
+    AT_CUDNN_CHECK(cudnnSetTensorNdDescriptor(mut_desc(), dataType, dim, size, strides_copy.data()));
   }
 };
 
diff --git a/aten/src/ATen/native/cudnn/Conv_v8.cpp b/aten/src/ATen/native/cudnn/Conv_v8.cpp
index 750cbcc..4bd7273 100644
--- a/aten/src/ATen/native/cudnn/Conv_v8.cpp
+++ b/aten/src/ATen/native/cudnn/Conv_v8.cpp
@@ -90,11 +90,13 @@
   auto strides = t.strides();
   bool channels_last = memory_format == at::MemoryFormat::ChannelsLast ||
       memory_format == at::MemoryFormat::ChannelsLast3d;
+
+  std::vector<int64_t> strides_copy(std::begin(strides), std::end(strides));
   fixSizeOneDimStride<int64_t>(
-      sizes.size(), &sizes[0], (int64_t*)&strides[0], channels_last);
+      sizes.size(), &sizes[0], (int64_t*)&strides_copy[0], channels_last);
   auto r = cudnn_frontend::TensorBuilder()
                .setDim(sizes.size(), sizes.data())
-               .setStrides(strides.size(), strides.data())
+               .setStrides(strides_copy.size(), strides_copy.data())
                .setId(id)
                .setAlignment(alignment)
                .setDataType(dataType)
diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py
index c0d7150..c030d07 100644
--- a/test/nn/test_convolution.py
+++ b/test/nn/test_convolution.py
@@ -528,6 +528,37 @@
 
     @unittest.skipIf(not TEST_CUDA, "CUDA not available")
     @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+    def test_cudnn_not_mutate_stride(self):
+        weight = torch.randn(64, 64, 1, 1)
+        x = torch.randn(2, 64, 10, 10).to(memory_format=torch.channels_last)
+        weight_stride = weight.stride()
+
+        def conv(x, weight):
+            return torch.convolution(
+                x,
+                weight,
+                stride=(1, 1),
+                padding=(0, 0),
+                dilation=(1, 1),
+                transposed=False,
+                output_padding=(0, 0),
+                groups=1,
+                bias=None,
+            )
+
+        # should have run in nhwc without mutating input strides
+        out_nhwc = conv(x, weight)
+        self.assertEqual(weight.stride(), weight_stride)
+        self.assertTrue(out_nhwc.is_contiguous(memory_format=torch.channels_last))
+
+        x = x.contiguous(memory_format=torch.contiguous_format)
+        out_c = conv(x, weight)
+        self.assertTrue(out_c.is_contiguous(memory_format=torch.contiguous_format))
+        self.assertEqual(out_c, out_nhwc)
+        self.assertEqual(weight.stride(), weight_stride)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
     def test_Conv2d_inconsistent_types_on_GPU_with_cudnn(self):
         inputs = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda")
         weights = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda")