# type: ignore
r'''
**This feature is experimental and its stability is not currently guaranteed. Proceed at your own risk**
FX is a toolkit for capturing and transforming functional PyTorch programs. It
consists of GraphModule and a corresponding intermediate representation (IR). Calling `symbolic_trace` on an
`nn.Module` instance traces through the computation of that Module's `forward` method symbolically, records
those operations in the FX intermediate representation, and returns them wrapped in a GraphModule.
```
import torch
from torch.fx import symbolic_trace
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = symbolic_trace(m)
```
The Intermediate Representation centers around a 6-opcode format:
```
print(gm.graph)
```
```
graph(x):
    %linear_weight : [uses=1] = self.linear.weight
    %add_1 : [uses=1] = call_function[target=<built-in function add>](args = (%x, %linear_weight), kwargs = {})
    %linear_1 : [uses=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
    %relu_1 : [uses=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
    %sum_1 : [uses=1] = call_function[target=<built-in method sum of type object at 0x7fad0a3c16a0>](args = (%relu_1,), kwargs = {dim: -1})
    %topk_1 : [uses=1] = call_function[target=<built-in method topk of type object at 0x7fad0a3c16a0>](args = (%sum_1, 3), kwargs = {})
    return topk_1
```
The semantics are as follows:
- `placeholder` represents a function input. The `name` attribute specifies the name this value will take on.
`target` is similarly the name of the argument. `args` and `kwargs` are don't-care. Placeholders correspond to
the function parameters (e.g. `x`) in the graph printout.
- `get_attr` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the
fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy.
`args` and `kwargs` are don't-care.
- `call_function` applies a free function to some values. `name` is similarly the name of the value to assign
to. `target` is the function to be applied. `args` and `kwargs` represent the arguments to the function,
following the Python calling convention.
- `call_module` applies a module in the module hierarchy's `forward()` method to given arguments. `name` is
as previous. `target` is the fully-qualified name of the module in the module hierarchy to call.
`args` and `kwargs` represent the arguments to invoke the module on, _excluding the self argument_.
- `call_method` calls a method on a value. `name` is as previous. `target` is the string name of the method
to apply to the `self` argument. `args` and `kwargs` represent the arguments to invoke the method on,
_including the self argument_.
- `output` contains the output of the traced function in its `args[0]` attribute. This corresponds to the "return" statement
in the Graph printout. A short sketch for inspecting these fields programmatically follows this list.
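Each of these fields is exposed directly on `Node`, so a table like the printout above can be reproduced by
walking the graph. A minimal sketch, reusing the `gm` traced earlier (exact output may vary by version):
```
# Walk the graph in order and print each node's opcode, assigned name,
# call target, and arguments -- the fields described in the list above.
for node in gm.graph.nodes:
    print(node.op, node.name, node.target, node.args, node.kwargs)
```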
GraphModule automatically generates Python code for the operations it symbolically observed:
```
print(gm.code)
```
```
def forward(self, x):
    linear_weight = self.linear.weight
    add_1 = x + linear_weight
    linear_1 = self.linear(add_1)
    relu_1 = linear_1.relu()
    sum_1 = torch.sum(relu_1, dim = -1)
    topk_1 = torch.topk(sum_1, 3)
    return topk_1
```
Because this code is valid PyTorch code, the resulting `GraphModule` can be used in any context where another
`nn.Module` can be used, including in TorchScript tracing/compilation.
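For instance, the traced `gm` can be invoked like the original module and handed to TorchScript. A minimal
sketch, reusing `m` and `gm` from above (the input shape is an assumption chosen to broadcast with
`linear.weight`):
```
inp = torch.rand(4)  # broadcasts against the (5, 4) linear weight
# The generated forward() reproduces the original module's computation.
assert torch.allclose(gm(inp).values, m(inp).values)
# GraphModule is itself an nn.Module, so TorchScript can compile it.
scripted = torch.jit.script(gm)
```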
'''
from .graph_module import GraphModule
from .symbolic_trace import symbolic_trace, Tracer
from .graph import Graph
from .node import Node, map_arg
from .proxy import Proxy