| import torch |
| from torch._prims.utils import ( |
| Number, |
| NumberType, |
| TensorLike, |
| TensorLikeType, |
| ELEMENTWISE_TYPE_PROMOTION_KIND, |
| ) |
| import torch._prims.utils as utils |
| from torch.utils._pytree import tree_flatten |
| |
| from typing import Callable, Sequence, Union |
| import inspect |
| from functools import wraps, reduce |
| import operator |
| import warnings |
| from itertools import chain |
| |
| # TODO: implement ref.cast with an option to enforce safe casting |
| def _maybe_convert_to_dtype( |
| a: Union[TensorLikeType, NumberType, Sequence], dtype: torch.dtype |
| ) -> Union[TensorLikeType, NumberType, Sequence]: |
| if isinstance(a, TensorLike): |
| if a.dtype != dtype: |
| # NOTE: this is incorrect on the CPU |
| # See https://github.com/pytorch/pytorch/issues/77553 |
| return prims.convert_element_type(a, dtype) |
| return a |
| if isinstance(a, Number): |
| return utils.dtype_to_type(dtype)(a) |
| if isinstance(a, Sequence): |
| return tuple(_maybe_convert_to_dtype(x, dtype) for x in a) |
| |
| raise ValueError( |
| "Received type {0} that is neither a tensor or a number!".format(type(a)) |
| ) |
| |
| |
| def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType: |
| if not isinstance(a, Number): |
| msg = "Found unknown type {0} when trying to convert scalars!".format(type(a)) |
| raise ValueError(msg) |
| if not utils.is_weakly_lesser_type(type(a), typ): |
| msg = "Scalar {0} of type {1} cannot be safely cast to type {2}!".format( |
| a, type(a), typ |
| ) |
| raise ValueError(msg) |
| |
| return typ(a) |
| |
| |
| def _annotation_has_type(*, typ, annotation): |
| if hasattr(annotation, "__args__"): |
| for a in annotation.__args__: |
| if _annotation_has_type(typ=typ, annotation=a): |
| return True |
| return False |
| |
| return typ is annotation |
| |
| |
class elementwise_type_promotion_wrapper(object):
    """
    Adds elementwise type promotion to a Python reference implementation.

    Takes two kwargs, type_promoting_args and type_promotion_kind.

    type_promoting_args must be a string Sequence specifying the argument names of all
    arguments that participate in type promotion (and should be type promoted). If the
    arg specifies a Sequence-type then every element of the Sequence will participate in
    type promotion.

    type_promotion_kind must be one of the kinds specified by ELEMENTWISE_TYPE_PROMOTION_KIND.
    See its documentation for details.

    Other type promotion behavior, like validating the Python type of scalar arguments, must
    be handled separately.

    The decorated function receives its promoted arguments converted to the computation
    dtype, and its (single tensor) result is converted to the result dtype.
    """

    def __init__(
        self,
        *,
        type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
        type_promoting_args: Sequence[str] = None,
    ):
        self.type_promoting_arg_names = type_promoting_args
        self.type_promotion_kind = type_promotion_kind

    def __call__(self, fn: Callable) -> Callable:
        """Returns fn wrapped so that type promotion happens around every call."""
        sig = inspect.signature(fn)

        @wraps(fn)
        def _fn(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)

            # Only the promoting arguments the caller actually passed take part;
            # sig.bind() does not fill in defaults.
            present_names = [
                name
                for name in self.type_promoting_arg_names  # type: ignore[union-attr]
                if name in bound.arguments
            ]
            type_promoting_args = tuple(bound.arguments[name] for name in present_names)

            flattened_type_promoting_args = tree_flatten(type_promoting_args)[0]
            compute_dtype, result_dtype = utils.elementwise_dtypes(
                *flattened_type_promoting_args,
                type_promotion_kind=self.type_promotion_kind,
            )

            promoted_args = {
                name: _maybe_convert_to_dtype(bound.arguments[name], compute_dtype)
                for name in present_names
            }
            bound.arguments.update(promoted_args)

            # Re-dispatch through bound.args / bound.kwargs rather than
            # **bound.arguments so positional-only, *args and **kwargs
            # parameters of fn are passed back correctly.
            result = fn(*bound.args, **bound.kwargs)

            # FIXME?: assumes result is a single tensor
            assert isinstance(result, TensorLike)
            return _maybe_convert_to_dtype(result, result_dtype)

        _fn.__signature__ = sig  # type: ignore[attr-defined]
        return _fn
