blob: ed7618372b57df6473cf28a18b8d31eac876a554 [file] [log] [blame]
from .node import Node, Argument, Target, map_arg
from typing import Callable, Any, List, Dict, Optional, Tuple, Set
import builtins
import torch
import keyword
import re
def _shadows_builtin_name(name: str) -> bool:
return name in builtins.__dict__ or name in keyword.kwlist
def _is_magic(x: str) -> bool:
return x.startswith('__') and x.endswith('__')
def snake_case(s: str) -> str:
    """
    Convert `s` to snake_case: every uppercase character is replaced by an
    underscore followed by its lowercase form, then all leading underscores
    are stripped from the result.
    """
    def _split_upper(ch: str) -> str:
        return f'_{ch.lower()}' if ch.isupper() else ch

    return ''.join(map(_split_upper, s)).lstrip('_')
def _qualified_name(func: Callable[..., Any]) -> str:
# things like getattr just appear in builtins
if getattr(builtins, func.__name__, None) is func:
return func.__name__
name = func.__name__
module = _find_module_of_method(func)
module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module
return f'{module}.{name}'
# this is fixed on master, WAR for 1.5
def _find_module_of_method(orig_method: Callable[..., Any]) -> str:
name = orig_method.__name__
module = orig_method.__module__
if module is not None:
return module
for guess in [torch, torch.nn.functional]:
if getattr(guess, name, None) is orig_method:
return guess.__name__
raise RuntimeError(f'cannot find module for {orig_method}')
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
args_s = ', '.join(repr(a) for a in args)
kwargs_s = ', '.join(f'{k} = {repr(v)}' for k, v in kwargs.items())
if args_s and kwargs_s:
return f'{args_s}, {kwargs_s}'
return args_s or kwargs_s
def _format_target(base: str, target: str) -> str:
elems = target.split('.')
r = base
for e in elems:
if not e.isidentifier():
r = f'getattr({r}, "{e}")'
else:
r = f'{r}.{e}'
return r
class insert_before:
    """
    Context manager that points the `_insert_point` of `n`'s graph at `n` for
    the duration of the `with` block, then restores the previous value on
    exit (including when the block raises).
    """
    def __init__(self, n : Node):
        self.n = n

    def __enter__(self):
        graph = self.n.graph
        # Remember whatever insert point was active so __exit__ can put it back
        self.orig_insert_point = graph._insert_point
        graph._insert_point = self.n

    def __exit__(self, type, value, tb):
        graph = self.n.graph
        graph._insert_point = self.orig_insert_point
