import torch._C
from torch import Tensor
from torch.autograd import Variable, function
from torch.nn import Module, ModuleList, ParameterList, Parameter
from torch.jit.frontend import get_jit_ast
from torch._six import raise_from, with_metaclass
from collections import defaultdict, OrderedDict, namedtuple
import sys
import warnings
import itertools
import weakref
import types
import contextlib
import os
import functools
import inspect
import copy
import numbers
import collections
import re
_flatten = torch._C._jit_flatten
_unflatten = torch._C._jit_unflatten
_jit_script_compile = torch._C._jit_script_compile
# This global variable is set when we are tracing a *forwards* computation.
# It is intended to be a cheap way to test if tracing has occurred, before
# doing the slower path using `get_tracing_state` (below.)
_tracing = False
def get_tracing_state(args):
if not torch._C._is_tracing(args):
return None
return torch._C._get_tracing_state(args)
@contextlib.contextmanager
def scope(scope_name, *vars):
tracing_state = get_tracing_state(vars)
if tracing_state:
tracing_state.push_scope(scope_name)
try:
yield
finally:
if tracing_state:
tracing_state.pop_scope()
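# Illustrative sketch of using the scope context manager above (names are hypothetical):
# nodes recorded by the tracer while the block is active are tagged with the given
# scope name; when no tracing is in progress the context manager is a no-op.
#
#     with torch.jit.scope('MyBlock[conv1]', x):
#         y = conv1(x)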
def compile(arg=None, nderivs=1, optimize=True, enabled=True):
"""
Decorator which marks a function or module class as eligible for
just-in-time compilation. The next time the function/module is executed, it
is traced, and the trace is compiled into an optimized representation which
is run in lieu of the original Python code upon subsequent invocations of
the function/module.
.. note::
A JIT compiled function/module may be compiled multiple times, as
different inputs can result in different traces. Currently, the
JIT compiler conservatively assumes the trace may change if the
`size` or `requires_grad` of `Tensor` inputs change, or if
any of the non-Tensor inputs change. For example, if you JIT
compile an RNN which takes the number of hidden units as a parameter,
we will compile a separate trace for every distinct value of that
parameter you use at runtime.
When a module class is JIT compiled, each instantiation of the module
gets a separate trace cache.
.. warning::
Just-in-time compilation currently only works for functions/modules
which are not data dependent (e.g., do not have conditionals on data in
tensors) and do not have any untracked external dependencies (e.g.,
performing input/output or accessing global variables). If you trace such
models, you will silently get incorrect results on subsequent
invocations of the model.
Keyword arguments:
nderivs (int, optional): the number of derivatives which this function/module
will be used with. You MUST accurately specify this number: set it too
low and you will see an error when you attempt to run `backward`;
set it too high, and the function/module will never be compiled
(as we always wait to see all derivatives before compiling.)
Default: 1 (i.e., we will compile forwards and backwards, but not
double-backwards).
optimize (bool, optional): whether or not to apply optimizations. Default: ``True``.
Debug arguments:
time (bool, optional): if ``True``, whenever we execute the model in question, we
will also print out some timing information for how long the model
took to execute. At the moment, there are three types of timings we
emit:
- unoptimized: the time it took to execute the vanilla Python
model. This only occurs when tracing is disabled, e.g., via
`enabled=False`
- tracing: the time it took to execute the vanilla Python model
with tracing enabled.
- optimized: the time it took to execute the optimized model.
At the moment, all of these timings are for the forward pass only.
Default: ``False``.
enabled (bool, optional): if ``False``, compilation is disabled and you
will get back your original model. This is a convenient way to
disable tracing without having to delete the annotation. Default: ``True``.
Example: Compile as class decorator.
>>> @jit.compile
>>> class MyModel(nn.Module):
>>> ...
>>> model = MyModel()
>>> out1 = model(x) # interpreted run
>>> out1.sum().backward() # won't compile without this line
>>> out2 = model(x) # compiled run
>>> out2.sum().backward() # also compiled
Example: Compile forward pass only as class decorator.
>>> @jit.compile(nderivs=0)
>>> class MyModel(nn.Module):
>>> ...
>>> model = MyModel()
>>> out1 = model(x) # interpreted run
>>> out2 = model(x) # compiled run
Example: Compile as function decorator. The same modes of use for the class
decorator are also supported for functions; however, the decorated
function must declare *all* Tensor inputs in its arguments.
>>> @jit.compile
>>> def f(x):
>>> return x * 2
"""
def _compile(arg):
if inspect.isclass(arg):
# NB: It might seem natural to create a subclass here, rather than
# make a copy of the class to insert the mixin. Unfortunately, this
# will break many user classes. Suppose you have:
#
# @torch.jit.compile
# class Foo(Module):
# def __init__(self):
# super(Foo, self).__init__() # Python 2 syntax!
#
# within the class definition, 'Foo' refers to the *decorated*
# class, not the undecorated class. This is bad juju if the
# decorator returns a subclass, since super(Foo, self) is going to
# refer to the *undecorated* Foo (and thus you have an infinite
# loop.) Python 3's argument-less super() does not have this
# problem, but in general we cannot ask users to rewrite their code.
#
# If we create a *copy* of the class (unrelated to the class the
# user passed in), this problem goes away, because the class that
# __init__ is a part of is indeed the decorated Foo.
old_init = arg.__init__
# Python 2 has a concept of unbound methods, which are returned when
# you take a method from a class. They behave just like regular functions,
# but check the type of the first argument (self). We don't want this here,
# because self in our __init__ will be an instance of this new class.
# Python 3 already returns a plain function, so nothing has to be done.
if sys.version_info[0] == 2:
old_init = old_init.im_func
def __init__(self, *args, **kwargs):
torch._C.CompiledFunction.__init__(self,
nderivs, optimize, enabled,
self.forward,
arg.__name__)
try:
old_init(self, *args, **kwargs)
except TypeError as e:
# If this fails here, the user probably didn't use this as a class decorator
if "super" in str(e):
raise_from(TypeError("torch.jit.compile must be used as a class decorator; "
"using it on an already defined class is not valid."
"\n\nOriginal error: {}".format(str(e))), e)
else:
raise
# NOTE: This can't be done in CompiledFunction constructor,
# because self.parameters() isn't well defined by then
# (Module constructor hasn't run yet).
self.set_captured_vars(list(self.parameters()))
new_dict = dict(arg.__dict__)
new_dict['__init__'] = __init__
new_dict['__call__'] = torch._C.CompiledFunction.__call__
# NOTE: we don't need to override casting methods, because we only capture
# parameters, and they mutate their data in-place.
return type(arg.__name__,
arg.__bases__ + (torch._C.CompiledFunction,),
new_dict)
elif isinstance(arg, Module):
# It requires work to compile module instances, because you would
# like the resulting compiled module to look just like the uncompiled
# version; actually achieving this requires a bit of fanciness.
# So for now, we only support the class mechanism.
raise TypeError("Compiling model instances is not supported. "
"Use @torch.jit.compile on a class instead.")
elif callable(arg):
compiled_fn = torch._C.CompiledFunction(nderivs, optimize, enabled,
arg, arg.__name__)
return compiled_fn
else:
raise TypeError("Cannot handle arg with type {}".format(type(arg)))
# Make empty parenthesis optional
if arg is None:
return _compile
else:
return _compile(arg)
def get_trace_graph(f, args=tuple(), kwargs=None, nderivs=0):
"""
Trace a function or model, returning a tuple consisting of both the
*trace* of an execution and the original return value.
Tracing is guaranteed not to change the semantics of the function/module
that is traced.
Arguments:
f (torch.nn.Module or function): the function or module
to be traced.
args (tuple or Tensor): the positional arguments to pass to the
function/module to be traced. A non-tuple is assumed to
be a single positional argument to be passed to the model.
kwargs (dict): the keyword arguments to pass to the function/module
to be traced.
nderivs (int, default 0): the number of derivatives to trace.
Traces of derivatives are recorded into the same trace returned
after executing the `forward` of the resulting module, but
are not present until you run `backward()` (an appropriate
number of times) on the resulting model.
Example: Trace the forwards pass only.
>>> trace, out = jit.get_trace_graph(nn.LSTMCell(), (input, hidden))
>>> print(trace)
Example: Trace the backwards pass too.
>>> trace, out = jit.get_trace_graph(nn.LSTMCell(), (input, hidden), nderivs=1)
>>> out.sum().backward()
>>> print(trace)
"""
if kwargs is None:
kwargs = {}
if not isinstance(args, tuple):
args = (args,)
return LegacyTracedModule(f, nderivs=nderivs)(*args, **kwargs)
def _unique_state_dict(module, keep_vars=False):
state_dict = module.state_dict(keep_vars=keep_vars)
filtered_dict = type(state_dict)()
seen_ids = set()
for k, v in state_dict.items():
if id(v) in seen_ids:
continue
seen_ids.add(id(v))
filtered_dict[k] = v
return filtered_dict
class LegacyTracedModule(Module):
def __init__(self, inner, nderivs=0):
super(LegacyTracedModule, self).__init__()
# inner may be a Module, or it may be an arbitrary callable
# If it's a Module, we get its parameters automatically, which lets
# us avoid special-casing functions versus modules.
self.inner = inner
self.nderivs = nderivs
def forward(self, *args):
global _tracing
in_vars, in_desc = _flatten(args)
# NOTE: use full state, because we need it for BatchNorm export
# This differs from the compiler path, which doesn't support it at the moment.
module_state = list(_unique_state_dict(self, keep_vars=True).values())
trace, all_trace_inputs = torch._C._tracer_enter(in_vars + module_state, self.nderivs)
_tracing = True
trace_inputs = _unflatten(all_trace_inputs[:len(in_vars)], in_desc)
out = self.inner(*trace_inputs)
out_vars, _ = _flatten(out)
_tracing = False
torch._C._tracer_exit(out_vars)
return trace, out
def _clone_inputs(args):
def clone_input(a):
if a is None:
return None
elif isinstance(a, torch.Tensor):
# TODO: figure out one liner to .clone() and set requires_grad
v = Variable(a.data.clone(), requires_grad=a.requires_grad)
if a.grad is not None:
v.grad = clone_input(v.grad)
return v
else:
return a.clone()
return function._nested_map(lambda x: isinstance(x, torch.Tensor),
clone_input, condition_msg="tensors")(args)
# This is purely for developer debugging. We are not going to advertise it.
_JIT_DUMP = os.environ.get('PYTORCH_JIT_DUMP', False)
_JIT_TIME = os.environ.get('PYTORCH_JIT_TIME', False) # CUDA-only timing
_JIT_DISABLE = os.environ.get('PYTORCH_JIT_DISABLE', False)
_JIT_STATS = os.environ.get('PYTORCH_JIT_STATS', False)
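# For example, running a script as `PYTORCH_JIT_DUMP=1 python train.py` (hypothetical
# script name) enables _dump_trace below, which writes a .ir dump and an .html
# visualization per trace.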
def _dump_trace(trace_name, pass_name, input_key, trace):
if not _JIT_DUMP:
return
import torch.contrib._graph_vis as graph_vis
filename = "{}_{}".format(trace_name, pass_name)
# TODO: Also paste out the backtrace when the trace was compiled
# (and maybe also when it was run?)
with open(filename + ".ir", "w") as f:
f.write("Input key: {}\n\n{}".format(input_key, str(trace)))
graph_vis.write(trace.graph(), filename + ".html")
@contextlib.contextmanager
def _time(trace_name, name, time=True):
if (not _JIT_TIME and not time) or not torch.cuda.is_available():
yield
return
stream = torch.cuda.current_stream()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
stream.record_event(start)
try:
yield
finally:
stream.record_event(end)
end.synchronize()
print("{} {} time: {} ms".format(trace_name, name, start.elapsed_time(end)))
def verify(model, args, loss_fn=torch.sum, devices=None):
"""
Verify that a JIT compiled model has the same behavior as its uncompiled
version along with its backwards pass. If your model returns multiple
outputs, you must also specify a `loss_fn` to produce a loss for which
the backwards will be computed.
This function has side-effects (e.g., it executes your model / saves and loads
parameters), so don't expect the model to come out exactly the same as what
you passed in.
Arguments:
model (compiled torch.nn.Module or function): the module/function to be
verified. The module/function definition MUST have been decorated with
`@torch.jit.compile`.
args (tuple or Tensor): the positional arguments to pass to the
compiled function/module to be verified. A non-tuple is assumed to
be a single positional argument to be passed to the model.
loss_fn (function, optional): the loss function to be applied to
the output of the model, before backwards is invoked. By default,
we assume that a model returns a single result, and we :func:`torch.sum`
before calling backwards; if this is inappropriate, you can pass your
own loss function. Note that if a model returns a tuple of results,
these are passed as separate positional arguments to `loss_fn`.
devices (iterable of device IDs, optional): the GPU devices which the
compiled module will be run on. This determines the RNG state we
must save when running both compiled and uncompiled versions of the model.
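Example: Verify a compiled module (illustrative sketch; ``MyModel`` and ``x``
are hypothetical).
>>> @torch.jit.compile
>>> class MyModel(nn.Module):
>>>     ...
>>> model = MyModel()
>>> torch.jit.verify(model, (x,))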
"""
# TODO: In principle, we track device information in our trace, so it
# should be possible to check if our execution actually obeyed the 'devices'
# the user provided.
# TODO: Consider adding a utility function to torch.jit to test
# for this case
if not isinstance(model, torch._C.CompiledFunction):
raise TypeError("Cannot verify an uncompiled module. Add @torch.jit.compile to compile it")
is_module = isinstance(model, Module)
if not isinstance(args, tuple):
args = (args,)
saved_args = _clone_inputs(args)
if is_module:
saved_state = copy.deepcopy(model.state_dict())
def run_fwd_bwd(args, force_trace=False, assert_compiled=False):
params = list(model.parameters()) if is_module else []
in_vars, _ = _flatten((args, params))
# We use a special API to reset the trace and compile it from scratch.
compiled_fn = model
if force_trace:
compiled_fn.clear_cache()
if assert_compiled:
hits = compiled_fn.hits
out = model(*args)
if assert_compiled and compiled_fn.hits == hits:
raise RuntimeError("failed to use the compiled function")
if not isinstance(out, tuple):
out = (out, )
if loss_fn == torch.sum and len(out) != 1:
raise ValueError(("Model returns {} outputs, but default loss function "
"(torch.sum) can only handle a single output").format(len(out)))
out_vars, _ = _flatten(out)
saved_outs = [v.data.clone() for v in out_vars]
loss = loss_fn(*out)
grads = torch.autograd.grad([loss], in_vars)
# TODO: I'm not sure if the clone here is necessary but it is safer
saved_grads = [v.data.clone() for v in grads]
return (saved_outs, saved_grads)
with torch.random.fork_rng(devices, _caller="torch.jit.verify"):
uncompiled_outs, uncompiled_grads = run_fwd_bwd(args, force_trace=True)
assert model.has_trace_for(*args)
if is_module:
model.load_state_dict(saved_state)
compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True)
_verify_equal(uncompiled_outs, compiled_outs)
_verify_equal(uncompiled_grads, compiled_grads)
def _verify_equal(xs, ys):
for x, y in zip(xs, ys):
if x.sub(y).abs().max() > 1e-6:
raise RuntimeError("JIT and real computation mismatch")
def trace(*args, **kwargs):
"""
Trace a function and return an executable trace that will be optimized
using just-in-time compilation.
.. warning::
Just-in-time compilation currently only works for functions/modules
which are not data dependent (e.g., do not have conditionals on data in
tensors) and do not have any untracked external dependencies (e.g.,
performing input/output or accessing global variables). If you trace such
models, you will silently get incorrect results on subsequent
invocations of the model.
Arguments:
*args - a list of example tensors that will be passed to the function
as inputs while tracing. The resulting trace can be run with
inputs of different types and shapes assuming the traced operations
support those types and shapes.
Keyword arguments:
optimize (bool, optional): whether or not to apply optimizations. Default: ``True``.
>>> @jit.trace(torch.rand(1))
... def f(x):
... return x * 2
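Example: Trace a module (illustrative sketch; ``model`` is a hypothetical
``nn.Module`` taking a single tensor input).
>>> traced_model = jit.trace(torch.rand(3, 4))(model)
>>> out = traced_model(torch.rand(3, 4))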
"""
def wrapper(func):
executor_options = {'optimize': True}
for name in executor_options:
executor_options[name] = kwargs.pop(name, executor_options[name])
if len(kwargs) != 0:
raise TypeError("got unexpected keyword arguments: {}".format(", ".join(kwargs.keys())))
if isinstance(func, torch.nn.Module):
module = TopLevelTracedModule(func, **executor_options)
module._create_method_from_trace('forward', func, args)
return module
else:
return torch._C.GraphExecutor(func, args, **executor_options)
return wrapper
def createResolutionCallback(frame_id=2):
"""
Creates a function which, given a string variable name,
returns the value of the variable in the scope of the caller of
the function which called createResolutionCallback (by default).
For example, the following program prints 2::
def bar():
cb = createResolutionCallback()
print(cb("foo"))
def baz():
foo = 2
bar()
baz()
This is used to enable access to in-scope Python variables inside
TorchScript fragments.
"""
frame = inspect.stack()[frame_id][0]
def env(key):
if key in frame.f_locals:
return frame.f_locals[key]
elif key in frame.f_globals:
return frame.f_globals[key]
else:
return None
return env
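# Illustrative sketch of CompilationUnit usage (the script source below is hypothetical):
# functions defined from the source string become callable attributes of the unit.
#
#     cu = torch.jit.CompilationUnit('''
#     def add(a, b):
#         return a + b
#     ''')
#     result = cu.add(torch.ones(2), torch.ones(2))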
class CompilationUnit(object):
def __init__(self, lang=None, optimize=True):
self.module = torch._C.ScriptModule()
self.module._set_optimized(optimize)
if lang is not None:
self.define(lang, frame_id=3)
self.optimize = optimize
def define(self, lang, rcb=None, frame_id=2):
if not rcb:
rcb = createResolutionCallback(frame_id)
self.module._define(lang, rcb, False)
def __getattr__(self, attr):
return self.module._get_method(attr)
def _script_graph(fn, frame_id=2):
rcb = createResolutionCallback(frame_id)
ast = get_jit_ast(fn)
return _jit_script_compile(ast, rcb)
def script(fn):
graph = _script_graph(fn, frame_id=3)
return torch._C.GraphExecutor(graph, True)
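# Illustrative sketch of `script` (the function below is hypothetical): unlike tracing,
# the function body is compiled from its Python AST, so it is not specialized to
# example inputs.
#
#     @torch.jit.script
#     def axpy(a, x, y):
#         return a * x + y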
ScriptMethodStub = namedtuple('ScriptMethodStub', ('resolution_callback', 'ast'))
def script_method(fn):
return ScriptMethodStub(createResolutionCallback(), get_jit_ast(fn))
# These OrderedDictWrapper classes replace the actual OrderedDicts in
# module with versions that get/set properties inside of script::Module.
# This allows us to reuse most of nn.Module while still storing the
# data in C++.
# Each OrderedDict needs to support:
# x not in view
# x in view
# view[name] = ...
# view.values()
# del view[name]
# view.items()
# view.keys()
# len(view)
class OrderedDictWrapper(object):
def __init__(self, module):
self.module_ref = weakref.ref(module)
@property
def module(self):
r = self.module_ref()
if r is None:
raise RuntimeError("_parameters or _modules alive after module is dead")
return r
def keys(self):
return [k for k, v in self.items()]
def values(self):
return [v for k, v in self.items()]
def __delitem__(self, k):
raise RuntimeError("cannot delete methods or parameters of a script module")
def items(self):
raise NotImplementedError
def __contains__(self, k):
raise NotImplementedError
def __getitem__(self, k):
raise NotImplementedError
def __setitem__(self, k, v):
raise NotImplementedError
class OrderedModuleDict(OrderedDictWrapper):
def __init__(self, module):
super(OrderedModuleDict, self).__init__(module)
# Contains _both_ script modules and non-script python-only modules.
# Because script modules are subclassed in python and the
# C++ script::Module class will not hold references to them,
# we also store them in this python dict, to ensure that you
# always get the same python value here.
self._python_modules = OrderedDict()
def items(self):
r = self._python_modules.items()
return r
def __contains__(self, k):
return k in self._python_modules
def __setitem__(self, k, v):
if k in self._python_modules:
raise RuntimeError("cannot re-assign modules in a ScriptModule")
if isinstance(v, ScriptModule):
self.module._register_module(k, v)
self._python_modules[k] = v
def __getitem__(self, k):
return self._python_modules[k]
class OrderedParameterDict(OrderedDictWrapper):
def __init__(self, module):
super(OrderedParameterDict, self).__init__(module)
def items(self):
return [(name, param) for name, param, is_buffer
in self.module._get_parameters()
if not is_buffer]
def __setitem__(self, k, v):
self.module._register_parameter(k, v, False)
def __contains__(self, k):
return self.module._has_parameter(k)
def __getitem__(self, k):
if k not in self:
raise KeyError(k)
return self.module._get_parameter(k)
class OrderedBufferDict(OrderedDictWrapper):
def __init__(self, module):
super(OrderedBufferDict, self).__init__(module)
def items(self):
return [(name, param) for name, param, is_buffer
in self.module._get_parameters()
if is_buffer]
def __setitem__(self, k, v):
self.module._register_parameter(k, v, True)
def __contains__(self, k):
return self.module._has_buffer(k)
def __getitem__(self, k):
if k not in self:
raise KeyError(k)
return self.module._get_parameter(k)
# base types that can be constants
# in addition, tuples and lists of these base types are also considered constants
# If you edit this list, then you also need to edit the handlers in
# ConstantValue in jit/script/init.cpp
_constant_types = (bool, float, int, types.FunctionType)
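# Illustrative sketch (hypothetical module): attributes named in `__constants__` on a
# ScriptModule subclass (defined below) are routed through _get_valid_constant when
# assigned.
#
#     class MyCell(ScriptModule):
#         __constants__ = ['hidden_size']
#         def __init__(self):
#             super(MyCell, self).__init__()
#             self.hidden_size = 64  # accepted: int is a valid constant type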
def _get_valid_constant(v):
if isinstance(v, _constant_types):
return v
elif isinstance(v, tuple) or isinstance(v, list):
return tuple(_get_valid_constant(x) for x in v)
constants = ", ".join(typ.__name__ for typ in _constant_types)
raise TypeError(
"'{}' object is not a valid constant.\n".format(type(v).__name__) +
"Valid constants are:\n" +
" 1. a nn.ModuleList\n" +
" 2. a value of type {{{}}}\n".format(constants) +
" 3. a list or tuple of (2)\n")
# For each user-defined class that subclasses ScriptModule, this metaclass
# (1) finds all the methods annotated with @script_method
# in a ScriptModule and removes them from the class attributes, and
# (2) puts a wrapper around the class's __init__ method to register
# all of the script_methods with the module after the original __init__
# has run. This has to occur after the user-defined __init__ so that
# submodules and parameters are initialized _before_ the script compiler
# resolves references to `self.param` or `self.module`.
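# Illustrative sketch of a class handled by this metaclass (names are hypothetical):
#
#     class MyModule(ScriptModule):
#         def __init__(self):
#             super(MyModule, self).__init__()
#             self.weight = Parameter(torch.rand(3, 3))
#
#         @script_method
#         def forward(self, x):
#             return x + self.weight
#
# The @script_method stubs are collected by the metaclass and compiled right after
# MyModule.__init__ returns, once self.weight has been registered.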
class ScriptMeta(type(torch._C.ScriptModule)):
# this has to inherit from pybind11's metaclass otherwise we get
# issues because ScriptModule inherits from torch._C.ScriptModule,
# a pybind11 type
def __init__(cls, name, bases, attrs):
# find all the script methods
methods = []
for k, v in sorted(attrs.items()):
if isinstance(v, ScriptMethodStub):
delattr(cls, k)
methods.append(v)
# after the user's __init__ runs, register all the script methods
# with the module
original_init = getattr(cls, '__init__', lambda self: None)
super_constants = getattr(super(cls), '_constants_set', set())
cls._constants_set = set(getattr(cls, '__constants__', ())).union(super_constants)
def init_then_register(self, *args, **kwargs):
# ensure even if the user forgets to call super that
# the pybind object is initialized so it will not segfault.
# Run this once, before the most-derived __init__ is called.
if cls is type(self):
torch._C.ScriptModule.__init__(self)
original_init(self, *args, **kwargs)
asts = [m.ast for m in methods]
rcbs = [m.resolution_callback for m in methods]
self._create_methods(asts, rcbs)
cls.__init__ = init_then_register
return super(ScriptMeta, cls).__init__(name, bases, attrs)
class ScriptModule(with_metaclass(ScriptMeta, Module, torch._C.ScriptModule)):
def __init__(self, optimize=True):
# must be before Module.init since the field is used in __getattr__
Module.__init__(self)
self._set_optimized(optimize)
self._parameters = OrderedParameterDict(self)
self._buffers = OrderedBufferDict(self)
self._modules = OrderedModuleDict(self)
def __getattr__(self, attr):
if self._has_method(attr):
return self._get_method(attr)
return Module.__getattr__(self, attr)
def __setattr__(self, attr, value):
if attr not in self._constants_set:
return super(ScriptModule, self).__setattr__(attr, value)
if hasattr(self, attr):
raise RuntimeError("attempting to re-assign constant '{}'".format(attr))
if isinstance(value, ModuleList):
# special case for list of modules. Modules need to be registered with their
# parent module. To do this, we create a _ConstModuleList, which is itself a module that
# contains each of these modules as submodules. The _ConstModuleList then
# is set as an attribute of the parent module.
super(ScriptModule, self).__setattr__(attr, _ConstModuleList(value))
else:
super(ScriptModule, self).__setattr__(attr, _get_valid_constant(value))
def __dir__(self):
return sorted(Module.__dir__(self) + self._method_names())
# Module already has this method defined, so we
# need to override it and send it through the ScriptModule lookup
def forward(self, *args, **kwargs):
return self.__getattr__('forward')(*args, **kwargs)
def define(self, lang):
rcb = createResolutionCallback()
self._define(lang, rcb, True)
def _get_methods(cls):
import inspect
# In Python 3 unbound methods are functions, but in Python 2 they are methods
return inspect.getmembers(cls, predicate=lambda x: inspect.isfunction(x) or inspect.ismethod(x))
_compiled_methods_whitelist = {
'forward', 'register_buffer', 'register_parameter', 'add_module',
'_apply', 'apply', 'cuda', 'cpu', 'type', 'float', 'double', 'half',
'state_dict', 'load_state_dict', '_load_from_state_dict', 'parameters',
'named_parameters', '_all_buffers', 'children', 'named_children', 'modules',
'named_modules', 'zero_grad', 'share_memory', '_get_name'
}
def _make_fail(name):
def fail(self, *args, **kwargs):
raise RuntimeError(name + " is not supported on TracedModules")
return fail
for name, method in _get_methods(torch.nn.Module):
if name.startswith('__'):
continue
if name not in ScriptModule.__dict__ and name not in _compiled_methods_whitelist:
setattr(ScriptModule, method.__name__, _make_fail(name))
class TracedModule(ScriptModule):
__frozen = False
def __init__(self, orig, id_set=None, optimize=True):
super(TracedModule, self).__init__(optimize=optimize)
if id_set is None:
id_set = set()
def check_unique(param):
if param in id_set:
raise ValueError("TracedModules don't support parameter sharing between modules")
id_set.add(param)
self.training = orig.training
for name, param in orig._parameters.items():
if param is not None:
self._parameters[name] = param
check_unique(param)
for name, buf in orig._buffers.items():
if buf is not None:
self._buffers[name] = buf
check_unique(buf)
self._orig_class = type(orig)
if orig._backward_hooks or orig._forward_hooks or orig._forward_pre_hooks:
raise ValueError("Modules that have hooks assigned can't be compiled")
for name, submodule in orig._modules.items():
self._modules[name] = TracedModule(submodule, id_set, optimize=optimize)
self._freeze()
def forward(self, *args, **kwargs):
raise RuntimeError('Traced submodules cannot be called.')
def _freeze(self):
self.__frozen = True
def _get_name(self):
return 'TracedModule[' + self._orig_class.__name__ + ']'
def __setattr__(self, attr, value):
if not self.__frozen or hasattr(self, attr):
return super(TracedModule, self).__setattr__(attr, value)
raise RuntimeError("Cannot set new properties on a traced module.")
class TopLevelTracedModule(TracedModule):
def forward(self, *args, **kwargs):
return self._get_method('forward')(*args, **kwargs)
class _ConstModuleList(ScriptModule):
def __init__(self, modules):
super(_ConstModuleList, self).__init__()
for i, module in enumerate(modules):
self.add_module(str(i), module)
def __getitem__(self, idx):
if isinstance(idx, slice):
return _ConstModuleList(list(self._modules.values())[idx])
else:
if not (-len(self) <= idx < len(self)):
raise IndexError('index {} is out of range'.format(idx))
if idx < 0:
idx += len(self)
return self._modules[str(idx)]
def __len__(self):
return len(self._modules)
def __iter__(self):
return iter(self._modules.values())
def __dir__(self):
keys = super(_ConstModuleList, self).__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
if not torch._C._jit_init():
raise RuntimeError("JIT initialization failed")