blob: 402a99946fe4e084758ce1c99f0e28d4d8bf8584 [file] [log] [blame]
import torch.autograd.function as function
import torch._C
from torch.autograd import Variable
from torch.nn import Module, ParameterList, Parameter
from torch._six import raise_from
from collections import defaultdict
import itertools
import types
import contextlib
import os
import functools
import inspect
class Placeholder(object):
    """Named sentinel whose ``str()`` and ``repr()`` are both its label."""

    def __init__(self, s):
        # The label; returned verbatim by __str__ and __repr__.
        self.s = s

    def __str__(self):
        return self.s

    # repr() intentionally shows the bare label, exactly like str().
    __repr__ = __str__


# HOLE is substituted for every Variable when _flatten records the shape of a
# nested structure; VOLATILE replaces the requires_grad entry of a variable
# key (see vars_key) when any input is volatile.
HOLE = Placeholder("HOLE")
VOLATILE = Placeholder("VOLATILE")
# TODO: verify is not implemented yet
def compile(arg=None, verify=False, **kwargs):
    """
    Decorator which marks a function or module class as eligible for
    just-in-time compilation.  The next time the function/module is executed,
    it is traced, and the trace is compiled into an optimized representation
    which is run in lieu of the original Python code upon subsequent
    invocations of the function/module.

    .. note::

        A JIT compiled function/module may be compiled multiple times, as
        different inputs can result in different traces.  Currently, the
        JIT compiler conservatively assumes the trace may change if the
        `size` or `requires_grad` of `Variable` inputs change, or if
        any of the non-Variable inputs change.  For example, if you JIT
        compile an RNN which takes the number of hidden units as a parameter,
        we will compile a trace for every RNN length you use at runtime.

        When a module class is JIT compiled, each instantiation of the module
        gets a separate trace cache.

    .. warning::

        Just-in-time compilation currently only works for functions/modules
        which are not data dependent (e.g., have conditionals on data in
        tensors) and do not have any untracked external dependencies (e.g.,
        perform input/output or access global variables).  If you trace such
        models, you will silently get incorrect results on subsequent
        invocations of the model.  `verify=True` is intended to check that
        the original Python code and optimized code are equivalent, but it is
        not implemented yet.

    Arguments:
        arg (optional): the module class or function to compile.  If `None`,
            a decorator is returned which can be applied to the class or
            function afterwards (this is what makes
            ``@jit.compile(nderivs=0)`` work).

    Keyword arguments:
        verify (bool, optional): if True, upon all invocations of the
            function/module, execute both the compiled and interpreted
            versions of the model, and verify that their results match.
            NOT IMPLEMENTED YET: the argument is accepted but this function
            does not act on it.  Default: False.
        nderivs (int, optional): the number of derivatives which this
            function/module will be used with.  You MUST accurately specify
            this number: set it too low and you will see an error when you
            attempt to run `backward`; set it too high, and the
            function/module will never be compiled (as we always wait to see
            all derivatives before compiling.)  Default: 1 (i.e., we will
            compile forwards and backwards, but not double-backwards).
        optimize (bool, optional): whether or not to apply optimizations.
            Default: True.
        name (str, optional): model name used when building the names of the
            generated traces.  Default: the name of the class.

    Debug arguments:
        time (bool, optional): if True, whenever we execute the model in
            question, we will also print out some timing information for how
            long the model took to execute.  At the moment, there are three
            types of timings we emit:

                - unoptimized: the time it took to execute the vanilla Python
                  model.  This only occurs when tracing is disabled, e.g.,
                  via `enabled=False`
                - tracing: the time it took to execute the vanilla Python
                  model with tracing enabled.
                - optimized: the time it took to execute the optimized model.

            At the moment, all of these timings are for the forward pass
            only.  Default: False.
        enabled (bool, optional): if False, compilation is disabled and you
            will get back your original model.  This is a convenient way to
            disable tracing without having to delete the annotation.
            Default: True.

    Raises:
        TypeError: if `arg` is an already compiled class, a module instance,
            or something that is neither a class nor callable.

    Example: Compile as class decorator.

        >>> @jit.compile
        >>> class MyModel(nn.Module):
        >>>     ...
        >>> model = MyModel()
        >>> out1 = model(x)        # interpreted run
        >>> out1.sum().backward()  # won't compile without this line
        >>> out2 = model(x)        # compiled run
        >>> out2.sum().backward()  # also compiled

    Example: Compile forward pass only as class decorator.

        >>> @jit.compile(nderivs=0)
        >>> class MyModel(nn.Module):
        >>>     ...
        >>> model = MyModel()
        >>> out1 = model(x)        # interpreted run
        >>> out2 = model(x)        # compiled run

    Example: Compile as function decorator.  The same modes of use for the
    class decorator are also supported for functions; however, the decorated
    function must declare *all* Variable inputs in its arguments.

        >>> @jit.compile
        >>> def f(x):
        >>>     return x * 2
    """
    def _compile_class(klass):
        # Build the compiled counterpart of a user-defined module class.
        if issubclass(klass, _CompiledMixin):
            raise TypeError("Cannot compile a model class that already is compiled")

        # NB: It might seem natural to create a subclass here, rather than
        # make a copy of the class to insert the mixin.  Unfortunately, this
        # will break many user classes.  Suppose you have:
        #
        #     @torch.jit.compile
        #     class Foo(Module):
        #         def __init__(self):
        #             super(Foo, self).__init__() # Python 2 syntax!
        #
        # within the class definition, 'Foo' refers to the *decorated*
        # class, not the undecorated class.  This is bad juju if the
        # decorator returns a subclass, since super(Foo, self) is going to
        # refer to the *undecorated* Foo (and thus you have an infinite
        # loop.)  Python 3's argument-less super() does not have this
        # problem, but in general we cannot ask users to rewrite their code.
        #
        # If we create a *copy* of the class (unrelated to the class the
        # user passed in), this problem goes away, because the class
        # __init__ is a part of is indeed Foo.

        # Copy of the class, with _CompiledMixin prepended to its bases.
        bases_with_mixin = (_CompiledMixin,) + klass.__bases__
        compiled_cls = type(klass.__name__, bases_with_mixin, dict(klass.__dict__))
        # Monkey-patch forward and __init__ with the compiler versions.
        compiled_cls.init_compiler(**kwargs)
        return compiled_cls

    def _compile_function(fn):
        # Wrap a plain callable in a throwaway compiled module class and
        # hand back an instance of it.
        @compile(**kwargs)
        class FuncModule(Module):
            def __init__(self, f):
                super(FuncModule, self).__init__()
                self.f = f

            def forward(self, *args):
                return self.f(*args)

        return FuncModule(fn)

    def _compile(arg):
        if inspect.isclass(arg):
            return _compile_class(arg)
        if isinstance(arg, Module):
            # It requires work to compile module instances, because you would
            # like the resulting compiled module to look just like the
            # uncompiled version; actually achieving this requires a bit of
            # fanciness.  So for now, we just only support the class
            # mechanism.
            raise TypeError("Compiling model instances is not supported. "
                            "Use @torch.jit.compile on a class instead.")
        if callable(arg):
            return _compile_function(arg)
        raise TypeError("Cannot handle arg with type {}".format(type(arg)))

    # Without a target we were called as @compile(...): return the decorator.
    return _compile if arg is None else _compile(arg)
