Add a decomposition for take() (#114813)
Presumably this can close https://github.com/pytorch/pytorch/pull/109784
Also related to https://github.com/pytorch/pytorch/issues/93757 (though `take` is not listed there).
There's no bounds checking here (out of bounds indices cause a segfault or undefined behavior). Should that be added somehow?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114813
Approved by: https://github.com/peterbell10, https://github.com/lezcano
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index a292f05..10546ae 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -1298,8 +1298,6 @@
aten::t_
aten::t_copy
aten::t_copy.out
-aten::take
-aten::take.out
aten::to_mkldnn
aten::to_mkldnn.out
aten::to_padded_tensor
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index 0dc1788..582084a 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -420,6 +420,7 @@
aten.sum.default,
aten.sum.out,
aten.t,
+ aten.take,
aten.tanh_backward,
aten.threshold,
aten.threshold_,
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index ea321ea..2d30d42 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -4420,6 +4420,13 @@
return x * (y / norm), norm
+@register_decomposition(aten.take)
+@out_wrapper()
+def take(self, index):
+ flattened = self.reshape(-1)
+ return flattened[index]
+
+
register_inplace(aten.addbmm_, aten.addbmm)
register_inplace(aten.addmm_, aten.addmm)
register_inplace(aten.addmv_, aten.addmv)
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index ccad8d3..baf7183 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -2222,7 +2222,6 @@
make_fallback(aten.special_scaled_modified_bessel_k1)
make_fallback(aten.special_spherical_bessel_j0, warn=False)
make_fallback(aten.special_zeta, warn=False)
-make_fallback(aten.take)
make_fallback(aten._trilinear)
make_fallback(aten.uniform, warn=False)
make_fallback(aten._adaptive_avg_pool3d_backward)
diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py
index 331d036..aa1e833 100644
--- a/torch/_prims_common/wrappers.py
+++ b/torch/_prims_common/wrappers.py
@@ -206,9 +206,9 @@
def out_wrapper(*out_names: str, exact_dtype: bool = False):
# The wrapped function needs to convert the output parameters to ensure
- # compatability between the Python API (which always uses "out" as the
+ # compatibility between the Python API (which always uses "out" as the
# parameter name and may be a tuple) and the Aten API (which may have
- # multiple output parematers and use different parameter names such as
+ # multiple output parameters and use different parameter names such as
# "grad_input", "indices" or "values".)
default_out_names = ("out",)