Fix edge case for size 1 channels dim in AdaptiveMaxPool (#116482)
Fixes https://github.com/pytorch/pytorch/issues/107842
Unlike `AdaptiveAvgPool`, `AdaptiveMaxPool` does not have a CUDA kernel for ChannelsLast inputs. We work around this by calling `contiguous()` on the input. However, there is an edge case when the channels dimension has size 1:
```python
>>> t = torch.randn(2, 1, 3, 3)
>>> t.stride()
(9, 9, 3, 1)
>>> t_c = t.to(memory_format=torch.channels_last)
>>> t_c.stride()
(9, 1, 3, 1) # (CHW, 1, CW, C)
>>> t_c.is_contiguous()
True # contiguity check doesn't check strides for singleton dimensions
```
The CUDA kernel treats the batch (`B`) and channels (`C`) dimensions as implicitly flattened and increments the data pointer for `input` to the start of the next plane using
https://github.com/pytorch/pytorch/blob/669b182d33f5e1368d8de6d86b891f65480c9b22/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu#L67
If the input falls into the aforementioned edge case, the `data_ptr` will not be incremented correctly. The simple fix is to compute the stride of the channels dimension as $\prod_{i > 1} \text{size}(i)$ instead of reading it from the input's strides.
An analogous fix is applied for the 3D case.
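As a minimal illustrative sketch (not code from this PR), the snippet below shows how `stride(1)` disagrees with the per-plane stride the kernel assumes in the size-1 channels case, and how the product of the trailing sizes recovers the value the kernel expects:
```python
import math
import torch

t = torch.randn(2, 1, 3, 3).to(memory_format=torch.channels_last)

# The kernel walks flattened (B*C) planes of H*W elements each.
reported = t.stride(1)             # 1 for this edge case
expected = math.prod(t.shape[2:])  # H * W = 9, matches the kernel's assumption
print(reported, expected)          # 1 9
```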
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116482
Approved by: https://github.com/albanD
diff --git a/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu b/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu
index d8fab31..2030dbb 100644
--- a/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu
+++ b/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu
@@ -268,7 +268,11 @@
int64_t isizeH = input_.size(2);
int64_t isizeW = input_.size(3);
- int64_t istrideD = input_.stride(1);
+ // In the kernel, the batch and channel dimensions are treated as if they
+ // are flattened and istrideD is used as the stride of this flattened dim
+ // Handle the edge case where input_.size(1) == 1, where despite passing the
+ // contiguity check the stride might not be H * W
+ int64_t istrideD = isizeH * isizeW;
int64_t istrideH = input_.stride(2);
int64_t istrideW = input_.stride(3);
diff --git a/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu b/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu
index 06a4cf6..ca99689 100644
--- a/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu
+++ b/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu
@@ -346,7 +346,11 @@
isizeH = input_.size(3);
isizeW = input_.size(4);
- istrideD = input_.stride(1);
+ // In the kernel, the batch and channel dimensions are treated as if they
+ // are flattened and istrideD is used as the stride of this flattened dim
+ // Handle the edge case where input_.size(1) == 1, where despite passing the
+ // contiguity check the stride might not be T * H * W
+ istrideD = isizeT * isizeH * isizeW;
istrideT = input_.stride(2);
istrideH = input_.stride(3);
istrideW = input_.stride(4);
diff --git a/test/nn/test_pooling.py b/test/nn/test_pooling.py
index 392ff81..d73b7ea 100644
--- a/test/nn/test_pooling.py
+++ b/test/nn/test_pooling.py
@@ -1072,38 +1072,48 @@
@dtypes(torch.float, torch.double)
def test_adaptive_pooling_max_nhwc(self, device, dtype):
- def helper(n, c, h, w, output_height, output_width, contig):
- input = torch.randint(1, 10, (n, c, h, w), device=device, dtype=dtype)
- input = input.contiguous(memory_format=torch.channels_last)
- grad = torch.randint(1, 10, (4, 8, output_height, output_width), device=device, dtype=dtype)
- grad = grad.contiguous(memory_format=torch.channels_last)
+ def helper(input_size, output_plane_size, contig):
+ n_plane_dims = len(output_plane_size)
+ mod = torch.nn.AdaptiveMaxPool2d if n_plane_dims == 2 else torch.nn.AdaptiveMaxPool3d
+ channels_last = torch.channels_last if n_plane_dims == 2 else torch.channels_last_3d
+ output_size = input_size[:2] + output_plane_size
+ input = torch.randint(1, 10, input_size, device=device, dtype=dtype)
+ input = input.contiguous(memory_format=channels_last)
+ grad = torch.randint(1, 10, output_size, device=device, dtype=dtype)
+ grad = grad.contiguous(memory_format=channels_last)
if not contig:
- input = input[:, ::2, :, :]
- grad = grad[:, ::2, :, :]
+ input = input[:, ::2]
+ grad = grad[:, ::2]
input.requires_grad_(True)
- pool = torch.nn.AdaptiveMaxPool2d((output_height, output_width), return_indices=True).to(device)
+ pool = mod(output_plane_size, return_indices=True).to(device)
ref_input = input.detach().clone().contiguous().requires_grad_(True)
ref_grad = grad.detach().clone().contiguous()
- ref_pool = torch.nn.AdaptiveMaxPool2d((output_height, output_width), return_indices=True).to(device)
+ ref_pool = mod(output_plane_size, return_indices=True).to(device)
out, ind = pool(input)
out.backward(grad)
ref_out, ref_ind = ref_pool(ref_input)
ref_out.backward(ref_grad)
- self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
+ # channels_last_3d case does not return channels_last_3d outputs
+ if n_plane_dims == 2:
+ self.assertTrue(out.is_contiguous(memory_format=channels_last))
+ self.assertTrue(ind.is_contiguous(memory_format=channels_last))
self.assertTrue(ref_out.is_contiguous())
- self.assertTrue(ind.is_contiguous(memory_format=torch.channels_last))
self.assertTrue(ref_ind.is_contiguous())
self.assertEqual(out, ref_out)
self.assertEqual(ind, ref_ind)
self.assertEqual(input.grad, ref_input.grad)
for contig in [True, False]:
- helper(4, 8, 10, 10, 7, 7, contig)
- helper(4, 8, 9, 14, 5, 8, contig)
- helper(4, 8, 11, 11, 1, 1, contig)
+ helper((4, 8, 10, 10), (7, 7), contig)
+ helper((4, 8, 9, 14), (5, 8), contig)
+ helper((4, 8, 11, 11), (1, 1), contig)
+ helper((2, 1, 3, 3), (1, 1), contig)
+ helper((4, 8, 10, 10, 10), (7, 7, 7), contig)
+ helper((4, 8, 11, 11, 11), (1, 1, 1), contig)
+ helper((2, 1, 3, 3, 3), (1, 1, 1), contig)
@dtypes(torch.float, torch.double)
def test_pooling_max_nhwc(self, device, dtype):