ONNX support for torch.take (#33061)
Summary:
Adding ONNX export support for torch.take()
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33061
Reviewed By: hl475
Differential Revision: D19782651
Pulled By: houseroad
fbshipit-source-id: 0168fb941e166acda4ca607165248b8e0b260ace
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 0d38fda..18d8e24 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -1295,6 +1295,15 @@
base = 1
self.run_test(IndexSelectScalerIndexModel(base), (x, index_offset))
+ def test_take(self):
+ class TakeModel(torch.nn.Module):
+ def forward(self, x, y):
+ return torch.take(x, y)
+
+ x = torch.randn(6, 4, 3, 3)
+ y = torch.tensor([4, 1, 7, 15, 63])
+ self.run_test(TakeModel(), (x, y))
+
def test_topk(self):
class MyModule(torch.nn.Module):
def forward(self, x):
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index 639768c..3e8ac54 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -2213,3 +2213,9 @@
def __getitem_(g, self, i):
return select(g, self, g.op("Constant", value_t=torch.tensor([0])), i)
+
+def take(g, self, index):
+ self_flattened = g.op('Reshape', self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
+ out = index_select(g, self_flattened, 0, index)
+ out = reshape_as(g, out, index)
+ return out