class Graph:
    """
    An ordered list of `Node`s together with the bookkeeping needed to give
    every node a unique name, to control where new nodes are inserted, and to
    render the graph either as Python source (`python_code`) or as a
    human-readable listing (`__str__`).
    """
    def __init__(self):
        self._nodes : List[Node] = []
        self._used_names : Dict[str, int] = {}  # base name -> number
        # When set, create_node() inserts new nodes before this node instead of appending
        self._insert_point : Optional[Node] = None

    @property
    def nodes(self):
        """Return a tuple snapshot of the nodes in their current order."""
        return tuple(self._nodes)

    def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node]) -> Optional[Argument]:
        """
        Append all nodes from graph `g` to this graph. `val_map` should be a dictionary
        that maps nodes in `g` to nodes in `self`. `val_map` will be populated with more
        items by this function. Returns the equivalent output value of `g` with
        Nodes switched to refer to nodes in `self`.

        Copying stops at the first `output` node of `g`; that node itself is not
        copied. Returns None when `g` has no `output` node.
        """
        for node in g._nodes:
            if node.op == 'output':
                rv = map_arg(node.args[0], lambda n: val_map[n])
                return rv
            val_map[node] = self.node_copy(node, lambda n : val_map[n])
        return None

    def _mark_uses(self, a: Argument):
        """Increment `uses` on every `Node` that `map_arg` visits inside `a`."""
        def add_use(n: Node):
            n.uses += 1
            return n
        map_arg(a, add_use)

    def create_node(self, op: str, target: Target,
                    args: Optional[Tuple[Argument, ...]] = None,
                    kwargs: Optional[Dict[str, Argument]] = None,
                    name: Optional[str] = None) -> Node:
        """
        Create a `Node` and add it to this graph.

        The node is inserted before the current insert point when one is set
        (see `insert_before`), otherwise it is appended. The `uses` count of every
        Node found in `args` and `kwargs` is incremented. When `name` is given it
        is registered via `_register_name_used` (and may come back with a numeric
        suffix); otherwise a name is derived from `target`.

        `op` must be one of 'call_function', 'call_method', 'get_attr',
        'call_module', 'placeholder' or 'output'.
        """
        assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output')
        args = () if args is None else args
        kwargs = {} if kwargs is None else kwargs
        self._mark_uses(args)
        self._mark_uses(kwargs)
        sanitized_name = self._register_name_used(name) if name is not None else self._name(target)
        n = Node(self, sanitized_name, op, target, args, kwargs)
        if self._insert_point is not None:
            before_idx = self._nodes.index(self._insert_point)
            self._nodes.insert(before_idx, n)
        else:
            self._nodes.append(n)
        return n

    def move_node_before(self, to_move : Node, before : Node):
        """
        Move node `to_move` before `before` in the Graph. Both `Node` arguments
        must be present in this graph.

        Raises:
            RuntimeError: if either node's `graph` is not this Graph.
        """
        # TODO: Computationally inefficient
        if to_move.graph != self or before.graph != self:
            raise RuntimeError('Node arguments must belong to this Graph!')
        # Both indices are looked up before the pop, matching the original ordering semantics
        node_idx = self._nodes.index(to_move)
        before_idx = self._nodes.index(before)
        self._nodes.insert(before_idx, self._nodes.pop(node_idx))

    def erase_node(self, to_erase : Node):
        """
        Erases the node `to_erase` from the `Graph`. Throws an exception if
        there are still uses of that node in the `Graph`.

        Raises:
            RuntimeError: if `to_erase.uses` is greater than zero.
        """
        if to_erase.uses > 0:
            raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {to_erase.uses} uses in the graph!')
        # NOTE(review): the `uses` counts of the erased node's own arguments are not
        # decremented here, so its inputs still appear used afterwards - confirm intended.
        node_indices = [i for i, n in enumerate(self._nodes) if n == to_erase]
        # Pop from the back so earlier indices stay valid
        for idx in reversed(node_indices):
            self._nodes.pop(idx)

    # sugar for above when you know the op
    def placeholder(self, name: str) -> Node:
        """Create a 'placeholder' node whose target is `name`."""
        return self.create_node('placeholder', name)

    def get_attr(self, name: str) -> Node:
        """Create a 'get_attr' node whose target is `name`."""
        return self.create_node('get_attr', name)

    def call_module(self,
                    module_name: str,
                    args: Optional[Tuple[Argument, ...]] = None,
                    kwargs: Optional[Dict[str, Argument]] = None) -> Node:
        """Create a 'call_module' node targeting `module_name`."""
        return self.create_node('call_module', module_name, args, kwargs)

    def call_method(self,
                    method_name: str,
                    args: Optional[Tuple[Argument, ...]] = None,
                    kwargs: Optional[Dict[str, Argument]] = None) -> Node:
        """Create a 'call_method' node targeting `method_name`."""
        return self.create_node('call_method', method_name, args, kwargs)

    def call_function(self,
                      the_function: Callable[..., Any],
                      args: Optional[Tuple[Argument, ...]] = None,
                      kwargs: Optional[Dict[str, Argument]] = None) -> Node:
        """Create a 'call_function' node targeting the callable `the_function`."""
        return self.create_node('call_function', the_function, args, kwargs)

    def node_copy(self, node: Node, arg_transform: Callable[[Node], Argument] = lambda x: x) -> Node:
        """ copy a node from one graph into another. arg_transform needs to transform arguments from the graph of node
            to the graph of self. Example:

            g : torch.fx.Graph = ...
            new_graph = torch.fx.Graph()
            value_remap = {}
            for node in g.nodes:
                value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n])
        """
        args = map_arg(node.args, arg_transform)
        kwargs = map_arg(node.kwargs, arg_transform)
        assert isinstance(args, tuple)
        assert isinstance(kwargs, dict)
        if node.op == "placeholder":
            # Placeholder names are user-visible, so they should be copied as-is without normalizing them.
            name = node.name
        else:
            # Strip a trailing `_<integer>` suffix so the copy is named from the base name
            sanitized_name = node.name
            if '_' in node.name:
                base, maybe_idx = node.name.rsplit('_', 1)
                try:
                    int(maybe_idx)
                    sanitized_name = base
                except ValueError:
                    pass
            # NOTE(review): _name() already registers the name and create_node() registers
            # the explicit `name` again, so the copy may receive an additional numeric suffix.
            name = self._name(sanitized_name)
        return self.create_node(node.op, node.target, args, kwargs, name)

    def output(self, result: Argument):
        """
        Create the 'output' node that returns `result` from the graph.

        The uses of the Nodes inside `result` are counted by `create_node`
        (it marks everything in `args`), so they must not be marked here as
        well; doing so would count each returned value twice.
        """
        return self.create_node(op='output', target='output', args=(result,))

    def _name(self, target: Target) -> str:
        """
        Derive and register a node name from `target`.

        Callables contribute their `__name__`; strings are used directly. Dunder
        wrappers are removed, characters outside `[0-9a-zA-Z_]` are replaced by
        underscores, and the result is converted with `snake_case`. A leading
        underscore is added when the result is empty or begins with a digit.
        """
        if callable(target):
            op = target.__name__
        else:
            assert isinstance(target, str)
            op = target
            if _is_magic(op):
                op = op[2:-2]
        op = op.replace('.', '_')
        # delete all characters that are illegal in a Python identifier
        op = re.sub('[^0-9a-zA-Z_]+', '_', op)
        op = snake_case(op)
        # Sanitizing can leave nothing behind (e.g. targets '_' or '__'); guard the
        # index so that case yields '_' instead of raising IndexError.
        if not op or op[0].isdigit():
            op = f'_{op}'
        return self._register_name_used(op)

    def _register_name_used(self, op : str) -> str:
        """
        Even if a user provides us with a name, we must register that that
        name is used to prevent duplication of names from further nodes as
        well as ensure that the name provided does not shadow a builtin.

        Returns `op` unchanged the first time it is seen, provided it does not
        shadow an attribute of `torch`, `torch.nn.functional`, `torch.nn`, a
        Python builtin or a keyword. Otherwise returns `op` with an incrementing
        `_<n>` suffix.
        """
        if op not in self._used_names:
            self._used_names[op] = 0
            # Avoid shadowing PyTorch and Python builtins.
            if not hasattr(torch, op) and \
               not hasattr(torch.nn.functional, op) and \
               not hasattr(torch.nn, op) and \
               not _shadows_builtin_name(op):
                return op
        i = self._used_names[op] = self._used_names[op] + 1
        return f'{op}_{i}'

    def python_code(self, root_module: str) -> str:
        """
        Render this graph as the source of a `forward(self, ...)` function.

        `root_module` is the expression used to reach module attributes (the
        base for 'call_module' and 'get_attr' targets). Placeholders become the
        function parameters. The top-level module of every qualified function
        name that is referenced is emitted as an `import` line ahead of the
        function definition.

        Raises:
            NotImplementedError: if a node has an op this method does not handle.
        """
        free_vars: List[str] = []
        modules_used : Set[str] = set()
        body: List[str] = []
        for node in self._nodes:
            if node.op == 'placeholder':
                assert isinstance(node.target, str)
                free_vars.append(node.target)
                # Targets may carry '*' (varargs-style); the bare name is what the body can reference
                raw_name = node.target.replace('*', '')
                if raw_name != node.name:
                    body.append(f'{node.name} = {raw_name}\n')
                continue
            elif node.op == 'call_method':
                assert isinstance(node.target, str)
                # The first positional argument is the object the method is called on
                body.append(
                    f'{node.name} = {_format_target(repr(node.args[0]), node.target)}'
                    f'({_format_args(node.args[1:], node.kwargs)})\n')
                continue
            elif node.op == 'call_function':
                assert callable(node.target)
                # pretty print operators
                if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
                    assert isinstance(node.args, tuple)
                    body.append(f'{node.name} = {magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}\n')
                    continue
                qualified_name = _qualified_name(node.target)
                if '.' in qualified_name:
                    module_name = qualified_name.split('.', maxsplit=1)[0]
                    modules_used.add(module_name)
                if qualified_name == 'getattr' and \
                   isinstance(node.args, tuple) and \
                   isinstance(node.args[1], str) and \
                   node.args[1].isidentifier():
                    # pretty print attribute access
                    body.append(f'{node.name} = {_format_target(repr(node.args[0]), node.args[1])}\n')
                    continue
                body.append(f'{node.name} = {qualified_name}({_format_args(node.args, node.kwargs)})\n')
                continue
            elif node.op == 'call_module':
                assert isinstance(node.target, str)
                body.append(f'{node.name} = {_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})\n')
                continue
            elif node.op == 'get_attr':
                assert isinstance(node.target, str)
                body.append(f'{node.name} = {_format_target(root_module, node.target)}\n')
                continue
            elif node.op == 'output':
                body.append(f'return {node.args[0]}')
                continue
            raise NotImplementedError(f'node: {node.op} {node.target}')

        import_block = '\n'.join(f'import {name}' for name in sorted(modules_used))

        code = ''.join(body)
        # Indent every body line so it sits inside the generated function
        code = '\n'.join(' ' + line for line in code.split('\n')) + '\n'
        fn_code = f"""\
{import_block}
def forward(self, {', '.join(free_vars)}):
{code}
"""
        return fn_code

    def __str__(self) -> str:
        """
        Return a readable listing: a `graph(<placeholders>):` header followed
        by one line per non-placeholder node showing its name, use count, op,
        target and arguments.
        """
        placeholder_names : List[str] = []

        def format_arg(arg) -> str:
            # Containers are rendered recursively; Nodes are prefixed with '%'
            if isinstance(arg, list):
                items = ', '.join(format_arg(a) for a in arg)
                return f'[{items}]'
            elif isinstance(arg, tuple):
                items = ', '.join(format_arg(a) for a in arg)
                maybe_comma = ',' if len(arg) == 1 else ''
                return f'({items}{maybe_comma})'
            elif isinstance(arg, dict):
                items_str = ', '.join(f'{k}: {format_arg(v)}' for k, v in arg.items())
                return f'{{{items_str}}}'

            if isinstance(arg, Node):
                return '%' + str(arg)
            else:
                return str(arg)

        def format_node(n : Node) -> Optional[str]:
            # Placeholders are collected for the header and produce no body line
            if n.op == 'placeholder':
                assert isinstance(n.target, str)
                placeholder_names.append(n.target)
                return None
            elif n.op == 'get_attr':
                return f'%{n.name} : [uses={n.uses}] = self.{n.target}'
            elif n.op == 'output':
                return f'return {n.args[0]}'
            else:
                return f'%{n.name} : [uses={n.uses}] = {n.op}[target={n.target}](' \
                       f'args = {format_arg(n.args)}, kwargs = {format_arg(n.kwargs)})'

        node_strs = [format_node(node) for node in self._nodes]
        param_str = ', '.join(placeholder_names)
        s = f'graph({param_str}):'
        for node_str in node_strs:
            if node_str:
                s += '\n ' + node_str
        return s

    def lint(self, root : Optional[torch.nn.Module] = None):
        """
        Runs various checks on this Graph to make sure it is well-formed. In
        particular:
        - Checks Nodes have correct ownership (owned by this graph)
        - Checks Nodes appear in topological order
        - Checks that no two Nodes share a name
        - If `root` is provided, checks that `target`s exist in `root`

        Raises:
            RuntimeError: on the first violation found.
        """

        # Check topo order
        def check_arg(arg : Node, n : Optional[Node] = None) -> None:
            context_str = f' of Node \'{n}\' ' if n else ' '
            if arg.graph is not self:
                raise RuntimeError(f'Argument \'{arg}\'{context_str}does not belong to this Graph, '
                                   f'but was used as an argument! If you are copying nodes from another graph, make '
                                   f'sure to use `arg_transform` on node_copy() to remap values\n{self}')
            if arg not in seen_values:
                raise RuntimeError(f'Argument \'{arg}\'{context_str}was used before it has been '
                                   f'defined! Please check that Nodes in the graph are topologically ordered\n{self}')

        seen_names : Set[str] = set()
        seen_values : Set[Node] = set()
        for node in self._nodes:
            if node.op not in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']:
                raise RuntimeError(f'Node {node} had unknown opcode {node.op}!')
            if node.graph is not self:
                raise RuntimeError(f'Node \'{node}\' does not belong to this Graph!')
            map_arg(node.args, lambda arg: check_arg(arg, node))
            map_arg(node.kwargs, lambda arg: check_arg(arg, node))
            # A node only counts as defined after its own arguments have been checked
            seen_values.add(node)

            if node.name in seen_names:
                raise RuntimeError(f'Node redefined name {node.name}!')
            seen_names.add(node.name)

        # Check targets are legit
        # Compare against None explicitly: modules that define __len__ (e.g. an empty
        # container) are falsy and would otherwise skip this check silently.
        if root is not None:
            for node in self._nodes:
                if node.op in ['get_attr', 'call_module']:
                    assert isinstance(node.target, str)
                    target_atoms = node.target.split('.')
                    m_itr = root
                    # Walk the dotted path one attribute at a time
                    for i, atom in enumerate(target_atoms):
                        m_itr = getattr(m_itr, atom, None)
                        if m_itr is None:
                            seen_qualname = '.'.join(target_atoms[:i])
                            raise RuntimeError(f'Node {node} target {node.target} references nonexistent attribute '
                                               f'{atom} of {seen_qualname}')
# Format templates for operators that also have a reflected (`__r*__`) form,
# keyed by the operator's name without the surrounding double underscores.
reflectable_magic_methods = {
    # arithmetic
    'add': '{} + {}',
    'sub': '{} - {}',
    'mul': '{} * {}',
    'floordiv': '{} // {}',
    'truediv': '{} / {}',
    'div': '{} / {}',
    'mod': '{} % {}',
    'pow': '{} ** {}',
    # shifts and bitwise
    'lshift': '{} << {}',
    'rshift': '{} >> {}',
    'and': '{} & {}',
    'or': '{} | {}',
    'xor': '{} ^ {}',
    # indexing
    'getitem': '{}[{}]',
}

# All operator templates: comparisons and unary operators first, followed by
# every entry of `reflectable_magic_methods`.
magic_methods = {
    # comparisons
    'eq': '{} == {}',
    'ne': '{} != {}',
    'lt': '{} < {}',
    'gt': '{} > {}',
    'le': '{} <= {}',
    'ge': '{} >= {}',
    # unary
    'pos': '+{}',
    'neg': '-{}',
    'invert': '~{}',
    **reflectable_magic_methods,
}