ONNX support for torch.take (#33061)

Summary:
Adding ONNX export support for torch.take()
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33061

Reviewed By: hl475

Differential Revision: D19782651

Pulled By: houseroad

fbshipit-source-id: 0168fb941e166acda4ca607165248b8e0b260ace
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 0d38fda..18d8e24 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -1295,6 +1295,15 @@
         base = 1
         self.run_test(IndexSelectScalerIndexModel(base), (x, index_offset))
 
+    def test_take(self):
+        class TakeModel(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.take(x, y)
+
+        x = torch.randn(6, 4, 3, 3)
+        y = torch.tensor([4, 1, 7, 15, 63])
+        self.run_test(TakeModel(), (x, y))
+
     def test_topk(self):
         class MyModule(torch.nn.Module):
             def forward(self, x):
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index 639768c..3e8ac54 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -2213,3 +2213,9 @@
 
 def __getitem_(g, self, i):
     return select(g, self, g.op("Constant", value_t=torch.tensor([0])), i)
+
+def take(g, self, index):
+    self_flattened = g.op('Reshape', self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
+    out = index_select(g, self_flattened, 0, index)
+    out = reshape_as(g, out, index)
+    return out