| import inspect |
| import typing |
| import weakref |
| |
| from torchgen.model import FunctionSchema, OperatorName, SchemaKind |
| |
| import torch |
| import torch._C as _C |
| import torch.library as library |
| import torch.utils._pytree as pytree |
| |
| """ |
| There are various APIs for defining custom-operator-like things in PyTorch: |
| - [user-facing] autograd.Function (Python) |
| - [user-facing] custom_op (Python) |
| - [for power users] torch.library (Python) |
| - [for power users] TORCH_LIBRARY (C++) |
| |
| This file contains the implementation for a Simple Custom Operator API (CustomOp). |
| Using CustomOp, you are able to define a custom operator and implement interactions |
| between the CustomOp and various PyTorch subsystems, including all the subsystems |
| that are necessary for a custom operator to work with torch.compile (i.e., |
| autograd, meta, functionalization). |
| |
| CustomOp is positioned as being safer and easier to use than |
| torch.library/TORCH_LIBRARY, which require deep understanding of PyTorch internals. |
In addition, it supports torch.compile better than, and is in general more
comprehensive than, autograd.Function, which only supports implementing
gradient computation and vmap rules.
| """ |
| |
| __all__ = ["custom_op", "CustomOp"] |
| |
| |
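# Maps a user-facing device type string to the dispatch key used to register
# kernels for it.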
| SUPPORTED_DEVICE_TYPE_TO_KEY = { |
| "cpu": "CPU", |
| "cuda": "CUDA", |
| } |
| |
# We do not let users register CustomOps under namespaces that could be
# mistaken for PyTorch internals, to avoid confusion.
| RESERVED_NS = { |
| "prim", |
| "prims", |
| "aten", |
| "at", |
| "torch", |
| "pytorch", |
| } |
| |
| |
| def custom_op(schema: str, *, ns: str) -> typing.Callable: |
| r"""Creates a new CustomOp object. |
| |
    In PyTorch, defining an op (short for "operator") is a two-step process:
| - we need to define (create) the op |
| - we need to implement behavior for how the operator interacts with |
| various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc. |
| |
| This entrypoint defines the CustomOp object (the first step); |
| you must then perform the second step by calling various methods on |
| the CustomOp object. |
| |
| This API is used as a decorator (see examples). |
| |
| Arguments: |
| schema (str): The schema of the CustomOp. |
| ns (str): The namespace of the CustomOp. PyTorch operators need a |
| namespace; a given operator may only be created once. If you |
| are writing a Python library, we recommend the namespace to be |
| the name of your top-level module. |
| |
| Example:: |
| >>> import numpy as np |
| >>> |
| >>> # Step 1: define the CustomOp. |
| >>> # We need to provide the decorator a "prototype function" |
| >>> # (a function with Python ellipses as the body). |
        >>> @custom_op('(Tensor x) -> Tensor', ns='custom')
| >>> def numpy_sin(x): |
| >>> ... |
| >>> |
| >>> # numpy_sin is now an instance of class CustomOp |
| >>> print(type(numpy_sin)) |
| >>> |
| >>> # Step 2: Register an implementation for various PyTorch subsystems |
| >>> |
| >>> # Register an implementation for CPU tensors |
        >>> @numpy_sin.impl('cpu')
| >>> def numpy_sin_impl_cpu(x): |
| >>> return torch.from_numpy(np.sin(x.numpy())) |
| >>> |
| >>> # Register an implementation for CUDA tensors |
        >>> @numpy_sin.impl('cuda')
| >>> def numpy_sin_impl_cuda(x): |
| >>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device) |
| >>> |
| >>> x = torch.randn(3) |
| >>> numpy_sin(x) # calls numpy_sin_impl_cpu |
| >>> |
| >>> x_cuda = x.cuda() |
        >>> numpy_sin(x_cuda)  # calls numpy_sin_impl_cuda
| |
| """ |
| |
| def inner(func): |
| if not inspect.isfunction(func): |
| raise ValueError( |
| f"custom_op(...)(func): Expected `func` to be a Python " |
| f"function, got: {type(func)}" |
| ) |
| |
| validate_namespace(ns) |
| schema_str = f"{func.__name__}{schema}" |
| function_schema = FunctionSchema.parse(schema_str) |
| validate_schema(function_schema) |
| validate_function_matches_schema(function_schema, func) |
| |
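        # "FRAGMENT" creates the namespace if needed, or appends to it if it
        # already exists; it does not claim exclusive ownership of the
        # namespace (unlike "DEF").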
| lib = library.Library(ns, "FRAGMENT") |
| lib.define(schema_str) |
| ophandle = find_ophandle_or_throw(ns, function_schema.name) |
| result = CustomOp(lib, ns, function_schema.name, ophandle, _private_access=True) |
| |
| result.__name__ = func.__name__ |
| result.__module__ = func.__module__ |
| result.__doc__ = func.__doc__ |
| |
| # NYI: autograd not supported |
| # In the near future we will either directly use the |
| # autograd_not_implemented kernels or make those the default fallback |
| # for the Autograd and ADInplaceOrView keys. Both of those are a bit tricky. |
| library.impl(lib, result._opname, "Autograd")( |
| get_autograd_not_implemented_kernel(weakref.proxy(result)) |
| ) |
| |
| return result |
| |
| return inner |
| |
| |
| class CustomOp: |
| r"""Class for custom operators in PyTorch. |
| |
| Use the CustomOp API to create user-defined custom operators that behave |
| just like regular PyTorch operators (e.g. torch.sin, torch.mm) when it |
| comes to various PyTorch subsystems (like torch.compile). |
| |
| To construct a `CustomOp`, use `custom_op`. |
| """ |
| |
| def __init__(self, lib, cpp_ns, operator_name, ophandle, *, _private_access=False): |
        super().__init__()
| if not _private_access: |
| raise RuntimeError( |
| "The CustomOp constructor is private and we do not guarantee " |
| "BC for it. Please use custom_op(...) to create a CustomOp object" |
| ) |
| name = f"{cpp_ns}::{str(operator_name.name)}" |
| self._lib: library.Library = lib |
| self._ophandle: _C._DispatchOperatorHandle = ophandle |
| # Has the name of the op, e.g. "foo". We cache here for convenience. |
| self._opname: str = str(operator_name) |
        # This is _opname, but with the namespace, e.g. "custom::foo"
        self._namespaced_opname: str = name
| self.__name__ = None # mypy requires this |
| |
| def __repr__(self): |
| return f'<CustomOp(op="{self._namespaced_opname}")>' |
| |
| def __call__(self, *args, **kwargs): |
| # Bypass torch.ops.* and directly do OperatorHandle::callBoxed. |
| # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime |
| # issues from caching operators that make testing CustomOp difficult). |
| result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs) |
| return result |
| |
| def impl( |
| self, device_types: typing.Union[str, typing.Iterable[str]] |
| ) -> typing.Callable: |
| r"""Register an implementation for a device type for this CustomOp object. |
| |
| If the CustomOp is passed multiple Tensor inputs with different device |
| types, it will dispatch to the registered implementation for the highest |
| priority device type among those present. |
        The supported device types, in order of priority, are ['cuda', 'cpu'].
| |
| This API is used as a decorator (see examples). |
| |
| Arguments: |
| device_types (str or Iterable[str]): the device type(s) to register the function for. |
| |
| Examples:: |
| >>> import numpy as np |
| >>> |
| >>> @custom_op('(Tensor x) -> Tensor', ns='custom') |
| >>> def numpy_sin(x): |
| >>> ... |
| >>> |
| >>> # Register an implementation for CPU Tensors |
            >>> @numpy_sin.impl('cpu')
| >>> def numpy_sin_impl_cpu(x): |
| >>> return torch.from_numpy(np.sin(x.numpy())) |
| >>> |
| >>> # Register an implementation for CUDA Tensors |
            >>> @numpy_sin.impl('cuda')
| >>> def numpy_sin_impl_cuda(x): |
| >>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device) |
| >>> |
| >>> x = torch.randn(3) |
| >>> numpy_sin(x) # calls numpy_sin_impl_cpu |
| >>> |
| >>> x_cuda = x.cuda() |
            >>> numpy_sin(x_cuda)  # calls numpy_sin_impl_cuda
| |
| """ |
| if isinstance(device_types, str): |
| device_types = [device_types] |
| for device_type in device_types: |
| validate_device_type(device_type) |
| |
| def inner(f): |
| for device_type in set(device_types): |
| dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type] |
| library.impl(self._lib, self._opname, dispatch_key)(f) |
| return f |
| |
| return inner |
| |
| def impl_meta(self) -> typing.Callable: |
| r"""Register a meta implementation for this CustomOp object. |
| |
| The meta implementation is a shape propagation rule that gets invoked |
| for device='meta' Tensors and FakeTensors (Tensors that do not have storage). |
| |
        Data-dependent shape propagation rules cannot be written as meta
        implementations; the API for registering such rules on FakeTensor
        has not been implemented yet.
| |
| This API is used as a decorator (see examples). |
| |
| Examples:: |
            >>> @custom_op('(Tensor tensor, int dim) -> Tensor', ns='custom')
            >>> def custom_sum(tensor, dim):
            >>>     ...
            >>>
            >>> @custom_sum.impl_meta()
            >>> def custom_sum_meta(tensor, dim):
            >>>     output_shape = list(tensor.shape)
            >>>     del output_shape[dim]
            >>>     return tensor.new_empty(output_shape)
            >>>
            >>> x = torch.randn(2, 3, device='meta')
            >>> y = custom_sum(x, 1)
            >>> assert y.shape == (2,)
| |
| """ |
| |
| def inner(f): |
| library.impl(self._lib, self._opname, "Meta")(f) |
| return f |
| |
| return inner |
| |
| |
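# Asks the C++ dispatcher for the handle to `cpp_ns::operator_name`, throwing
# if no operator with that name has been registered.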
| def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName): |
| overload_name = ( |
| "" if operator_name.overload_name is None else operator_name.overload_name |
| ) |
| return _C._dispatch_find_schema_or_throw( |
| f"{cpp_ns}::{str(operator_name.name)}", overload_name |
| ) |
| |
| |
| def validate_namespace(ns: str) -> None: |
| if "." in ns: |
| raise ValueError( |
| f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a ' |
| f"valid variable name)" |
| ) |
| if ns in RESERVED_NS: |
| raise ValueError( |
| f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, " |
| f"please choose something else. " |
| ) |
| |
| |
| def validate_schema(schema: FunctionSchema) -> None: |
    # Support for non-functional schemas is coming in the future; it requires
    # correct handling of the ADInplaceOrView key.
| if schema.kind() != SchemaKind.functional: |
| raise ValueError( |
| f"custom_op does not support non-functional function schema. Got: {schema}" |
| ) |
| |
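    # A return with a non-write alias annotation (e.g. `-> Tensor(a)`) aliases
    # an input without mutating it, i.e. the operator is a view function.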
| rets = schema.returns |
| is_non_mutating_view = len(rets) > 0 and any( |
| r.annotation is not None and not r.annotation.is_write for r in rets |
| ) |
| if is_non_mutating_view: |
| raise ValueError(f"custom_op does not support view functions. Got: {schema}") |
| |
| # Requires us to have handling for factory functions |
| if not schema.arguments.has_tensor_arg(): |
| raise ValueError( |
| f"custom_op does not support function schema with no Tensor inputs. Got: {schema}" |
| ) |
| # Just seems weird so banning for now |
| if not schema.returns: |
| raise ValueError( |
| f"custom_op does not support function schema with no outputs. Got: {schema}" |
| ) |
| |
| # For simplicity: don't allow self arguments |
| if schema.arguments.self_arg is not None: |
| raise ValueError( |
| f"custom_op does not support arguments named 'self'. Please " |
| f"rename your argument. Got: {schema}" |
| ) |
| |
| |
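# Splits a namespaced entity such as "custom::foo" into ("custom", "foo").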
| def parse_namespace(namespaced_entity: str) -> typing.Tuple[str, str]: |
| names = namespaced_entity.split("::", 1) |
| if len(names) != 2: |
| raise ValueError(f"Expected there to be a namespace in {namespaced_entity}.") |
| return names[0], names[1] |
| |
| |
| def validate_device_type(device_type: str) -> None: |
| if device_type not in SUPPORTED_DEVICE_TYPE_TO_KEY: |
| raise ValueError( |
| f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type " |
| f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}." |
| ) |
| |
| |
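# Note: the caller passes `custom_op` as a weakref.proxy (see custom_op(...))
# so that this closure does not keep the CustomOp object alive.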
| def get_autograd_not_implemented_kernel(custom_op) -> typing.Callable: |
    def autograd_not_implemented(*args, **kwargs):
| if pytree.tree_any( |
| lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs) |
| ): |
            raise RuntimeError(
                f"Autograd has not been implemented for operator {custom_op}"
            )
        with _C._AutoDispatchBelowAutograd():
            return custom_op(*args, **kwargs)
| |
| return autograd_not_implemented |
| |
| |
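# For example, the schema "(Tensor x, *, int dim) -> Tensor" only matches a
# function defined as `def func(x, *, dim)`: both the positional and the
# keyword-only argument names must line up exactly.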
| def validate_function_matches_schema( |
| schema: FunctionSchema, func: typing.Callable |
| ) -> None: |
| arg_spec = inspect.getfullargspec(func) |
| |
| arg_names = tuple(arg.name for arg in schema.arguments.post_self_positional) |
| if arg_names != tuple(arg_spec.args): |
| raise ValueError( |
| f"custom_op: Expected the schema to match the signature of `func`. " |
| f"Schema has arg names {arg_names} but function has {arg_spec.args}." |
| ) |
| |
| kwonlyarg_names = tuple( |
| arg.name for arg in schema.arguments.pre_tensor_options_kwarg_only |
| ) |
| if kwonlyarg_names != tuple(arg_spec.kwonlyargs): |
| raise ValueError( |
| f"custom_op: Expected the schema to match the signature of `func`. " |
| f"Schema has kwonlyarg names {kwonlyarg_names} but function has " |
| f"{arg_spec.kwonlyargs}." |
| ) |