def trace(arg=None, nderivs=0, params=tuple()):
    """
    Instrument a function or module for tracing, wrapping it in a
    :class:`TracedModule`, whose forward accepts the same arguments as the
    original function/module, but returns a tuple consisting of the
    *trace* of an execution, as well as the original return value.

    Tracing is guaranteed not to change the semantics of the function/module
    that is traced.

    Arguments:
        arg (optional, torch.nn.Module or function): the function or module
            to be traced.  If it is not callable (e.g. `None`), `trace`
            returns a decorator which can be applied to the function or
            module you want to trace.
        nderivs (int, default 0): the number of derivatives to trace.
            Traces of derivatives are recorded into the same trace returned
            after executing the `forward` of the resulting module, but
            are not present until you run `backward()` (an appropriate
            number of times) on the resulting model.
        params (tuple of torch.nn.Parameter): extra parameters for a traced
            function, which do not occur as arguments to the function in
            question.  You generally do not need this for tracing modules, as
            the parameters of a module are automatically computed.

    Example: Trace as higher order function.  (Notice that trace is a
    *curried* function; you first apply it with the function/model to trace,
    and then apply the result with the arguments.)

        >>> traced_model = jit.trace(nn.LSTMCell())
        >>> trace, out = traced_model(input, hidden)

    Example: Trace the backwards pass as higher order function.

        >>> traced_model = jit.trace(nn.LSTMCell(), nderivs=1)
        >>> trace, out = traced_model(input, hidden)
        >>> out.sum().backward()
        >>> print(trace)
    """
    # TODO: handle decorating a class (not a callable)
    def _trace(inner):
        return TracedModule(inner, nderivs=nderivs, params=params)

    # Nothing to wrap yet: hand back the decorator for later application.
    if not callable(arg):
        return _trace
    return _trace(arg)