| |
| |
| # TODO: handle tuples of tensors |
| def _maybe_resize_out(out: TensorLikeType, shape): |
| if out.numel() == 0: |
| return prims.resize(out, shape) |
| |
| if out.numel() != reduce(operator.mul, shape, 1): |
| msg = ( |
| "An output with one or more elements was resized since it had shape {0} " |
| "which does not match the required output shape {1}. " |
| "This behavior is deprecated, and in a future PyTorch release outputs will not " |
| "be resized unless they have zero elements. " |
| "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0).".format( |
| str(out.shape), str(shape) |
| ) |
| ) |
| warnings.warn(msg) |
| return prims.resize(out, shape) |
| |
| return out |
| |
| |
| def _safe_copy_out(*, copy_from: TensorLikeType, copy_to: TensorLikeType): |
| # Checks same device |
| if copy_from.device != copy_to.device: |
| msg = "Attempting to copy from device {0} to device {1}, but cross-device copies are not allowed!".format( |
| copy_from.device, copy_to.device |
| ) |
| raise RuntimeError(msg) |
| |
| # Checks safe cast |
| if not utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype): |
| msg = "Attempting to cast from {0} to out tensor with dtype {1}, but this can't be cast because it is not safe!".format( |
| copy_from.dtype, copy_to.dtype |
| ) |
| raise RuntimeError(msg) |
| |
| return prims.copy_to(copy_to, copy_from) |
| |
| |
| # FIXME: only supports out parameter that is literally called "out" |
def out_wrapper(fn: Callable) -> Callable:
    """
    Adds the out parameter to a Python reference.

    Note that this currently only supports operations that return a single tensor.

    The returned wrapper accepts a keyword-only out argument. When out is None the
    result of fn is returned as is. Otherwise out is resized if necessary, the
    result is copied into it, and the value returned by that copy is returned.

    The wrapper's signature and annotations advertise the extra out parameter; the
    annotations of fn itself are left untouched.
    """

    @wraps(fn)
    def _fn(*args, out=None, **kwargs):
        result = fn(*args, **kwargs)
        if out is None:
            return result

        assert isinstance(out, TensorLike)
        out = _maybe_resize_out(out, result.shape)
        return _safe_copy_out(copy_from=result, copy_to=out)  # type: ignore[arg-type]

    sig = inspect.signature(fn)
    out_param = inspect.Parameter(
        "out",
        kind=inspect.Parameter.KEYWORD_ONLY,
        default=None,
        annotation=TensorLikeType,
    )
    params = chain(sig.parameters.values(), (out_param,))
    _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
        parameters=params, return_annotation=sig.return_annotation  # type: ignore[arg-type]
    )

    # functools.wraps assigns fn.__annotations__ by reference, so take a copy
    # before adding "out"; otherwise the entry would leak into fn's own
    # annotations. getattr covers callables that carry no __annotations__.
    _fn.__annotations__ = dict(getattr(fn, "__annotations__", {}))
    _fn.__annotations__["out"] = TensorLikeType
    return _fn
| |
| |
def out_wrapper_multi(*out_names):
    """
    Returns a decorator that adds one keyword-only out parameter per name in out_names.

    The decorated function must return a tuple. Either every out parameter is
    passed or none of them is (checked with an assert). When they are passed,
    element i of the result is copied into the out tensor given for out_names[i]
    (after a resize if necessary) and a tuple of the copy results is returned;
    otherwise the function's own result is returned.

    The wrapper's signature and annotations advertise the extra parameters; the
    annotations of the decorated function itself are left untouched.
    """

    def go(fn: Callable) -> Callable:
        @wraps(fn)
        def _fn(*args, **kwargs):
            out_kwargs = {}
            has_out_kwargs = None
            for o in out_names:
                # Strip the out kwargs so fn never sees them
                out_kwargs[o] = kwargs.pop(o, None)
                # Either all of the out kwargs are set or none of them
                if has_out_kwargs is None:
                    has_out_kwargs = out_kwargs[o] is not None
                else:
                    assert has_out_kwargs == (out_kwargs[o] is not None)

            result = fn(*args, **kwargs)
            assert isinstance(result, tuple)
            if not has_out_kwargs:
                return result

            final_result = []
            for i, o in enumerate(out_names):
                out = out_kwargs[o]
                assert isinstance(out, TensorLike)
                out = _maybe_resize_out(out, result[i].shape)
                final_result.append(_safe_copy_out(copy_from=result[i], copy_to=out))  # type: ignore[arg-type]
            return tuple(final_result)

        sig = inspect.signature(fn)
        out_params = [
            inspect.Parameter(
                o,
                kind=inspect.Parameter.KEYWORD_ONLY,
                default=None,
                annotation=TensorLikeType,
            )
            for o in out_names
        ]
        params = chain(sig.parameters.values(), out_params)
        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
            parameters=params, return_annotation=sig.return_annotation  # type: ignore[arg-type]
        )

        # functools.wraps assigns fn.__annotations__ by reference, so take a copy
        # before adding the out names; otherwise they would leak into fn's own
        # annotations. getattr covers callables that carry no __annotations__.
        _fn.__annotations__ = dict(getattr(fn, "__annotations__", {}))
        for o in out_names:
            _fn.__annotations__[o] = TensorLikeType
        return _fn

    return go
| |
| |
| # TODO: when tracing this will add torch tensors and not TensorMeta objects |
| # to the trace -- we should fix this by adding a tracing context and NumberMeta classes |
| # TODO: this wrapper is currently untested |
def elementwise_unary_scalar_wrapper(fn: Callable) -> Callable:
    """
    Allows unary operators that accept tensors to work with Python numbers.

    When the first positional argument is a Python number it is wrapped in a
    tensor whose dtype comes from utils.type_to_dtype, fn is called with that
    tensor, and the tensor result is unwrapped with .item(). Every other call is
    forwarded to fn unchanged.
    """
    sig = inspect.signature(fn)

    @wraps(fn)
    def _fn(*args, **kwargs):
        # Anything but a leading positional Python number goes straight through
        if not args or not isinstance(args[0], Number):
            return fn(*args, **kwargs)

        scalar, *remaining = args
        scalar_dtype = utils.type_to_dtype(type(scalar))
        scalar_tensor = torch.tensor(scalar, dtype=scalar_dtype)

        result = fn(scalar_tensor, *remaining, **kwargs)
        assert isinstance(result, torch.Tensor)
        return result.item()

    _fn.__signature__ = sig  # type: ignore[attr-defined]
    return _fn
| |
| |
| # avoid mypy import cycle |
| import torch._prims as prims |