torch.func interaction with torch.compile
==============================================
So you want to use a `torch.func` ("functorch") transform (like `vmap`, `grad`, `jacrev`, etc.) with `torch.compile`. Here's a guide to what works today, what doesn't, and how to work around it.

Applying a `torch.func` transform to a `torch.compile`'d function
-----------------------------------------------------------------
This doesn't work and is being tracked by `https://github.com/pytorch/pytorch/issues/100320`.

.. code:: python

    import torch

    @torch.compile
    def f(x):
        # Reduce to a scalar, because grad requires a scalar output.
        return torch.sin(x).sum()

    def g(x):
        # A torch.func transform applied to a compiled function: not supported.
        return torch.func.grad(f)(x)

    x = torch.randn(2, 3)
    g(x)

As a workaround, please put the `torch.compile` outside of the `torch.func` transform:

.. code:: python

    import torch

    def f(x):
        return torch.sin(x)

    @torch.compile
    def g(x):
        # The transform is applied inside the compiled function instead.
        return torch.vmap(f)(x)

    x = torch.randn(2, 3)
    g(x)

Doesn't work (PT 2.0): calling a `torch.func` transform inside of a `torch.compile`'d function
------------------------------------------------------------------------------------------------
.. code:: python

    import torch

    @torch.compile
    def f(x):
        return torch.vmap(torch.sum)(x)

    x = torch.randn(2, 3)
    f(x)

This doesn't work yet. Please see the workaround (the next section).

Workaround: use `torch._dynamo.allow_in_graph`
----------------------------------------------
`allow_in_graph` is an escape hatch. If your code does not work with `torch.compile`, which introspects Python bytecode, but you believe it will work via a symbolic tracing approach (like `jax.jit`), then use `allow_in_graph`.

By using `allow_in_graph` to annotate a function, you promise PyTorch a few things that we are unable to completely verify:

- Your function is pure. That is, all outputs only depend on the inputs and do not depend on any captured Tensors.
- Your function is functional. That is, it does not mutate any state. This requirement is relaxed in practice: we support functions that appear functional from the outside. They may use in-place PyTorch operations internally, but may not mutate global state or their inputs.
- Your function does not raise data-dependent errors.

.. code:: python

    import torch

    @torch.compile
    def f(x):
        return torch._dynamo.allow_in_graph(torch.vmap(torch.sum))(x)

    x = torch.randn(2, 3)
    f(x)

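
To make the purity requirement concrete, here is a minimal sketch in plain PyTorch (no compilation involved). The names `scale`, `scaled_sum_impure`, and `scaled_sum_pure` are hypothetical and chosen only for illustration. The first function reads a captured Tensor, so it breaks the promise; the second takes the same Tensor as an explicit input.

.. code:: python

    import torch

    # Hypothetical module-level tensor, used only for this illustration.
    scale = torch.tensor(2.0)

    def scaled_sum_impure(x):
        # Not pure: the result depends on the captured tensor `scale`.
        return (x * scale).sum()

    def scaled_sum_pure(x, scale):
        # Pure: every tensor the result depends on is an explicit input.
        return (x * scale).sum()

    x = torch.randn(2, 3)
    # Both compute the same value; only the pure version declares its inputs.
    assert torch.allclose(scaled_sum_impure(x), scaled_sum_pure(x, scale))

Of the two, only the second form satisfies the first promise in the list above.
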

A common pitfall is using `allow_in_graph` to annotate a function that invokes an `nn.Module`. This breaks the purity promise, because the outputs now depend on the parameters of the `nn.Module`, which are captured rather than passed in. To actually get this to work, use `torch.func.functional_call` to extract the module state and pass it to the function as explicit inputs.
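
As a rough sketch of that pattern, the example below turns a module call into a function whose parameters are explicit inputs, and then applies `vmap` to it. The names `model`, `params`, and `pure_forward` are assumptions made for illustration. The sketch only demonstrates the `functional_call` step; it does not show (or verify) how a particular PyTorch release treats such a function under `allow_in_graph`.

.. code:: python

    import torch
    from torch.func import functional_call, vmap

    # Hypothetical module used only for this illustration.
    model = torch.nn.Linear(3, 4)

    # Extract the module state so it can be passed in explicitly.
    params = dict(model.named_parameters())

    def pure_forward(params, x):
        # Runs `model` with the given parameters instead of its own attributes.
        return functional_call(model, params, (x,))

    x = torch.randn(2, 3)

    # Map over the batch dimension of `x` while sharing the same parameters.
    out = vmap(pure_forward, in_dims=(None, 0))(params, x)

Here `model` supplies only the structure of the computation; the values the output depends on all arrive through `params` and `x`.
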