updated code to ensure error check for negative dims

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/31636

Differential Revision: D19233031

Pulled By: anjali411

fbshipit-source-id: c29265ddd1f887f1a0b98aca56a2691d7584353d
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index fed2096..453c559 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -46,8 +46,9 @@
     mask.flip();
   } else {
     for (int64_t dim : dims) {
-      TORCH_CHECK(dim < 64, "PyTorch doesn't support reduction operations for dim>=64");
-      mask.set(maybe_wrap_dim(dim, ndim));
+      int64_t pos_dim = maybe_wrap_dim(dim, ndim);
+      TORCH_CHECK(pos_dim < 64, "PyTorch doesn't support reduction operations for dim>=64");
+      mask.set(pos_dim);
     }
   }
   return mask;
diff --git a/test/test_torch.py b/test/test_torch.py
index a133f65..086746c 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -682,6 +682,8 @@
         x = torch.randn(sizes)
         with self.assertRaisesRegex(RuntimeError, "PyTorch doesn't support reduction operations for dim>=64"):
             torch.sum(x, 64)
+        with self.assertRaisesRegex(RuntimeError, "PyTorch doesn't support reduction operations for dim>=64"):
+            torch.sum(x, -1)
 
     @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
     def test_logsumexp(self):