[FX] Preserve type annotations on generated code in Graph (#45880)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45880
Test Plan: Imported from OSS
Reviewed By: dzhulgakov
Differential Revision: D24127303
Pulled By: jamesr66a
fbshipit-source-id: 3a042bcfb0bf9f58ac318cc814dfc3cca683c7f8
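For context, a minimal sketch of what this change enables, based on the `test_typename_print` and `test_fn_type_annotations` tests added below (the names `g`, `x`, and `b` are illustrative, not part of the diff):

    import torch
    import torch.fx
    from typing import List

    # Build a Graph by hand and attach type expressions to nodes via the new
    # `type_expr` keyword argument.
    g : torch.fx.Graph = torch.fx.Graph()
    x = g.create_node('placeholder', 'x', type_expr=torch.Tensor)
    b = g.create_node('call_function', target=torch.relu, args=(x,), type_expr=List[float])
    g.output(b)

    # The type expressions now survive into the printed graph and into the
    # generated Python code, e.g. str(g) contains 'typing.List[float]'.
    print(str(g))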
diff --git a/test/test_fx.py b/test/test_fx.py
index 76217fc..94bea70 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -15,7 +15,7 @@
from fx.quantization import Quantizer
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, NamedTuple, List, Optional, Tuple, Union
from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, IS_WINDOWS, IS_SANDCASTLE, IS_MACOS
from torch.testing._internal.jit_utils import JitTestCase
@@ -33,6 +33,10 @@
def a_non_torch_leaf(a, b):
return a + b
+class Pair(NamedTuple):
+ x : torch.Tensor
+ y : torch.Tensor
+
class TestFX(JitTestCase):
def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None):
"""Check that an nn.Module's results match the GraphModule version
@@ -131,7 +135,8 @@
# Custom delegate to disallow in-place tensor operations
class NoMutableCallTracer(Tracer):
def create_node(self, kind : str, target : Union[str, Callable],
- args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None) -> Node:
+ args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None,
+ type_expr : Optional[Any] = None) -> Node:
name = target if isinstance(target, str) else torch.typename(target)
if name[-1] == '_':
raise RuntimeError('In-place operations are not supported')
@@ -448,7 +453,8 @@
def test_node_tagging(self):
class TaggingTracer(Tracer):
def create_node(self, kind : str, target : Union[str, Callable],
- args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None) -> Node:
+ args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None,
+ type_expr : Optional[Any] = None) -> Node:
n = super().create_node(kind, target, args, kwargs, name)
n.tag = 'foo'
return n
@@ -765,6 +771,26 @@
# Test shape propagation and make sure results match actual
self.assertEqual(output_shape, ref_out.shape)
+ def test_fn_type_annotations(self):
+ class Foo(torch.nn.Module):
+ def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]:
+ return {'a': p.x + p.y + z + i}
+
+ foo_scripted = torch.jit.script(Foo())
+ foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)
+
+ fxed = symbolic_trace(Foo())
+ fxed_scripted = torch.jit.script(fxed)
+ fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)
+
+ def test_typename_print(self):
+ graph : torch.fx.Graph = torch.fx.Graph()
+ x : torch.fx.Node = graph.create_node('placeholder', 'x')
+ b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,),
+ type_expr=List[float])
+ output : torch.fx.Node = graph.output(b)
+ self.assertTrue('typing.List[float]' in str(graph))
+
def test_find_single_partition(self):
class testModule(torch.nn.Module):
def forward(self, a, b):
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index 600fcb2..9994bd8 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -3,6 +3,7 @@
from typing import Callable, Any, List, Dict, Optional, Tuple, Set
import builtins
import torch
+import types
import keyword
import re
@@ -52,6 +53,29 @@
r = f'{r}.{e}'
return r
+# Borrowed from CPython typing module
+# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156
+def _type_repr(obj):
+ """Return the repr() of an object, special-casing types (internal helper).
+ If obj is a type, we return a shorter version than the default
+ type.__repr__, based on the module and qualified name, which is
+ typically enough to uniquely identify a type. For everything
+ else, we fall back on repr(obj).
+ """
+ # HACK: In Python 3.6, type aliases from `typing` are instances of `type`, but in
+ # later Python versions, type aliases are not instances of `type`!! We want
+ # all type aliases to fall through to `repr`, so if we have a type that is
+ # in the module typing, don't go down this path.
+ if isinstance(obj, type) and obj.__module__ != 'typing':
+ if obj.__module__ == 'builtins':
+ return obj.__qualname__
+ return f'{obj.__module__}.{obj.__qualname__}'
+ if obj is ...:
+ return('...')
+ if isinstance(obj, types.FunctionType):
+ return obj.__name__
+ return repr(obj)
+
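As a rough illustration of what the helper above produces (an editor's sketch, assuming `_type_repr` is imported from `torch.fx.graph`; not part of this diff):

    import torch
    from typing import List
    from torch.fx.graph import _type_repr

    assert _type_repr(int) == 'int'                         # builtins use the bare qualified name
    assert _type_repr(torch.Tensor) == 'torch.Tensor'       # other types use module.qualname
    assert _type_repr(List[float]) == 'typing.List[float]'  # typing aliases fall through to repr()
    assert _type_repr(...) == '...'                         # Ellipsis is special-cased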
class insert_before:
def __init__(self, n : Node):
self.n = n
@@ -65,6 +89,9 @@
class Graph:
def __init__(self):
+ """
+ Construct an empty Graph.
+ """
self._nodes : List[Node] = []
self._used_names : Dict[str, int] = {} # base name -> number
self._insert_point : Optional[Node] = None
@@ -90,12 +117,13 @@
def create_node(self, op: str, target: Target,
args: Optional[Tuple[Argument, ...]] = None,
kwargs: Optional[Dict[str, Argument]] = None,
- name: Optional[str] = None) -> Node:
+ name: Optional[str] = None,
+ type_expr: Optional[Any] = None) -> Node:
assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output')
args = () if args is None else args
kwargs = {} if kwargs is None else kwargs
sanitized_name = self._register_name_used(name) if name is not None else self._name(target)
- n = Node(self, sanitized_name, op, target, args, kwargs)
+ n = Node(self, sanitized_name, op, target, args, kwargs, type_expr)
if self._insert_point is not None:
before_idx = self._nodes.index(self._insert_point)
self._nodes.insert(before_idx, n)
@@ -130,29 +158,32 @@
self._nodes.pop(idx)
# sugar for above when you know the op
- def placeholder(self, name: str) -> Node:
- return self.create_node('placeholder', name)
+ def placeholder(self, name: str, type_expr: Optional[Any] = None) -> Node:
+ return self.create_node('placeholder', name, type_expr=type_expr)
- def get_attr(self, name: str) -> Node:
- return self.create_node('get_attr', name)
+ def get_attr(self, name: str, type_expr: Optional[Any] = None) -> Node:
+ return self.create_node('get_attr', name, type_expr=type_expr)
def call_module(self,
module_name: str,
args: Optional[Tuple[Argument, ...]] = None,
- kwargs: Optional[Dict[str, Argument]] = None) -> Node:
- return self.create_node('call_module', module_name, args, kwargs)
+ kwargs: Optional[Dict[str, Argument]] = None,
+ type_expr: Optional[Any] = None) -> Node:
+ return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr)
def call_method(self,
method_name: str,
args: Optional[Tuple[Argument, ...]] = None,
- kwargs: Optional[Dict[str, Argument]] = None) -> Node:
- return self.create_node('call_method', method_name, args, kwargs)
+ kwargs: Optional[Dict[str, Argument]] = None,
+ type_expr: Optional[Any] = None) -> Node:
+ return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr)
def call_function(self,
the_function: Callable[..., Any],
args: Optional[Tuple[Argument, ...]] = None,
- kwargs: Optional[Dict[str, Argument]] = None) -> Node:
- return self.create_node('call_function', the_function, args, kwargs)
+ kwargs: Optional[Dict[str, Argument]] = None,
+ type_expr: Optional[Any] = None) -> Node:
+ return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr)
def node_copy(self, node: Node, arg_transform: Callable[[Node], Argument] = lambda x: x) -> Node:
""" copy a node from one graph into another. arg_transform needs to transform arguments from the graph of node
@@ -181,10 +212,10 @@
except ValueError:
pass
name = self._name(sanitized_name)
- return self.create_node(node.op, node.target, args, kwargs, name)
+ return self.create_node(node.op, node.target, args, kwargs, name, node.type)
- def output(self, result: Argument):
- return self.create_node(op='output', target='output', args=(result,))
+ def output(self, result: Argument, type_expr: Optional[Any] = None):
+ return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr)
def _name(self, target: Target) -> str:
if callable(target):
@@ -224,10 +255,23 @@
free_vars: List[str] = []
modules_used : Set[str] = set()
body: List[str] = []
+ maybe_return_annotation : str = ''
+
+ def register_modules_used(qualified_name : str):
+ if '.' in qualified_name:
+ module_name = qualified_name.split('.', maxsplit=1)[0]
+ modules_used.add(module_name)
+
+ def type_repr(o : Any):
+ typename = _type_repr(o)
+ register_modules_used(typename)
+ return typename
+
for node in self._nodes:
if node.op == 'placeholder':
assert isinstance(node.target, str)
- free_vars.append(node.target)
+ maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
+ free_vars.append(f'{node.target}{maybe_type_annotation}')
raw_name = node.target.replace('*', '')
if raw_name != node.name:
body.append(f'{node.name} = {raw_name}\n')
@@ -246,9 +290,7 @@
body.append(f'{node.name} = {magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}\n')
continue
qualified_name = _qualified_name(node.target)
- if '.' in qualified_name:
- module_name = qualified_name.split('.', maxsplit=1)[0]
- modules_used.add(module_name)
+ register_modules_used(qualified_name)
if qualified_name == 'getattr' and \
isinstance(node.args, tuple) and \
isinstance(node.args[1], str) and \
@@ -267,6 +309,8 @@
body.append(f'{node.name} = {_format_target(root_module, node.target)}\n')
continue
elif node.op == 'output':
+ if node.type is not None:
+ maybe_return_annotation = f" -> {type_repr(node.type)}"
body.append(f'return {node.args[0]}')
continue
raise NotImplementedError(f'node: {node.op} {node.target}')
@@ -277,13 +321,17 @@
code = '\n'.join(' ' + line for line in code.split('\n')) + '\n'
fn_code = f"""\
{import_block}
-def forward(self, {', '.join(free_vars)}):
+def forward(self, {', '.join(free_vars)}){maybe_return_annotation}:
{code}
"""
+
return fn_code
def __str__(self) -> str:
placeholder_names : List[str] = []
+ # This is a one-element array just so `format_node` can modify the
+ # closed-over value
+ maybe_return_typename : List[str] = ['']
def format_arg(arg) -> str:
if isinstance(arg, list):
@@ -305,20 +353,26 @@
def format_node(n : Node) -> Optional[str]:
if n.op == 'placeholder':
assert isinstance(n.target, str)
- placeholder_names.append(n.target)
+ arg_str = n.target
+ arg_str += f': {_type_repr(n.type)}' if n.type is not None else ''
+ placeholder_names.append(arg_str)
return None
elif n.op == 'get_attr':
- return f'%{n.name} : [#users={len(n.users)}] = self.{n.target}'
+ maybe_typename = f'{_type_repr(n.type)} ' if n.type is not None else ''
+ return f'%{n.name} : {maybe_typename}[#users={len(n.users)}] = self.{n.target}'
elif n.op == 'output':
+ if n.type is not None:
+ maybe_return_typename[0] = f' -> {_type_repr(n.type)}'
return f'return {n.args[0]}'
else:
- return f'%{n.name} : [#users={len(n.users)}] = {n.op}[target={n.target}](' \
+ maybe_typename = f'{_type_repr(n.type)} ' if n.type is not None else ''
+ return f'%{n.name} : {maybe_typename}[#users={len(n.users)}] = {n.op}[target={n.target}](' \
f'args = {format_arg(n.args)}, kwargs = {format_arg(n.kwargs)})'
node_strs = [format_node(node) for node in self._nodes]
param_str = ', '.join(placeholder_names)
- s = f'graph({param_str}):'
+ s = f'graph({param_str}){maybe_return_typename[0]}:'
for node_str in node_strs:
if node_str:
s += '\n ' + node_str
diff --git a/torch/fx/node.py b/torch/fx/node.py
index 458e1d3..7d35483 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -21,7 +21,8 @@
class Node:
def __init__(self, graph: 'Graph', name: str, op: str, target: Target,
- args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> None:
+ args: Tuple[Argument, ...], kwargs: Dict[str, Argument],
+ type : Optional[Any] = None) -> None:
self.graph = graph
self.name = name # unique name of value being created
assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']
@@ -39,6 +40,17 @@
#
# Is a dict to act as an "ordered set". Keys are significant, values are don't-care
self.users : Dict['Node', None] = {}
+ # Type expression representing the output value of this node.
+ # This should contain the same class of Type objects that would appear
+ # as type annotations for function inputs/outputs.
+ #
+ # For placeholder nodes, this value will be used to type-annotate the
+ # generated function parameters.
+ # For the return node, this value will be used to type-annotate the
+ # generated function return type. (Note this is a special case. `return`
+ # does not produce a value, it's more of a notation. Thus, this value
+ # describes the type of args[0] in the `return` node.)
+ self.type : Optional[Any] = type
@property
def args(self) -> Tuple[Argument, ...]:
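The comment block above describes how `Node.type` is consumed; a small sketch of the flow through the `Graph` sugar methods (illustrative names, not part of this diff):

    import torch
    import torch.fx

    g = torch.fx.Graph()
    # type_expr on a placeholder ends up annotating the generated parameter
    x = g.placeholder('x', type_expr=torch.Tensor)
    y = g.call_function(torch.relu, args=(x,))
    # type_expr on the output node annotates the generated return type
    g.output(y, type_expr=torch.Tensor)

    assert x.type is torch.Tensor   # stored directly on the Node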
diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py
index 90593c4..20d7178 100644
--- a/torch/fx/proxy.py
+++ b/torch/fx/proxy.py
@@ -11,7 +11,8 @@
graph: Graph
def create_node(self, kind : str, target : Union[str, Callable],
- args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None) -> Node:
+ args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
+ type_expr : Optional[Any] = None) -> Node:
"""
Inserts a graph node given target, args, kwargs, and name.
@@ -19,7 +20,7 @@
modification of values used in node creation. For example, one might
want to disallow in-place operations from being recorded.
"""
- return self.graph.create_node(kind, target, args, kwargs, name)
+ return self.graph.create_node(kind, target, args, kwargs, name, type_expr)
def create_arg(self, a: Any) -> Argument:
"""
@@ -65,12 +66,13 @@
# Unwrap the proxies inside args, and kwargs, create the resulting node
# and then wrap the result in a proxy.
-def _create_proxy(tracer: 'TracerBase', op: str, target: Target, args_: Tuple[Any, ...], kwargs_: Dict[str, Any], name=None):
+def _create_proxy(tracer: 'TracerBase', op: str, target: Target, args_: Tuple[Any, ...], kwargs_: Dict[str, Any],
+ name=None, type_expr : Optional[Any] = None):
args = tracer.create_arg(args_)
kwargs = tracer.create_arg(kwargs_)
assert isinstance(args, tuple)
assert isinstance(kwargs, dict)
- rn = tracer.create_node(op, target, args, kwargs, name)
+ rn = tracer.create_node(op, target, args, kwargs, name, type_expr)
return Proxy(rn, tracer)
class Proxy:
diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py
index 7c295f6..7c5b2d7 100644
--- a/torch/fx/symbolic_trace.py
+++ b/torch/fx/symbolic_trace.py
@@ -121,18 +121,23 @@
def trace(self, root: torch.nn.Module) -> Graph:
self.root = root
+ fn = type(root).forward
self.graph = Graph()
- fn = type(root).forward
assert isinstance(fn, FunctionType)
co = fn.__code__
total_args = co.co_argcount + co.co_kwonlyargcount
names_iter = iter(co.co_varnames)
next(names_iter) # skip self
args : List[Any] = [root]
- args.extend(self._proxy_placeholder(next(names_iter)) for name in range(1, total_args))
+
+ def make_proxy_placeholder():
+ name = next(names_iter)
+ return self._proxy_placeholder(name, fn.__annotations__.get(name, None))
+ args.extend(make_proxy_placeholder() for _ in range(1, total_args))
if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF:
+ # TODO: type annotations for *args and **kwargs
if co.co_flags & inspect.CO_VARARGS:
args.append(self._proxy_placeholder('*' + next(names_iter)))
if co.co_flags & inspect.CO_VARKEYWORDS:
@@ -149,13 +154,14 @@
return _create_proxy(self, 'call_module', module_qualified_name, args, kwargs)
try:
torch.nn.Module.__call__ = module_call_wrapper
- self.create_node('output', 'output', (self.create_arg(fn(*args)),), {})
+ self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
+ type_expr=fn.__annotations__.get('return', None))
finally:
torch.nn.Module.__call__ = orig_call
return self.graph
- def _proxy_placeholder(self, name: str) -> Proxy:
- return Proxy(self.create_node('placeholder', name, (), {}), self)
+ def _proxy_placeholder(self, name: str, type_expr: Optional[Any] = None) -> Proxy:
+ return Proxy(self.create_node('placeholder', name, (), {}, type_expr=type_expr), self)
# Symbolic tracing API
#
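End-to-end, the symbolic_trace changes above pull annotations off the original forward via `fn.__annotations__` and thread them through to the generated code, which lets `torch.jit.script` see the same signature as the original module. A sketch mirroring `test_fn_type_annotations` (the module and shapes here are illustrative):

    import torch
    from torch.fx import symbolic_trace

    class AnnotatedModule(torch.nn.Module):
        def forward(self, x : torch.Tensor, i : int) -> torch.Tensor:
            return x + i

    traced = symbolic_trace(AnnotatedModule())
    # The generated forward keeps the annotations, roughly:
    #   def forward(self, x : torch.Tensor, i : int) -> torch.Tensor:
    # so TorchScript can compile the traced module without losing types.
    scripted = torch.jit.script(traced)
    scripted(torch.rand(5), 3)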