[pytorch][te] Handle negative axis in chunk (#48084)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48084
Normalize negative `dim` values in the TE kernel's `chunk` lowering: `dim` and `chunks` are now passed as `int64_t`, and a negative `dim` is wrapped to `axes.size() + dim` before it is used to index the buffer sizes.
ghstack-source-id: 116870328
Test Plan: new unit test (`test_chunk_mul_one` in `test_jit_fuser_te.py`) exercising `torch.chunk(x, 3, -1)` under the TE fuser.
Reviewed By: Krovatkin
Differential Revision: D25017489
fbshipit-source-id: 0d1998fccad6f509db04b6c67a4e4e4093d96751
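
Context: `torch.chunk` accepts a negative `dim` that counts from the last axis, but the TE kernel's `chunk` lowering previously took `dim` as a `size_t`, so a negative axis would wrap to a huge unsigned index. The fix widens the parameters to `int64_t` and normalizes negative values. A minimal sketch of the intended semantics (Python, illustrative only; `normalize_dim` is a hypothetical helper, not part of the codebase):

    import torch

    def normalize_dim(dim: int, rank: int) -> int:
        # Negative dims count from the end: -1 is the last axis, -rank the first.
        return dim + rank if dim < 0 else dim

    x = torch.rand(64, 1, 3072)
    z, y, w = torch.chunk(x, 3, -1)               # dim == -1, i.e. the last axis
    assert normalize_dim(-1, x.dim()) == 2        # resolves to axis 2
    assert z.shape == y.shape == w.shape == torch.Size([64, 1, 1024])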
diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index 30e5e60..1eb7a33 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -1347,6 +1347,14 @@
t = torch.rand(8, dtype=torch.float, device='cuda')
scripted = self.checkScript(eager, (t, t, t, t, 0.1))
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ def test_chunk_mul_one(self):
+ def eager(x):
+ z, y, w = torch.chunk(x, 3, -1)
+ return z*3, y, w
+ x = torch.rand(64, 1, 3072, dtype=torch.float, device='cuda')
+ script = self.checkScript(eager, (x,))
+
if __name__ == '__main__':
run_tests()
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index 83ecb69..fd14f55 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -129,9 +129,12 @@
ExprHandle TensorExprKernel::chunk(
Tensor* t,
size_t chunkIdx,
- size_t dim,
- size_t chunks,
+ int64_t dim,
+ int64_t chunks,
const std::vector<ExprHandle>& axes) {
+ if (dim < 0) {
+ dim = axes.size() + dim;
+ }
auto sizes = bufferSizes(t);
size_t step = sizes[dim] / chunks;
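
Note on the kernel.cpp hunk above: after normalizing `dim`, the lowering computes a per-chunk step of `sizes[dim] / chunks` along that axis. A rough Python mirror of that slicing arithmetic (a sketch assuming evenly divisible sizes, not the TE code itself):

    def chunk_bounds(sizes, chunk_idx, dim, chunks):
        # Mirror of the normalization added in kernel.cpp: negative dims wrap around.
        if dim < 0:
            dim += len(sizes)
        step = sizes[dim] // chunks           # extent of each chunk along `dim`
        start = chunk_idx * step
        return dim, start, start + step       # chunk covers [start, start + step)

    # For the test input of shape (64, 1, 3072) with dim=-1 and 3 chunks,
    # the middle chunk covers indices [1024, 2048) of axis 2.
    assert chunk_bounds([64, 1, 3072], 1, -1, 3) == (2, 1024, 2048)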
diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h
index c6df7f4..54a876e 100644
--- a/torch/csrc/jit/tensorexpr/kernel.h
+++ b/torch/csrc/jit/tensorexpr/kernel.h
@@ -65,8 +65,8 @@
ExprHandle chunk(
Tensor* t,
size_t chunkIdx,
- size_t dim,
- size_t chunks,
+ int64_t dim,
+ int64_t chunks,
const std::vector<ExprHandle>& axes);
std::vector<ExprHandle> valueShape(const torch::jit::Value* v);