# blob: 078e9eea5f5ac693e7620fb8c4524130028ef5be [file] [log] [blame]
from .node import Node, Argument, Target, map_arg
from typing import Callable, Any, List, Dict, Optional, Tuple, Set
import builtins
import torch
import types
import keyword
import re
def _shadows_builtin_name(name: str) -> bool:
return name in builtins.__dict__ or name in keyword.kwlist
def _is_magic(x: str) -> bool:
return x.startswith('__') and x.endswith('__')
def snake_case(s: str) -> str:
    """
    Lower-case every upper-case character of `s`, prefixing each with '_',
    then strip all leading underscores from the result.
    """
    pieces = []
    for ch in s:
        if ch.isupper():
            pieces.append('_' + ch.lower())
        else:
            pieces.append(ch)
    return ''.join(pieces).lstrip('_')
def _qualified_name(func: Callable[..., Any]) -> str:
# things like getattr just appear in builtins
if getattr(builtins, func.__name__, None) is func:
return func.__name__
name = func.__name__
module = _find_module_of_method(func)
module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module
return f'{module}.{name}'
# this is fixed on master, WAR for 1.5
def _find_module_of_method(orig_method: Callable[..., Any]) -> str:
name = orig_method.__name__
module = orig_method.__module__
if module is not None:
return module
for guess in [torch, torch.nn.functional]:
if getattr(guess, name, None) is orig_method:
return guess.__name__
raise RuntimeError(f'cannot find module for {orig_method}')
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
    """
    Render a call's arguments as source text: positional values by `repr`,
    then keyword values as `key = repr(value)`, all separated by ', '.
    """
    positional = ', '.join(map(repr, args))
    keyword_part = ', '.join(f'{key} = {value!r}' for key, value in kwargs.items())
    # Only the non-empty halves are joined, so no stray separator is produced.
    return ', '.join(part for part in (positional, keyword_part) if part)