# It's OK for TracedModule to look different from the inner module, since
# the forward() return type changed anyway.
class TracedModule(Module):
    """Module wrapper whose ``forward`` returns ``(trace, output)``.

    Arguments:
        inner: the thing to trace; may be a Module, or an arbitrary callable.
        params: extra parameters, stored in a ParameterList on this module.
        nderivs (int): number of derivatives handed to the tracer.
    """

    def __init__(self, inner, params=tuple(), nderivs=0):
        super(TracedModule, self).__init__()
        self.inner = inner
        self.params = ParameterList(list(params))
        self.nderivs = nderivs

    def forward(self, *args):
        # The traced inputs are the call arguments followed by every entry of
        # this module's state dict.
        state_values = self.state_dict(keep_vars=True).values()
        flat_inputs, input_struct = _flatten(args, state_values)

        # TODO: Possible optimization: use the unflattened
        # output so we don't unflatten it when we get out
        # NB: Not a method because trace_func_raw can't deal
        # with methods
        @_raw_trace(nderivs=self.nderivs)
        def traced_inner(in_vars, in_struct):
            return _flatten(self.inner(*args))

        trace, flat_result = traced_inner(flat_inputs, input_struct)
        out_vars, out_struct = flat_result
        out, leftover = _unflatten(out_vars, out_struct)
        # Every flat output must have been placed back into the structure.
        assert len(leftover) == 0
        return trace, out
