[FX][2/2] Make docstrings pretty when rendered (#48871)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48871

Test Plan: Imported from OSS

Reviewed By: ansley

Differential Revision: D25351588

Pulled By: jamesr66a

fbshipit-source-id: 4c6fd341100594c204a35d6a3aab756e3e22297b
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index 072aef6..ca4b8d6 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -148,26 +148,7 @@
             %topk_1 : [#users=1] = call_function[target=<built-in method topk of type object at 0x7ff2da9dc300>](args = (%sum_1, 3), kwargs = {}) # noqa: B950
             return topk_1
 
-    The Node semantics are as follows:
-
-    - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on.
-      ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument
-      denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to
-      the function parameters (e.g. ``x``) in the graph printout.
-    - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the
-      fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy.
-      ``args`` and ``kwargs`` are don't-care
-    - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign
-      to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function,
-      following the Python calling convention
-    - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is
-      as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call.
-      ``args`` and ``kwargs`` represent the arguments to invoke the module on, *including the self argument*.
-    - ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method
-      to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on,
-      *including the self argument*
-    - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement
-      in the Graph printout.
+    For the semantics of operations represented in the ``Graph``, please see :class:`Node`.
     """
     def __init__(self):
         """
diff --git a/torch/fx/node.py b/torch/fx/node.py
index 1cc94be..fd8a4bc 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -21,8 +21,34 @@
 ]]
 
 class Node:
-    def __init__(self, graph: 'Graph', name: str, op: str, target: Target,
-                 args: Tuple[Argument, ...], kwargs: Dict[str, Argument],
+    """
+    ``Node`` is the data structure that represents individual operations within
+    a ``Graph``. For the most part, Nodes represent callsites to various entities,
+    such as operators, methods, and Modules (some exceptions include nodes that
+    specify function inputs and outputs). Each ``Node`` has a function specified
+    by its ``op`` property. The ``Node`` semantics for each value of ``op`` are as follows:
+
+    - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on.
+      ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument
+      denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to
+      the function parameters (e.g. ``x``) in the graph printout.
+    - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the
+      fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy.
+      ``args`` and ``kwargs`` are don't-care
+    - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign
+      to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function,
+      following the Python calling convention
+    - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is
+      as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call.
+      ``args`` and ``kwargs`` represent the arguments to invoke the module on, *excluding the self argument*.
+    - ``call_method`` calls a method on a value. ``name`` is as previous. ``target`` is the string name of the method
+      to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the method on,
+      *including the self argument*.
+    - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement
+      in the Graph printout.
+    """
+    def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target',
+                 args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'],
                  type : Optional[Any] = None) -> None:
         self.graph = graph
         self.name = name  # unique name of value being created
@@ -60,23 +86,33 @@
     @property
     def next(self) -> 'Node':
         """
-        Get the next node in the linked list
+        Returns the next ``Node`` in the linked list of Nodes.
+
+        Returns:
+
+            The next ``Node`` in the linked list of Nodes.
         """
         return self._next
 
     @property
     def prev(self) -> 'Node':
         """
-        Get the previous node in the linked list
+        Returns the previous ``Node`` in the linked list of Nodes.
+
+        Returns:
+
+            The previous ``Node`` in the linked list of Nodes.
         """
         return self._prev
 
-    def prepend(self, x: 'Node'):
-        """Insert x before this node in the list of nodes in the graph.
-           Before: p -> self
-                   bx -> x -> ax
-           After:  p -> x -> self
-                   bx -> ax
+    def prepend(self, x: 'Node') -> None:
+        """
+        Insert x before this node in the list of nodes in the graph. Example::
+
+            Before: p -> self
+                    bx -> x -> ax
+            After:  p -> x -> self
+                    bx -> ax
 
         Args:
             x (Node): The node to put before this node. Must be a member of the same graph.
@@ -87,8 +123,9 @@
         p._next, x._prev = x, p
         x._next, self._prev = self, x
 
-    def append(self, x: 'Node'):
-        """Insert x after this node in the list of nodes in the graph.
+    def append(self, x: 'Node') -> None:
+        """
+        Insert x after this node in the list of nodes in the graph.
         Equvalent to ``self.next.prepend(x)``
 
         Args:
@@ -103,9 +140,12 @@
     @property
     def args(self) -> Tuple[Argument, ...]:
         """
-        Return the tuple of arguments to this Node. The interpretation of arguments
-        depends on the node's opcode. See the ``fx.Graph`` docstring for more
+        The tuple of arguments to this ``Node``. The interpretation of arguments
+        depends on the node's opcode. See the :class:`Node` docstring for more
         information.
+
+        Assignment to this property is allowed. All accounting of uses and users
+        is updated automatically on assignment.
         """
         return self._args
 
@@ -121,9 +161,12 @@
     @property
     def kwargs(self) -> Dict[str, Argument]:
         """
-        Return the dict of kwargs to this Node. The interpretation of arguments
-        depends on the node's opcode. See the ``fx.Graph`` docstring for more
+        The dict of keyword arguments to this ``Node``. The interpretation of arguments
+        depends on the node's opcode. See the :class:`Node` docstring for more
         information.
+
+        Assignment to this property is allowed. All accounting of uses and users
+        is updated automatically on assignment.
         """
         return self._kwargs
 
@@ -141,7 +184,12 @@
         """
         Return all Nodes that are inputs to this Node. This is equivalent to
         iterating over ``args`` and ``kwargs`` and only collecting the values that
-        are Nodes
+        are Nodes.
+
+        Returns:
+
+            List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this
+            ``Node``, in that order.
         """
         all_nodes : List['Node'] = []
         map_arg(self.args, lambda n: all_nodes.append(n))
