[pytorch][te] Handle negative axis in chunk (#48084)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48084

as title
ghstack-source-id: 116870328

Test Plan: new unit test

Reviewed By: Krovatkin

Differential Revision: D25017489

fbshipit-source-id: 0d1998fccad6f509db04b6c67a4e4e4093d96751
diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index 30e5e60..1eb7a33 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -1347,6 +1347,14 @@
         t = torch.rand(8, dtype=torch.float, device='cuda')
         scripted = self.checkScript(eager, (t, t, t, t, 0.1))
 
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    def test_chunk_mul_one(self):
+        def eager(x):
+            z, y, w = torch.chunk(x, 3, -1)
+            return z*3, y, w
+        x = torch.rand(64, 1, 3072, dtype=torch.float, device='cuda')
+        script = self.checkScript(eager, (x,))
+
 
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index 83ecb69..fd14f55 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -129,9 +129,12 @@
 ExprHandle TensorExprKernel::chunk(
     Tensor* t,
     size_t chunkIdx,
-    size_t dim,
-    size_t chunks,
+    int64_t dim,
+    int64_t chunks,
     const std::vector<ExprHandle>& axes) {
+  if (dim < 0) {
+    dim = axes.size() + dim;
+  }
   auto sizes = bufferSizes(t);
   size_t step = sizes[dim] / chunks;
 
diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h
index c6df7f4..54a876e 100644
--- a/torch/csrc/jit/tensorexpr/kernel.h
+++ b/torch/csrc/jit/tensorexpr/kernel.h
@@ -65,8 +65,8 @@
   ExprHandle chunk(
       Tensor* t,
       size_t chunkIdx,
-      size_t dim,
-      size_t chunks,
+      int64_t dim,
+      int64_t chunks,
       const std::vector<ExprHandle>& axes);
 
   std::vector<ExprHandle> valueShape(const torch::jit::Value* v);