# Functional version that assumes that all parameters are explicitly
# specified
def _raw_trace(nderivs=0):
def raw_trace(f):
# f takes two arguments, (in_vars, in_struct) (as determined
# by _flatten); furthermore, it must be the case that in_vars
# contains all Variable inputs (including parameters.) It must
# produce two outputs, (out_vars, out_struct) (also as determined
# by _flatten).
@functools.wraps(f)
def wrapper(in_vars, in_struct=None):
trace = torch._C._tracer_enter(in_vars, nderivs)
out_vars, out_struct = f(in_vars, in_struct)
torch._C._tracer_exit(out_vars)
return trace, (out_vars, out_struct)
return wrapper
return raw_trace
# Lifecycle of a compiler:
#
# - It is given an underlying function, which knows how to actually
# execute the code that we want to compile.
# - When we encounter an input configuration for which we don't
# have an optimized trace, we run the underlying function, tracing its
# result. The trace is not done yet, so we save it into our set of pending
# traces for that configuration.
# - When we encounter an input configuration whose trace is "ready"
# (that is, we've seen all of the passes, so the trace contains
# forwards/backwards/etc), we compile it, and then register this
# as the compiled trace.
# - When we encounter an input configuration whose trace is compiled,
# we just directly run the compiled trace.
#
# You should never use this class directly; instead, use compile. However,
# the intended manual usage of this class looks like this:
#
#     class CompiledModel(_CompiledMixin, nn.Module):
#         def forward(self, x):
#             ...
#     CompiledModel.init_compiler()
#     model = CompiledModel()
#
class _CompiledMixin(object):
# Global over ALL compilations! This helps us disambig if two Modules have
# the same __name__ but actually are different
__next_id = 0
@classmethod
def init_compiler(cls, params=tuple(), name=None, enabled=True, time=False, **kwargs):
# Ensure we are not shadowing this method on the class we mixed with
assert not hasattr(super(_CompiledMixin, cls), "init_compiler")
# TODO: Consider saving the backtrace of this constructor, so it's easier
# to correlate dump files with invocations in Python
#
# NB: Use private methods/variables here in order to prevent a class
# we mix with from accidentally scrambling us
#
# NB: Class variables are also accessible via self!
cls.__params = ParameterList(list(params))
kwargs["time"] = time # also want to pass onto ktrace
cls.__ktrace_kwargs = kwargs
cls.__enabled = enabled
cls.__time = time
cls.__model_name = name
# Monkey patch the constructor and forward functions *inplace*
cls.__old_forward = cls.forward
cls.forward = cls.__new_forward
cls.__old_init = cls.__init__
cls.__init__ = cls.__new_init
def __new_init(self, *args, **kwargs):
try:
# __old_init is assumed to handle super call
self.__old_init(*args, **kwargs)
except TypeError as e:
# If this fails here, the user probably didn't use this as a class
# decorator
if "super" in str(e):
raise_from(TypeError("torch.jit.compile must be used as a class decorator; "
"using it on an already defined class is not valid."
"\n\nOriginal error: {}".format(str(e))), e)
else:
raise
model_name = self.__model_name if self.__model_name else type(self).__name__
self.__name = "jit_{}_{}".format(model_name, _CompiledMixin.__next_id)
_CompiledMixin.__next_id += 1
self.__ktrace_cache = {}
self.__next_ktrace_id = 0
def __process_args(self, args):
in_vars, in_struct = _flatten(args, self.state_dict(keep_vars=True).values())
is_volatile, in_vars_key = vars_key(in_vars)
in_key = (in_vars_key, in_struct)
return in_vars, in_struct, is_volatile, in_key
# NB: In principle, there could also be a 'raw' version of this compiler,
# but since the logic is so complicated, testing code wouldn't benefit much
def __new_forward(self, *args):
if _JIT_DISABLE or not self.__enabled:
with _time(self.__name, "unoptimized", self.__time):
# Call to the saved old forward function
return self.__old_forward(*args)
in_vars, in_struct, is_volatile, in_key = self.__process_args(args)
ktrace = self.__ktrace_cache.get(in_key)
if ktrace is None:
ktrace_name = '{}_{}'.format(self.__name, self.__next_ktrace_id)
self.__next_ktrace_id += 1
ktrace = TraceForKey(ktrace_name, in_key, volatile=is_volatile, **self.__ktrace_kwargs)
self.__ktrace_cache[in_key] = ktrace
closure = ktrace.maybe_closure()
if closure is not None:
# We already compiled it! Run it directly, and
# use the saved out_struct to unflatten.
with _time(ktrace.name, "optimized", self.__time):
out_vars = closure()(*in_vars)
out_struct = ktrace.out_struct
else:
# No compiled trace available. Run it by hand.
with _time(ktrace.name, "tracing", self.__time):
out_vars, out_struct = ktrace.add_trace(self.__old_forward, args, in_vars, in_struct)
if isinstance(out_vars, Variable):
out_vars = (out_vars, )
out, extras = _unflatten(out_vars, out_struct)
assert len(extras) == 0
return out
def has_trace_for(self, *args):
# Ensure we are not shadowing this method on the class we mixed with
assert not hasattr(super(_CompiledMixin, self), "has_trace_for")
in_vars, in_struct, is_volatile, in_key = self.__process_args(args)
ktrace = self.__ktrace_cache.get(in_key)
if ktrace is None:
return False
return ktrace.maybe_closure() is not None
# TODO: Provide more compiled code management utility methods
# CompiledModule memoizes multiple traces and switches between them based on
# inputs provided to a call; a TraceForKey logically represents one such trace
# (in reality, a TraceForKey may contain multiple traces, but they all share
# the same input configuration and should be equivalent). Things
# that need to be considered include non-Variable argument (e.g. num_layers=3;
# compared by equality) or Variable flags and sizes. TraceForKey is the object
# that is used to hold a trace / compiled code for a single input configuration
# aka in_key.
class TraceForKey(object):
    """Traces, and eventually the compiled closure, for one input key.

    Lifecycle:
      - We accumulate 'traces'
      - At some point, one of these traces becomes complete ('is_complete'
        is True).  This occurs when we run enough backwards on a trace
        to complete it (i.e., this is an external event to this object).
      - Whenever we want to run this trace, we call 'maybe_closure'.  This
        returns None if we don't have a complete trace yet, or the
        autograd closure to actually run the trace if we do.
    """

    def __init__(self, name, key, nderivs=1, optimize=True, volatile=False, time=False):
        self.name = name
        self.key = key
        # TODO: Not convinced about this volatile special case...
        self.nderivs = 0 if volatile else nderivs
        self.optimize = optimize
        self.traces = []
        self.closure = None
        self.out_struct = None  # initialized when we call trace, checked thereafter
        self.time = time

    # The signature here is a little goofy; it's a perf optimization.
    # Additionally, f is passed in as an argument (even though it is fixed as
    # class initialization) to avoid a circular reference.
    def add_trace(self, f, args, in_vars, in_struct):
        """Run ``f(*args)`` under the tracer and remember the new trace.

        Returns the flattened ``(out_vars, out_struct)`` of the call.
        """
        # TODO: Deduplicate this code
        @_raw_trace(nderivs=self.nderivs)
        def traced_f(in_vars, in_struct):
            return _flatten(f(*args))

        trace, flat_result = traced_f(in_vars, in_struct)
        out_vars, out_struct = flat_result
        # Only the first output structure is recorded.
        # TODO: in debug mode, assert the output structs are same
        if self.out_struct is None:
            self.out_struct = out_struct
        self.traces.append(trace)
        return out_vars, out_struct

    def maybe_closure(self):
        """Return the compiled closure, compiling a complete trace if needed.

        Returns None while no accumulated trace is complete.
        """
        if self.closure is not None:
            return self.closure

        # GC expired traces
        live_traces = [t for t in self.traces if not t.is_expired]
        self.traces = live_traces

        # Search for a complete trace; when several are complete the last
        # one is used, and the pending traces are dropped.
        finished = [t for t in live_traces if t.is_complete]
        if not finished:
            return None
        complete_trace = finished[-1]
        self.traces = []

        def _run_pass(p, trace):
            # Apply one pass, dump the result under the pass's short name,
            # then lint the trace.
            short_name = p.__name__.replace('_jit_pass_', '')
            p(trace)
            _dump_trace(self.name, short_name, self.key, trace)
            torch._C._jit_pass_lint(trace)

        with _time(self.name, "compiling", self.time):
            _dump_trace(self.name, "init", self.key, complete_trace)

            # It's important to always run DCE, because backward can create
            # a lot of unnecessary nodes
            _run_pass(torch._C._jit_pass_dce, complete_trace)
            if self.optimize:
                _run_pass(torch._C._jit_pass_onnx, complete_trace)
                _run_pass(torch._C._jit_pass_fuse, complete_trace)

            _dump_trace(self.name, "final", self.key, complete_trace)

            self.closure = torch._C._jit_createAutogradClosure(complete_trace)
            return self.closure
