from typing import Callable, Iterable, Optional, Union

from .custom_ops import custom_op


def triton_op(
    name: str,
    fn: Optional[Callable] = None,
    /,
    *,
    mutates_args: Union[str, Iterable[str]],
    schema: Optional[str] = None,
) -> Callable:
    """Create a custom operator whose implementation is backed by 1+ triton kernels.

    Use this instead of :func:`torch.library.custom_op` when the implementation
    consists of 1+ triton kernels. :func:`torch.library.custom_op` treats
    custom operators as opaque (:func:`torch.compile` and
    :func:`torch.export.export` will never trace into them), but ``triton_op``
    makes the implementation visible to these subsystems, allowing them
    to optimize the triton kernel(s).

    Note that ``fn`` must only consist of calls to PyTorch-understood
    operators and triton kernels. Any triton kernels called inside ``fn``
    must be wrapped in a call to :func:`torch._library.capture_triton`.

    Args:
        name (str): A name for the custom op that looks like "{namespace}::{name}",
            e.g. "mylib::my_linear". The name is used as the op's stable identifier
            in PyTorch subsystems (e.g. torch.export, FX graphs).
            To avoid name collisions, please use your project name as the namespace;
            e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
        mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
            This MUST be accurate; otherwise, the behavior is undefined. If "unknown",
            it pessimistically assumes that all inputs to the operator are being mutated.
        schema (None | str): A schema string for the operator. If None
            (recommended) we'll infer a schema for the operator from its type
            annotations. We recommend letting us infer a schema unless you
            have a specific reason not to.
            Example: "(Tensor x, int y) -> (Tensor, Tensor)". An explicit
            schema is also shown in the second example below.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> from torch._library import triton_op, capture_triton
        >>>
        >>> import triton
        >>> from triton import language as tl
        >>>
        >>> @triton.jit
        >>> def add_kernel(
        >>>     in_ptr0,
        >>>     in_ptr1,
        >>>     out_ptr,
        >>>     n_elements,
        >>>     BLOCK_SIZE: "tl.constexpr",
        >>> ):
        >>>     pid = tl.program_id(axis=0)
        >>>     block_start = pid * BLOCK_SIZE
        >>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
        >>>     mask = offsets < n_elements
        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
        >>>     output = x + y
        >>>     tl.store(out_ptr + offsets, output, mask=mask)
        >>>
        >>> @triton_op("mylib::add", mutates_args={})
        >>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        >>>     output = torch.empty_like(x)
        >>>     n_elements = output.numel()
        >>>
        >>>     def grid(meta):
        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        >>>
        >>>     # NB: we need to wrap the triton kernel in a call to capture_triton
        >>>     capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
        >>>     return output
        >>>
        >>> @torch.compile
        >>> def f(x, y):
        >>>     return add(x, y)
        >>>
        >>> x = torch.randn(3, device="cuda")
        >>> y = torch.randn(3, device="cuda")
        >>>
        >>> z = f(x, y)
        >>> assert torch.allclose(z, x + y)
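
    An explicit ``schema`` string may also be supplied instead of relying on
    inference. A minimal sketch (``mylib::add_explicit`` is an illustrative
    name; it reuses ``add_kernel`` from above)::

        >>> # Illustrative sketch: pass a schema explicitly instead of inferring it.
        >>> @triton_op(
        >>>     "mylib::add_explicit",
        >>>     mutates_args=(),
        >>>     schema="(Tensor x, Tensor y) -> Tensor",
        >>> )
        >>> def add_explicit(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        >>>     output = torch.empty_like(x)
        >>>     n_elements = output.numel()
        >>>
        >>>     def grid(meta):
        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        >>>
        >>>     capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
        >>>     return output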

    """

    def dec(fn: Callable) -> Callable:
        result = custom_op(name, fn, mutates_args=mutates_args, schema=schema)
        from .._subclasses.functional_tensor import FunctionalTensorMode

        # We require that the user pass us a function that is make_fx traceable,
        # so we can just register it as the Fake/meta kernel.
        result.register_fake(fn)

        # We decompose the operator when FunctionalTensorMode is active.
        # The goal is to decompose the operator in AOTDispatcher.
        # - With torch.compile, this means that the backend (usually Inductor)
        #   can see a call to the triton kernel(s) and so it can directly optimize
        #   them by inlining them into the lowering process.
        # - With post-dispatch torch.export, this means that there will
        #   be a call(s) to the triton_kernel_wrapper_functional HOP in the
        #   graph (that we have yet to figure out how to serialize).
        def functional_decomp(  # type: ignore[no-untyped-def]
            mode, _, types, args, kwargs
        ):
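            # Handler signature is (mode, op, types, args, kwargs); the op is
            # unused because we simply re-run the Python body of ``fn``.
            # Entering ``mode`` re-enables FunctionalTensorMode (it is paused
            # while its handler runs), so tracing sees the decomposition,
            # including the captured triton kernel call, not the opaque op.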
            with mode:
                return fn(*args, **kwargs)

        result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
        return result

    if fn is None:
        return dec
    else:
        return dec(fn)


def capture_triton(triton_kernel: Callable, /) -> Callable:
    """Allows capture of a triton kernel into a graph via make_fx or
    non-strict export (coming soon).

    These technologies perform Dispatcher-based tracing (via
    ``__torch_dispatch__``) and cannot see calls to raw triton kernels.
    The ``capture_triton`` API returns a new callable that can actually
    be traced into a graph.

    Examples:

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import triton
        >>> from triton import language as tl
        >>> from torch.fx.experimental.proxy_tensor import make_fx
        >>> from torch._higher_order_ops.triton_kernel_wrap import capture_triton
        >>>
        >>> @triton.jit
        >>> def add_kernel(
        >>>     in_ptr0,
        >>>     in_ptr1,
        >>>     out_ptr,
        >>>     n_elements,
        >>>     BLOCK_SIZE: "tl.constexpr",
        >>> ):
        >>>     pid = tl.program_id(axis=0)
        >>>     block_start = pid * BLOCK_SIZE
        >>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
        >>>     mask = offsets < n_elements
        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
        >>>     output = x + y
        >>>     tl.store(out_ptr + offsets, output, mask=mask)
        >>>
        >>> def add(x, y):
        >>>     output = torch.empty_like(x)
        >>>     n_elements = output.numel()
        >>>
        >>>     def grid_fn(meta):
        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        >>>
        >>>     capture_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
        >>>     return output
        >>>
        >>> x = torch.randn(3, device="cuda")
        >>> y = torch.randn(3, device="cuda")
        >>> gm = make_fx(add)(x, y)
        >>> print(gm.code)
        >>> # def forward(self, x_1, y_1):
        >>> #     empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
        >>> #     triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
        >>> #         kernel_idx = 0, constant_args_idx = 0,
        >>> #         grid = [(1, 1, 1)], kwargs = {
        >>> #             'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
        >>> #             'n_elements': 3, 'BLOCK_SIZE': 16
        >>> #         })
        >>> #     return empty_like
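
    ``capture_triton`` also accepts kernels wrapped in ``triton.autotune``.
    A minimal sketch (``add_kernel_autotuned`` and ``add_autotuned`` are
    illustrative names; the kernel body mirrors ``add_kernel`` above):

        >>> # xdoctest: +SKIP
        >>> # Illustrative sketch: capturing an autotuned kernel.
        >>> @triton.autotune(
        >>>     configs=[
        >>>         triton.Config({"BLOCK_SIZE": 16}),
        >>>         triton.Config({"BLOCK_SIZE": 32}),
        >>>     ],
        >>>     key=["n_elements"],
        >>> )
        >>> @triton.jit
        >>> def add_kernel_autotuned(
        >>>     in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"
        >>> ):
        >>>     pid = tl.program_id(axis=0)
        >>>     offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        >>>     mask = offsets < n_elements
        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
        >>>     tl.store(out_ptr + offsets, x + y, mask=mask)
        >>>
        >>> def add_autotuned(x, y):
        >>>     output = torch.empty_like(x)
        >>>     n_elements = output.numel()
        >>>
        >>>     def grid_fn(meta):
        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        >>>
        >>>     # BLOCK_SIZE comes from the selected config, so it is not passed here.
        >>>     capture_triton(add_kernel_autotuned)[grid_fn](x, y, output, n_elements)
        >>>     return output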

    """
    from triton.runtime.autotuner import Autotuner
    from triton.runtime.jit import JITFunction

    from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper

    if not isinstance(triton_kernel, (JITFunction, Autotuner)):
        raise RuntimeError(
            "capture_triton only works on functions annotated with triton.jit or triton.autotune"
        )
    return TraceableTritonKernelWrapper(triton_kernel, None, None)