[FX] Preserve type annotations on generated code in Graph (#45880)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45880

Test Plan: Imported from OSS

Reviewed By: dzhulgakov

Differential Revision: D24127303

Pulled By: jamesr66a

fbshipit-source-id: 3a042bcfb0bf9f58ac318cc814dfc3cca683c7f8
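
Example (not part of the patch): a minimal sketch of what the new `type_expr` plumbing enables, mirroring the tests added in test/test_fx.py below. The `Foo` module and the signatures quoted in comments are illustrative, not output captured from this build.

    from typing import Dict, List
    import torch
    import torch.fx

    # 1) Hand-built graphs: a type expression can be attached per node and shows
    #    up in str(graph) and in the generated code (mirrors test_typename_print).
    graph: torch.fx.Graph = torch.fx.Graph()
    x = graph.placeholder('x', type_expr=torch.Tensor)
    relu = graph.create_node('call_function', target=torch.relu, args=(x,),
                             type_expr=List[float])
    graph.output(relu, type_expr=List[float])
    print(graph)  # roughly: graph(x: torch.Tensor) -> typing.List[float]: ...

    # 2) symbolic_trace: annotations on forward() now survive tracing, so the
    #    traced module keeps its signature and still scripts
    #    (mirrors test_fn_type_annotations).
    class Foo(torch.nn.Module):
        def forward(self, z: torch.Tensor, i: int) -> Dict[str, torch.Tensor]:
            return {'a': z + i}

    traced = torch.fx.symbolic_trace(Foo())
    # Generated signature now reads roughly:
    #   def forward(self, z : torch.Tensor, i : int) -> typing.Dict[str, torch.Tensor]:
    scripted = torch.jit.script(traced)
    scripted(torch.rand(5), 3)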
diff --git a/test/test_fx.py b/test/test_fx.py
index 76217fc..94bea70 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -15,7 +15,7 @@
 
 from fx.quantization import Quantizer
 
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, NamedTuple, List, Optional, Tuple, Union
 from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, IS_WINDOWS, IS_SANDCASTLE, IS_MACOS
 from torch.testing._internal.jit_utils import JitTestCase
 
@@ -33,6 +33,10 @@
 def a_non_torch_leaf(a, b):
     return a + b
 
+class Pair(NamedTuple):
+    x : torch.Tensor
+    y : torch.Tensor
+
 class TestFX(JitTestCase):
     def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None):
         """Check that an nn.Module's results match the GraphModule version
@@ -131,7 +135,8 @@
         # Custom delegate to disallow in-place tensor operations
         class NoMutableCallTracer(Tracer):
             def create_node(self, kind : str, target : Union[str, Callable],
-                            args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None) -> Node:
+                            args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None,
+                            type_expr : Optional[Any] = None) -> Node:
                 name = target if isinstance(target, str) else torch.typename(target)
                 if name[-1] == '_':
                     raise RuntimeError('In-place operations are not supported')
@@ -448,7 +453,8 @@
     def test_node_tagging(self):
         class TaggingTracer(Tracer):
             def create_node(self, kind : str, target : Union[str, Callable],
-                            args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None) -> Node:
+                            args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None,
+                            type_expr : Optional[Any] = None) -> Node:
                 n = super().create_node(kind, target, args, kwargs, name)
                 n.tag = 'foo'
                 return n
@@ -765,6 +771,26 @@
         # Test shape propogation and make sure results match actual
         self.assertEqual(output_shape, ref_out.shape)
 
+    def test_fn_type_annotations(self):
+        class Foo(torch.nn.Module):
+            def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]:
+                return {'a': p.x + p.y + z + i}
+
+        foo_scripted = torch.jit.script(Foo())
+        foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)
+
+        fxed = symbolic_trace(Foo())
+        fxed_scripted = torch.jit.script(fxed)
+        fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)
+
+    def test_typename_print(self):
+        graph : torch.fx.Graph = torch.fx.Graph()
+        x : torch.fx.Node = graph.create_node('placeholder', 'x')
+        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,),
+                                              type_expr=List[float])
+        output : torch.fx.Node = graph.output(b)
+        self.assertTrue('typing.List[float]' in str(graph))
+
     def test_find_single_partition(self):
         class testModule(torch.nn.Module):
             def forward(self, a, b):
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index 600fcb2..9994bd8 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -3,6 +3,7 @@
 from typing import Callable, Any, List, Dict, Optional, Tuple, Set
 import builtins
 import torch
+import types
 import keyword
 import re
 
@@ -52,6 +53,29 @@
             r = f'{r}.{e}'
     return r
 
+# Borrowed from CPython typing module
+# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156
+def _type_repr(obj):
+    """Return the repr() of an object, special-casing types (internal helper).
+    If obj is a type, we return a shorter version than the default
+    type.__repr__, based on the module and qualified name, which is
+    typically enough to uniquely identify a type.  For everything
+    else, we fall back on repr(obj).
+    """
+    # HACK: In Python 3.6, type aliases from `typing` are instances of `type`, but
+    # in later Python versions they are not. We want all type aliases to fall
+    # through to `repr`, so if we have a type whose module is `typing`, don't go
+    # down this path.
+    if isinstance(obj, type) and obj.__module__ != 'typing':
+        if obj.__module__ == 'builtins':
+            return obj.__qualname__
+        return f'{obj.__module__}.{obj.__qualname__}'
+    if obj is ...:
+        return '...'
+    if isinstance(obj, types.FunctionType):
+        return obj.__name__
+    return repr(obj)
+
 class insert_before:
     def __init__(self, n : Node):
         self.n = n
@@ -65,6 +89,9 @@
 
 class Graph:
     def __init__(self):
+        """
+        Construct an empty Graph.
+        """
         self._nodes : List[Node] = []
         self._used_names : Dict[str, int] = {}  # base name -> number
         self._insert_point : Optional[Node] = None
@@ -90,12 +117,13 @@
     def create_node(self, op: str, target: Target,
                     args: Optional[Tuple[Argument, ...]] = None,
                     kwargs: Optional[Dict[str, Argument]] = None,
-                    name: Optional[str] = None) -> Node:
+                    name: Optional[str] = None,
+                    type_expr: Optional[Any] = None) -> Node:
         assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output')
         args = () if args is None else args
         kwargs = {} if kwargs is None else kwargs
         sanitized_name = self._register_name_used(name) if name is not None else self._name(target)
-        n = Node(self, sanitized_name, op, target, args, kwargs)
+        n = Node(self, sanitized_name, op, target, args, kwargs, type_expr)
         if self._insert_point is not None:
             before_idx = self._nodes.index(self._insert_point)
             self._nodes.insert(before_idx, n)
@@ -130,29 +158,32 @@
             self._nodes.pop(idx)
 
     # sugar for above when you know the op
-    def placeholder(self, name: str) -> Node:
-        return self.create_node('placeholder', name)
+    def placeholder(self, name: str, type_expr: Optional[Any] = None) -> Node:
+        return self.create_node('placeholder', name, type_expr=type_expr)
 
-    def get_attr(self, name: str) -> Node:
-        return self.create_node('get_attr', name)
+    def get_attr(self, name: str, type_expr: Optional[Any] = None) -> Node:
+        return self.create_node('get_attr', name, type_expr=type_expr)
 
     def call_module(self,
                     module_name: str,
                     args: Optional[Tuple[Argument, ...]] = None,
-                    kwargs: Optional[Dict[str, Argument]] = None) -> Node:
-        return self.create_node('call_module', module_name, args, kwargs)
+                    kwargs: Optional[Dict[str, Argument]] = None,
+                    type_expr: Optional[Any] = None) -> Node:
+        return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr)
 
     def call_method(self,
                     method_name: str,
                     args: Optional[Tuple[Argument, ...]] = None,
-                    kwargs: Optional[Dict[str, Argument]] = None) -> Node:
-        return self.create_node('call_method', method_name, args, kwargs)
+                    kwargs: Optional[Dict[str, Argument]] = None,
+                    type_expr: Optional[Any] = None) -> Node:
+        return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr)
 
     def call_function(self,
                       the_function: Callable[..., Any],
                       args: Optional[Tuple[Argument, ...]] = None,
-                      kwargs: Optional[Dict[str, Argument]] = None) -> Node:
-        return self.create_node('call_function', the_function, args, kwargs)
+                      kwargs: Optional[Dict[str, Argument]] = None,
+                      type_expr: Optional[Any] = None) -> Node:
+        return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr)
 
     def node_copy(self, node: Node, arg_transform: Callable[[Node], Argument] = lambda x: x) -> Node:
         """ copy a node from one graph into another. arg_transform needs to transform arguments from the graph of node
@@ -181,10 +212,10 @@
                 except ValueError:
                     pass
             name = self._name(sanitized_name)
-        return self.create_node(node.op, node.target, args, kwargs, name)
+        return self.create_node(node.op, node.target, args, kwargs, name, node.type)
 
-    def output(self, result: Argument):
-        return self.create_node(op='output', target='output', args=(result,))
+    def output(self, result: Argument, type_expr: Optional[Any] = None):
+        return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr)
 
     def _name(self, target: Target) -> str:
         if callable(target):
@@ -224,10 +255,23 @@
         free_vars: List[str] = []
         modules_used : Set[str] = set()
         body: List[str] = []
+        maybe_return_annotation : str = ''
+
+        def register_modules_used(qualified_name : str):
+            if '.' in qualified_name:
+                module_name = qualified_name.split('.', maxsplit=1)[0]
+                modules_used.add(module_name)
+
+        def type_repr(o : Any):
+            typename = _type_repr(o)
+            register_modules_used(typename)
+            return typename
+
         for node in self._nodes:
             if node.op == 'placeholder':
                 assert isinstance(node.target, str)
-                free_vars.append(node.target)
+                maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
+                free_vars.append(f'{node.target}{maybe_type_annotation}')
                 raw_name = node.target.replace('*', '')
                 if raw_name != node.name:
                     body.append(f'{node.name} = {raw_name}\n')
@@ -246,9 +290,7 @@
                     body.append(f'{node.name} = {magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}\n')
                     continue
                 qualified_name = _qualified_name(node.target)
-                if '.' in qualified_name:
-                    module_name = qualified_name.split('.', maxsplit=1)[0]
-                    modules_used.add(module_name)
+                register_modules_used(qualified_name)
                 if qualified_name == 'getattr' and \
                    isinstance(node.args, tuple) and \
                    isinstance(node.args[1], str) and \
@@ -267,6 +309,8 @@
                 body.append(f'{node.name} = {_format_target(root_module, node.target)}\n')
                 continue
             elif node.op == 'output':
+                if node.type is not None:
+                    maybe_return_annotation = f" -> {type_repr(node.type)}"
                 body.append(f'return {node.args[0]}')
                 continue
             raise NotImplementedError(f'node: {node.op} {node.target}')
@@ -277,13 +321,17 @@
         code = '\n'.join('    ' + line for line in code.split('\n')) + '\n'
         fn_code = f"""\
 {import_block}
-def forward(self, {', '.join(free_vars)}):
+def forward(self, {', '.join(free_vars)}){maybe_return_annotation}:
 {code}
 """
+
         return fn_code
 
     def __str__(self) -> str:
         placeholder_names : List[str] = []
+        # This is a one-element list just so `format_node` can modify the
+        # closed-over value
+        maybe_return_typename : List[str] = ['']
 
         def format_arg(arg) -> str:
             if isinstance(arg, list):
@@ -305,20 +353,26 @@
         def format_node(n : Node) -> Optional[str]:
             if n.op == 'placeholder':
                 assert isinstance(n.target, str)
-                placeholder_names.append(n.target)
+                arg_str = n.target
+                arg_str += f': {_type_repr(n.type)}' if n.type is not None else ''
+                placeholder_names.append(arg_str)
                 return None
             elif n.op == 'get_attr':
-                return f'%{n.name} : [#users={len(n.users)}] = self.{n.target}'
+                maybe_typename = f'{_type_repr(n.type)} ' if n.type is not None else ''
+                return f'%{n.name} : {maybe_typename}[#users={len(n.users)}] = self.{n.target}'
             elif n.op == 'output':
+                if n.type is not None:
+                    maybe_return_typename[0] = f' -> {_type_repr(n.type)}'
                 return f'return {n.args[0]}'
             else:
-                return f'%{n.name} : [#users={len(n.users)}] = {n.op}[target={n.target}](' \
+                maybe_typename = f'{_type_repr(n.type)} ' if n.type is not None else ''
+                return f'%{n.name} : {maybe_typename}[#users={len(n.users)}] = {n.op}[target={n.target}](' \
                        f'args = {format_arg(n.args)}, kwargs = {format_arg(n.kwargs)})'
 
 
         node_strs = [format_node(node) for node in self._nodes]
         param_str = ', '.join(placeholder_names)
-        s = f'graph({param_str}):'
+        s = f'graph({param_str}){maybe_return_typename[0]}:'
         for node_str in node_strs:
             if node_str:
                 s += '\n    ' + node_str
diff --git a/torch/fx/node.py b/torch/fx/node.py
index 458e1d3..7d35483 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -21,7 +21,8 @@
 
 class Node:
     def __init__(self, graph: 'Graph', name: str, op: str, target: Target,
-                 args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> None:
+                 args: Tuple[Argument, ...], kwargs: Dict[str, Argument],
+                 type : Optional[Any] = None) -> None:
         self.graph = graph
         self.name = name  # unique name of value being created
         assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']
@@ -39,6 +40,17 @@
         #
         # Is a dict to act as an "ordered set". Keys are significant, value dont-care
         self.users : Dict['Node', None] = {}
+        # Type expression representing the output value of this node.
+        # This should contain the same class of Type objects that would appear
+        # as type annotations for function inputs/outputs.
+        #
+        # For placeholder nodes, this value will be used to type-annotate the
+        # generated function parameters.
+        # For the return node, this value will be used to type-annotate the
+        # generated function return type. (Note this is a special case: `return`
+        # does not produce a value; it is more of a notation. Thus, this value
+        # describes the type of args[0] in the `return` node.)
+        self.type : Optional[Any] = type
 
     @property
     def args(self) -> Tuple[Argument, ...]:
diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py
index 90593c4..20d7178 100644
--- a/torch/fx/proxy.py
+++ b/torch/fx/proxy.py
@@ -11,7 +11,8 @@
     graph: Graph
 
     def create_node(self, kind : str, target : Union[str, Callable],
-                    args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None) -> Node:
+                    args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
+                    type_expr : Optional[Any] = None) -> Node:
         """
         Inserts a graph node given target, args, kwargs, and name.
 
@@ -19,7 +20,7 @@
         modification of values used in node creation. For example, one might
         want to disallow in-place operations from being recorded.
         """
-        return self.graph.create_node(kind, target, args, kwargs, name)
+        return self.graph.create_node(kind, target, args, kwargs, name, type_expr)
 
     def create_arg(self, a: Any) -> Argument:
         """
@@ -65,12 +66,13 @@
 
 # Unwrap the proxies inside args, and kwargs, create the resulting node
 # and then wrap the result in a proxy.
-def _create_proxy(tracer: 'TracerBase', op: str, target: Target, args_: Tuple[Any, ...], kwargs_: Dict[str, Any], name=None):
+def _create_proxy(tracer: 'TracerBase', op: str, target: Target, args_: Tuple[Any, ...], kwargs_: Dict[str, Any],
+                  name=None, type_expr : Optional[Any] = None):
     args = tracer.create_arg(args_)
     kwargs = tracer.create_arg(kwargs_)
     assert isinstance(args, tuple)
     assert isinstance(kwargs, dict)
-    rn = tracer.create_node(op, target, args, kwargs, name)
+    rn = tracer.create_node(op, target, args, kwargs, name, type_expr)
     return Proxy(rn, tracer)
 
 class Proxy:
diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py
index 7c295f6..7c5b2d7 100644
--- a/torch/fx/symbolic_trace.py
+++ b/torch/fx/symbolic_trace.py
@@ -121,18 +121,23 @@
 
     def trace(self, root: torch.nn.Module) -> Graph:
         self.root = root
+        fn = type(root).forward
         self.graph = Graph()
 
-        fn = type(root).forward
         assert isinstance(fn, FunctionType)
         co = fn.__code__
         total_args = co.co_argcount + co.co_kwonlyargcount
         names_iter = iter(co.co_varnames)
         next(names_iter)  # skip self
         args : List[Any] = [root]
-        args.extend(self._proxy_placeholder(next(names_iter)) for name in range(1, total_args))
+
+        def make_proxy_placeholder():
+            name = next(names_iter)
+            return self._proxy_placeholder(name, fn.__annotations__.get(name, None))
+        args.extend(make_proxy_placeholder() for _ in range(1, total_args))
 
         if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF:
+            # TODO: type annotations for *args and **kwargs
             if co.co_flags & inspect.CO_VARARGS:
                 args.append(self._proxy_placeholder('*' + next(names_iter)))
             if co.co_flags & inspect.CO_VARKEYWORDS:
@@ -149,13 +154,14 @@
                 return _create_proxy(self, 'call_module', module_qualified_name, args, kwargs)
         try:
             torch.nn.Module.__call__ = module_call_wrapper
-            self.create_node('output', 'output', (self.create_arg(fn(*args)),), {})
+            self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
+                             type_expr=fn.__annotations__.get('return', None))
         finally:
             torch.nn.Module.__call__ = orig_call
         return self.graph
 
-    def _proxy_placeholder(self, name: str) -> Proxy:
-        return Proxy(self.create_node('placeholder', name, (), {}), self)
+    def _proxy_placeholder(self, name: str, type_expr: Optional[Any] = None) -> Proxy:
+        return Proxy(self.create_node('placeholder', name, (), {}, type_expr=type_expr), self)
 
 # Symbolic tracing API
 #