def vars_key(in_vars):
    """
    Compute the key for variables: some properties of variables
    affect the trace, e.g., size and requires_grad.

    Returns ``(is_volatile, key)`` where ``is_volatile`` tells whether any
    Variable input is volatile and ``key`` holds one ``(type, grad_key,
    size)`` triple per input, in order.  When any input is volatile, every
    ``grad_key`` is the ``VOLATILE`` marker.
    """
    is_volatile = any(isinstance(v, Variable) and v.volatile for v in in_vars)

    def var_key(x):
        if isinstance(x, Variable):
            grad_key, ty = x.requires_grad, x.data.type()
        else:
            # Non-Variable inputs never require grad.
            grad_key, ty = False, x.type()
        if is_volatile:
            grad_key = VOLATILE
        return ty, grad_key, x.size()

    return is_volatile, tuple(var_key(v) for v in in_vars)
@contextlib.contextmanager
def _fork_rng(enabled=True):
"""
Forks the RNG, so that when you return, the RNG is reset
to the state that it was previously in. This is important
if we are evaluating a trace twice, and it incorporates
randomness: if we don't reset the seed, we might get totally
different results!
TODO: Randomness in models is a big problem for reproduceability,
because it means if we start executing models out of order,
they may behave differently. Interesting question is whether
or not backwards pass ever has randomness. I hope not.
"""
if not enabled:
yield
return
cpu_rng_state = torch.get_rng_state()
gpu_rng_state = None
if torch.cuda.is_available():
gpu_rng_state = torch.cuda.get_rng_state()
yield
torch.set_rng_state(cpu_rng_state)
if gpu_rng_state is not None:
torch.cuda.set_rng_state(gpu_rng_state)
# _flatten and _unflatten are inverses
def _unflatten(input, proto):
def unflatten_helper(input, proto):
res = []
if not isinstance(proto, (list, tuple)):
return input[0], input[1:]
for e in proto:
res_e, input = unflatten_helper(input, e)
res.append(res_e)
return type(proto)(res), input
return unflatten_helper(input, proto)
def _flatten(obj, params=tuple()):
    """Split ``obj`` into its variables and a placeholder structure.

    Returns ``(obj_vars, obj_struct)``: ``obj_vars`` is a tuple of the
    variables iterated out of ``obj`` followed by ``params``; ``obj_struct``
    is ``obj`` mapped through ``_nested_map`` with every Variable replaced by
    ``HOLE``.
    """
    def is_variable(o):
        return isinstance(o, Variable)

    def to_hole(_):
        return HOLE

    flat_vars = tuple(itertools.chain(function._iter_variables(obj), params))
    skeleton = function._nested_map(is_variable, to_hole)(obj)
    return flat_vars, skeleton
