| import torch |
| from typing import List |
| |
# Public API of this module, i.e. the names a star-import exposes.
# The order is significant only for readability; keep new names grouped.
__all__ = [
    "compile", "assume_constant_result", "reset",
    "allow_in_graph", "list_backends", "disable",
    "cudagraph_mark_step_begin", "wrap_numpy",
    "is_compiling", "is_dynamo_compiling",
]
| |
def compile(*args, **kwargs):
    """
    Alias of :func:`torch.compile`.

    Every positional and keyword argument is forwarded unchanged, and the
    result of :func:`torch.compile` is handed back as-is. See
    :func:`torch.compile` for details on the arguments for this function.
    """
    compiled = torch.compile(*args, **kwargs)
    return compiled
| |
def reset() -> None:
    """
    Clear all compilation caches and restore the system to its initial state.

    Calling this is recommended between unrelated uses of
    ``torch.compile(...)`` so that the next compilation starts from a clean
    state. The work is delegated to :func:`torch._dynamo.reset`.
    """
    # Imported lazily so that merely importing this module does not pull in
    # the TorchDynamo package.
    from torch import _dynamo

    _dynamo.reset()
| |
def allow_in_graph(fn):
    """
    Customize which functions compilation will include in the generated graph.

    The function bypasses all introspection of its symbolic Python code and is
    written directly into the graph instead. If ``fn`` is a list or tuple of
    callables, :func:`allow_in_graph` is applied to each element recursively
    and a new collection containing the results is returned.

    Args:
        fn: A callable (or a list/tuple of callables) to be included in the
            graph.

    Returns:
        Whatever :func:`torch._dynamo.allow_in_graph` returns for ``fn``.

    .. warning::

        :func:`allow_in_graph` skips TorchDynamo completely on the decorated
        function, which skips all TorchDynamo safety checks (graph breaks,
        handling closures, etc). Use :func:`allow_in_graph` with great care,
        because subsystems like AOT Autograd rely on TorchDynamo. Careless use
        can lead to soundness issues that are very hard to debug.
    """
    # Lazy import keeps TorchDynamo out of this module's import time.
    from torch import _dynamo

    return _dynamo.allow_in_graph(fn)
| |
| |
def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
    """
    Return the names that are valid for ``torch.compile(..., backend="name")``.

    Args:
        exclude_tags (optional): A tuple of strings naming the tags to
            exclude. Defaults to ``("debug", "experimental")``.

    Returns:
        The list of backend names reported by
        :func:`torch._dynamo.list_backends` for ``exclude_tags``.
    """
    from torch import _dynamo

    names = _dynamo.list_backends(exclude_tags)
    return names
| |
def assume_constant_result(fn):
    """
    Mark ``fn`` as having a constant result.

    This allows the compiler to optimize away calls to ``fn``.

    Args:
        fn: The function to be marked as having a constant result.

    Returns:
        The same function ``fn``.

    .. warning::
        If the constant assumption is invalid, `assume_constant_result` can
        cause safety and soundness issues: :func:`torch.compile` will not
        attempt to validate whether the assumption actually holds.
    """
    from torch import _dynamo

    return _dynamo.assume_constant_result(fn)
| |
| def disable(fn=None, recursive=True): |
| """ |
| This function provides both a decorator and a context manager to disable compilation on a function |
| It also provides the option of recursively disabling called functions |
| |
| Args: |
| fn (optional): The function to disable |
| recursive (optional): A boolean value indicating whether the disabling should be recursive. |
| """ |
| import torch._dynamo |
| |
| return torch._dynamo.disable(fn, recursive) |
| |
| def cudagraph_mark_step_begin(): |
| """ |
| Indicates that a new iteration of inference or training is about to begin. |
| |
| CUDA Graphs will free tensors of a prior iteration. A new iteration is started on each invocation of |
| torch.compile, so long as there is not a pending backward that has not been called. |
| |
| If that heuristic is wrong, such as in the following example, manually mark it with this api. |
| |
| .. code-block:: python |
| |
| @torch.compile(mode="reduce-overhead") |
| def rand_foo(): |
| return torch.rand([4], device="cuda") |
| |
| for _ in range(5): |
| torch.compiler.cudagraph_mark_step_begin() |
| rand_foo() + rand_foo() |
| |
| For more details, see `torch.compiler_cudagraph_trees <https://pytorch.org/docs/main/torch.compiler_cudagraph_trees.html>`__ |
| """ |
| from torch._inductor import cudagraph_trees |
| |
| cudagraph_trees.mark_step_begin() |
| |
| def wrap_numpy(fn): |
| r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function |
| from ``torch.Tensor``s to ``torch.Tensor``s. |
| |
| It is designed to be used with :func:`torch.compile` with ``fullgraph=True``. It allows to |
| compile a NumPy function as if it were a PyTorch function. This allows you to run NumPy code |
| on CUDA or compute its gradients. |
| |
| .. note:: |
| |
| This decorator does not work without :func:`torch.compile`. |
| |
| Example:: |
| |
| >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) |
| >>> # Compile a NumPy function as a Tensor -> Tensor function |
| >>> @torch.compile(fullgraph=True) |
| >>> @torch.compiler.wrap_numpy |
| >>> def fn(a: np.ndarray): |
| >>> return np.sum(a * a) |
| >>> # Execute the NumPy function using Tensors on CUDA and compute the gradients |
| >>> x = torch.arange(6, dtype=torch.float32, device="cuda", requires_grad=True) |
| >>> out = fn(x) |
| >>> out.backward() |
| >>> print(x.grad) |
| tensor([ 0., 2., 4., 6., 8., 10.], device='cuda:0') |
| """ |
| from torch._dynamo.external_utils import wrap_numpy as wrap |
| return wrap(fn) |
| |
# Module-level state returned by is_compiling(). It is not assigned anywhere
# in the code shown here; presumably the compile/export machinery toggles it
# (TODO confirm).
_is_compiling_flag: bool = False


def is_compiling() -> bool:
    """
    Report whether a graph is executed/traced as part of ``torch.compile()``
    or ``torch.export()``.

    Under TorchScript scripting this always returns ``False``. Otherwise it
    returns the current value of the module-level ``_is_compiling_flag``.

    Note that there are two other related flags that should be deprecated
    eventually:

    * ``torch._dynamo.external_utils.is_compiling()``
    * ``torch._utils.is_compiling()``

    Example::

        >>> def forward(self, x):
        >>>     if not torch.compiler.is_compiling():
        >>>        ...logic that is not needed in a compiled/traced graph...
        >>>
        >>>     ...rest of the function...
    """
    # Kept as an ``if`` on ``torch.jit.is_scripting()`` so TorchScript can
    # resolve the branch statically.
    if torch.jit.is_scripting():
        return False
    return _is_compiling_flag
| |
def is_dynamo_compiling() -> bool:
    """
    Report whether a graph is traced via TorchDynamo.

    This is stricter than :func:`is_compiling`: it is only ``True`` when
    TorchDynamo itself is tracing. When called normally, the body below just
    returns ``False``. Presumably TorchDynamo special-cases this call while
    tracing so that it evaluates to ``True`` there. That mechanism is not
    visible here; confirm it against the Dynamo trace rules.

    Example::

        >>> def forward(self, x):
        >>>     if not torch.compiler.is_dynamo_compiling():
        >>>        ...logic that is not needed in a TorchDynamo-traced graph...
        >>>
        >>>     ...rest of the function...
    """
    return False