import torch
import torch.overrides
import linecache
from typing import Type, Dict, List, Any, Union
from .graph import Graph
import copy
import sys
import traceback
import math
# A normal `exec` call loses the source code; however, we can patch
# the linecache module to still recover it.
# Using exec_with_source will add the source to our local cache,
# and then tools like TorchScript will be able to get source info.
_next_id = 0
def exec_with_source(src: str, globals: Dict[str, Any]):
global _next_id
key = f'<eval_with_key_{_next_id}>'
_next_id += 1
_eval_cache[key] = [line + '\n' for line in src.splitlines()]
exec(compile(src, key, 'exec'), globals)
# patch linecache so that any code we exec using exec_with_source
# works with inspect
_eval_cache : Dict[str, List[str]] = {}
_orig_getlines = linecache.getlines
def patched_getlines(*args, **kwargs):
if args[0] in _eval_cache:
return _eval_cache[args[0]]
return _orig_getlines(*args, **kwargs)
linecache.getlines = patched_getlines
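# Example (a minimal sketch; the source string and scope are hypothetical):
#
#   scope: Dict[str, Any] = {}
#   exec_with_source("def f(x):\n    return x + 1\n", scope)
#   import inspect
#   print(inspect.getsource(scope['f']))  # works, because the patched
#                                         # linecache finds the cached lines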
def _forward_from_src(src : str):
gbls: Dict[str, Any] = {'inf': math.inf, 'nan': math.nan}
exec_with_source(src, gbls)
return gbls['forward']
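# Example (a minimal sketch; the source string is hypothetical):
#
#   fwd = _forward_from_src("def forward(self, x):\n    return x * 2\n")
#   # `fwd` is a plain function that can be installed as a method,
#   # e.g. `SomeModuleClass.forward = fwd`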
def deserialize_graphmodule(body : dict) -> torch.nn.Module:
"""
Deserialize a GraphModule given the dictionary of the original module,
using the code to reconstruct the graph. We delete the actual graph before
saving the dictionary so that changes to the in-memory graph format do not
get serialized.
"""
# We create a dummy class here because symbolic_trace pulls the forward()
# function off of the class, rather than the instance
class CodeOnlyModule(torch.nn.Module):
def __init__(self, body):
super().__init__()
self.__dict__ = body
CodeOnlyModule.forward = _forward_from_src(body['code'])
from .symbolic_trace import Tracer
    # we shouldn't trace into any of the submodules;
    # they were not traced in the original GraphModule
class KeepModules(Tracer):
def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool:
return True
com = CodeOnlyModule(body)
return GraphModule(com, KeepModules().trace(com))
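# Example (a minimal sketch of the pickle round-trip this enables;
# `gm` stands in for an existing GraphModule):
#
#   import pickle
#   payload = pickle.dumps(gm)   # GraphModule.__reduce__ drops the graph, keeps the code
#   gm2 = pickle.loads(payload)  # deserialize_graphmodule re-traces from the code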
# Copy an attribute value with qualified name 'target' from 'from_module' to 'to_module'.
# This installs empty Modules where none exist yet if they are subpaths of 'target'.
def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str):
*prefix, field = target.split('.')
for item in prefix:
f = getattr(from_module, item)
t = getattr(to_module, item, None)
if f is t:
# we have already installed one of its parents
# (e.g. target = root.linear.weight, but we have already installed root.linear)
# once we install a parent, we no longer need to copy the children
# since all the needed properties will already be present
return
if t is None:
t = torch.nn.Module()
setattr(to_module, item, t)
from_module, to_module = f, t
setattr(to_module, field, getattr(from_module, field))
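# Example (hypothetical modules; a minimal sketch):
#
#   src = torch.nn.Module()
#   src.linear = torch.nn.Linear(4, 4)
#   dst = torch.nn.Module()
#   _copy_attr(src, dst, 'linear.weight')
#   # `dst.linear` is an empty Module created on the fly, and
#   # `dst.linear.weight` now refers to `src.linear.weight`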
# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module'.
# This installs empty Modules where none exist yet if they are subpaths of 'target'.
def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str):
*prefix, field = target.split('.')
for item in prefix:
t = getattr(to_module, item, None)
if t is None:
t = torch.nn.Module()
setattr(to_module, item, t)
to_module = t
setattr(to_module, field, from_obj)
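# Example (hypothetical values; a minimal sketch):
#
#   dst = torch.nn.Module()
#   _assign_attr(torch.nn.Parameter(torch.rand(3)), dst, 'sub.param')
#   # `dst.sub` is created as an empty Module, and `dst.sub.param`
#   # holds the given Parameter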
class GraphModule(torch.nn.Module):
"""
GraphModule is an nn.Module generated from an fx.Graph. GraphModule has
important attributes:
graph : The graph from which this GraphModule was generated
code : The Python source code for the function generated from `graph`
forward : The Python method generated from `graph`
Note that when `graph` is reassigned, `code` and `forward` will be automatically
regenerated. However, if you edit the contents of the `graph` without reassigning
the `graph` attribute itself, you must call `recompile()` to update the generated
code.
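    A minimal sketch of the two update paths (`gm` stands in for an existing
    GraphModule, `new_graph` for a Graph):
        gm.graph = new_graph   # the setter runs recompile() automatically
        # ... in-place edits to gm.graph ...
        gm.recompile()         # in-place edits require a manual recompile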
"""
def __new__(cls: 'Type[GraphModule]', *args, **kwargs):
        # Each instance of a graph module needs its own forward method,
        # so create a new singleton class for each instance.
        # It is a subclass of the user-defined class, the only difference
        # being an extra layer to install the forward method.
class GraphModuleImpl(cls): # type: ignore
pass
return super().__new__(GraphModuleImpl)
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph):
"""
Construct a GraphModule.
root - `root` can either be an nn.Module instance or a Dict mapping strings to any attribute type.
- In the case that `root` is a Module, any references to Module-based objects (via qualified
name) in the Graph's Nodes' `target` field will be copied over from the respective place
within `root`'s Module hierarchy into the GraphModule's module hierarchy.
- In the case that `root` is a dict, the qualified name found in a Node's `target` will be
looked up directly in the dict's keys. The object mapped to by the Dict will be copied
over into the appropriate place within the GraphModule's module hierarchy.
graph - `graph` contains the nodes this GraphModule should use for code generation
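        A minimal sketch of the two forms (`g` stands in for a Graph whose
        nodes reference the qualified name 'sub.param'):
            # Module root: 'sub.param' is copied from the module hierarchy
            gm = GraphModule(my_module, g)
            # dict root: 'sub.param' is looked up as a literal key
            gm = GraphModule({'sub.param': torch.nn.Parameter(torch.rand(3))}, g)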
"""
super().__init__()
if isinstance(root, torch.nn.Module):
if hasattr(root, 'training'):
self.training = root.training
for node in graph.nodes:
if node.op in ['get_attr', 'call_module']:
assert isinstance(node.target, str)
_copy_attr(root, self, node.target)
elif isinstance(root, dict):
targets_to_copy = []
for node in graph.nodes:
if node.op in ['get_attr', 'call_module']:
assert isinstance(node.target, str)
if node.target not in root:
raise RuntimeError('Node ' + str(node) + ' referenced target ' + node.target +
' but that target was not provided in `root`!')
targets_to_copy.append(node.target)
# Sort targets in ascending order of the # of atoms.
# This will ensure that less deeply nested attributes are assigned
# before more deeply nested attributes. For example, foo.bar
# will be assigned before foo.bar.baz. Otherwise, we might assign
# the user-provided `foo.bar` and wipe out the previously-assigned
# `foo.bar.baz`
targets_to_copy.sort(key=lambda t: t.count('.'))
for target_to_copy in targets_to_copy:
_assign_attr(root[target_to_copy], self, target_to_copy)
else:
raise RuntimeError('Unsupported type ' + str(root) + ' passed for root!')
self.graph = graph
# TorchScript breaks trying to compile the graph setter because of the
# continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
#
# Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway
__jit_unused_properties__ = ['graph']
@property
def graph(self):
return self._graph
@graph.setter
def graph(self, val) -> None:
self._graph = val
self.recompile()
def recompile(self) -> None:
"""
Recompile this GraphModule from its `graph` attribute. This should be
called after editing the contained `graph`, otherwise the generated
code of this `GraphModule` will be out of date.
"""
self.code = self._graph.python_code(root_module='self')
cls = type(self)
cls.forward = _forward_from_src(self.code)
cls_call = cls.__call__
def print_full_traceback(exctype, value, tb):
traceback.print_exception(exctype, value, tb)
def wrapped_call(self, *args, **kwargs):
old_excepthook = sys.excepthook
try:
sys.excepthook = print_full_traceback
return cls_call(self, *args, **kwargs)
finally:
sys.excepthook = old_excepthook
cls.__call__ = wrapped_call
def __reduce__(self):
dict_without_graph = self.__dict__.copy()
del dict_without_graph['_graph']
return (deserialize_graphmodule, (dict_without_graph,))
    # Because __reduce__ is defined for serialization,
    # we need to define deepcopy; otherwise it will call __reduce__
    # and cause symbolic tracing to occur every time we try to copy the object
def __deepcopy__(self, memo):
fake_mod = torch.nn.Module()
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
return GraphModule(fake_mod, self.graph)
def __copy__(self):
return GraphModule(self, self.graph)
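    # Example (a minimal sketch; `gm` stands in for an existing GraphModule):
    #
    #   import copy
    #   gm2 = copy.deepcopy(gm)  # rebuilds from the same graph, no re-tracing
    #   gm3 = copy.copy(gm)      # shallow: shares submodules with `gm`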
def __str__(self) -> str:
orig_str = super().__str__()
return '\n'.join([orig_str, self.code])
# workarounds for issues in __torch_function__
# WAR for __torch_function__ not handling tensor lists,
# fix is in https://github.com/pytorch/pytorch/pull/34725
# orig_cat = torch.cat
# def patched_cat(*args, **kwargs):
# tensors = args[0]
# for t in tensors:
# if isinstance(t, Proxy):
# return t.__torch_function__(patched_cat, (), args, kwargs)
# return orig_cat(*args, **kwargs)
# patched_cat.__module__ = 'torch'
# patched_cat.__name__ = 'cat'
# torch.cat = patched_cat