blob: 18bfe0773826a58e111465dd4ccf99fd20afe311 [file] [log] [blame]
import inspect
from types import CodeType, FunctionType
from typing import Any, Optional, List, Callable, Union
import torch
from .node import Argument
from .graph import Graph
from .graph_module import GraphModule
from .proxy import TracerBase
HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS
def _patch_function(fn: FunctionType, nargs: int) -> FunctionType:
co = fn.__code__
co_flags = co.co_flags & ~HAS_VARSTUFF
co_args : tuple
if hasattr(co, "co_posonlyargcount"):
co_args = (
nargs, 0,
0, co.co_nlocals, co.co_stacksize,
co_flags, co.co_code, co.co_consts, co.co_names,
co.co_varnames, co.co_filename, co.co_name,
co.co_firstlineno, co.co_lnotab, co.co_freevars,
co.co_cellvars
)
else:
co_args = (
nargs, 0, co.co_nlocals,
co.co_stacksize, co_flags, co.co_code, co.co_consts,
co.co_names, co.co_varnames, co.co_filename,
co.co_name, co.co_firstlineno, co.co_lnotab,
co.co_freevars, co.co_cellvars)
new_code = CodeType(*co_args) # type: ignore
return FunctionType(new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__)
# we need to insert placeholder nodes for *args, and **kwargs,
# so we can't call this function normally, otherwise it would try to unpack them
# instead, let's make python think that args and kwargs are normay variables
class Tracer(TracerBase):
    """Build a ``Graph`` by symbolically executing an ``nn.Module`` or function.

    ``trace`` calls the root's code with placeholder proxies and records the
    operations performed on them. While tracing, ``torch.nn.Module.__call__``
    is temporarily replaced so that every submodule call goes through
    ``call_module``. That method either records a ``call_module`` node (for
    leaf modules, see ``is_leaf_module``) or traces through the submodule's
    ``forward``.
    """

    def __init__(self):
        super().__init__()

    def create_arg(self, a: Any) -> Argument:
        """Convert ``a`` into a value that can be stored as a node argument.

        Extends the base tracer with references into the module hierarchy
        rooted at ``self.root``:

        * An ``nn.Parameter`` emits a ``get_attr`` node naming the parameter's
          qualified path.
        * Any other ``torch.Tensor`` emits a ``get_attr`` node naming one of
          two targets. If some module holds the tensor as a buffer or
          attribute, the node names that location. Otherwise the tensor is
          stored on ``self.root`` under a new ``__tensor_constant{i}`` name,
          and the node names that.

        Everything else is delegated to ``TracerBase.create_arg``.

        Raises:
            NameError: if ``a`` is a parameter that is not registered under
                ``self.root``.
        """
        # The base tracer is used to construct Graphs when there is no associated
        # module hierarchy, so it can never create parameter references.
        # The default tracer adds the ability to refer to parameters when
        # tracing modules.
        if isinstance(a, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if a is p:
                    return self.create_node('get_attr', n, (), {})
            raise NameError('parameter is not a member of this module')

        # Tensors do not have a reliable string repr() from which they can be
        # constructed (and we probably don't want to rely on that, either), so
        # for any constant Tensor values we encounter, first search for if they
        # are an attribute of some module in the module hierarchy. If so, emit
        # a get_attr to retrieve that tensor. Otherwise, we'll store away the
        # tensor value into a special attribute on the Module s.t. we can
        # retrieve it with a get_attr.
        if isinstance(a, torch.Tensor):
            # Registered buffers live in ``Module._buffers``, not in the
            # instance ``__dict__`` searched below. Without this lookup they
            # would be copied into ``__tensor_constant`` attributes instead of
            # being referenced by name.
            for n, b in self.root.named_buffers():
                if a is b:
                    return self.create_node('get_attr', n, (), {})

            # TODO: slow
            def search_for_tensor(m: torch.nn.Module) -> Optional[List[str]]:
                """
                Return the attribute path (list of names) from ``m`` to a plain
                instance attribute that is ``a``. Check ``m.__dict__`` first,
                then recurse into child submodules. Return None if not found.
                """
                for n, p in m.__dict__.items():
                    if a is p:
                        return [n]
                for n, c in m.named_children():
                    maybe_result: Optional[List[str]] = search_for_tensor(c)
                    if maybe_result:
                        return [n] + maybe_result
                return None

            # Retrieve the qualname for an existing Tensor attribute
            qualname_atoms: Optional[List[str]] = search_for_tensor(self.root)
            qualname = '.'.join(qualname_atoms) if qualname_atoms else None

            # Tensor was not found in the Module hierarchy, stow it away in a
            # special attribute and set the qualname to refer to that
            if not qualname:
                i = 0
                while True:
                    qualname = f'__tensor_constant{i}'
                    if not hasattr(self.root, qualname):
                        break
                    i += 1
                setattr(self.root, qualname, a)

            return self.create_node('get_attr', qualname, (), {})

        return super().create_arg(a)

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        """
        Return whether ``m`` is a "leaf" module.

        Leaf modules are the atomic units that appear in the IR, referenced by
        ``call_module`` nodes. By default, a module is a leaf when its class is
        defined in a module whose name starts with ``torch.nn``, except for
        ``torch.nn.Sequential``. All other modules are traced through and their
        constituent ops are recorded. Subclasses may override this method to
        change that choice.

        Args:
            m: The module itself.
            module_qualified_name: The path from the root to this module. For
                example, if you have a module hierarchy where submodule ``foo``
                contains submodule ``bar``, which contains submodule ``baz``,
                that module will appear with the qualified name
                ``foo.bar.baz`` here.
        """
        return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential)

    def path_of_module(self, mod):
        """
        Return the qualified name of ``mod`` within ``self.root``.

        Raises:
            NameError: if ``mod`` is not reachable from ``self.root``.
        """
        for n, p in self.root.named_modules():
            if mod is p:
                return n
        raise NameError('module is not installed as a submodule')

    def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args, kwargs):
        """
        Handle a call to submodule ``m`` made while tracing.

        Non-leaf modules are traced through by invoking ``forward``. Leaf
        modules produce a ``call_module`` proxy targeting the module's
        qualified name.
        """
        module_qualified_name = self.path_of_module(m)
        if not self.is_leaf_module(m, module_qualified_name):
            return forward(*args, **kwargs)
        return self.create_proxy('call_module', module_qualified_name, args, kwargs)

    def create_args_for_root(self, root_fn, is_module):
        """
        Create the arguments that ``root_fn`` will be called with while tracing.

        One ``placeholder`` proxy is created per parameter. When ``is_module``
        is true, the first parameter (``self``) is bound to ``self.root``
        instead. ``*args`` and ``**kwargs`` parameters get placeholders named
        ``*<name>`` and ``**<name>``. If the signature has keyword-only,
        ``*args`` or ``**kwargs`` parameters, ``root_fn`` is replaced by a
        ``_patch_function`` copy that accepts all arguments positionally.

        Returns:
            ``(fn, args)``: the function to call (possibly patched) and the
            list of positional arguments to pass to it.
        """
        co = root_fn.__code__
        total_args = co.co_argcount + co.co_kwonlyargcount
        names_iter = iter(co.co_varnames)
        # Read annotations from the unpatched function; the copy built by
        # ``_patch_function`` via ``FunctionType(...)`` may not carry them.
        annotations = root_fn.__annotations__
        args: List[Any] = []
        skip_arg_idx = 0
        if is_module:
            skip_arg_idx = 1
            next(names_iter)  # skip self
            args.append(self.root)

        def proxy_placeholder(name: str):
            return self.create_proxy('placeholder', name, (), {},
                                     type_expr=annotations.get(name, None))

        args.extend(proxy_placeholder(next(names_iter)) for _ in range(skip_arg_idx, total_args))

        if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF:
            # TODO: type annotations for *args and **kwargs
            if co.co_flags & inspect.CO_VARARGS:
                args.append(proxy_placeholder('*' + next(names_iter)))
            if co.co_flags & inspect.CO_VARKEYWORDS:
                args.append(proxy_placeholder('**' + next(names_iter)))
            root_fn = _patch_function(root_fn, len(args))

        return root_fn, args

    def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph:
        """
        Trace ``root`` and return the recorded ``Graph``.

        Sets ``self.root`` (``root`` itself, or a fresh ``nn.Module`` when
        ``root`` is a plain function) and ``self.graph``. Tracing may add
        ``__tensor_constant{i}`` attributes to ``self.root`` (see
        ``create_arg``). ``torch.nn.Module.__call__`` is replaced for the
        duration of the trace and always restored afterwards.
        """
        is_module = isinstance(root, torch.nn.Module)
        if is_module:
            self.root = root
            fn = type(root).forward
        else:
            self.root = torch.nn.Module()
            fn = root
        self.graph = Graph()

        assert isinstance(fn, FunctionType)
        # ``create_args_for_root`` may return a ``_patch_function`` copy,
        # which is built without annotations, so capture the return
        # annotation from the original function first.
        return_annotation = fn.__annotations__.get('return', None)
        fn, args = self.create_args_for_root(fn, is_module)

        orig_call = torch.nn.Module.__call__

        def module_call_wrapper(mod, *call_args, **call_kwargs):
            def forward(*fwd_args, **fwd_kwargs):
                return orig_call(mod, *fwd_args, **fwd_kwargs)
            return self.call_module(mod, forward, call_args, call_kwargs)

        try:
            torch.nn.Module.__call__ = module_call_wrapper
            self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
                             type_expr=return_annotation)
        finally:
            torch.nn.Module.__call__ = orig_call
        return self.graph
# Symbolic tracing API
def symbolic_trace(root: Union[torch.nn.Module, Callable]) -> GraphModule:
    """
    Symbolically trace ``root`` into a ``GraphModule``.

    Given an ``nn.Module`` or function instance ``root``, return a
    ``GraphModule`` constructed by recording the operations seen while
    tracing through ``root``.

    Args:
        root: The ``nn.Module`` instance or plain function to trace.

    Returns:
        A ``GraphModule`` wrapping the traced ``Graph``.
    """
    tracer = Tracer()
    graph = tracer.trace(root)
    # ``tracer.root`` is ``root`` itself for modules. For plain functions it
    # is the fresh ``nn.Module`` the tracer created, which is where constant
    # tensors are stowed as ``__tensor_constant{i}`` attributes for the
    # graph's ``get_attr`` nodes. Passing a different, empty Module would
    # leave those targets unresolvable (presumably GraphModule resolves
    # get_attr targets against the module it is given).
    return GraphModule(tracer.root, graph)