import torch
from torch.library import Library
from torch._ops import OpOverload
from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseTy, BaseType
from torch._C import _ExcludeDispatchKeyGuard, DispatchKeySet, DispatchKey
from .autograd import autograd_not_implemented
import torch.utils._pytree as pytree
import weakref


def register_functional_op(
    lib: Library,
    new_op_name: str,
    mutable_op: OpOverload,
) -> None:
| """Given a mutable operator, registers the functional variant. |
| |
| This API also correctly links the functional variant with the mutable |
| operator for the purposes of functionalization. |
| |
| All of the new registrations are performed on the ``lib`` passed in. |
| |
| Arguments: |
| lib (Library): Should be a torch.library.Library object that has |
| the same namespace as ``mutable_op``'s namespace. |
| lib will be used to register the new functional op as well |
| as a functionalization kernel for the ``mutable_op`` |
| If you don't have a library handy, use |
| ``torch.library.Library(ns, 'FRAGMENT')`` to construct one. |
| new_op_name (str): The name of the functional operator (without the |
| namespace). If no namespace, the new functional variant will be |
| accessible under ``torch.ops.{lib.ns}.new_op_name``. |
| mutable_op (OpOverload): The mutable custom operator. Note |
| that you may need to add a `.default` to it, like |
| `torch.ops.aten.abs_.default`. |
| |
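    Example (an illustrative sketch, not an existing op; the ``mylib``
    namespace and the ``numpy_sin`` operator below are hypothetical)::

        lib = torch.library.Library("mylib", "FRAGMENT")
        lib.define("numpy_sin(Tensor x, Tensor(a!) out) -> ()")

        def numpy_sin_impl(x, out):
            out.copy_(x.sin())

        lib.impl("numpy_sin", numpy_sin_impl, "CompositeExplicitAutograd")

        register_functional_op(
            lib, "numpy_sin_functional", torch.ops.mylib.numpy_sin.default)

        # The functional variant returns the would-be-mutated Tensor(s) as
        # extra outputs instead of mutating ``out`` in place:
        x, out = torch.randn(3), torch.empty(3)
        out_result = torch.ops.mylib.numpy_sin_functional(x, out)
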
| """ |
| validate(mutable_op) |
| schema = functional_schema(new_op_name, mutable_op) |
| lib.define(schema) |
| |
| functional_impl = construct_functional_impl(mutable_op) |
| lib.impl(new_op_name, functional_impl, 'CompositeExplicitAutograd') |
| |
| functional_op = getattr(getattr(torch.ops, lib.ns), new_op_name).default |
| |
| # There's no easy way for us to generate the autograd kernel, so we |
| # use autograd_not_implemented. Also, this makes it so that the user |
| # is unable to register an autograd formula themselves. This shouldn't |
| # be a problem if the user doesn't use the functional op direclty |
| # in their program, but we may need to revist this in the future. |
| lib.impl(new_op_name, autograd_not_implemented(functional_op), 'Autograd') |
| |
| f_kernel = construct_functionalization_kernel(weakref.proxy(mutable_op), functional_op) |
| |
| lib.impl(mutable_op, f_kernel, 'Functionalize') |
| |
| |
def construct_functional_impl(mutable_op):
    def functional_impl(*args):
        # Strategy:
        # - clone args that would have been mutated
        # - run mutable_op
        # - return the cloned args as additional outputs
        new_args = []
        extra_rets = []
        for is_write, arg in zip(mutable_args(mutable_op), args):
            if is_write:
                cloned = arg.clone()
                new_args.append(cloned)
                extra_rets.append(cloned)
            else:
                new_args.append(arg)
        result = mutable_op(*new_args)
        if result is None:
            return tuple(extra_rets)
        if isinstance(result, tuple):
            return (*result, *extra_rets)
        return (result, *extra_rets)
    return functional_impl
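
# Illustrative sketch (hypothetical op, not defined in this module): for a
# mutable op with schema ``mylib::foo(Tensor x, Tensor(a!) out) -> ()``, the
# impl constructed above behaves roughly like:
#
#     def foo_functional_impl(x, out):
#         cloned_out = out.clone()
#         torch.ops.mylib.foo(x, cloned_out)  # only the clone is mutated
#         return (cloned_out,)                # the clone becomes an extra output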


def construct_functionalization_kernel(mutable_op, functional_op):
    def kernel(*args):
        # There's nothing to be functionalized!
        # We can still end up here because DispatchKey::Functionalize is a mode key
        if pytree.tree_all_only(torch.Tensor, lambda x: not torch._is_functional_tensor(x), args):
            with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
                return mutable_op(*args)

        # NB: This differs from the codegen -- the codegen handles cases where
        # FunctionalTensorWrapper and non-FunctionalTensorWrapper tensors are mixed.
        # That only really matters for XLA (mixed CPU-XLA tensors) and for running
        # functionalization without the PT2 stack (the PT2 stack guarantees that
        # all tensors are FunctionalTensorWrapper).
        if not pytree.tree_all_only(torch.Tensor, torch._is_functional_tensor, args):
            raise RuntimeError(f"{mutable_op}: expected all args to be FunctionalTensorWrapper")

        unwrapped_args = []
        for arg in args:
            if isinstance(arg, torch.Tensor) and torch._is_functional_tensor(arg):
                torch._sync(arg)
                unwrapped = torch._from_functional_tensor(arg)
                unwrapped_args.append(unwrapped)
            else:
                unwrapped_args.append(arg)

        with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
            output = functional_op(*unwrapped_args)

        num_actual_output = len(mutable_op._schema.returns)
        actual_output = pytree.tree_map(
            torch._to_functional_tensor, output[:num_actual_output])

        new_values_to_propagate = output[num_actual_output:]
        inputs_to_replace = [arg for is_write, arg in zip(mutable_args(mutable_op), args)
                             if is_write]
        assert len(new_values_to_propagate) == len(inputs_to_replace)
        for new_value, arg in zip(new_values_to_propagate, inputs_to_replace):
            torch._C._propagate_xla_data(arg, new_value)
            torch._C._replace_(arg, new_value)
            torch._C._commit_update(arg)
            torch._sync(arg)

        if len(actual_output) == 1:
            return actual_output[0]
        elif len(actual_output) == 0:
            return None
        return actual_output

    return kernel
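
# Illustrative sketch (hypothetical op, not defined in this module): under
# functionalization, a call ``torch.ops.mylib.foo(x, out)`` on
# FunctionalTensorWrapper inputs is routed to the kernel above, which roughly does:
#
#     x_u, out_u = (torch._from_functional_tensor(a) for a in (x, out))
#     new_out = torch.ops.mylib.foo_functional(x_u, out_u)
#     torch._C._replace_(out, new_out)    # write the functional result back
#     torch._C._commit_update(out)        # into the wrapper for ``out``
#     torch._sync(out)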


def validate(mutable_op: OpOverload):
    if not isinstance(mutable_op, OpOverload):
        raise TypeError(
            f"register_functional_op(mutable_op): expected mutable_op to be an instance of "
            f"OpOverload but got {type(mutable_op)}")

    # There are generally three types of "in-place" or "mutable" ops.
    # Each of them has its own conventions:
    # - inplace (first input modified in-place and returned as the only output)
    # - out= (some args modified in-place and returned as outputs)
    # - mutable (some args modified in-place but none of those returned as outputs)
    # In theory we can support all three, but we'll just support the last
    # option right now for simplicity (see the example schemas below).
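    # Example schemas for the three conventions (``foo`` is hypothetical, the
    # others are existing aten ops):
    #   inplace: "abs_(Tensor(a!) self) -> Tensor(a!)"
    #   out=:    "abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"
    #   mutable: "foo(Tensor x, Tensor(a!) out) -> ()"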
    schema = FunctionSchema.parse(str(mutable_op._schema))
    if schema.kind() != SchemaKind.mutable:
        raise RuntimeError("Expected op to be mutable (as opposed to functional, inplace or out)")
    for ret in schema.returns:
        # construct_functionalization_kernel assumes this for simplicity
        if ret.annotation is not None:
            raise NotImplementedError(
                "NYI: register_functional_op(op) where op returns a mutated or aliased value. "
                "Please file an issue (and as a workaround, modify your operator to "
                "not return the mutated value or aliases)")
    for arg in schema.arguments.flat_all:
        # construct_functionalization_kernel assumes this for simplicity
        if arg.type.is_tensor_like() and arg.type != BaseType(BaseTy.Tensor):
            raise NotImplementedError(
                "NYI: register_functional_op(op) where op accepts Optional or List of tensors. "
                "Please file an issue.")


def functional_schema(new_op_name, op: OpOverload):
    schema = FunctionSchema.parse(str(op._schema))
    schema = schema.signature().with_name(OperatorName.parse(new_op_name))
    return str(schema)
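
# For example (hypothetical op, assumed schema): for a mutable op with schema
# "mylib::foo(Tensor x, Tensor(a!) out) -> ()" and new_op_name "foo_functional",
# this produces roughly "foo_functional(Tensor x, Tensor out) -> Tensor": the
# alias annotations are dropped and the would-be-mutated args become returns.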


def mutable_args(op: OpOverload):
    return tuple(False if arg.alias_info is None else arg.alias_info.is_write
                 for arg in op._schema.arguments)
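
# For example (hypothetical op, assumed schema): for an op with schema
# "foo(Tensor x, Tensor(a!) out) -> ()", mutable_args returns (False, True).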