[FX] Add a bunch of docstrings (#47719)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47719

Test Plan: Imported from OSS

Reviewed By: zdevito

Differential Revision: D24875400

Pulled By: jamesr66a

fbshipit-source-id: a1dd43d2eee914a441eff43c4f2efe61a399e8a5
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index dd07ff7..45e5184 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -13,7 +13,7 @@
 def _is_magic(x: str) -> bool:
     return x.startswith('__') and x.endswith('__')
 
-def snake_case(s: str) -> str:
+def _snake_case(s: str) -> str:
     return ''.join(['_' + i.lower() if i.isupper() else i for i in s]).lstrip('_')
 
 def get_qualified_name(func: Callable[..., Any]) -> str:
@@ -108,6 +108,68 @@
         return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev')
 
 class Graph:
+    """
+    `Graph` is the main data structure used in the FX Intermediate Representation.
+    It consists of a series of `Node`s, each representing callsites (or other
+    syntactic constructs). The list of `Node`s, taken together, constitute a
+    valid Python function.
+
+    For example, the following code
+
+    ```
+    import torch
+    from torch.fx import symbolic_trace
+
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.param = torch.nn.Parameter(torch.rand(3, 4))
+            self.linear = torch.nn.Linear(4, 5)
+
+        def forward(self, x):
+            return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)
+
+    m = MyModule()
+    gm = symbolic_trace(m)
+    ```
+
+    Will produce the following Graph:
+
+    ```
+    print(gm.graph)
+    ```
+
+    ```
+    graph(x):
+        %linear_weight : [uses=1] = self.linear.weight
+        %add_1 : [uses=1] = call_function[target=<built-in function add>](args = (%x, %linear_weight), kwargs = {})
+        %linear_1 : [uses=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
+        %relu_1 : [uses=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
+        %sum_1 : [uses=1] = call_function[target=<built-in method sum of type object at 0x7fad0a3c16a0>](args = (%relu_1,), kwargs = {dim: -1}) # noqa: B950
+        %topk_1 : [uses=1] = call_function[target=<built-in method topk of type object at 0x7fad0a3c16a0>](args = (%sum_1, 3), kwargs = {}) # noqa: B950
+        return topk_1
+    ```
+
+    The Node semantics are as follows:
+
+    - `placeholder` represents a function input. The `name` attribute specifies the name this value will take on.
+    `target` is similarly the name of the argument. `args` and `kwargs` are don't-care. Placeholders correspond to
+    the function parameters (e.g. `x`) in the graph printout.
+    - `get_attr` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the
+    fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy.
+    `args` and `kwargs` are don't-care
+    - `call_function` applies a free function to some values. `name` is similarly the name of the value to assign
+    to. `target` is the function to be applied. `args` and `kwargs` represent the arguments to the function,
+    following the Python calling convention
+    - `call_module` applies a module in the module hierarchy's `forward()` method to given arguments. `name` is
+    as previous. `target` is the fully-qualified name of the module in the module hierarchy to call.
+    `args` and `kwargs` represent the arguments to invoke the module on, _including the self argument_.
+    - `call_method` calls a method on a value. `name` is as similar. `target` is the string name of the method
+    to apply to the `self` argument. `args` and `kwargs` represent the arguments to invoke the module on,
+    _including the self argument_.
+    - `output` contains the output of the traced function in its `args[0]` attribute. This corresponds to the "return" statement
+    in the Graph printout.
+    """
     def __init__(self):
         """
         Construct an empty Graph.
@@ -118,7 +180,13 @@
         self._len = 0
 
     @property
-    def nodes(self):
+    def nodes(self) -> _node_list:
+        """
+        Get the list of `Node`s that constitute this Graph.
+
+        Note that this `Node` list representation is a doubly-linked list. Mutations
+        during iteration (e.g. delete a Node, add a Node) are safe.
+        """
         return _node_list(self)
 
     def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node]) -> Optional[Argument]:
@@ -156,6 +224,21 @@
                     kwargs: Optional[Dict[str, Argument]] = None,
                     name: Optional[str] = None,
                     type_expr: Optional[Any] = None) -> Node:
+        """
+        Create a `Node` and add it to the `Graph` at the current insert-point.
+        Note that the current insert-point can be set via `Graph.inserting_before`
+        and `Graph.inserting_after`.
+
+        - op is the opcode for this Node. One of 'call_function', 'call_method', 'get_attr',
+          'call_module', 'placeholder', or 'output'. The semantics of these opcodes are
+          described in the `Graph` docstring.
+        - args is a tuple of arguments to this node.
+        - kwargs is a dict from string to argument, representing the kwargs of this Node
+        - name is an optional string name for the `Node`. This will influence the name
+          of the value assigned to in the Python generated code.
+        - type_expr is an optional type annotation representing the Python type
+          the output of this node will have.
+        """
         assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output')
         args = () if args is None else args
         kwargs = {} if kwargs is None else kwargs
@@ -224,16 +307,49 @@
 
     # sugar for create_node when you know the op
     def placeholder(self, name: str, type_expr: Optional[Any] = None) -> Node:
+        """
+        Insert a `placeholder` node into the Graph. A `placeholder` represents
+        a function input. This function takes a string `name` for the input
+        value as well as an optional `type_expr`, which is a type expression
+        describing the type of value this input will take. The type expression
+        is needed in some cases for proper code generation.
+
+        The same insertion point rules apply for this method as `Graph.create_node`.
+        """
         return self.create_node('placeholder', name, type_expr=type_expr)
 
-    def get_attr(self, name: str, type_expr: Optional[Any] = None) -> Node:
-        return self.create_node('get_attr', name, type_expr=type_expr)
+    def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node:
+        """
+        Insert a `get_attr` node into the Graph. A `get_attr` `Node` represents the
+        fetch of an attribute from the `Module` hierarchy. `qualified_name` is the
+        fully-qualified name of the attribute to be retrieved. For example, if
+        the traced Module has a submodule named `foo`, which has a submodule named
+        `bar`, which has an attribute named `baz`, the qualified name `foo.bar.baz`
+        should be passed as `qualified_name`.
+
+        The same insertion point and type expression rules apply for this method
+        as `Graph.create_node`.
+        """
+        return self.create_node('get_attr', qualified_name, type_expr=type_expr)
 
     def call_module(self,
                     module_name: str,
                     args: Optional[Tuple[Argument, ...]] = None,
                     kwargs: Optional[Dict[str, Argument]] = None,
                     type_expr: Optional[Any] = None) -> Node:
+        """
+        Insert a `call_module` `Node` into the `Graph`. A `call_module` node
+        represents a call to the forward() function of a `Module` in the `Module`
+        hierarchy. For example, if the traced `Module` has a submodule named `foo`,
+        which has a submodule named `bar`, the qualified name `foo.bar` should
+        be passed as `module_name` to call that module.
+
+        `args` and `kwargs` represent the args and kwargs passed to the called
+        `Module`, respectively.
+
+        The same insertion point and type expression rules apply for this method
+        as `Graph.create_node`.
+        """
         return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr)
 
     def call_method(self,
@@ -241,6 +357,18 @@
                     args: Optional[Tuple[Argument, ...]] = None,
                     kwargs: Optional[Dict[str, Argument]] = None,
                     type_expr: Optional[Any] = None) -> Node:
+        """
+        Insert a `call_method` `Node` into the `Graph`. A `call_method` node
+        represents a call to a given method on the 0th element of `args.
+        For example, if args[0] is a `Node` representing a `Tensor`, then to call
+        `relu()` on that `Tensor`, pass `relu` to `method_name`.
+
+        `args` and `kwargs` represent the args and kwargs passed to the called
+        method, respectively.
+
+        The same insertion point and type expression rules apply for this method
+        as `Graph.create_node`.
+        """
         return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr)
 
     def call_function(self,
@@ -248,10 +376,22 @@
                       args: Optional[Tuple[Argument, ...]] = None,
                       kwargs: Optional[Dict[str, Argument]] = None,
                       type_expr: Optional[Any] = None) -> Node:
+        """
+        Insert a `call_function` `Node` into the `Graph`. A `call_function` node
+        represents a call to a Python callable, specified by `the_function`. `the_function`
+        can be any PyTorch operator, Python function, or member of the `builtins`
+        or `operator` namespaces.
+
+        `args` and `kwargs` represent the args and kwargs passed to the called
+        method, respectively.
+
+        The same insertion point and type expression rules apply for this method
+        as `Graph.create_node`.
+        """
         return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr)
 
     def node_copy(self, node: Node, arg_transform: Callable[[Node], Argument] = lambda x: x) -> Node:
-        """ copy a node from one graph into another. arg_transform needs to transform arguments from the graph of node
+        """ Copy a node from one graph into another. arg_transform needs to transform arguments from the graph of node
             to the graph of self. Example:
 
             g : torch.fx.Graph = ...
@@ -281,6 +421,14 @@
         return self.create_node(node.op, node.target, args, kwargs, name, node.type)
 
     def output(self, result: Argument, type_expr: Optional[Any] = None):
+        """
+        Insert an `output` `Node` into the `Graph`. An `output` node represents
+        a `return` statement in the Python code. `result` is the value that should
+        be returned.
+
+        The same insertion point and type expression rules apply for this method
+        as `Graph.create_node`.
+        """
         return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr)
 
     def _name(self, target: Target) -> str:
@@ -294,7 +442,7 @@
         op = op.replace('.', '_')
         # delete all characters that are illegal in a Python identifier
         op = re.sub('[^0-9a-zA-Z_]+', '_', op)
-        op = snake_case(op)
+        op = _snake_case(op)
         if op[0].isdigit():
             op = f'_{op}'
 
@@ -318,6 +466,9 @@
         return f'{op}_{i}'
 
     def python_code(self, root_module: str) -> str:
+        """
+        Turn this `Graph` into valid Python code.
+        """
         free_vars: List[str] = []
         modules_used : Set[str] = set()
         body: List[str] = []
@@ -405,6 +556,10 @@
         return fn_code
 
     def __str__(self) -> str:
+        """
+        Print a human-readable (not machine-readable) string representation
+        of this Graph
+        """
         placeholder_names : List[str] = []
         # This is a one-element array just so `format_node` can modify the closed
         # over value
diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py
index 4a7ab8c..3525e18 100644
--- a/torch/fx/graph_module.py
+++ b/torch/fx/graph_module.py
@@ -172,11 +172,19 @@
 
     @property
     def graph(self):
+        """
+        Return the `Graph` underlying this `GraphModule`
+        """
         return self._graph
 
     @graph.setter
-    def graph(self, val) -> None:
-        self._graph = val
+    def graph(self, g) -> None:
+        """
+        Set the underlying `Graph` for this `GraphModule`. This will internally
+        recompile the `GraphModule` so that the generated `forward()` function
+        corresponds to `g`
+        """
+        self._graph = g
         self.recompile()
 
     def recompile(self) -> None:
@@ -204,6 +212,13 @@
         cls.__call__ = wrapped_call
 
     def __reduce__(self):
+        """
+        Serialization of GraphModule. We serialize only the generated code, not
+        the underlying `Graph`. This is because `Graph` does not have on-disk
+        backward-compatibility guarantees, whereas Python source code does.
+        On the deserialization side, we symbolically trace through the generated
+        code to regenerate the underlying `Graph`
+        """
         dict_without_graph = self.__dict__.copy()
         del dict_without_graph['_graph']
         return (deserialize_graphmodule, (dict_without_graph,))
diff --git a/torch/fx/node.py b/torch/fx/node.py
index 118b32f..dd304a8 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -59,10 +59,16 @@
 
     @property
     def next(self) -> 'Node':
+        """
+        Get the next node in the linked list
+        """
         return self._next
 
     @property
     def prev(self) -> 'Node':
+        """
+        Get the previous node in the linked list
+        """
         return self._prev
 
     def prepend(self, x: 'Node'):
@@ -96,18 +102,38 @@
 
     @property
     def args(self) -> Tuple[Argument, ...]:
+        """
+        Return the tuple of arguments to this Node. The interpretation of arguments
+        depends on the node's opcode. See the `fx.Graph` docstring for more
+        information.
+        """
         return self._args
 
     @args.setter
     def args(self, a : Tuple[Argument, ...]):
+        """
+        Set the tuple of arguments to this Node. The interpretation of arguments
+        depends on the node's opcode. See the `fx.Graph` docstring for more
+        information.
+        """
         self._update_args_kwargs(map_arg(a, lambda x: x), self._kwargs)  # type: ignore
 
     @property
     def kwargs(self) -> Dict[str, Argument]:
+        """
+        Return the dict of kwargs to this Node. The interpretation of arguments
+        depends on the node's opcode. See the `fx.Graph` docstring for more
+        information.
+        """
         return self._kwargs
 
     @kwargs.setter
     def kwargs(self, k : Dict[str, Argument]):
+        """
+        Set the dict of kwargs to this Node. The interpretation of arguments
+        depends on the node's opcode. See the `fx.Graph` docstring for more
+        information.
+        """
         self._update_args_kwargs(self._args, map_arg(k, lambda x: x))  # type: ignore
 
     def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict[str, Argument]):
@@ -151,7 +177,7 @@
 
 
 def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
-    """ apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. """
+    """ Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. """
     if isinstance(a, tuple):
         return tuple(map_arg(elem, fn) for elem in a)
     if isinstance(a, list):
diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py
index 20566bb..a75d5ff 100644
--- a/torch/fx/symbolic_trace.py
+++ b/torch/fx/symbolic_trace.py
@@ -38,10 +38,34 @@
     # instead, let's make python think that args and kwargs are normal variables
 
 class Tracer(TracerBase):
+    """
+    `Tracer` is the class that implements the symbolic tracing functionality
+    of `torch.fx.symbolic_trace`. A call to `symbolic_trace(m)` is equivalent
+    to `Tracer().trace(m)`.
+
+    Tracer can be subclassed to override various behaviors of the tracing
+    process. The different behaviors that can be overridden are described
+    in the docstrings of the methods on this class.
+    """
     def __init__(self):
         super().__init__()
 
     def create_arg(self, a: Any) -> Argument:
+        """
+        A method to specify the behavior of tracing when preparing values to
+        be used as arguments to nodes in the `Graph`.
+
+        By default, the behavior includes:
+        - Iterate through collection types (e.g. tuple, list, dict) and recursively
+          call `create_args` on the elements.
+        - Given a Proxy object, return a reference to the underlying IR `Node`
+        - Given a non-Proxy Tensor object, emit IR for various cases:
+            - For a Parameter, emit a `get_attr` node referring to that Parameter
+            - For a non-Parameter Tensor, store the Tensor away in a special
+              attribute referring to that attribute.
+
+        This method can be overridden to support more types.
+        """
         # The base tracer is used to construct Graphs when there is no associated
         # module hierarchy, so it can never create parameter references.
         # The default tracer adds the ability to refer to parameters when
@@ -95,19 +119,43 @@
         """
         return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential)
 
-    def path_of_module(self, mod):
+    def path_of_module(self, mod) -> str:
+        """
+        Helper method to find the qualified name of `mod` in the Module hierarchy
+        of `root`. For example, if `root` has a submodule named `foo`, which has
+        a submodule named `bar`, passing `bar` into this function will return
+        the string "foo.bar".
+        """
         for n, p in self.root.named_modules():
             if mod is p:
                 return n
         raise NameError('module is not installed as a submodule')
 
     def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args, kwargs):
+        """
+        Method that specifies the behavior of this `Tracer` when it encounters
+        a call to an `nn.Module` instance.
+
+        By default, the behavior is to check if the called module is a leaf module
+        via `is_leaf_module`. If it is, emit a `call_module` node referring to
+        `m` in the `Graph`. Otherwise, call the `Module` normally, tracing through
+        the operations in its `forward` function.
+
+        This method can be overridden to--for example--create nested traced
+        GraphModules, or any other behavior you would want while tracing across
+        `Module` boundaries.
+        """
         module_qualified_name = self.path_of_module(m)
         if not self.is_leaf_module(m, module_qualified_name):
             return forward(*args, **kwargs)
         return self.create_proxy('call_module', module_qualified_name, args, kwargs)
 
     def create_args_for_root(self, root_fn, is_module):
+        """
+        Create `placeholder` nodes corresponding to the signature of the `root`
+        Module. This method introspects `root`'s signature and emits those
+        nodes accordingly, also supporting *args and **kwargs.
+        """
         # In some cases, a function or method has been decorated with a wrapper
         # defined via `functools.wraps`. In this case, the outer code object
         # will likely not contain the actual parameters we care about, so unwrap
@@ -149,6 +197,10 @@
         return root_fn, args
 
     def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph:
+        """
+        Trace `root` and return the corresponding FX `Graph` representation. `root`
+        can either be an `nn.Module` instance or a Python callable.
+        """
         if isinstance(root, torch.nn.Module):
             self.root = root
             fn = type(root).forward
@@ -211,12 +263,11 @@
         return self.graph
 
 
-# Symbolic tracing API
-#
-# Given an `nn.Module` or function instance `root`, this function will return a `GraphModule`
-# constructed by recording operations seen while tracing through `root`.
-#
-# Args:
-#   - root - the `nn.Module` instance to trace
 def symbolic_trace(root : Union[torch.nn.Module, Callable]) -> GraphModule:
+    """
+    Symbolic tracing API
+
+     Given an `nn.Module` or function instance `root`, this function will return a `GraphModule`
+     constructed by recording operations seen while tracing through `root`.
+    """
     return GraphModule(root if isinstance(root, torch.nn.Module) else torch.nn.Module(), Tracer().trace(root))