Fix edge case for size 1 channels dim in AdaptiveMaxPool (#116482)

Fixes https://github.com/pytorch/pytorch/issues/107842

Unlike `AdaptiveAvgPool`, `AdaptiveMaxPool` does not have a CUDA kernel for channels-last inputs. We work around this by calling `contiguous()` on the input. However, there is an edge case when the channels dimension has size 1:

```python
>>> t = torch.randn(2, 1, 3, 3)
>>> t.stride()
(9, 9, 3, 1)
>>> t_c = t.to(memory_format=torch.channels_last)
>>> t_c.stride()
(9, 1, 3, 1)  # (CHW, 1, CW, C)
>>> t_c.is_contiguous()
True  # contiguity check doesn't check strides for singleton dimensions
```

The CUDA kernel treats the batch (`B`) and channels (`C`) dimensions as implicitly flattened and increments the data pointer for `input` to the start of the next plane using `istrideD`:

https://github.com/pytorch/pytorch/blob/669b182d33f5e1368d8de6d86b891f65480c9b22/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu#L67

If the input falls into the aforementioned edge case, the `data_ptr` will not be incremented correctly. The simple fix is to compute the stride of the channels dimension as $\prod_{i > 1} \text{size}(i)$, i.e. the product of the sizes of the trailing spatial dimensions.
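
As a minimal Python sketch (using the edge-case tensor from above; the names here are illustrative, not the kernel's), the reported channels stride under-counts the per-plane offset, while the product of the trailing sizes gives the value the kernel needs:

```python
import torch

# Edge case: channels dimension of size 1, converted to channels_last
t_c = torch.randn(2, 1, 3, 3).to(memory_format=torch.channels_last)

H, W = t_c.size(2), t_c.size(3)
print(t_c.stride(1))  # 1 -> stepping by this stays inside the same plane
print(H * W)          # 9 -> correct offset to the next (flattened B*C) plane
```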

An analogous fix is applied for the 3D case.
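
The same size-1 edge case shows up with `torch.channels_last_3d`; a small illustrative check (not part of the PR) of the stride mismatch:

```python
import torch

t = torch.randn(2, 1, 3, 3, 3).to(memory_format=torch.channels_last_3d)
print(t.is_contiguous())                  # True, the size-1 channels dim is skipped
print(t.stride(1))                        # 1
print(t.size(2) * t.size(3) * t.size(4))  # 27 == T * H * W, the offset the kernel needs
```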

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116482
Approved by: https://github.com/albanD
diff --git a/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu b/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu
index d8fab31..2030dbb 100644
--- a/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu
+++ b/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu
@@ -268,7 +268,11 @@
     int64_t isizeH = input_.size(2);
     int64_t isizeW = input_.size(3);
 
-    int64_t istrideD = input_.stride(1);
+    // In the kernel, the batch and channel dimensions are treated as if they
+    // are flattened and istrideD is used as the stride of this flattened dim
+    // Handle the edge case where input_.size(1) == 1, where despite passing the
+    // contiguity check the stride might not be H * W
+    int64_t istrideD = isizeH * isizeW;
     int64_t istrideH = input_.stride(2);
     int64_t istrideW = input_.stride(3);
 
diff --git a/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu b/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu
index 06a4cf6..ca99689 100644
--- a/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu
+++ b/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu
@@ -346,7 +346,11 @@
     isizeH = input_.size(3);
     isizeW = input_.size(4);
 
-    istrideD = input_.stride(1);
+    // In the kernel, the batch and channel dimensions are treated as if they
+    // are flattened and istrideD is used as the stride of this flattened dim
+    // Handle the edge case where input_.size(1) == 1, where despite passing the
+    // contiguity check the stride might not be T * H * W
+    istrideD = isizeT * isizeH * isizeW;
     istrideT = input_.stride(2);
     istrideH = input_.stride(3);
     istrideW = input_.stride(4);
diff --git a/test/nn/test_pooling.py b/test/nn/test_pooling.py
index 392ff81..d73b7ea 100644
--- a/test/nn/test_pooling.py
+++ b/test/nn/test_pooling.py
@@ -1072,38 +1072,48 @@
 
     @dtypes(torch.float, torch.double)
     def test_adaptive_pooling_max_nhwc(self, device, dtype):
-        def helper(n, c, h, w, output_height, output_width, contig):
-            input = torch.randint(1, 10, (n, c, h, w), device=device, dtype=dtype)
-            input = input.contiguous(memory_format=torch.channels_last)
-            grad = torch.randint(1, 10, (4, 8, output_height, output_width), device=device, dtype=dtype)
-            grad = grad.contiguous(memory_format=torch.channels_last)
+        def helper(input_size, output_plane_size, contig):
+            n_plane_dims = len(output_plane_size)
+            mod = torch.nn.AdaptiveMaxPool2d if n_plane_dims == 2 else torch.nn.AdaptiveMaxPool3d
+            channels_last = torch.channels_last if n_plane_dims == 2 else torch.channels_last_3d
+            output_size = input_size[:2] + output_plane_size
+            input = torch.randint(1, 10, input_size, device=device, dtype=dtype)
+            input = input.contiguous(memory_format=channels_last)
+            grad = torch.randint(1, 10, output_size, device=device, dtype=dtype)
+            grad = grad.contiguous(memory_format=channels_last)
             if not contig:
-                input = input[:, ::2, :, :]
-                grad = grad[:, ::2, :, :]
+                input = input[:, ::2]
+                grad = grad[:, ::2]
             input.requires_grad_(True)
-            pool = torch.nn.AdaptiveMaxPool2d((output_height, output_width), return_indices=True).to(device)
+            pool = mod(output_plane_size, return_indices=True).to(device)
 
             ref_input = input.detach().clone().contiguous().requires_grad_(True)
             ref_grad = grad.detach().clone().contiguous()
-            ref_pool = torch.nn.AdaptiveMaxPool2d((output_height, output_width), return_indices=True).to(device)
+            ref_pool = mod(output_plane_size, return_indices=True).to(device)
 
             out, ind = pool(input)
             out.backward(grad)
             ref_out, ref_ind = ref_pool(ref_input)
             ref_out.backward(ref_grad)
 
-            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
+            # channels_last_3d case does not return channels_last_3d outputs
+            if n_plane_dims == 2:
+                self.assertTrue(out.is_contiguous(memory_format=channels_last))
+                self.assertTrue(ind.is_contiguous(memory_format=channels_last))
             self.assertTrue(ref_out.is_contiguous())
-            self.assertTrue(ind.is_contiguous(memory_format=torch.channels_last))
             self.assertTrue(ref_ind.is_contiguous())
             self.assertEqual(out, ref_out)
             self.assertEqual(ind, ref_ind)
             self.assertEqual(input.grad, ref_input.grad)
 
         for contig in [True, False]:
-            helper(4, 8, 10, 10, 7, 7, contig)
-            helper(4, 8, 9, 14, 5, 8, contig)
-            helper(4, 8, 11, 11, 1, 1, contig)
+            helper((4, 8, 10, 10), (7, 7), contig)
+            helper((4, 8, 9, 14), (5, 8), contig)
+            helper((4, 8, 11, 11), (1, 1), contig)
+            helper((2, 1, 3, 3), (1, 1), contig)
+            helper((4, 8, 10, 10, 10), (7, 7, 7), contig)
+            helper((4, 8, 11, 11, 11), (1, 1, 1), contig)
+            helper((2, 1, 3, 3, 3), (1, 1, 1), contig)
 
     @dtypes(torch.float, torch.double)
     def test_pooling_max_nhwc(self, device, dtype):