def _format_target(base: str, target: str) -> str:
elems = target.split('.')
r = base
for e in elems:
if not e.isidentifier():
r = f'getattr({r}, "{e}")'
else:
r = f'{r}.{e}'
return r
# Borrowed from CPython typing module
# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156
def _type_repr(obj):
"""Return the repr() of an object, special-casing types (internal helper).
If obj is a type, we return a shorter version than the default
type.__repr__, based on the module and qualified name, which is
typically enough to uniquely identify a type. For everything
else, we fall back on repr(obj).
"""
# HACK: In Python 3.6, type aliases from `typing` are instances of `type`, but in
# later Python versions, type aliases are not instances of `type`!! We want
# all type aliases to fall through to `repr`, so if we have a type that is
# in the module typing, don't go down this path.
if isinstance(obj, type) and obj.__module__ != 'typing':
if obj.__module__ == 'builtins':
return obj.__qualname__
return f'{obj.__module__}.{obj.__qualname__}'
if obj is ...:
return('...')
if isinstance(obj, types.FunctionType):
return obj.__name__
return repr(obj)
class _InsertPoint:
def __init__(self, graph, new_insert):
self.graph = graph
self.orig_insert, graph._insert = graph._insert, new_insert
def __enter__(self):
pass
def __exit__(self, type, value, tb):
self.graph._insert = self.orig_insert
class _node_list:
def __init__(self, graph: 'Graph', direction: str = '_next'):
assert direction in ['_next', '_prev']
self.graph = graph
self.direction = direction
def __len__(self):
return self.graph._len
def __iter__(self):
root, direction = self.graph._root, self.direction
cur = getattr(root, direction)
while cur is not root:
if not cur._erased:
yield cur
cur = getattr(cur, direction)
def __reversed__(self):
return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev')
class Graph:
    """
    An ordered collection of `Node`s that can be printed, linted and rendered
    as the Python source of a `forward` function.

    Nodes are created through `create_node` (or the per-opcode helpers such as
    `placeholder`, `call_function` and `output`), handed to the current insert
    point, and named via `_name` / `_register_name_used`.
    """

    def __init__(self):
        """
        Construct an empty Graph.
        """
        # Sentinel of the node list: iterating `nodes` starts at the root's
        # neighbour and stops once the root is reached again.
        self._root : Node = Node(self, '', 'root', '', (), {})
        self._used_names : Dict[str, int] = {}  # base name -> number
        # Default insert function; `inserting_before` / `inserting_after` replace it.
        self._insert = self._root.prepend
        # Incremented by `create_node`, decremented by `erase_node`.
        self._len = 0

    @property
    def nodes(self):
        """Iterable view of the (non-erased) nodes of this graph; supports `len` and `reversed`."""
        return _node_list(self)

    def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node]) -> Optional[Argument]:
        """
        Copy all nodes from graph `g` into this graph. `val_map` should be a dictionary
        that maps nodes in `g` to nodes in `self`. `val_map` will be populated with more
        items by this function. Nodes already present in `val_map` are not copied again.

        Returns the equivalent output value of `g` with Nodes switched to refer to
        nodes in `self`, or `None` if `g` contains no 'output' node. Copying stops
        at the first 'output' node, which itself is not copied.
        """
        for node in g.nodes:
            if node in val_map:
                continue
            if node.op == 'output':
                rv = map_arg(node.args[0], lambda n: val_map[n])
                return rv
            val_map[node] = self.node_copy(node, lambda n : val_map[n])
        return None

    def create_node(self, op: str, target: Target,
                    args: Optional[Tuple[Argument, ...]] = None,
                    kwargs: Optional[Dict[str, Argument]] = None,
                    name: Optional[str] = None,
                    type_expr: Optional[Any] = None) -> Node:
        """
        Create a `Node`, pass it to the current insert function and return it.

        Args:
            op: the opcode; one of 'call_function', 'call_method', 'get_attr',
                'call_module', 'placeholder' or 'output'.
            target: the target of the node; also used to derive a name when
                `name` is not given.
            args: positional arguments of the node; defaults to `()`.
            kwargs: keyword arguments of the node; defaults to `{}`.
            name: explicit name. It is still passed through `_register_name_used`,
                so the final name may carry a numeric suffix.
            type_expr: passed through to the `Node` constructor.
        """
        assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output')
        args = () if args is None else args
        kwargs = {} if kwargs is None else kwargs
        sanitized_name = self._register_name_used(name) if name is not None else self._name(target)
        n = Node(self, sanitized_name, op, target, args, kwargs, type_expr)
        self._insert(n)
        self._len += 1
        return n

    def erase_node(self, to_erase : Node):
        """
        Erases the node `to_erase` from the `Graph`. Throws an exception if
        there are still users of that node in the `Graph`.
        """
        if len(to_erase.users) > 0:
            raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} '
                               f'users in the graph: {to_erase.users}!')

        to_erase._remove_from_list()
        to_erase._erased = True  # iterators may retain handles to erased nodes
        self._len -= 1

    def inserting_before(self, n: Optional[Node] = None):
        """Set the point at which create_node and companion methods will insert into the graph.
        When used within a 'with' statement, this will temporarily set the insert point and
        then restore it when the with statement exits:

            with g.inserting_before(n):
                ... # inserting before node n
            ... # insert point restored to what it was previously
            g.inserting_before(n) #  set the insert point permanently

        Args:
            n (Optional[Node]): The node before which to insert. If None this will insert before
                                the beginning of the entire graph (it delegates to
                                `inserting_after(self._root)`).

        Returns:
            A resource manager that will restore the insert point on `__exit__`.
        """
        if n is None:
            return self.inserting_after(self._root)
        assert n.graph == self, "Node to insert before is not in graph."
        return _InsertPoint(self, n.prepend)

    def inserting_after(self, n: Optional[Node] = None):
        """Set the point at which create_node and companion methods will insert into the graph.
        When used within a 'with' statement, this will temporarily set the insert point and
        then restore it when the with statement exits:

            with g.inserting_after(n):
                ... # inserting after node n
            ... # insert point restored to what it was previously
            g.inserting_after(n) #  set the insert point permanently

        Args:
            n (Optional[Node]): The node after which to insert. If None this delegates to
                                `inserting_before(self._root)`, i.e. it installs
                                `self._root.prepend`, the same insert function a freshly
                                constructed Graph starts with.

        Returns:
            A resource manager that will restore the insert point on `__exit__`.
        """
        if n is None:
            return self.inserting_before(self._root)
        assert n.graph == self, "Node to insert after is not in graph."
        return _InsertPoint(self, n.append)

    # sugar for create_node when you know the op
    def placeholder(self, name: str, type_expr: Optional[Any] = None) -> Node:
        """Create a 'placeholder' node whose target is `name`."""
        return self.create_node('placeholder', name, type_expr=type_expr)

    def get_attr(self, name: str, type_expr: Optional[Any] = None) -> Node:
        """Create a 'get_attr' node whose target is the attribute path `name`."""
        return self.create_node('get_attr', name, type_expr=type_expr)

    def call_module(self,
                    module_name: str,
                    args: Optional[Tuple[Argument, ...]] = None,
                    kwargs: Optional[Dict[str, Argument]] = None,
                    type_expr: Optional[Any] = None) -> Node:
        """Create a 'call_module' node whose target is `module_name`."""
        return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr)

    def call_method(self,
                    method_name: str,
                    args: Optional[Tuple[Argument, ...]] = None,
                    kwargs: Optional[Dict[str, Argument]] = None,
                    type_expr: Optional[Any] = None) -> Node:
        """Create a 'call_method' node; `python_code` treats `args[0]` as the receiver."""
        return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr)

    def call_function(self,
                      the_function: Callable[..., Any],
                      args: Optional[Tuple[Argument, ...]] = None,
                      kwargs: Optional[Dict[str, Argument]] = None,
                      type_expr: Optional[Any] = None) -> Node:
        """Create a 'call_function' node whose target is the callable `the_function`."""
        return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr)

    def node_copy(self, node: Node, arg_transform: Callable[[Node], Argument] = lambda x: x) -> Node:
        """ copy a node from one graph into another. arg_transform needs to transform arguments from the graph of node
            to the graph of self. Example:

            g : torch.fx.Graph = ...
            new_graph = torch.fx.Graph()
            value_remap = {}
            for node in g.nodes:
                value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n])
        """
        args = map_arg(node.args, arg_transform)
        kwargs = map_arg(node.kwargs, arg_transform)
        assert isinstance(args, tuple)
        assert isinstance(kwargs, dict)
        if node.op == "placeholder":
            # Placeholder names are user-visible, so they should be copied as-is without normalizing them.
            name = node.name
        else:
            # Drop a trailing '_<number>' so this graph applies its own numbering
            # to the base name instead of stacking suffixes.
            sanitized_name = node.name
            if '_' in node.name:
                base, maybe_idx = node.name.rsplit('_', 1)
                try:
                    int(maybe_idx)
                    sanitized_name = base
                except ValueError:
                    pass
            name = self._name(sanitized_name)
        return self.create_node(node.op, node.target, args, kwargs, name, node.type)

    def output(self, result: Argument, type_expr: Optional[Any] = None):
        """Create the 'output' node, whose single positional argument is `result`."""
        return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr)

    def _name(self, target: Target) -> str:
        """
        Derive a node name from `target` and register it.

        Callables contribute their `__name__`; strings are used directly. Surrounding
        double underscores are removed, characters that are illegal in a Python
        identifier are replaced, the result is snake-cased, and a leading digit is
        protected with an underscore.
        """
        if callable(target):
            op = target.__name__
        else:
            assert isinstance(target, str)
            op = target
            if _is_magic(op):
                op = op[2:-2]
        op = op.replace('.', '_')
        # delete all characters that are illegal in a Python identifier
        op = re.sub('[^0-9a-zA-Z_]+', '_', op)
        op = snake_case(op)
        if not op:
            # Targets such as '_' or '__' sanitize to the empty string; fall back
            # to a fixed base name instead of failing on `op[0]` below.
            op = '_unnamed'
        if op[0].isdigit():
            op = f'_{op}'

        return self._register_name_used(op)

    def _register_name_used(self, op : str) -> str:
        """
        Even if a user provides us with a name, we must register that that
        name is used to prevent duplication of names from further nodes as
        well as ensure that the name provided does not shadow a builtin.

        Returns `op` unchanged the first time it is seen, provided it does not
        clash with an attribute of `torch`, `torch.nn.functional`, `torch.nn` or
        a Python builtin/keyword; otherwise returns `op` with a `_<counter>` suffix.
        """
        if op not in self._used_names:
            self._used_names[op] = 0
            # Avoid shadowing PyTorch and Python builtins.
            if not hasattr(torch, op) and \
               not hasattr(torch.nn.functional, op) and \
               not hasattr(torch.nn, op) and \
               not _shadows_builtin_name(op):
                return op
        i = self._used_names[op] = self._used_names[op] + 1
        return f'{op}_{i}'

    def python_code(self, root_module: str) -> str:
        """
        Render this graph as the source of a `forward(self, ...)` function.

        Args:
            root_module: source text of the expression on which 'call_module' and
                'get_attr' targets are looked up.

        Returns:
            Source text consisting of one `import` line per top-level module
            referenced by the body, followed by the `forward` definition.

        Raises:
            NotImplementedError: if a node has an opcode this method does not handle.
        """
        free_vars: List[str] = []
        modules_used : Set[str] = set()
        body: List[str] = []
        maybe_return_annotation : str = ''

        def register_modules_used(qualified_name : str):
            # Only dotted names need an import; record their top-level package.
            if '.' in qualified_name:
                module_name = qualified_name.split('.', maxsplit=1)[0]
                modules_used.add(module_name)

        def type_repr(o : Any):
            typename = _type_repr(o)
            register_modules_used(typename)
            return typename

        for node in self.nodes:
            if node.op == 'placeholder':
                assert isinstance(node.target, str)
                maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
                free_vars.append(f'{node.target}{maybe_type_annotation}')
                # The target may carry '*' (e.g. '*args'); bind the node name to the
                # bare parameter name when the two differ.
                raw_name = node.target.replace('*', '')
                if raw_name != node.name:
                    body.append(f'{node.name} = {raw_name}\n')
                continue
            elif node.op == 'call_method':
                assert isinstance(node.target, str)
                body.append(
                    f'{node.name} = {_format_target(repr(node.args[0]), node.target)}'
                    f'({_format_args(node.args[1:], node.kwargs)})\n')
                continue
            elif node.op == 'call_function':
                assert callable(node.target)
                # pretty print operators
                if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
                    assert isinstance(node.args, tuple)
                    body.append(f'{node.name} = {magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}\n')
                    continue
                qualified_name = _qualified_name(node.target)
                register_modules_used(qualified_name)
                if qualified_name == 'getattr' and \
                   isinstance(node.args, tuple) and \
                   isinstance(node.args[1], str) and \
                   node.args[1].isidentifier():
                    # pretty print attribute access
                    body.append(f'{node.name} = {_format_target(repr(node.args[0]), node.args[1])}\n')
                    continue
                body.append(f'{node.name} = {qualified_name}({_format_args(node.args, node.kwargs)})\n')
                continue
            elif node.op == 'call_module':
                assert isinstance(node.target, str)
                body.append(f'{node.name} = {_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})\n')
                continue
            elif node.op == 'get_attr':
                assert isinstance(node.target, str)
                body.append(f'{node.name} = {_format_target(root_module, node.target)}\n')
                continue
            elif node.op == 'output':
                if node.type is not None:
                    maybe_return_annotation = f" -> {type_repr(node.type)}"
                body.append(f'return {node.args[0]}')
                continue
            raise NotImplementedError(f'node: {node.op} {node.target}')

        import_block = '\n'.join(f'import {name}' for name in sorted(modules_used))

        # Indent every body line so it sits inside the generated function.
        code = ''.join(body)
        code = '\n'.join(' ' + line for line in code.split('\n')) + '\n'
        fn_code = f"""\
{import_block}
def forward(self, {', '.join(free_vars)}){maybe_return_annotation}:
{code}
"""
        return fn_code

    def __str__(self) -> str:
        """
        Return a human-readable listing of the graph: a `graph(<placeholders>)`
        header followed by one line per non-placeholder node.
        """
        placeholder_names : List[str] = []
        # This is a one-element array just so `format_node` can modify the closed
        # over value
        maybe_return_typename : List[str] = ['']

        def format_arg(arg) -> str:
            if isinstance(arg, list):
                items = ', '.join(format_arg(a) for a in arg)
                return f'[{items}]'
            elif isinstance(arg, tuple):
                items = ', '.join(format_arg(a) for a in arg)
                maybe_comma = ',' if len(arg) == 1 else ''
                return f'({items}{maybe_comma})'
            elif isinstance(arg, dict):
                items_str = ', '.join(f'{k}: {format_arg(v)}' for k, v in arg.items())
                return f'{{{items_str}}}'

            if isinstance(arg, Node):
                return '%' + str(arg)
            else:
                return str(arg)

        def format_node(n : Node) -> Optional[str]:
            if n.op == 'placeholder':
                assert isinstance(n.target, str)
                arg_str = n.target
                # Append the annotation once; the previous `arg_str += arg_str + ...`
                # form repeated the parameter name whenever a type was present.
                if n.type is not None:
                    arg_str += f': {_type_repr(n.type)}'
                placeholder_names.append(arg_str)
                return None
            elif n.op == 'get_attr':
                maybe_typename = f'{_type_repr(n.type)} ' if n.type is not None else ''
                return f'%{n.name} : {maybe_typename}[#users={len(n.users)}] = self.{n.target}'
            elif n.op == 'output':
                if n.type is not None:
                    maybe_return_typename[0] = f' -> {_type_repr(n.type)}'
                return f'return {n.args[0]}'
            else:
                maybe_typename = f'{_type_repr(n.type)} ' if n.type is not None else ''
                return f'%{n.name} : {maybe_typename}[#users={len(n.users)}] = {n.op}[target={n.target}](' \
                       f'args = {format_arg(n.args)}, kwargs = {format_arg(n.kwargs)})'

        # Formatting every node first fills in `placeholder_names` and the return
        # type, both of which the header line needs.
        node_strs = [format_node(node) for node in self.nodes]
        param_str = ', '.join(placeholder_names)
        s = f'graph({param_str}){maybe_return_typename[0]}:'
        for node_str in node_strs:
            if node_str:
                s += '\n ' + node_str
        return s

    def lint(self, root : Optional[torch.nn.Module] = None):
        """
        Runs various checks on this Graph to make sure it is well-formed. In
        particular:
        - Checks Nodes have correct ownership (owned by this graph)
        - Checks Nodes appear in topological order
        - Checks Node names are unique and opcodes are known
        - If `root` is provided, checks that `target`s exist in `root`

        Raises:
            RuntimeError: on the first violation found.
        """

        # Check topo order
        def check_arg(arg : Node, n : Optional[Node] = None) -> None:
            context_str = f' of Node \'{n}\' ' if n else ' '
            if arg.graph is not self:
                raise RuntimeError(f'Argument \'{arg}\'{context_str}does not belong to this Graph, '
                                   f'but was used as an argument! If you are copying nodes from another graph, make '
                                   f'sure to use `arg_transform` on node_copy() to remap values\n{self}')
            if arg not in seen_values:
                raise RuntimeError(f'Argument \'{arg}\'{context_str}was used before it has been '
                                   f'defined! Please check that Nodes in the graph are topologically ordered\n{self}')

        seen_names : Set[str] = set()
        seen_values : Set[Node] = set()
        for node in self.nodes:
            if node.op not in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']:
                raise RuntimeError(f'Node {node} had unknown opcode {node.op}!')
            if node.graph is not self:
                raise RuntimeError(f'Node \'{node}\' does not belong to this Graph!')
            map_arg(node.args, lambda arg: check_arg(arg, node))
            map_arg(node.kwargs, lambda arg: check_arg(arg, node))
            seen_values.add(node)

            if node.name in seen_names:
                raise RuntimeError(f'Node redefined name {node.name}!')
            seen_names.add(node.name)

        # Check targets are legit
        # Compare against None explicitly: a module that defines `__len__` (for
        # example an empty container module) is falsy and would otherwise be skipped.
        if root is not None:
            for node in self.nodes:
                if node.op in ['get_attr', 'call_module']:
                    assert isinstance(node.target, str)
                    target_atoms = node.target.split('.')
                    m_itr = root
                    for i, atom in enumerate(target_atoms):
                        m_itr = getattr(m_itr, atom, None)
                        if m_itr is None:
                            seen_qualname = '.'.join(target_atoms[:i])
                            raise RuntimeError(f'Node {node} target {node.target} references nonexistent attribute '
                                               f'{atom} of {seen_qualname}')
# Format templates keyed by operator name. `Graph.python_code` looks a
# `call_function` target's `__name__` up in `magic_methods` and fills the
# template with the repr of each argument.

# Infix operators: every template has the shape '{} <symbol> {}'.
reflectable_magic_methods = {
    method: f'{{}} {symbol} {{}}'
    for method, symbol in [
        ('add', '+'),
        ('sub', '-'),
        ('mul', '*'),
        ('floordiv', '//'),
        ('truediv', '/'),
        ('div', '/'),
        ('mod', '%'),
        ('pow', '**'),
        ('lshift', '<<'),
        ('rshift', '>>'),
        ('and', '&'),
        ('or', '|'),
        ('xor', '^'),
    ]
}
# Indexing is the one entry here that is not written as an infix operator.
reflectable_magic_methods['getitem'] = '{}[{}]'

# Comparisons and unary operators first, followed by every entry above.
magic_methods = {
    'eq': '{} == {}',
    'ne': '{} != {}',
    'lt': '{} < {}',
    'gt': '{} > {}',
    'le': '{} <= {}',
    'ge': '{} >= {}',
    'pos': '+{}',
    'neg': '-{}',
    'invert': '~{}',
    **reflectable_magic_methods,
}