[pt2] convert `out` params in `register_meta` (#101344)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101344
Approved by: https://github.com/lezcano
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index 11cafc7..8c618e1 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -54,6 +54,48 @@
             registry[op_overload] = fn
 
 
+def _convert_out_params(f):
+    sig = inspect.signature(f)
+    out_annotation = f.__annotations__.get("out")
+    # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this
+    fn = f
+    if out_annotation and getattr(out_annotation, "__origin__", None) is tuple:
+        out_names = sig.return_annotation._fields
+        # If out is a tuple, we need to register a function that unpacks all the out
+        # elements as this is what native_functions.yaml expects
+
+        @wraps(f)
+        def _fn(*args, **kwargs):
+            out_kwargs = tuple(kwargs.pop(o, None) for o in out_names)
+            # Either all of the out kwargs are set or none of them
+            is_none = out_kwargs[0] is None
+            assert all((o is None) == is_none for o in out_kwargs)
+            return f(*args, **kwargs, out=None if is_none else out_kwargs)
+
+        out_params = [
+            inspect.Parameter(
+                o,
+                kind=inspect.Parameter.KEYWORD_ONLY,
+                default=None,
+                annotation=t,
+            )
+            for o, t in zip(out_names, out_annotation.__args__)
+        ]
+        # Drop the out parameter and concatenate the new kwargs in the signature
+        params = chain((v for k, v in sig.parameters.items() if k != "out"), out_params)
+        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
+            parameters=params, return_annotation=sig.return_annotation  # type: ignore[arg-type]
+        )
+        # Drop the out parameter and concatenate the new kwargs in the annotations
+        _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
+        for o in out_params:
+            _fn.__annotations__[o.name] = o.annotation
+
+        fn = _fn
+
+    return fn
+
+
 def register_decomposition(aten_op, registry=None, *, type="post_autograd"):
     """
     A decorator to register a function as a decomposition to the Python
@@ -77,48 +119,8 @@
 
     assert type in {"post_autograd", "pre_autograd", "meta"}
 
-    def decomposition_decorator(f: Callable) -> Callable:
-        sig = inspect.signature(f)
-        out_annotation = f.__annotations__.get("out")
-        # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this
-        fn = f
-        if out_annotation and getattr(out_annotation, "__origin__", None) is tuple:
-            out_names = sig.return_annotation._fields
-            # If out is a tuple, we need to register a function that unpacks all the out
-            # elements as this is what native_functions.yaml expects
-
-            @wraps(f)
-            def _fn(*args, **kwargs):
-                out_kwargs = tuple(kwargs.pop(o, None) for o in out_names)
-                # Either all of the out kwargs are set or none of them
-                is_none = out_kwargs[0] is None
-                assert all((o is None) == is_none for o in out_kwargs)
-                return f(*args, **kwargs, out=None if is_none else out_kwargs)
-
-            out_params = [
-                inspect.Parameter(
-                    o,
-                    kind=inspect.Parameter.KEYWORD_ONLY,
-                    default=None,
-                    annotation=t,
-                )
-                for o, t in zip(out_names, out_annotation.__args__)
-            ]
-            # Drop the out parameter and concatenate the new kwargs in the signature
-            params = chain(
-                (v for k, v in sig.parameters.items() if k != "out"), out_params
-            )
-            _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
-                parameters=params, return_annotation=sig.return_annotation  # type: ignore[arg-type]
-            )
-            # Drop the out parameter and concatenate the new kwargs in the annotations
-            _fn.__annotations__ = {
-                k: v for k, v in f.__annotations__.items() if k != "out"
-            }
-            for o in out_params:
-                _fn.__annotations__[o.name] = o.annotation
-
-            fn = _fn
+    def decomposition_decorator(fn: Callable) -> Callable:
+        fn = _convert_out_params(fn)
 
         nonlocal registry
         if registry is None:
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 11ab86d..3566ae9 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -4,7 +4,12 @@
 import torch
 import torch._prims_common as utils
 from torch import Tensor
-from torch._decomp import _add_op_to_registry, global_decomposition_table, meta_table
+from torch._decomp import (
+    _add_op_to_registry,
+    _convert_out_params,
+    global_decomposition_table,
+    meta_table,
+)
 from torch._ops import OpOverload
 from torch._prims import _elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND
 from torch._prims_common import (
@@ -36,6 +41,8 @@
 
 def register_meta(op):
     def wrapper(fn):
+        fn = _convert_out_params(fn)
+
         def register(op):
             _add_op_to_registry(meta_table, op, fn)
 
@@ -63,42 +70,20 @@
 
 
 @register_meta([aten.take.default, aten.take.out])
-def meta_take(self, index, *, out=None):
+@out_wrapper()
+def meta_take(self, index):
     # Type and device checks
     check(
         index.dtype == torch.long,
         lambda: f"take(): Expected a long tensor for index, but got {index.dtype}",
     )
-    if out is not None:
-        check(
-            self.dtype == out.dtype,
-            lambda: (
-                f"take(): self and out expected to have the same dtype, "
-                f"but got self.dtype = {self.dtype} and out.dtype = {out.dtype}"
-            ),
-        )
-        check(
-            self.device == out.device and self.device == index.device,
-            lambda: (
-                f"take(): self, index and out expected to be in the same device, "
-                f"but got self.device = {self.device}, index.device = {index.device}, "
-                f"and out.device = {out.device}"
-            ),
-        )
-
     # Index checks
     check(
         not (self.numel() == 0 and index.numel() != 0),
         lambda: "take(): tried to take from an empty tensor",
         IndexError,
     )
-
-    result = self.new_empty(index.shape)
-    if out is not None:
-        assert isinstance(out, TensorLike)
-        out = _maybe_resize_out(out, result.shape)
-        return _safe_copy_out(copy_from=result, copy_to=out)  # type: ignore[arg-type]
-    return result
+    return self.new_empty(index.shape)
 
 
 @register_meta([aten.linalg_cross.default, aten.linalg_cross.out])
@@ -433,13 +418,11 @@
 
 
 @register_meta([aten._linalg_eigh.default, aten._linalg_eigh.eigenvalues])
+@out_wrapper("eigenvalues", "eigenvectors")
 def meta__linalg_eigh(
     A: Tensor,
     UPLO: str = "L",
     compute_v: bool = True,
-    *,
-    eigenvalues: Tensor = None,
-    eigenvectors: Tensor = None,
 ):
     squareCheckInputs(A, "linalg.eigh")
     checkUplo(UPLO)
@@ -454,15 +437,6 @@
     shape.pop()
     vals = A.new_empty(shape, dtype=toRealValueType(A.dtype))
 
-    if eigenvalues is not None and eigenvectors is not None:
-        assert isinstance(eigenvalues, TensorLike)
-        assert isinstance(eigenvectors, TensorLike)
-        eigenvalues = _maybe_resize_out(eigenvalues, vals.shape)
-        eigenvectors = _maybe_resize_out(eigenvectors, vecs.shape)
-        _safe_copy_out(copy_from=vals, copy_to=eigenvalues)  # type: ignore[arg-type]
-        _safe_copy_out(copy_from=vecs, copy_to=eigenvectors)  # type: ignore[arg-type]
-        return eigenvalues, eigenvectors
-
     return vals, vecs