blob: 692c824775c95fcf3b2d6c86af866ee57dc925bf [file] [log] [blame]
import functools
import inspect
from types import CodeType, FunctionType
from typing import Any, Dict, NamedTuple, Optional, Set, Tuple, List, Callable, Union

import torch
from torch._C import ScriptObject # type: ignore

from .node import Argument, map_aggregate
from .graph import Graph
from .graph_module import GraphModule
from .proxy import TracerBase, Proxy
# Bitmask of the code-object flags that mark a variadic ``*args`` parameter
# (CO_VARARGS) or a ``**kwargs`` parameter (CO_VARKEYWORDS).
HAS_VARSTUFF = inspect.CO_VARKEYWORDS | inspect.CO_VARARGS
def _patch_function(fn: FunctionType, nargs: int) -> FunctionType:
    """
    Return a copy of ``fn`` whose code object treats every parameter as an
    ordinary positional parameter.

    The rewritten code object declares ``nargs`` positional parameters, no
    positional-only or keyword-only parameters, and clears the
    ``CO_VARARGS`` / ``CO_VARKEYWORDS`` flags. ``co_varnames`` is left
    untouched, so the arguments bind, in order, to the names that
    ``Tracer.create_args_for_root`` walks when it creates placeholders
    (regular parameters, then keyword-only ones, then ``*args``, then
    ``**kwargs``). This lets the tracer pass one placeholder value per
    parameter, including the variadic ones, without Python trying to pack
    or unpack them.

    Args:
        fn (FunctionType): The function to patch.
        nargs (int): Number of positional arguments the patched function takes.

    Returns:
        A new function sharing ``fn``'s globals, name, defaults and closure.
    """
    co = fn.__code__
    co_flags = co.co_flags & ~HAS_VARSTUFF
    if hasattr(co, "replace"):
        # Python 3.8+: ``CodeType.replace`` copies every field we do not
        # override. Calling the ``CodeType`` constructor positionally breaks
        # whenever fields are added (3.11 added ``co_qualname`` and
        # ``co_exceptiontable``).
        new_code = co.replace(
            co_argcount=nargs,
            co_posonlyargcount=0,
            co_kwonlyargcount=0,
            co_flags=co_flags,
        )
    else:
        # Pre-3.8 constructor signature (no posonlyargcount argument).
        co_args : tuple = (
            nargs, 0, co.co_nlocals,
            co.co_stacksize, co_flags, co.co_code, co.co_consts,
            co.co_names, co.co_varnames, co.co_filename,
            co.co_name, co.co_firstlineno, co.co_lnotab,
            co.co_freevars, co.co_cellvars)
        new_code = CodeType(*co_args) # type: ignore
    return FunctionType(new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__)