@@ -149,6 +197,9 @@
         return all_nodes
 
     def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict[str, Argument]):
+        """
+        This API is internal. Do *not* call it directly.
+        """
         self._args = new_args
         self._kwargs = new_kwargs
 
@@ -168,7 +219,14 @@
     def replace_all_uses_with(self, replace_with : 'Node') -> List['Node']:
         """
         Replace all uses of ``self`` in the Graph with the Node ``replace_with``.
-        Returns the list of nodes on which this change was made.
+
+        Args:
+
+            replace_with (Node): The node to replace all uses of ``self`` with.
+
+        Returns:
+
+            The list of Nodes on which this change was made.
         """
         to_process = list(self.users)
         for use_node in to_process:
diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py
index 6bdc8dd..69e3c70 100644
--- a/torch/fx/symbolic_trace.py
+++ b/torch/fx/symbolic_trace.py
@@ -1,6 +1,6 @@
 import inspect
 from types import CodeType, FunctionType
-from typing import Any, Dict, Optional, List, Callable, Union
+from typing import Any, Dict, Optional, Tuple, List, Callable, Union
 import torch
 from torch._C import ScriptObject  # type: ignore
 
@@ -51,21 +51,31 @@
     def __init__(self):
         super().__init__()
 
-    def create_arg(self, a: Any) -> Argument:
+    def create_arg(self, a: Any) -> 'Argument':
         """
         A method to specify the behavior of tracing when preparing values to
         be used as arguments to nodes in the ``Graph``.
 
         By default, the behavior includes:
-        - Iterate through collection types (e.g. tuple, list, dict) and recursively
-          call ``create_args`` on the elements.
-        - Given a Proxy object, return a reference to the underlying IR ``Node``
-        - Given a non-Proxy Tensor object, emit IR for various cases:
-            - For a Parameter, emit a ``get_attr`` node referring to that Parameter
-            - For a non-Parameter Tensor, store the Tensor away in a special
-              attribute referring to that attribute.
 
+        #. Iterate through collection types (e.g. tuple, list, dict) and recursively
+           call ``create_arg`` on the elements.
+        #. Given a Proxy object, return a reference to the underlying IR ``Node``
+        #. Given a non-Proxy Tensor object, emit IR for various cases:
+
+            * For a Parameter, emit a ``get_attr`` node referring to that Parameter
+            * For a non-Parameter Tensor, store the Tensor away in a special
+              attribute referring to that attribute.
+
         This method can be overridden to support more types.
+
+        Args:
+
+            a (Any): The value to be emitted as an ``Argument`` in the ``Graph``.
+
+        Returns:
+
+            The value ``a`` converted into the appropriate ``Argument``
         """
         # The base tracer is used to construct Graphs when there is no associated
         # module hierarchy, so it can never create parameter references.
@@ -115,28 +125,32 @@
         their constituent ops are recorded, unless specified otherwise
         via this parameter.
 
-        Args
-        m - The module itself
-        module_qualified_name - The path to root of this module. For example,
-            if you have a module hierarchy where submodule ``foo`` contains
-            submodule ``bar``, which contains submodule ``baz``, that module will
-            appear with the qualified name ``foo.bar.baz`` here.
+        Args:
+            m (Module): The module being queried about
+            module_qualified_name (str): The path to root of this module. For example,
+                if you have a module hierarchy where submodule ``foo`` contains
+                submodule ``bar``, which contains submodule ``baz``, that module will
+                appear with the qualified name ``foo.bar.baz`` here.
         """
         return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential)
 
-    def path_of_module(self, mod) -> str:
+    def path_of_module(self, mod : torch.nn.Module) -> str:
         """
         Helper method to find the qualified name of ``mod`` in the Module hierarchy
         of ``root``. For example, if ``root`` has a submodule named ``foo``, which has
         a submodule named ``bar``, passing ``bar`` into this function will return
         the string "foo.bar".
+
+        Args:
+
+            mod (Module): The ``Module`` to retrieve the qualified name for.
         """
         for n, p in self.root.named_modules():
             if mod is p:
                 return n
         raise NameError('module is not installed as a submodule')
 
-    def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args, kwargs):
+    def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any:
         """
         Method that specifies the behavior of this ``Tracer`` when it encounters
         a call to an ``nn.Module`` instance.
@@ -149,6 +163,20 @@
         This method can be overridden to--for example--create nested traced
         GraphModules, or any other behavior you would want while tracing across
         ``Module`` boundaries.
+
+
+        Args:
+
+            m (Module): The module for which a call is being emitted
+            forward (Callable): The forward() method of the ``Module`` to be invoked
+            args (Tuple): args of the module callsite
+            kwargs (Dict): kwargs of the module callsite
+
+        Returns:
+
+            The return value from the Module call. In the case that a ``call_module``
+            node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever
+            value was returned from the ``Module`` invocation.
         """
         module_qualified_name = self.path_of_module(m)
         if not self.is_leaf_module(m, module_qualified_name):
@@ -205,6 +233,16 @@
         """
         Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
         can either be an ``nn.Module`` instance or a Python callable.
+
+
+        Args:
+
+            root (Union[Module, Callable]): Either a ``Module`` or a function to be
+                traced through.
+
+        Returns:
+
+            A ``Graph`` representing the semantics of the passed-in ``root``.
         """
         if isinstance(root, torch.nn.Module):
             self.root = root