Add a decomposition for take() (#114813)

Presumably this can close https://github.com/pytorch/pytorch/pull/109784

Also related to https://github.com/pytorch/pytorch/issues/93757 (though `take` is not listed there).
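
For context, `torch.take` reads from the input as if it were flattened to 1-D, gathering elements at the given flat indices; the decomposition in the diff below just spells that out as `self.reshape(-1)[index]`. A minimal illustration:

```python
import torch

x = torch.arange(12).reshape(3, 4)
idx = torch.tensor([0, 5, 11])

# torch.take treats x as a flat 1-D tensor and gathers at the flat indices,
# which is exactly what reshape(-1) followed by advanced indexing does.
assert torch.equal(torch.take(x, idx), x.reshape(-1)[idx])
```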

There's no bounds checking here (out-of-bounds indices cause a segfault or other undefined behavior). Should that be added somehow?
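
If bounds checking were desired, one hypothetical option (not part of this PR, and it's unclear whether a data-dependent check like this is acceptable inside a decomposition used for tracing) would be something along these lines:

```python
import torch
from torch import Tensor

# Hypothetical sketch only: adds an explicit range check before the flattened
# indexing. Negative indices are ignored for simplicity, and the bool(...)
# forces a data-dependent evaluation, which may not be graph/tracing friendly.
def take_with_bounds_check(self: Tensor, index: Tensor) -> Tensor:
    numel = self.numel()
    torch._check(
        bool(((index >= 0) & (index < numel)).all()),
        lambda: f"take(): index out of range for input with {numel} elements",
    )
    return self.reshape(-1)[index]
```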

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114813
Approved by: https://github.com/peterbell10, https://github.com/lezcano
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index a292f05..10546ae 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -1298,8 +1298,6 @@
 aten::t_
 aten::t_copy
 aten::t_copy.out
-aten::take
-aten::take.out
 aten::to_mkldnn
 aten::to_mkldnn.out
 aten::to_padded_tensor
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index 0dc1788..582084a 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -420,6 +420,7 @@
             aten.sum.default,
             aten.sum.out,
             aten.t,
+            aten.take,
             aten.tanh_backward,
             aten.threshold,
             aten.threshold_,
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index ea321ea..2d30d42 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -4420,6 +4420,13 @@
     return x * (y / norm), norm
 
 
+@register_decomposition(aten.take)
+@out_wrapper()
+def take(self, index):
+    flattened = self.reshape(-1)
+    return flattened[index]
+
+
 register_inplace(aten.addbmm_, aten.addbmm)
 register_inplace(aten.addmm_, aten.addmm)
 register_inplace(aten.addmv_, aten.addmv)
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index ccad8d3..baf7183 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -2222,7 +2222,6 @@
 make_fallback(aten.special_scaled_modified_bessel_k1)
 make_fallback(aten.special_spherical_bessel_j0, warn=False)
 make_fallback(aten.special_zeta, warn=False)
-make_fallback(aten.take)
 make_fallback(aten._trilinear)
 make_fallback(aten.uniform, warn=False)
 make_fallback(aten._adaptive_avg_pool3d_backward)
diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py
index 331d036..aa1e833 100644
--- a/torch/_prims_common/wrappers.py
+++ b/torch/_prims_common/wrappers.py
@@ -206,9 +206,9 @@
 
 def out_wrapper(*out_names: str, exact_dtype: bool = False):
     # The wrapped function needs to convert the output parameters to ensure
-    # compatability between the Python API (which always uses "out" as the
+    # compatibility between the Python API (which always uses "out" as the
     # parameter name and may be a tuple) and the Aten API (which may have
-    # multiple output parematers and use different parameter names such as
+    # multiple output parameters and use different parameter names such as
     # "grad_input", "indices" or "values".)
 
     default_out_names = ("out",)
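As a side note on the `out_wrapper()` used above: it is what lets the registered decomposition cover the `aten.take.out` overload as well, by copying the functional result into the provided `out` tensor. A rough usage-level illustration (assuming the decomposition is what handles both overloads):

```python
import torch

x = torch.arange(6.0).reshape(2, 3)
idx = torch.tensor([1, 4])

# The functional overload (aten.take) and the out= overload (aten.take.out)
# produce the same values.
out = torch.empty(2)
torch.take(x, idx, out=out)
assert torch.equal(out, torch.take(x, idx))
```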