# This is purely for developer debugging. We are not going to advertise it.
_JIT_DUMP = os.environ.get('PYTORCH_JIT_DUMP', False)
_JIT_TIME = os.environ.get('PYTORCH_JIT_TIME', False) # CUDA-only timing
_JIT_DISABLE = os.environ.get('PYTORCH_JIT_DISABLE', False)
def _dump_trace(trace_name, pass_name, input_key, trace):
    """Write debugging dumps of ``trace``; a no-op unless ``_JIT_DUMP`` is set.

    Two files named ``<trace_name>_<pass_name>`` are written to the current
    directory: a ``.ir`` text file with the input key and ``str(trace)``, and
    a ``.html`` file produced by ``graph_vis.write`` from ``trace.graph()``.
    """
    if _JIT_DUMP:
        # Imported lazily: only needed when dumping is switched on.
        import torch.contrib._graph_vis as graph_vis

        base_name = "{}_{}".format(trace_name, pass_name)
        # TODO: Also paste out the backtrace when the trace was compiled
        # (and maybe also when it was run?)
        with open(base_name + ".ir", "w") as ir_file:
            ir_file.write("Input key: {}\n\n{}".format(input_key, str(trace)))
        graph_vis.write(trace.graph(), base_name + ".html")
@contextlib.contextmanager
def _time(trace_name, name, time=True):
    """Context manager that prints how long its body took, using CUDA events.

    Timing happens only when it is requested (through the ``time`` argument
    or the ``_JIT_TIME`` flag) and CUDA is available; otherwise the body just
    runs.  The elapsed time is printed even if the body raises.

    Arguments:
        trace_name: name printed first in the report line.
        name: label of the measured phase, printed second.
        time (bool): request timing for this invocation.
    """
    timing_requested = _JIT_TIME or time
    if not timing_requested or not torch.cuda.is_available():
        yield
        return

    stream = torch.cuda.current_stream()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    stream.record_event(start_event)
    try:
        yield
    finally:
        stream.record_event(end_event)
        # Wait for the end event so that elapsed_time can be read.
        end_event.synchronize()
        elapsed_ms = start_event.elapsed_time(end_event)
        print("{} {} time: {} ms".format(trace_name, name, elapsed_ms))
# Runs once, when this module is imported: call torch._C._jit_init() and
# treat a falsy return value as a fatal initialization failure.
if not torch._C._jit_init():
    raise RuntimeError("JIT initialization failed")