blob: d1672d332f14bac43fccd397c8194524bb183e2e [file] [log] [blame]
import dis
import torch
import inspect
import operator
from .graph import magic_methods, reflectable_magic_methods, Graph
from typing import Tuple, Dict, Optional, Iterable, Any, Iterator
from .node import Target, Node, Argument, base_types
class TracerBase:
    """
    Base class for objects that record operations into a ``Graph``.

    ``graph`` is only annotated here; a concrete tracer must assign it before
    any node is created. The methods below are hooks that subclasses can
    override to customize how symbolic values are recorded.
    """
    graph: Graph

    def create_node(self, kind : str, target : Target,
                    args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
                    type_expr : Optional[Any] = None) -> Node:
        """
        Inserts a graph node given target, args, kwargs, and name.

        This method can be overridden to do extra checking, validation, or
        modification of values used in node creation. For example, one might
        want to disallow in-place operations from being recorded.
        """
        return self.graph.create_node(kind, target, args, kwargs, name, type_expr)

    def proxy(self, node: Node) -> 'Proxy':
        """Wrap ``node`` in a ``Proxy`` bound to this tracer."""
        return Proxy(node, self)

    def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
                     name: Optional[str] = None, type_expr : Optional[Any] = None):
        """
        Lower ``args``/``kwargs`` with ``create_arg``, insert a node for the
        operation via ``create_node``, and return a ``Proxy`` wrapping it.
        """
        args_ = self.create_arg(args)
        kwargs_ = self.create_arg(kwargs)
        assert isinstance(args_, tuple)
        assert isinstance(kwargs_, dict)
        return self.proxy(self.create_node(kind, target, args_, kwargs_, name, type_expr))

    def create_arg(self, a: Any) -> Argument:
        """
        A method that lowers the objects seen as arguments during symbolic evaluation
        into Argument types that can be stored in IR.

        Tuples (including namedtuples), lists, dicts with string keys and
        slices are lowered element-wise; ``Proxy`` objects are unwrapped to
        their ``Node``; ``base_types`` values and ``None`` pass through
        unchanged. Anything else raises ``NotImplementedError``.

        Can be overridden to support more trace-specific types.
        """
        # aggregates
        if isinstance(a, tuple) and hasattr(a, '_fields'):
            # NamedTuple constructors take one positional argument per field,
            # so they cannot be rebuilt from a single iterable like plain tuples.
            return type(a)(*(self.create_arg(elem) for elem in a))
        elif isinstance(a, (tuple, list)):
            return type(a)(self.create_arg(elem) for elem in a)
        elif isinstance(a, dict):
            r = {}
            for k, v in a.items():
                if not isinstance(k, str):
                    raise NotImplementedError(f"dictionaries with non-string keys: {a}")
                r[k] = self.create_arg(v)
            return r
        elif isinstance(a, slice):
            return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))

        if isinstance(a, Proxy):
            # base case: we unwrap the Proxy object
            return a.node
        elif isinstance(a, base_types) or a is None:
            return a

        raise NotImplementedError(f"argument of type: {type(a)}")

    def to_bool(self, obj: 'Proxy') -> bool:
        """Called when a proxy object is being converted to a boolean, such as
        when used in control flow. Normally we don't know what to do because
        we don't know the value of the proxy, but a custom tracer can attach more
        information to the graph node using create_node and can choose to return a value.

        Raises:
            TraceError: always, in this base implementation.
        """
        raise TraceError('symbolically traced variables cannot be used as inputs to control flow')

    def iter(self, obj: 'Proxy') -> Iterator:
        """Called when a proxy object is being iterated over, such as
        when used in control flow. Normally we don't know what to do because
        we don't know the value of the proxy, but a custom tracer can attach more
        information to the graph node using create_node and can choose to return an iterator.

        Raises:
            TraceError: always, in this base implementation.
        """
        raise TraceError('Proxy object cannot be iterated. '
                         'This can be attempted when used in a for loop or as a *args or **kwargs function argument.')

    def keys(self, obj: 'Proxy') -> Any:
        """Called when a proxy object has its ``keys()`` method called.

        This is what happens when ``**`` is applied to a proxy. Override it to
        return an iterator if ``**`` is supposed to work in your custom tracer.
        The default records a ``keys`` method call on ``obj``.
        """
        return Attribute(obj, 'keys')()
class GraphAppendingTracer(TracerBase):
    """
    Tracer used by ``Proxy`` objects that are created without an explicit
    tracer: new nodes are appended straight to ``graph`` while not tracing.
    """

    def __init__(self, graph: Graph):
        super(GraphAppendingTracer, self).__init__()
        self.graph = graph
class TraceError(ValueError):
    """Raised when a symbolic ``Proxy`` is used in a way that cannot be traced."""
