[FX][2/2] Make docstrings pretty when rendered (#48871)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48871
Test Plan: Imported from OSS
Reviewed By: ansley
Differential Revision: D25351588
Pulled By: jamesr66a
fbshipit-source-id: 4c6fd341100594c204a35d6a3aab756e3e22297b
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index 072aef6..ca4b8d6 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -148,26 +148,7 @@
%topk_1 : [#users=1] = call_function[target=<built-in method topk of type object at 0x7ff2da9dc300>](args = (%sum_1, 3), kwargs = {}) # noqa: B950
return topk_1
- The Node semantics are as follows:
-
- - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on.
- ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument
- denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to
- the function parameters (e.g. ``x``) in the graph printout.
- - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the
- fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy.
- ``args`` and ``kwargs`` are don't-care
- - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign
- to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function,
- following the Python calling convention
- - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is
- as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call.
- ``args`` and ``kwargs`` represent the arguments to invoke the module on, *including the self argument*.
- - ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method
- to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on,
- *including the self argument*
- - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement
- in the Graph printout.
+ For the semantics of operations represented in the ``Graph``, please see :class:`Node`.
"""
def __init__(self):
"""
diff --git a/torch/fx/node.py b/torch/fx/node.py
index 1cc94be..fd8a4bc 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -21,8 +21,34 @@
]]
class Node:
- def __init__(self, graph: 'Graph', name: str, op: str, target: Target,
- args: Tuple[Argument, ...], kwargs: Dict[str, Argument],
+ """
+ ``Node`` is the data structure that represents individual operations within
+ a ``Graph``. For the most part, Nodes represent callsites to various entities,
+ such as operators, methods, and Modules (some exceptions include nodes that
+ specify function inputs and outputs). Each ``Node`` has a function specified
+ by its ``op`` property. The ``Node`` semantics for each value of ``op`` are as follows:
+
+ - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on.
+ ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument
+ denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to
+ the function parameters (e.g. ``x``) in the graph printout.
+ - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the
+ fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy.
+ ``args`` and ``kwargs`` are don't-care
+ - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign
+ to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function,
+ following the Python calling convention
+ - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is
+ as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call.
+ ``args`` and ``kwargs`` represent the arguments to invoke the module on, *including the self argument*.
+ - ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method
+ to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on,
+ *including the self argument*
+ - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement
+ in the Graph printout.
+ """
+ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target',
+ args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'],
type : Optional[Any] = None) -> None:
self.graph = graph
self.name = name # unique name of value being created
@@ -60,23 +86,33 @@
@property
def next(self) -> 'Node':
"""
- Get the next node in the linked list
+ Returns the next ``Node`` in the linked list of Nodes.
+
+ Returns:
+
+ The next ``Node`` in the linked list of Nodes.
"""
return self._next
@property
def prev(self) -> 'Node':
"""
- Get the previous node in the linked list
+ Returns the previous ``Node`` in the linked list of Nodes.
+
+ Returns:
+
+ The previous ``Node`` in the linked list of Nodes.
"""
return self._prev
- def prepend(self, x: 'Node'):
- """Insert x before this node in the list of nodes in the graph.
- Before: p -> self
- bx -> x -> ax
- After: p -> x -> self
- bx -> ax
+ def prepend(self, x: 'Node') -> None:
+ """
+ Insert x before this node in the list of nodes in the graph. Example::
+
+ Before: p -> self
+ bx -> x -> ax
+ After: p -> x -> self
+ bx -> ax
Args:
x (Node): The node to put before this node. Must be a member of the same graph.
@@ -87,8 +123,9 @@
p._next, x._prev = x, p
x._next, self._prev = self, x
- def append(self, x: 'Node'):
- """Insert x after this node in the list of nodes in the graph.
+ def append(self, x: 'Node') -> None:
+ """
+ Insert x after this node in the list of nodes in the graph.
Equvalent to ``self.next.prepend(x)``
Args:
@@ -103,9 +140,12 @@
@property
def args(self) -> Tuple[Argument, ...]:
"""
- Return the tuple of arguments to this Node. The interpretation of arguments
- depends on the node's opcode. See the ``fx.Graph`` docstring for more
+ The tuple of arguments to this ``Node``. The interpretation of arguments
+ depends on the node's opcode. See the :class:`Node` docstring for more
information.
+
+ Assignment to this property is allowed. All accounting of uses and users
+ is updated automatically on assignment.
"""
return self._args
@@ -121,9 +161,12 @@
@property
def kwargs(self) -> Dict[str, Argument]:
"""
- Return the dict of kwargs to this Node. The interpretation of arguments
- depends on the node's opcode. See the ``fx.Graph`` docstring for more
+ The dict of keyword arguments to this ``Node``. The interpretation of arguments
+ depends on the node's opcode. See the :class:`Node` docstring for more
information.
+
+ Assignment to this property is allowed. All accounting of uses and users
+ is updated automatically on assignment.
"""
return self._kwargs
@@ -141,7 +184,12 @@
"""
Return all Nodes that are inputs to this Node. This is equivalent to
iterating over ``args`` and ``kwargs`` and only collecting the values that
- are Nodes
+ are Nodes.
+
+ Returns:
+
+ List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this
+ ``Node``, in that order.
"""
all_nodes : List['Node'] = []
map_arg(self.args, lambda n: all_nodes.append(n))
@@ -149,6 +197,9 @@
return all_nodes
def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict[str, Argument]):
+ """
+ This API is internal. Do *not* call it directly.
+ """
self._args = new_args
self._kwargs = new_kwargs
@@ -168,7 +219,14 @@
def replace_all_uses_with(self, replace_with : 'Node') -> List['Node']:
"""
Replace all uses of ``self`` in the Graph with the Node ``replace_with``.
- Returns the list of nodes on which this change was made.
+
+ Args:
+
+ replace_with (Node): The node to replace all uses of ``self`` with.
+
+ Returns:
+
+ The list of Nodes on which this change was made.
"""
to_process = list(self.users)
for use_node in to_process:
diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py
index 6bdc8dd..69e3c70 100644
--- a/torch/fx/symbolic_trace.py
+++ b/torch/fx/symbolic_trace.py
@@ -1,6 +1,6 @@
import inspect
from types import CodeType, FunctionType
-from typing import Any, Dict, Optional, List, Callable, Union
+from typing import Any, Dict, Optional, Tuple, List, Callable, Union
import torch
from torch._C import ScriptObject # type: ignore
@@ -51,21 +51,31 @@
def __init__(self):
super().__init__()
- def create_arg(self, a: Any) -> Argument:
+ def create_arg(self, a: Any) -> 'Argument':
"""
A method to specify the behavior of tracing when preparing values to
be used as arguments to nodes in the ``Graph``.
By default, the behavior includes:
- - Iterate through collection types (e.g. tuple, list, dict) and recursively
- call ``create_args`` on the elements.
- - Given a Proxy object, return a reference to the underlying IR ``Node``
- - Given a non-Proxy Tensor object, emit IR for various cases:
- - For a Parameter, emit a ``get_attr`` node referring to that Parameter
- - For a non-Parameter Tensor, store the Tensor away in a special
- attribute referring to that attribute.
+ #. Iterate through collection types (e.g. tuple, list, dict) and recursively
+ call ``create_args`` on the elements.
+ #. Given a Proxy object, return a reference to the underlying IR ``Node``
+ #. Given a non-Proxy Tensor object, emit IR for various cases:
+
+ * For a Parameter, emit a ``get_attr`` node referring to that Parameter
+ * For a non-Parameter Tensor, store the Tensor away in a special
+ attribute referring to that attribute.
This method can be overridden to support more types.
+
+ Args:
+
+ a (Any): The value to be emitted as an ``Argument`` in the ``Graph``.
+
+
+ Returns:
+
+ The value ``a`` converted into the appropriate ``Argument``
"""
# The base tracer is used to construct Graphs when there is no associated
# module hierarchy, so it can never create parameter references.
@@ -115,28 +125,32 @@
their constituent ops are recorded, unless specified otherwise
via this parameter.
- Args
- m - The module itself
- module_qualified_name - The path to root of this module. For example,
- if you have a module hierarchy where submodule ``foo`` contains
- submodule ``bar``, which contains submodule ``baz``, that module will
- appear with the qualified name ``foo.bar.baz`` here.
+ Args:
+ m (Module): The module being queried about
+ module_qualified_name (str): The path to root of this module. For example,
+ if you have a module hierarchy where submodule ``foo`` contains
+ submodule ``bar``, which contains submodule ``baz``, that module will
+ appear with the qualified name ``foo.bar.baz`` here.
"""
return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential)
- def path_of_module(self, mod) -> str:
+ def path_of_module(self, mod : torch.nn.Module) -> str:
"""
Helper method to find the qualified name of ``mod`` in the Module hierarchy
of ``root``. For example, if ``root`` has a submodule named ``foo``, which has
a submodule named ``bar``, passing ``bar`` into this function will return
the string "foo.bar".
+
+ Args:
+
+ mod (str): The ``Module`` to retrieve the qualified name for.
"""
for n, p in self.root.named_modules():
if mod is p:
return n
raise NameError('module is not installed as a submodule')
- def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args, kwargs):
+ def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any:
"""
Method that specifies the behavior of this ``Tracer`` when it encounters
a call to an ``nn.Module`` instance.
@@ -149,6 +163,20 @@
This method can be overridden to--for example--create nested traced
GraphModules, or any other behavior you would want while tracing across
``Module`` boundaries.
+ ``Module`` boundaries.
+
+ Args:
+
+ m (Module): The module for which a call is being emitted
+ forward (Callable): The forward() method of the ``Module`` to be invoked
+ args (Tuple): args of the module callsite
+ kwargs (Dict): kwargs of the module callsite
+
+ Return:
+
+ The return value from the Module call. In the case that a ``call_module``
+ node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever
+ value was returned from the ``Module`` invocation.
"""
module_qualified_name = self.path_of_module(m)
if not self.is_leaf_module(m, module_qualified_name):
@@ -205,6 +233,16 @@
"""
Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
can either be an ``nn.Module`` instance or a Python callable.
+
+
+ Args:
+
+ root (Union[Module, Callable]): Either a ``Module`` or a function to be
+ traced through.
+
+ Returns:
+
+ A ``Graph`` representing the semantics of the passed-in ``root``.
"""
if isinstance(root, torch.nn.Module):
self.root = root