Ensure the reduction dim>=64 error check also covers negative dims
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/31636
Differential Revision: D19233031
Pulled By: anjali411
fbshipit-source-id: c29265ddd1f887f1a0b98aca56a2691d7584353d
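
The old check compared the raw `dim` against 64 before wrapping, so any negative `dim` passed trivially (e.g. `-1 < 64`) even when it wrapped to an index past the end of the 64-bit dim bitset. The fix wraps the dim first and bounds-checks the wrapped value. A minimal repro, assuming a tensor with 65 size-1 dimensions as in the updated test:

```python
import torch

x = torch.randn([1] * 65)  # 65-dim tensor; the reduction dim mask only holds 64 bits

# Before the fix, -1 slipped past the `dim < 64` check (-1 < 64) and set an
# out-of-range bit in the mask; now -1 wraps to 64 first and raises:
# "PyTorch doesn't support reduction operations for dim>=64"
torch.sum(x, -1)
```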
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index fed2096..453c559 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -46,8 +46,9 @@
mask.flip();
} else {
for (int64_t dim : dims) {
- TORCH_CHECK(dim < 64, "PyTorch doesn't support reduction operations for dim>=64");
- mask.set(maybe_wrap_dim(dim, ndim));
+ int64_t pos_dim = maybe_wrap_dim(dim, ndim);
+ TORCH_CHECK(pos_dim < 64, "PyTorch doesn't support reduction operations for dim>=64");
+ mask.set(pos_dim);
}
}
return mask;
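
For context, `maybe_wrap_dim` normalizes a possibly negative dim into `[0, ndim)`, so bounds-checking the wrapped value is what makes the guard meaningful for negative inputs. A rough Python analogue of the wrapping rule (the helper name here is illustrative, not ATen's API):

```python
def wrap_dim(dim, ndim):
    # Negative dims count from the end: -1 means the last dim, i.e. ndim - 1.
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-dim tensor")
    return dim + ndim if dim < 0 else dim

# With ndim = 65, dim = -1 wraps to 64, which the relocated
# `pos_dim < 64` check now correctly rejects.
assert wrap_dim(-1, 65) == 64
```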
diff --git a/test/test_torch.py b/test/test_torch.py
index a133f65..086746c 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -682,6 +682,8 @@
x = torch.randn(sizes)
with self.assertRaisesRegex(RuntimeError, "PyTorch doesn't support reduction operations for dim>=64"):
torch.sum(x, 64)
+ with self.assertRaisesRegex(RuntimeError, "PyTorch doesn't support reduction operations for dim>=64"):
+ torch.sum(x, -1)
@unittest.skipIf(not TEST_SCIPY, "Scipy not found")
def test_logsumexp(self):