# Proxy objects are stand-in values for normal values in a PyTorch computation.
# Instead of performing the computation, they record it into a Graph.
# Each proxy wraps the Node instance that represents the expression that defines
# its value.
class Proxy:
    """
    Stand-in value that records operations into a ``Graph`` instead of
    computing them.

    ``node`` is the graph ``Node`` whose expression produces this value, and
    ``tracer`` is the ``TracerBase`` through which new nodes are recorded.
    """

    def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None):
        if tracer is None:
            # this allows you to create a proxy object around a raw node
            # so that if you are doing graph transforms you can use the overloaded operators
            # to add additional things to a graph.
            tracer = GraphAppendingTracer(node.graph)
        self.tracer = tracer
        self.node = node

    def __repr__(self) -> str:
        return f'Proxy({self.node.name})'

    def __getattr__(self, k) -> 'Attribute':
        # note: not added to the graph yet, if this is a method call
        # we peephole optimize to the method invocation
        return Attribute(self, k)

    def __call__(self, *args, **kwargs) -> 'Proxy':
        return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs)

    def __iter__(self) -> Iterable['Proxy']:
        """
        Support tuple unpacking of a proxy (``a, b = proxy``).

        The caller's bytecode is inspected: if the instruction that triggered
        iteration is ``UNPACK_SEQUENCE``, the number of targets is known
        statically, so one ``proxy[i]`` is recorded per target. Any other kind
        of iteration is delegated to ``self.tracer.iter``.
        """
        frame = inspect.currentframe()
        assert frame is not None
        calling_frame = frame.f_back
        assert calling_frame is not None
        # Locate the caller's current instruction by byte offset. Indexing the
        # instruction list with ``f_lasti // 2`` assumes every instruction is
        # exactly two bytes wide, which no longer holds once CPython (3.11+)
        # places inline cache entries after some opcodes. EXTENDED_ARG
        # prefixes are skipped: they only widen the following opcode's
        # argument, which ``dis`` already folds into that opcode.
        inst = None
        for candidate in dis.get_instructions(calling_frame.f_code):
            if candidate.offset >= calling_frame.f_lasti and candidate.opname != 'EXTENDED_ARG':
                inst = candidate
                break
        if inst is not None and inst.opname == 'UNPACK_SEQUENCE':
            return (self[i] for i in range(inst.argval))  # type: ignore
        return self.tracer.iter(self)

    def __bool__(self) -> bool:
        # The truth value is unknown symbolically; defer to the tracer's policy.
        return self.tracer.to_bool(self)

    def keys(self):
        # Reached when ``**`` is applied to a proxy; delegate to the tracer.
        return self.tracer.keys(self)

    def __torch_function__(self, orig_method, types, args=None, kwargs=None):
        """
        Record a call to a ``torch`` function or tensor method that received
        this proxy.

        Tensor methods/properties become ``call_method`` nodes targeting the
        method name; any other callable becomes a ``call_function`` node whose
        name is derived from ``orig_method.__name__`` via the graph.
        """
        args = args if args else ()
        kwargs = kwargs if kwargs else {}
        if torch.overrides.is_tensor_method_or_property(orig_method):
            return self.tracer.create_proxy('call_method', orig_method.__name__, args, kwargs)
        else:
            return self.tracer.create_proxy('call_function', orig_method, args, kwargs,
                                            name=self.tracer.graph._name(orig_method.__name__))
class Attribute(Proxy):
    """
    Proxy for ``root.<attr>`` whose graph node is created only on demand.

    Calling the attribute records a single ``call_method`` node on ``root``
    rather than a ``getattr`` node followed by a call; a ``getattr`` node is
    recorded only if ``node`` is actually read.
    """

    def __init__(self, root: Proxy, attr: str):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._node: Optional[Node] = None

    @property
    def node(self):
        # Most attributes are immediately called as methods and never need a
        # getattr node, so it is only materialized here, on first access, and
        # then cached for later reads.
        if self._node is not None:
            return self._node
        getattr_proxy = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {})
        self._node = getattr_proxy.node
        return self._node

    def __call__(self, *args, **kwargs):
        method_args = (self.root, *args)
        return self.tracer.create_proxy('call_method', self.attr, method_args, kwargs)
# Install one operator overload per entry of ``magic_methods`` on ``Proxy``
# (``__<name>__``); each records a ``call_function`` node whose target is the
# same-named function from the ``operator`` module.
for method in magic_methods:
    def scope(method):
        # ``scope`` gives each ``impl`` its own binding of ``method``; a bare
        # closure over the loop variable would only ever see its last value.
        def impl(*args, **kwargs):
            tracer = args[0].tracer
            # Resolved lazily at call time. Python keywords (``and``, ``or``)
            # cannot be attribute names, so the ``operator`` module spells
            # those functions with a trailing underscore (``operator.and_``);
            # fall back to that spelling when the plain name is absent.
            target = getattr(operator, method, None)
            if target is None:
                target = getattr(operator, f'{method}_')
            return tracer.create_proxy('call_function', target, args, kwargs)
        impl.__name__ = method
        as_magic = f'__{method}__'
        setattr(Proxy, as_magic, impl)
    scope(method)
def _define_reflectable(orig_method_name):
    """
    Install ``Proxy.__r<orig_method_name>__`` so that ``rhs <op> proxy`` is
    recorded as a ``call_function`` node calling the ``operator`` function of
    the same name with ``(rhs, proxy)``.
    """
    method_name = f'__r{orig_method_name}__'

    def impl(self, rhs):
        # Python keywords (``and``, ``or``) cannot be attribute names, so the
        # ``operator`` module spells those functions with a trailing
        # underscore (``operator.and_``); fall back to that spelling.
        target = getattr(operator, orig_method_name, None)
        if target is None:
            target = getattr(operator, f'{orig_method_name}_')
        # The proxy is the right-hand operand, so keep the original order.
        return self.tracer.create_proxy('call_function', target, (rhs, self), {})

    impl.__name__ = method_name
    impl.__qualname__ = method_name
    setattr(Proxy, method_name, impl)


for orig_method_name in reflectable_magic_methods:
    _define_reflectable(orig_method_name)