# We need to insert placeholder nodes for *args and **kwargs. We can't call the
# traced function normally, because Python would then try to pack/unpack them.
# Instead, ``_patch_function`` makes Python treat args and kwargs as ordinary
# positional variables.
class Tracer(TracerBase):
    """
    ``Tracer`` is the class that implements the symbolic tracing functionality
    of ``torch.fx.symbolic_trace``. A call to ``symbolic_trace(m)`` is equivalent
    to ``Tracer().trace(m)``.

    Tracer can be subclassed to override various behaviors of the tracing
    process. The different behaviors that can be overridden are described
    in the docstrings of the methods on this class.
    """
    def __init__(self):
        super().__init__()

    def create_arg(self, a: Any) -> 'Argument':
        """
        A method to specify the behavior of tracing when preparing values to
        be used as arguments to nodes in the ``Graph``.

        By default, the behavior includes:

        #. Iterate through collection types (e.g. tuple, list, dict) and recursively
           call ``create_arg`` on the elements.
        #. Given a Proxy object, return a reference to the underlying IR ``Node``
        #. Given a non-Proxy Tensor object, emit IR for various cases:

            * For a Parameter, emit a ``get_attr`` node referring to that Parameter
            * For a buffer of ``root``, emit a ``get_attr`` node referring to that buffer
            * For any other Tensor (or ``ScriptObject``), emit a ``get_attr`` node
              referring to the module attribute that holds it. If no module in the
              hierarchy holds it, first store it in a new ``_tensor_constant<N>``
              attribute on ``root``.

        #. Given a NamedTuple instance, emit a ``call_function`` node that
           reconstructs it from its (converted) elements.

        This method can be overridden to support more types.

        Args:
            a (Any): The value to be emitted as an ``Argument`` in the ``Graph``.

        Returns:
            The value ``a`` converted into the appropriate ``Argument``

        Raises:
            NameError: if ``a`` is a Parameter that is not a member of ``root``.
        """
        # The base tracer is used to construct Graphs when there is no associated
        # module hierarchy, so it can never create parameter references.
        # The default tracer adds the ability to refer to parameters when
        # tracing modules.
        if isinstance(a, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if a is p:
                    return self.create_node('get_attr', n, (), {})
            raise NameError('parameter is not a member of this module')
        elif isinstance(a, torch.Tensor):
            for n, p in self.root.named_buffers():
                if a is p:
                    return self.create_node('get_attr', n, (), {})

        # For NamedTuple instances that appear literally as args, we emit
        # a node to construct the NamedTuple and use that Node as the argument.
        if isinstance(a, tuple) and hasattr(a, '_fields'):
            args = tuple(self.create_arg(elem) for elem in a)
            return self.create_node('call_function', a.__class__, args, {})

        # Tensors do not have a reliable string repr() from which they can be
        # constructed (and we probably don't want to rely on that, either), so
        # for any constant Tensor values we encounter, first search for if they
        # are an attribute of some module in the module hierarchy. If so, emit
        # a get_attr to retrieve that tensor. Otherwise, we'll store away the
        # tensor value into a special attribute on the Module s.t. we can
        # retrieve it with a get_attr.
        if isinstance(a, (torch.Tensor, ScriptObject)):
            qualname : Optional[str] = self.tensor_attrs.get(a)

            # Tensor was not found in the Module hierarchy, stow it away in a
            # special attribute and set the qualname to refer to that
            if not qualname:
                i = 0
                while True:
                    qualname = f'_tensor_constant{i}'
                    if not hasattr(self.root, qualname):
                        break
                    i += 1
                # Remember the new attribute so later uses of the same constant
                # reuse it instead of creating another ``_tensor_constant<N>``.
                self.tensor_attrs[a] = qualname
                setattr(self.root, qualname, a)

            return self.create_node('get_attr', qualname, (), {})

        return super().create_arg(a)

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
        """
        A method to specify whether a given ``nn.Module`` is a "leaf" module.

        Leaf modules are the atomic units that appear in
        the IR, referenced by ``call_module`` calls. By default,
        Modules in the PyTorch standard library namespace (torch.nn)
        are leaf modules, except ``torch.nn.Sequential``. All other modules
        are traced through and their constituent ops are recorded, unless
        specified otherwise by overriding this method.

        Args:
            m (Module): The module being queried about
            module_qualified_name (str): The path to root of this module. For example,
                if you have a module hierarchy where submodule ``foo`` contains
                submodule ``bar``, which contains submodule ``baz``, that module will
                appear with the qualified name ``foo.bar.baz`` here.

        Returns:
            ``True`` if ``m`` should be emitted as a single ``call_module`` node.
        """
        return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential)

    def path_of_module(self, mod : torch.nn.Module) -> str:
        """
        Helper method to find the qualified name of ``mod`` in the Module hierarchy
        of ``root``. For example, if ``root`` has a submodule named ``foo``, which has
        a submodule named ``bar``, passing ``bar`` into this function will return
        the string "foo.bar".

        Args:
            mod (Module): The ``Module`` to retrieve the qualified name for.

        Returns:
            The qualified name of ``mod`` within ``root``.

        Raises:
            NameError: if ``mod`` is not a submodule of ``root``.
        """
        for n, p in self.root.named_modules():
            if mod is p:
                return n
        raise NameError('module is not installed as a submodule')

    def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Method that specifies the behavior of this ``Tracer`` when it encounters
        a call to an ``nn.Module`` instance.

        By default, the behavior is to check if the called module is a leaf module
        via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to
        ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through
        the operations in its ``forward`` function.

        This method can be overridden to--for example--create nested traced
        GraphModules, or any other behavior you would want while tracing across
        ``Module`` boundaries.

        Args:
            m (Module): The module for which a call is being emitted
            forward (Callable): The forward() method of the ``Module`` to be invoked
            args (Tuple): args of the module callsite
            kwargs (Dict): kwargs of the module callsite

        Returns:
            The return value from the Module call. In the case that a ``call_module``
            node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever
            value was returned from the ``Module`` invocation.
        """
        module_qualified_name = self.path_of_module(m)
        if not self.is_leaf_module(m, module_qualified_name):
            return forward(*args, **kwargs)
        return self.create_proxy('call_module', module_qualified_name, args, kwargs)

    def create_args_for_root(self, root_fn, is_module):
        """
        Create ``placeholder`` nodes corresponding to the signature of the ``root``
        Module. This method introspects root's signature and emits those
        nodes accordingly, also supporting ``*args`` and ``**kwargs``.

        Args:
            root_fn (FunctionType): The function whose parameters become
                placeholders (for a Module, the ``forward`` function of its class).
            is_module (bool): Whether the first parameter of ``root_fn`` is
                ``self``. If it is, ``self.root`` is passed for it instead of a
                placeholder.

        Returns:
            A ``(fn, args)`` pair. ``fn`` is ``root_fn``, rewritten by
            ``_patch_function`` if it has keyword-only or variadic parameters,
            and ``args`` is the list of values to call it with.

        Raises:
            RuntimeError: if ``is_module`` is true but ``root_fn`` has no named
                parameters to bind ``self`` to.
        """
        # In some cases, a function or method has been decorated with a wrapper
        # defined via ``functools.wraps``. In this case, the outer code object
        # will likely not contain the actual parameters we care about, so unwrap
        # the function to get to the innermost callable.
        fn_for_analysis = inspect.unwrap(root_fn)
        co = fn_for_analysis.__code__
        total_args = co.co_argcount + co.co_kwonlyargcount
        names_iter = iter(co.co_varnames)
        args : List[Any] = []
        skip_arg_idx = 0
        if is_module:
            if total_args == 0:
                raise RuntimeError('``self`` argument cannot be part of *args expansion!')
            skip_arg_idx = 1
            next(names_iter)  # skip self
            args.append(self.root)

        sig = inspect.signature(fn_for_analysis)

        def proxy_placeholder(name: str):
            if name[0] == '*':
                default = ()    # type: ignore
            else:
                param = sig.parameters[name]
                default = () if param.default is inspect.Parameter.empty else (param.default,)  # type: ignore
            return self.create_proxy('placeholder', name, default, {},
                                     type_expr=fn_for_analysis.__annotations__.get(name, None))

        args.extend(proxy_placeholder(next(names_iter)) for _ in range(skip_arg_idx, total_args))

        if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF:
            # TODO: type annotations for *args and **kwargs
            if co.co_flags & inspect.CO_VARARGS:
                args.append(proxy_placeholder('*' + next(names_iter)))
            if co.co_flags & inspect.CO_VARKEYWORDS:
                args.append(proxy_placeholder('**' + next(names_iter)))
            root_fn = _patch_function(root_fn, len(args))

        return root_fn, args

    def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph:
        """
        Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
        can either be an ``nn.Module`` instance or a Python callable.

        While tracing, ``torch.nn.Module.__call__`` and ``torch.nn.Module.__getattr__``
        are temporarily replaced, as are the globals registered via ``wrap()``.
        All of them are restored before this method returns or raises.

        Args:
            root (Union[Module, Callable]): Either a ``Module`` or a function to be
                traced through.

        Returns:
            A ``Graph`` representing the semantics of the passed-in ``root``.
        """
        if isinstance(root, torch.nn.Module):
            self.root = root
            fn = type(root).forward
        else:
            self.root = torch.nn.Module()
            fn = root
        self.graph = Graph()

        # When we encounter a Tensor value that's not a parameter, we look if it
        # is some other attribute on the model. Construct a dict mapping Tensor
        # values to the qualified name here for efficiency. This is used downstream
        # in create_arg
        self.tensor_attrs : Dict[torch.Tensor, str] = {}

        def collect_tensor_attrs(m : torch.nn.Module, prefix_atoms : List[str]):
            for k, v in m.__dict__.items():
                if isinstance(v, (torch.Tensor, ScriptObject)):
                    self.tensor_attrs[v] = '.'.join(prefix_atoms + [k])
            for k, v in m.named_children():
                collect_tensor_attrs(v, prefix_atoms + [k])

        collect_tensor_attrs(self.root, [])

        assert isinstance(fn, FunctionType)

        fn, args = self.create_args_for_root(fn, isinstance(root, torch.nn.Module))

        orig_call = torch.nn.Module.__call__
        orig_getattr = torch.nn.Module.__getattr__

        parameter_proxy_cache : Dict[str, Proxy] = {}  # Reduce number of get_attr calls

        # Method dispatch on parameters is not recorded unless it's directly used.
        # Thus, we need to insert a proxy when __getattr__ requests a parameter.
        def module_getattr_wrapper(mod, attr):
            attr_val = orig_getattr(mod, attr)
            if isinstance(attr_val, torch.nn.Parameter):
                for n, p in self.root.named_parameters():
                    if attr_val is p:
                        if n not in parameter_proxy_cache:
                            parameter_proxy_cache[n] = self.create_proxy('get_attr', n, (), {})
                        return parameter_proxy_cache[n]
            return attr_val

        def module_call_wrapper(mod, *args, **kwargs):
            def forward(*args, **kwargs):
                return orig_call(mod, *args, **kwargs)

            return self.call_module(mod, forward, args, kwargs)

        # Bound before the ``try`` so that the ``finally`` clause can always run,
        # even if ``_patch_wrapped_functions`` itself raises. Previously this
        # raised UnboundLocalError there and masked the original exception.
        orig_fns : List[PatchedFn] = []
        try:
            # Seems to be a mypy limitation: https://github.com/python/mypy/issues/2427
            torch.nn.Module.__getattr__ = module_getattr_wrapper  # type: ignore
            torch.nn.Module.__call__ = module_call_wrapper
            orig_fns = _patch_wrapped_functions()

            self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
                             type_expr=fn.__annotations__.get('return', None))
        finally:
            _unpatch_wrapped_functions(orig_fns)
            torch.nn.Module.__call__ = orig_call
            torch.nn.Module.__getattr__ = orig_getattr  # type: ignore
        return self.graph
# Registry of (globals dict, function name) pairs recorded by ``wrap()``.
# While tracing, ``_patch_wrapped_functions`` temporarily replaces each named
# global in its dict with a wrapper that emits a ``call_function`` node.
_wrapped_fns_to_patch : List[Tuple[dict, str]] = list()
def _create_wrapped_func(orig_fn):
    """
    Build the stand-in that a ``wrap()``-registered global is replaced with
    while tracing.

    Args:
        orig_fn (Callable): The original global function.

    Returns:
        A wrapper with ``orig_fn``'s metadata (``__name__``, ``__doc__``,
        ``__wrapped__``, ...). When any of its arguments contains a ``Proxy``,
        it emits a ``call_function`` node targeting ``orig_fn`` on that
        Proxy's tracer. Otherwise it just calls ``orig_fn``.
    """
    # functools.wraps copies orig_fn's metadata onto the wrapper, so code that
    # inspects the patched global during tracing still sees the original
    # function's name and signature.
    @functools.wraps(orig_fn)
    def wrapped(*args, **kwargs):
        # Search the args and kwargs for a Proxy object. If there is one, emit a
        # ``call_function`` node to preserve the call to this leaf function
        # directly. Otherwise, just return the results of this function call, as
        # this function is not being traced.
        proxy = None

        def find_proxy(x):
            nonlocal proxy
            if isinstance(x, Proxy):
                proxy = x

        map_aggregate(args, find_proxy)
        map_aggregate(kwargs, find_proxy)

        if proxy is not None:
            return proxy.tracer.create_proxy('call_function', orig_fn, args, kwargs)
        else:
            return orig_fn(*args, **kwargs)

    return wrapped
class PatchedFn(NamedTuple):
    """
    Record of one global function replaced by ``_patch_wrapped_functions``.

    Holds what ``_unpatch_wrapped_functions`` needs in order to put the
    original function back.
    """
    # Globals dict (module namespace) that the function was taken from.
    frame_dict: Dict[str, Any]
    # Key of the patched function within ``frame_dict``.
    fn_name: str
    # The original function object, restored on unpatch.
    orig_fn: Any
def _patch_wrapped_functions() -> List[PatchedFn]:
    """
    Go through ``_wrapped_fns_to_patch`` and, for each (globals dict, name)
    entry, replace the named global function with a ``_create_wrapped_func``
    wrapper.

    Returns:
        A list of ``PatchedFn`` records, one per global that was patched. Each
        record holds the original function so that
        ``_unpatch_wrapped_functions`` can restore it.

    Raises:
        KeyError: if a registered name is not defined in its globals dict.
            Any globals already patched by this call are restored before the
            error propagates.
    """
    orig_fns : List[PatchedFn] = []
    # Set to deduplicate entries. Wrapping a function multiple times would
    # be an error, since it would cause a `call_function` node for the
    # wrapper to be emitted rather than the actual underlying function
    #
    # Use id(frame_dict) as a hashable identity here since none of the
    # frame dicts should be destroyed during symtracing
    processed_entries : Set[Tuple[int, str]] = set()
    try:
        for frame_dict, name in _wrapped_fns_to_patch:
            if (id(frame_dict), name) in processed_entries:
                continue
            orig_fn = frame_dict[name]
            orig_fns.append(PatchedFn(frame_dict, name, orig_fn))
            frame_dict[name] = _create_wrapped_func(orig_fn)
            processed_entries.add((id(frame_dict), name))
    except BaseException:
        # The caller never receives ``orig_fns`` when we raise, so undo any
        # partial patching here rather than leaving wrapped globals behind.
        for patched in reversed(orig_fns):
            patched.frame_dict[patched.fn_name] = patched.orig_fn
        raise
    return orig_fns
def _unpatch_wrapped_functions(orig_fns : List[PatchedFn]):
    """
    Undo ``_patch_wrapped_functions``: write each recorded original function
    back into the globals dict it was taken from.

    Args:
        orig_fns: The ``PatchedFn`` records returned by
            ``_patch_wrapped_functions``.
    """
    for record in orig_fns:
        globals_dict, name, original = record
        globals_dict[name] = original
def wrap(fn_or_name : Union[str, Callable]):
    """
    This function can be called at module-level scope to register
    ``fn_or_name`` as a "leaf function". While symbolically tracing, calls to
    the global function of that name in the calling module are inserted into
    the graph as ``call_function`` nodes instead of being traced through.

    Example::

        # foo/bar/baz.py
        def my_custom_function(x, y):
            return x * x + y * y

        torch.fx.wrap('my_custom_function')

        def fn_to_be_traced(x, y):
            # When symbolic tracing, the below call to my_custom_function will be inserted into
            # the graph rather than tracing it.
            return my_custom_function(x, y)

    Args:
        fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the
            graph when it's called

    Raises:
        RuntimeError: if ``fn_or_name`` is neither a string nor a Python function
            (an object with a ``__code__`` attribute).
        NotImplementedError: if not called at the top level of a module.
    """
    if isinstance(fn_or_name, str):
        fn_name = fn_or_name
    elif callable(fn_or_name) and hasattr(fn_or_name, '__code__'):
        fn_name = fn_or_name.__code__.co_name
    elif callable(fn_or_name):
        # e.g. builtins or functools.partial objects. These have no code-object
        # name to look up in the module's globals.
        raise RuntimeError(f'Unsupported callable for global function: {fn_or_name!r}. Must be a '
                           'Python function or string name')
    else:
        raise RuntimeError('Unsupported type for global function! Must be either a callable or '
                           'string name')

    currentframe = inspect.currentframe()
    assert currentframe is not None
    # The caller's frame: its globals are the module namespace whose global
    # will be patched during tracing.
    f = currentframe.f_back
    assert f is not None
    if f.f_code.co_name != '<module>':
        raise NotImplementedError('wrap must be called at the top level of a module')

    _wrapped_fns_to_patch.append((f.f_globals, fn_name))
def symbolic_trace(root : Union[torch.nn.Module, Callable]) -> GraphModule:
    """Symbolic tracing API

    Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
    constructed by recording operations seen while tracing through ``root``.

    Args:
        root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted
            into a Graph representation.

    Returns:
        GraphModule: a Module created from the recorded operations from ``root``.
    """
    tracer = Tracer()
    graph = tracer.trace(root)
    # ``tracer.root`` is ``root`` itself when ``root`` is a Module. For a plain
    # function, it is the Module that ``Tracer.trace`` created, and on which
    # ``create_arg`` stored any ``_tensor_constant<N>`` attributes the graph's
    # ``get_attr`` nodes refer to. A fresh empty Module would lack them.
    return GraphModule(tracer.root, graph)