[FX] Add a bunch of docstrings (#47719)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47719
Test Plan: Imported from OSS
Reviewed By: zdevito
Differential Revision: D24875400
Pulled By: jamesr66a
fbshipit-source-id: a1dd43d2eee914a441eff43c4f2efe61a399e8a5
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index dd07ff7..45e5184 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -13,7 +13,7 @@
def _is_magic(x: str) -> bool:
return x.startswith('__') and x.endswith('__')
-def snake_case(s: str) -> str:
+def _snake_case(s: str) -> str:
return ''.join(['_' + i.lower() if i.isupper() else i for i in s]).lstrip('_')
def get_qualified_name(func: Callable[..., Any]) -> str:
@@ -108,6 +108,68 @@
return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev')
class Graph:
+ """
+ `Graph` is the main data structure used in the FX Intermediate Representation.
+ It consists of a series of `Node`s, each representing callsites (or other
+ syntactic constructs). The list of `Node`s, taken together, constitutes a
+ valid Python function.
+
+ For example, the following code
+
+ ```
+ import torch
+ from torch.fx import symbolic_trace
+
+ class MyModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.param = torch.nn.Parameter(torch.rand(3, 4))
+ self.linear = torch.nn.Linear(4, 5)
+
+ def forward(self, x):
+ return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)
+
+ m = MyModule()
+ gm = symbolic_trace(m)
+ ```
+
+ Will produce the following Graph:
+
+ ```
+ print(gm.graph)
+ ```
+
+ ```
+ graph(x):
+ %linear_weight : [uses=1] = self.linear.weight
+ %add_1 : [uses=1] = call_function[target=<built-in function add>](args = (%x, %linear_weight), kwargs = {})
+ %linear_1 : [uses=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
+ %relu_1 : [uses=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
+ %sum_1 : [uses=1] = call_function[target=<built-in method sum of type object at 0x7fad0a3c16a0>](args = (%relu_1,), kwargs = {dim: -1}) # noqa: B950
+ %topk_1 : [uses=1] = call_function[target=<built-in method topk of type object at 0x7fad0a3c16a0>](args = (%sum_1, 3), kwargs = {}) # noqa: B950
+ return topk_1
+ ```
+
+ The Node semantics are as follows:
+
+ - `placeholder` represents a function input. The `name` attribute specifies the name this value will take on.
+ `target` is similarly the name of the argument. `args` and `kwargs` are don't-care. Placeholders correspond to
+ the function parameters (e.g. `x`) in the graph printout.
+ - `get_attr` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the
+ fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy.
+ `args` and `kwargs` are don't-care.
+ - `call_function` applies a free function to some values. `name` is similarly the name of the value to assign
+ to. `target` is the function to be applied. `args` and `kwargs` represent the arguments to the function,
+ following the Python calling convention.
+ - `call_module` applies a module in the module hierarchy's `forward()` method to given arguments. `name` is
+ as previous. `target` is the fully-qualified name of the module in the module hierarchy to call.
+ `args` and `kwargs` represent the arguments to invoke the module on, _excluding the self argument_.
+ - `call_method` calls a method on a value. `name` is as previous. `target` is the string name of the method
+ to apply to the `self` argument. `args` and `kwargs` represent the arguments to invoke the method on,
+ _including the self argument_.
+ - `output` contains the output of the traced function in its `args[0]` attribute. This corresponds to the "return" statement
+ in the Graph printout.
+ """
def __init__(self):
"""
Construct an empty Graph.
@@ -118,7 +180,13 @@
self._len = 0
@property
- def nodes(self):
+ def nodes(self) -> _node_list:
+ """
+ Get the list of `Node`s that constitute this Graph.
+
+ Note that this `Node` list representation is a doubly-linked list. Mutations
+ during iteration (e.g. delete a Node, add a Node) are safe.
+ """
return _node_list(self)
def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node]) -> Optional[Argument]:
@@ -156,6 +224,21 @@
kwargs: Optional[Dict[str, Argument]] = None,
name: Optional[str] = None,
type_expr: Optional[Any] = None) -> Node:
+ """
+ Create a `Node` and add it to the `Graph` at the current insert-point.
+ Note that the current insert-point can be set via `Graph.inserting_before`
+ and `Graph.inserting_after`.
+
+ - op is the opcode for this Node. One of 'call_function', 'call_method', 'get_attr',
+ 'call_module', 'placeholder', or 'output'. The semantics of these opcodes are
+ described in the `Graph` docstring.
+ - args is a tuple of arguments to this node.
+ - kwargs is a dict from string to argument, representing the kwargs of this Node
+ - name is an optional string name for the `Node`. This will influence the name
+ of the value assigned to in the Python generated code.
+ - type_expr is an optional type annotation representing the Python type
+ the output of this node will have.
+ """
assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output')
args = () if args is None else args
kwargs = {} if kwargs is None else kwargs
@@ -224,16 +307,49 @@
# sugar for create_node when you know the op
def placeholder(self, name: str, type_expr: Optional[Any] = None) -> Node:
+ """
+ Insert a `placeholder` node into the Graph. A `placeholder` represents
+ a function input. This function takes a string `name` for the input
+ value as well as an optional `type_expr`, which is a type expression
+ describing the type of value this input will take. The type expression
+ is needed in some cases for proper code generation.
+
+ The same insertion point rules apply for this method as `Graph.create_node`.
+ """
return self.create_node('placeholder', name, type_expr=type_expr)
- def get_attr(self, name: str, type_expr: Optional[Any] = None) -> Node:
- return self.create_node('get_attr', name, type_expr=type_expr)
+ def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node:
+ """
+ Insert a `get_attr` node into the Graph. A `get_attr` `Node` represents the
+ fetch of an attribute from the `Module` hierarchy. `qualified_name` is the
+ fully-qualified name of the attribute to be retrieved. For example, if
+ the traced Module has a submodule named `foo`, which has a submodule named
+ `bar`, which has an attribute named `baz`, the qualified name `foo.bar.baz`
+ should be passed as `qualified_name`.
+
+ The same insertion point and type expression rules apply for this method
+ as `Graph.create_node`.
+ """
+ return self.create_node('get_attr', qualified_name, type_expr=type_expr)
def call_module(self,
module_name: str,
args: Optional[Tuple[Argument, ...]] = None,
kwargs: Optional[Dict[str, Argument]] = None,
type_expr: Optional[Any] = None) -> Node:
+ """
+ Insert a `call_module` `Node` into the `Graph`. A `call_module` node
+ represents a call to the forward() function of a `Module` in the `Module`
+ hierarchy. For example, if the traced `Module` has a submodule named `foo`,
+ which has a submodule named `bar`, the qualified name `foo.bar` should
+ be passed as `module_name` to call that module.
+
+ `args` and `kwargs` represent the args and kwargs passed to the called
+ `Module`, respectively.
+
+ The same insertion point and type expression rules apply for this method
+ as `Graph.create_node`.
+ """
return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr)
def call_method(self,
@@ -241,6 +357,18 @@
args: Optional[Tuple[Argument, ...]] = None,
kwargs: Optional[Dict[str, Argument]] = None,
type_expr: Optional[Any] = None) -> Node:
+ """
+ Insert a `call_method` `Node` into the `Graph`. A `call_method` node
+ represents a call to a given method on the 0th element of `args`.
+ For example, if args[0] is a `Node` representing a `Tensor`, then to call
+ `relu()` on that `Tensor`, pass `relu` to `method_name`.
+
+ `args` and `kwargs` represent the args and kwargs passed to the called
+ method, respectively.
+
+ The same insertion point and type expression rules apply for this method
+ as `Graph.create_node`.
+ """
return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr)
def call_function(self,
@@ -248,10 +376,22 @@
args: Optional[Tuple[Argument, ...]] = None,
kwargs: Optional[Dict[str, Argument]] = None,
type_expr: Optional[Any] = None) -> Node:
+ """
+ Insert a `call_function` `Node` into the `Graph`. A `call_function` node
+ represents a call to a Python callable, specified by `the_function`. `the_function`
+ can be any PyTorch operator, Python function, or member of the `builtins`
+ or `operator` namespaces.
+
+ `args` and `kwargs` represent the args and kwargs passed to the called
+ function, respectively.
+
+ The same insertion point and type expression rules apply for this method
+ as `Graph.create_node`.
+ """
return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr)
def node_copy(self, node: Node, arg_transform: Callable[[Node], Argument] = lambda x: x) -> Node:
- """ copy a node from one graph into another. arg_transform needs to transform arguments from the graph of node
+ """ Copy a node from one graph into another. arg_transform needs to transform arguments from the graph of node
to the graph of self. Example:
g : torch.fx.Graph = ...
@@ -281,6 +421,14 @@
return self.create_node(node.op, node.target, args, kwargs, name, node.type)
def output(self, result: Argument, type_expr: Optional[Any] = None):
+ """
+ Insert an `output` `Node` into the `Graph`. An `output` node represents
+ a `return` statement in the Python code. `result` is the value that should
+ be returned.
+
+ The same insertion point and type expression rules apply for this method
+ as `Graph.create_node`.
+ """
return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr)
def _name(self, target: Target) -> str:
@@ -294,7 +442,7 @@
op = op.replace('.', '_')
# delete all characters that are illegal in a Python identifier
op = re.sub('[^0-9a-zA-Z_]+', '_', op)
- op = snake_case(op)
+ op = _snake_case(op)
if op[0].isdigit():
op = f'_{op}'
@@ -318,6 +466,9 @@
return f'{op}_{i}'
def python_code(self, root_module: str) -> str:
+ """
+ Turn this `Graph` into valid Python code.
+ """
free_vars: List[str] = []
modules_used : Set[str] = set()
body: List[str] = []
@@ -405,6 +556,10 @@
return fn_code
def __str__(self) -> str:
+ """
+ Print a human-readable (not machine-readable) string representation
+ of this Graph
+ """
placeholder_names : List[str] = []
# This is a one-element array just so `format_node` can modify the closed
# over value
diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py
index 4a7ab8c..3525e18 100644
--- a/torch/fx/graph_module.py
+++ b/torch/fx/graph_module.py
@@ -172,11 +172,19 @@
@property
def graph(self):
+ """
+ Return the `Graph` underlying this `GraphModule`
+ """
return self._graph
@graph.setter
- def graph(self, val) -> None:
- self._graph = val
+ def graph(self, g) -> None:
+ """
+ Set the underlying `Graph` for this `GraphModule`. This will internally
+ recompile the `GraphModule` so that the generated `forward()` function
+ corresponds to `g`
+ """
+ self._graph = g
self.recompile()
def recompile(self) -> None:
@@ -204,6 +212,13 @@
cls.__call__ = wrapped_call
def __reduce__(self):
+ """
+ Serialization of GraphModule. We serialize only the generated code, not
+ the underlying `Graph`. This is because `Graph` does not have on-disk
+ backward-compatibility guarantees, whereas Python source code does.
+ On the deserialization side, we symbolically trace through the generated
+ code to regenerate the underlying `Graph`
+ """
dict_without_graph = self.__dict__.copy()
del dict_without_graph['_graph']
return (deserialize_graphmodule, (dict_without_graph,))
diff --git a/torch/fx/node.py b/torch/fx/node.py
index 118b32f..dd304a8 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -59,10 +59,16 @@
@property
def next(self) -> 'Node':
+ """
+ Get the next node in the linked list
+ """
return self._next
@property
def prev(self) -> 'Node':
+ """
+ Get the previous node in the linked list
+ """
return self._prev
def prepend(self, x: 'Node'):
@@ -96,18 +102,38 @@
@property
def args(self) -> Tuple[Argument, ...]:
+ """
+ Return the tuple of arguments to this Node. The interpretation of arguments
+ depends on the node's opcode. See the `fx.Graph` docstring for more
+ information.
+ """
return self._args
@args.setter
def args(self, a : Tuple[Argument, ...]):
+ """
+ Set the tuple of arguments to this Node. The interpretation of arguments
+ depends on the node's opcode. See the `fx.Graph` docstring for more
+ information.
+ """
self._update_args_kwargs(map_arg(a, lambda x: x), self._kwargs) # type: ignore
@property
def kwargs(self) -> Dict[str, Argument]:
+ """
+ Return the dict of kwargs to this Node. The interpretation of arguments
+ depends on the node's opcode. See the `fx.Graph` docstring for more
+ information.
+ """
return self._kwargs
@kwargs.setter
def kwargs(self, k : Dict[str, Argument]):
+ """
+ Set the dict of kwargs to this Node. The interpretation of arguments
+ depends on the node's opcode. See the `fx.Graph` docstring for more
+ information.
+ """
self._update_args_kwargs(self._args, map_arg(k, lambda x: x)) # type: ignore
def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict[str, Argument]):
@@ -151,7 +177,7 @@
def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
- """ apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. """
+ """ Apply fn to each Node appearing in `a`. `a` may be a list, tuple, slice, or dict with string keys. """
if isinstance(a, tuple):
return tuple(map_arg(elem, fn) for elem in a)
if isinstance(a, list):
diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py
index 20566bb..a75d5ff 100644
--- a/torch/fx/symbolic_trace.py
+++ b/torch/fx/symbolic_trace.py
@@ -38,10 +38,34 @@
# instead, let's make python think that args and kwargs are normal variables
class Tracer(TracerBase):
+ """
+ `Tracer` is the class that implements the symbolic tracing functionality
+ of `torch.fx.symbolic_trace`. A call to `symbolic_trace(m)` is equivalent
+ to `Tracer().trace(m)`.
+
+ Tracer can be subclassed to override various behaviors of the tracing
+ process. The different behaviors that can be overridden are described
+ in the docstrings of the methods on this class.
+ """
def __init__(self):
super().__init__()
def create_arg(self, a: Any) -> Argument:
+ """
+ A method to specify the behavior of tracing when preparing values to
+ be used as arguments to nodes in the `Graph`.
+
+ By default, the behavior includes:
+ - Iterate through collection types (e.g. tuple, list, dict) and recursively
+ call `create_arg` on the elements.
+ - Given a Proxy object, return a reference to the underlying IR `Node`
+ - Given a non-Proxy Tensor object, emit IR for various cases:
+ - For a Parameter, emit a `get_attr` node referring to that Parameter
+ - For a non-Parameter Tensor, store the Tensor away in a special
+ attribute referring to that attribute.
+
+ This method can be overridden to support more types.
+ """
# The base tracer is used to construct Graphs when there is no associated
# module hierarchy, so it can never create parameter references.
# The default tracer adds the ability to refer to parameters when
@@ -95,19 +119,43 @@
"""
return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential)
- def path_of_module(self, mod):
+ def path_of_module(self, mod) -> str:
+ """
+ Helper method to find the qualified name of `mod` in the Module hierarchy
+ of `root`. For example, if `root` has a submodule named `foo`, which has
+ a submodule named `bar`, passing `bar` into this function will return
+ the string "foo.bar".
+ """
for n, p in self.root.named_modules():
if mod is p:
return n
raise NameError('module is not installed as a submodule')
def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args, kwargs):
+ """
+ Method that specifies the behavior of this `Tracer` when it encounters
+ a call to an `nn.Module` instance.
+
+ By default, the behavior is to check if the called module is a leaf module
+ via `is_leaf_module`. If it is, emit a `call_module` node referring to
+ `m` in the `Graph`. Otherwise, call the `Module` normally, tracing through
+ the operations in its `forward` function.
+
+ This method can be overridden to--for example--create nested traced
+ GraphModules, or any other behavior you would want while tracing across
+ `Module` boundaries.
+ """
module_qualified_name = self.path_of_module(m)
if not self.is_leaf_module(m, module_qualified_name):
return forward(*args, **kwargs)
return self.create_proxy('call_module', module_qualified_name, args, kwargs)
def create_args_for_root(self, root_fn, is_module):
+ """
+ Create `placeholder` nodes corresponding to the signature of the `root`
+ Module. This method introspects `root`'s signature and emits those
+ nodes accordingly, also supporting *args and **kwargs.
+ """
# In some cases, a function or method has been decorated with a wrapper
# defined via `functools.wraps`. In this case, the outer code object
# will likely not contain the actual parameters we care about, so unwrap
@@ -149,6 +197,10 @@
return root_fn, args
def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph:
+ """
+ Trace `root` and return the corresponding FX `Graph` representation. `root`
+ can either be an `nn.Module` instance or a Python callable.
+ """
if isinstance(root, torch.nn.Module):
self.root = root
fn = type(root).forward
@@ -211,12 +263,11 @@
return self.graph
-# Symbolic tracing API
-#
-# Given an `nn.Module` or function instance `root`, this function will return a `GraphModule`
-# constructed by recording operations seen while tracing through `root`.
-#
-# Args:
-# - root - the `nn.Module` instance to trace
def symbolic_trace(root : Union[torch.nn.Module, Callable]) -> GraphModule:
+ """
+ Symbolic tracing API
+
+ Given an `nn.Module` or function instance `root`, this function will return a `GraphModule`
+ constructed by recording operations seen while tracing through `root`.
+ """
return GraphModule(root if isinstance(root, torch.nn.Module) else torch.nn.Module(), Tracer().trace(root))