[pt2] convert `out` params in `register_meta` (#101344)
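Factors the logic that unpacks a tuple-typed `out` parameter into per-tensor out kwargs out of `register_decomposition` and into a reusable `_convert_out_params` helper, and applies it in `register_meta` as well. This lets meta kernels use `out_wrapper` instead of hand-rolled `out=` handling; `meta_take` and `meta__linalg_eigh` are converted accordingly.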
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101344
Approved by: https://github.com/lezcano
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index 11cafc7..8c618e1 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -54,6 +54,48 @@
     registry[op_overload] = fn
 
 
+def _convert_out_params(f):
+    sig = inspect.signature(f)
+    out_annotation = f.__annotations__.get("out")
+    # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this
+    fn = f
+    if out_annotation and getattr(out_annotation, "__origin__", None) is tuple:
+        out_names = sig.return_annotation._fields
+        # If out is a tuple, we need to register a function that unpacks all the out
+        # elements as this is what native_functions.yaml expects
+
+        @wraps(f)
+        def _fn(*args, **kwargs):
+            out_kwargs = tuple(kwargs.pop(o, None) for o in out_names)
+            # Either all of the out kwargs are set or none of them
+            is_none = out_kwargs[0] is None
+            assert all((o is None) == is_none for o in out_kwargs)
+            return f(*args, **kwargs, out=None if is_none else out_kwargs)
+
+        out_params = [
+            inspect.Parameter(
+                o,
+                kind=inspect.Parameter.KEYWORD_ONLY,
+                default=None,
+                annotation=t,
+            )
+            for o, t in zip(out_names, out_annotation.__args__)
+        ]
+        # Drop the out parameter and concatenate the new kwargs in the signature
+        params = chain((v for k, v in sig.parameters.items() if k != "out"), out_params)
+        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
+            parameters=params, return_annotation=sig.return_annotation  # type: ignore[arg-type]
+        )
+        # Drop the out parameter and concatenate the new kwargs in the annotations
+        _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
+        for o in out_params:
+            _fn.__annotations__[o.name] = o.annotation
+
+        fn = _fn
+
+    return fn
+
+
 def register_decomposition(aten_op, registry=None, *, type="post_autograd"):
     """
     A decorator to register a function as a decomposition to the Python
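As a sanity check on the helper, here is a minimal sketch of what `_convert_out_params` does to a function whose `out` is annotated as a tuple and whose return annotation is a NamedTuple. `TopkResult` and `toy_topk` are made-up names for illustration; the helper itself is private to `torch._decomp`, so treat this as a sketch against a build that includes this patch:

```python
import inspect
from typing import NamedTuple, Tuple

from torch._decomp import _convert_out_params


class TopkResult(NamedTuple):
    values: list
    indices: list


# Toy stand-in for a decomposition: a single tuple-typed `out` kwarg,
# with a NamedTuple return annotation supplying the out-parameter names.
def toy_topk(xs, k, *, out: Tuple[list, list] = None) -> TopkResult:
    result = TopkResult(sorted(xs, reverse=True)[:k], list(range(k)))
    if out is not None:
        out[0][:], out[1][:] = result.values, result.indices
        return TopkResult(*out)
    return result


converted = _convert_out_params(toy_topk)

# The tuple-typed `out` is replaced by one keyword-only parameter per
# NamedTuple field, the calling convention native_functions.yaml expects:
print(inspect.signature(converted))
# (xs, k, *, values: list = None, indices: list = None)

values, indices = [], []
converted([3, 1, 2], 2, values=values, indices=indices)
print(values, indices)  # [3, 2] [0, 1]
```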
@@ -77,48 +119,8 @@
 
     assert type in {"post_autograd", "pre_autograd", "meta"}
 
-    def decomposition_decorator(f: Callable) -> Callable:
-        sig = inspect.signature(f)
-        out_annotation = f.__annotations__.get("out")
-        # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this
-        fn = f
-        if out_annotation and getattr(out_annotation, "__origin__", None) is tuple:
-            out_names = sig.return_annotation._fields
-            # If out is a tuple, we need to register a function that unpacks all the out
-            # elements as this is what native_functions.yaml expects
-
-            @wraps(f)
-            def _fn(*args, **kwargs):
-                out_kwargs = tuple(kwargs.pop(o, None) for o in out_names)
-                # Either all of the out kwargs are set or none of them
-                is_none = out_kwargs[0] is None
-                assert all((o is None) == is_none for o in out_kwargs)
-                return f(*args, **kwargs, out=None if is_none else out_kwargs)
-
-            out_params = [
-                inspect.Parameter(
-                    o,
-                    kind=inspect.Parameter.KEYWORD_ONLY,
-                    default=None,
-                    annotation=t,
-                )
-                for o, t in zip(out_names, out_annotation.__args__)
-            ]
-            # Drop the out parameter and concatenate the new kwargs in the signature
-            params = chain(
-                (v for k, v in sig.parameters.items() if k != "out"), out_params
-            )
-            _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
-                parameters=params, return_annotation=sig.return_annotation  # type: ignore[arg-type]
-            )
-            # Drop the out parameter and concatenate the new kwargs in the annotations
-            _fn.__annotations__ = {
-                k: v for k, v in f.__annotations__.items() if k != "out"
-            }
-            for o in out_params:
-                _fn.__annotations__[o.name] = o.annotation
-
-            fn = _fn
+    def decomposition_decorator(fn: Callable) -> Callable:
+        fn = _convert_out_params(fn)
 
         nonlocal registry
         if registry is None:
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 11ab86d..3566ae9 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -4,7 +4,12 @@
 import torch
 import torch._prims_common as utils
 from torch import Tensor
-from torch._decomp import _add_op_to_registry, global_decomposition_table, meta_table
+from torch._decomp import (
+    _add_op_to_registry,
+    _convert_out_params,
+    global_decomposition_table,
+    meta_table,
+)
 from torch._ops import OpOverload
 from torch._prims import _elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND
 from torch._prims_common import (
@@ -36,6 +41,8 @@
 
 def register_meta(op):
     def wrapper(fn):
+        fn = _convert_out_params(fn)
+
         def register(op):
             _add_op_to_registry(meta_table, op, fn)
 
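Running `_convert_out_params` inside `register_meta` is what makes the `out_wrapper` conversions below work: `out_wrapper` attaches a single tuple-typed `out` parameter plus a NamedTuple return annotation, which the helper then unpacks into the per-tensor out kwargs the aten out-overloads declare. A hedged sketch of that composition (`toy_sincos` is made up, and the exact printed signatures may vary between PyTorch builds):

```python
import inspect

import torch
from torch._decomp import _convert_out_params
from torch._prims_common.wrappers import out_wrapper


# out_wrapper with two names yields a fn taking a single `out` tuple.
@out_wrapper("sin_out", "cos_out")
def toy_sincos(x: torch.Tensor):
    return torch.empty_like(x), torch.empty_like(x)


print(inspect.signature(toy_sincos))
# roughly: (x: Tensor, *, out: Tuple[Tensor, Tensor] = None)

# register_meta now applies this conversion before registration:
print(inspect.signature(_convert_out_params(toy_sincos)))
# roughly: (x: Tensor, *, sin_out: Tensor = None, cos_out: Tensor = None)
```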
@@ -63,42 +70,20 @@
 
 
 @register_meta([aten.take.default, aten.take.out])
-def meta_take(self, index, *, out=None):
+@out_wrapper()
+def meta_take(self, index):
     # Type and device checks
     check(
         index.dtype == torch.long,
         lambda: f"take(): Expected a long tensor for index, but got {index.dtype}",
     )
-    if out is not None:
-        check(
-            self.dtype == out.dtype,
-            lambda: (
-                f"take(): self and out expected to have the same dtype, "
-                f"but got self.dtype = {self.dtype} and out.dtype = {out.dtype}"
-            ),
-        )
-        check(
-            self.device == out.device and self.device == index.device,
-            lambda: (
-                f"take(): self, index and out expected to be in the same device, "
-                f"but got self.device = {self.device}, index.device = {index.device}, "
-                f"and out.device = {out.device}"
-            ),
-        )
-
     # Index checks
     check(
         not (self.numel() == 0 and index.numel() != 0),
         lambda: "take(): tried to take from an empty tensor",
         IndexError,
     )
-
-    result = self.new_empty(index.shape)
-    if out is not None:
-        assert isinstance(out, TensorLike)
-        out = _maybe_resize_out(out, result.shape)
-        return _safe_copy_out(copy_from=result, copy_to=out)  # type: ignore[arg-type]
-    return result
+    return self.new_empty(index.shape)
 
 
 @register_meta([aten.linalg_cross.default, aten.linalg_cross.out])
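A quick behavioral check of the `meta_take` conversion, assuming a build with this patch: the `out=` overload on meta tensors now flows through `out_wrapper`, which performs the resize and the dtype/device-checked copy that the deleted hand-written branches used to do.

```python
import torch

src = torch.arange(6.0, device="meta")
index = torch.zeros(2, dtype=torch.long, device="meta")
out = torch.empty(0, device="meta")

# out_wrapper resizes `out` to the result shape and safe-copies into it.
torch.take(src, index, out=out)
print(out.shape)  # torch.Size([2])
```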
@@ -433,13 +418,11 @@
 
 
 @register_meta([aten._linalg_eigh.default, aten._linalg_eigh.eigenvalues])
+@out_wrapper("eigenvalues", "eigenvectors")
 def meta__linalg_eigh(
     A: Tensor,
     UPLO: str = "L",
     compute_v: bool = True,
-    *,
-    eigenvalues: Tensor = None,
-    eigenvectors: Tensor = None,
 ):
     squareCheckInputs(A, "linalg.eigh")
     checkUplo(UPLO)
@@ -454,15 +437,6 @@
     shape.pop()
 
     vals = A.new_empty(shape, dtype=toRealValueType(A.dtype))
-    if eigenvalues is not None and eigenvectors is not None:
-        assert isinstance(eigenvalues, TensorLike)
-        assert isinstance(eigenvectors, TensorLike)
-        eigenvalues = _maybe_resize_out(eigenvalues, vals.shape)
-        eigenvectors = _maybe_resize_out(eigenvectors, vecs.shape)
-        _safe_copy_out(copy_from=vals, copy_to=eigenvalues)  # type: ignore[arg-type]
-        _safe_copy_out(copy_from=vecs, copy_to=eigenvectors)  # type: ignore[arg-type]
-        return eigenvalues, eigenvectors
-
     return vals, vecs
 
 
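And an end-to-end sketch for the multi-out case, assuming a build with this patch (toy shapes): calling the `_linalg_eigh.eigenvalues` out-overload on meta tensors resizes both out tensors to the shapes the meta kernel computes.

```python
import torch

A = torch.eye(3, device="meta")
eigenvalues = torch.empty(0, device="meta")
eigenvectors = torch.empty(0, device="meta")

# The out tensors are passed as the named kwargs the overload declares.
torch.ops.aten._linalg_eigh.eigenvalues(
    A, "L", True, eigenvalues=eigenvalues, eigenvectors=eigenvectors
)
print(eigenvalues.shape, eigenvectors.shape)
# torch.Size([3]) torch.Size([3, 3])
```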