Ellipsis in subscript
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17763
Differential Revision: D14893533
Pulled By: Krovatkin
fbshipit-source-id: c46b4e386d3aa30e6dc03e3052d2e5ff097fa74b
diff --git a/test/test_jit.py b/test/test_jit.py
index 47901d4..ff4a9b3 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -8124,6 +8124,36 @@
self.assertEqual(8, bar(torch.ones(1, 1)))
+    def test_ellipsis_mid(self):  # Ellipsis between a leading select and trailing slices
+        def ellipsize(x):
+            # type: (Tensor) -> Tensor
+            return x[2, ..., 0:4, 4:8]  # Ellipsis covers dims 1-2; result shape (8, 8, 4, 4)
+
+        dummy = torch.rand(8, 8, 8, 8, 8)  # non-constant values: wrong elements fail, not just wrong shapes
+        self.checkScript(ellipsize, (dummy,), optimize=True)
+
+    def test_ellipsis_mid_select(self):  # Ellipsis followed by several selects and a slice
+        def ellipsize(x):
+            # type: (Tensor) -> Tensor
+            return x[2, ..., 4, 4, 4:8, 2]  # Ellipsis covers dims 1-2; result shape (8, 8, 4)
+
+        dummy = torch.rand(8, 8, 8, 8, 8, 8, 8)  # non-constant values: wrong elements fail, not just wrong shapes
+        self.checkScript(ellipsize, (dummy,), optimize=True)
+
+    def test_ellipsis_start(self):  # leading Ellipsis followed by two slices
+        def ellipsize(x):
+            # type: (Tensor) -> Tensor
+            return x[..., 0:4, 4:8]  # Ellipsis covers dims 0-2; result shape (8, 8, 8, 4, 4)
+        dummy = torch.rand(8, 8, 8, 8, 8)  # non-constant values: swapped slice bounds would fail
+        self.checkScript(ellipsize, (dummy,), optimize=True)
+
+    def test_ellipsis_end(self):  # trailing Ellipsis after a slice and a select
+        def ellipsize(x):
+            # type: (Tensor) -> Tensor
+            return x[0:4, 2, ...]  # Ellipsis covers dims 2-4; result shape (4, 8, 8, 8)
+        dummy = torch.rand(8, 8, 8, 8, 8)  # non-constant values: a wrong select index would fail
+        self.checkScript(ellipsize, (dummy,), optimize=True)
+
def test_tracing_slicing(self):
@_trace(torch.zeros(10))
def foo_trace(x):