fix: max_unpool3d buffer overflow (#94372)
Fixes #88032
Previously, `output_size` was accessed before the shape-length check, which led to a buffer overflow when fewer than three elements were passed.
The fix is simply to move the `output_size` reads after the validation checks so the length is verified first.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94372
Approved by: https://github.com/albanD
diff --git a/aten/src/ATen/native/MaxUnpooling.cpp b/aten/src/ATen/native/MaxUnpooling.cpp
index adab802..3ba0c3c 100644
--- a/aten/src/ATen/native/MaxUnpooling.cpp
+++ b/aten/src/ATen/native/MaxUnpooling.cpp
@@ -84,9 +84,7 @@
IntArrayRef stride,
IntArrayRef padding,
const char *fn_name) {
- int64_t oT = output_size[0];
- int64_t oH = output_size[1];
- int64_t oW = output_size[2];
+
TORCH_CHECK(
indices.scalar_type() == at::ScalarType::Long,
"elements in indices should be type int64");
@@ -118,6 +116,10 @@
"strides should be greater than zero, but got stride: ",
stride);
+ int64_t oT = output_size[0];
+ int64_t oH = output_size[1];
+ int64_t oW = output_size[2];
+
int dimw = 3;
int dimh = 2;
int dimt = 1;
@@ -167,9 +169,6 @@
at::globalContext().alertNotDeterministic("max_unpooling3d_forward_out");
TORCH_CHECK(output.is_contiguous(), "output must be contiguous");
- int64_t oT = output_size[0];
- int64_t oH = output_size[1];
- int64_t oW = output_size[2];
auto self = self_.contiguous();
auto indices = indices_.contiguous();
@@ -177,6 +176,10 @@
max_unpooling3d_shape_check(
self_, Tensor(), indices_, output_size, stride, padding, "max_unpooling3d_forward_out_cpu()");
+ int64_t oT = output_size[0];
+ int64_t oH = output_size[1];
+ int64_t oW = output_size[2];
+
if (self_.ndimension() == 5) {
output.resize_({self.size(0), self.size(1), oT, oH, oW});
} else {
diff --git a/test/nn/test_pooling.py b/test/nn/test_pooling.py
index df8eb59..e795d6b 100644
--- a/test/nn/test_pooling.py
+++ b/test/nn/test_pooling.py
@@ -353,6 +353,10 @@
self.assertEqual(F.max_unpool3d(output, indices, 2), F.max_unpool3d(output, indices, 2, stride=2))
gradcheck(F.max_unpool3d, (output, indices, 2), check_forward_ad=True)
+ def test_max_unpool3d_input_check(self):
+ x = torch.ones(1, 3, 1, 1, 1)
+ with self.assertRaises(RuntimeError):
+ F.max_unpool3d(x, torch.zeros(x.shape, dtype=int), [1, 1])
class TestPoolingNNDeviceType(NNTestCase):
@onlyNativeDeviceTypes