# blob: c1bdd45a92c9679293af69eb5de5d1de2da4f109 [file] [log] [blame]
import torch._C
import torch._jit_internal as _jit_internal
from torch.jit._builtins import _find_builtin, _get_builtin_table, _register_builtin # noqa
from torch._jit_internal import Future
from torch.nn import Module
from torch.utils import set_module
from torch.autograd.grad_mode import _DecoratorContextManager
from typing import Optional, List
import collections
import contextlib
import functools
import os
import pathlib
# These are imported so users can access them from the `torch.jit` module
from torch._jit_internal import Final, _overload, _overload_method
from torch._jit_internal import ignore, export, unused
from torch.jit._script import script, Attribute, ScriptModule, is_scripting, script_method, \
RecursiveScriptModule, ScriptWarning, interface
from torch.jit._trace import trace, trace_module, TracedModule, TracerWarning, TracingCheckError, \
is_tracing, ONNXTracedModule, _unique_state_dict, _flatten, TopLevelTracedModule
from torch.jit._async import fork, wait
from torch.jit._serialization import save, load
# NOTE(review): presumably this makes `Future` advertise "torch.jit" as its
# home module -- confirm against `torch.utils.set_module`.
set_module(Future, "torch.jit")
# For backwards compatibility: keep the underscore-prefixed names available as
# plain aliases of `fork` and `wait`.
_fork = fork
_wait = wait
@contextlib.contextmanager
def optimized_execution(should_optimize):
    """
    Context manager that temporarily overrides the graph executor's
    "optimize" flag.

    The flag that is active on entry is remembered and written back when the
    ``with`` block finishes, whether it finishes normally or by raising.

    Arguments:
        should_optimize: the value passed to
            ``torch._C._set_graph_executor_optimize`` for the duration of the
            block.
    """
    previous_flag = torch._C._get_graph_executor_optimize()
    torch._C._set_graph_executor_optimize(should_optimize)
    try:
        yield
    finally:
        # Always put the caller's original setting back.
        torch._C._set_graph_executor_optimize(previous_flag)
@contextlib.contextmanager
def fuser(name):
    """
    Context manager that switches the active backend fuser for the duration
    of a ``with`` block and puts the previous configuration back afterwards.

    Valid names:
    * ``fuser0`` - enables only legacy fuser
    * ``fuser1`` - enables only NNC
    * ``fuser2`` - enables only nvFuser

    Any other name raises ``Exception("unrecognized fuser option")`` before
    any setting has been modified.
    """
    def apply_flags(cpu, gpu, texpr, nvfuser):
        # One place for the four switches, always written in the same order.
        torch._C._jit_override_can_fuse_on_cpu(cpu)
        torch._C._jit_override_can_fuse_on_gpu(gpu)
        torch._C._jit_set_texpr_fuser_enabled(texpr)
        torch._C._jit_set_nvfuser_enabled(nvfuser)

    # Snapshot the current configuration so it can be restored on exit.
    saved_flags = (
        torch._C._jit_can_fuse_on_cpu(),
        torch._C._jit_can_fuse_on_gpu(),
        torch._C._jit_texpr_fuser_enabled(),
        torch._C._jit_nvfuser_enabled(),
    )

    if name == 'fuser0':  # legacy fuser
        apply_flags(True, True, False, False)
    elif name == 'fuser1':  # NNC
        # The setters hand back the value they replace; keep it for the exit path.
        saved_profiling_executor = torch._C._jit_set_profiling_executor(True)
        saved_profiling_mode = torch._C._jit_set_profiling_mode(True)
        apply_flags(False, False, True, False)
    elif name == 'fuser2':  # nvFuser
        apply_flags(False, False, False, True)
    else:
        raise Exception("unrecognized fuser option")

    try:
        yield
    finally:
        if name == 'fuser1':  # NNC
            torch._C._jit_set_profiling_executor(saved_profiling_executor)
            torch._C._jit_set_profiling_mode(saved_profiling_mode)
        # recover the previous values
        apply_flags(*saved_flags)
def export_opnames(m):
    r"""
    Return the operator names used by a script module and its submodules.

    Arguments:
        m: an object whose ``_c`` attribute is handed to
            ``torch._C._export_opnames``.

    Returns:
        Whatever ``torch._C._export_opnames`` produces for ``m._c``.
    """
    cpp_module = m._c
    return torch._C._export_opnames(cpp_module)
def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False,
                     return_inputs=False, _return_inputs_states=False):
    """
    .. warning::
        This function is internal-only and should only be used by the ONNX
        exporter. If you are trying to get a graph through tracing, please go
        through the public API instead::

            trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
            trace_graph = trace.graph

    Trace a function or model and return a tuple made of the *trace* of one
    execution together with the original return value. If ``return_inputs``
    is set, the trace inputs are returned as part of the tuple as well.

    Tracing is guaranteed not to change the semantics of the function/module
    that is traced.

    Arguments:
        f (torch.nn.Module or function): the function or module
            to be traced.
        args (tuple or Tensor): the positional arguments to pass to the
            function/module to be traced.  A non-tuple is assumed to
            be a single positional argument to be passed to the model.
        kwargs (dict): the keyword arguments to pass to the function/module
            to be traced.

    Example (trace a cell):

    .. testcode::

        trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
    """
    tracer = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)
    # A bare (non-tuple) argument stands for a single positional argument.
    positional = args if isinstance(args, tuple) else (args,)
    keywords = {} if kwargs is None else kwargs
    return tracer(*positional, **keywords)
def freeze(mod, preserved_attrs: Optional[List[str]] = None):
    r"""
    Freeze a :class:`ScriptModule`: clone it and try to inline the clone's
    submodules, parameters, and attributes as constants in the TorchScript IR Graph.

    `forward` is preserved by default, together with every attribute or method named in
    `preserved_attrs`. Any attribute that is modified inside a preserved method is
    preserved as well.

    Only ScriptModules in eval mode are currently accepted.

    Arguments:
        mod (:class:`ScriptModule`): a module to be frozen

        preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
        Attributes modified in preserved methods will also be preserved.

    Returns:
        Frozen :class:`ScriptModule`.

    Raises:
        RuntimeError: if `mod` is not a :class:`ScriptModule`, or if it is in training mode.

    Example (Freezing a simple module with a Parameter):

    .. testcode::
        import torch
        class MyModule(torch.nn.Module):
            def __init__(self, N, M):
                super(MyModule, self).__init__()
                self.weight = torch.nn.Parameter(torch.rand(N, M))
                self.linear = torch.nn.Linear(N, M)

            def forward(self, input):
                output = self.weight.mm(input)
                output = self.linear(output)
                return output

        scripted_module = torch.jit.script(MyModule(2, 3).eval())
        frozen_module = torch.jit.freeze(scripted_module)
        # parameters have been removed and inlined into the Graph as constants
        assert len(list(frozen_module.named_parameters())) == 0
        # See the compiled graph as Python code
        print(frozen_module.code)

    Example (Freezing a module with preserved attributes)

    .. testcode::
        import torch
        class MyModule2(torch.nn.Module):
            def __init__(self):
                super(MyModule2, self).__init__()
                self.modified_tensor = torch.tensor(10.)
                self.version = 1

            def forward(self, input):
                self.modified_tensor += 1
                return input + self.modified_tensor

        scripted_module = torch.jit.script(MyModule2().eval())
        frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
        # we've manually preserved `version`, so it still exists on the frozen module and can be modified
        assert frozen_module.version == 1
        frozen_module.version = 2
        # `modified_tensor` is detected as being mutated in the forward, so freezing preserves
        # it to retain model semantics
        assert frozen_module(torch.tensor(1)) == torch.tensor(12)
        # now that we've run it once, the next result will be incremented by one
        assert frozen_module(torch.tensor(1)) == torch.tensor(13)

    Note:
        If you're not sure why an attribute is not being inlined as a constant, you can run
        `dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
        attribute is being modified.
    """
    # Guard clauses: reject inputs that cannot be frozen.
    if not isinstance(mod, ScriptModule):
        raise RuntimeError("Freezing expects a ScriptModule as input. "
                           "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'.")

    if mod.training:
        raise RuntimeError("Freezing is currently only implemented for modules in eval mode. "
                           "Please call .eval() on your module before freezing.")

    if preserved_attrs is None:
        preserved_attrs = []

    frozen_cpp_module = torch._C._freeze_module(mod._c, preserved_attrs)
    out = RecursiveScriptModule(frozen_cpp_module)
    RecursiveScriptModule._finalize_scriptmodule(out)
    return out
class CompilationUnit(object):
    """
    Thin Python wrapper around a ``torch._C.CompilationUnit``.

    Source text passed to the constructor or to :meth:`define` is forwarded to
    the wrapped object; functions are then looked up by name through ordinary
    attribute access (``cu.my_function``).
    """

    def __init__(self, lang=None, _frames_up=0):
        """
        Arguments:
            lang: optional source text to define immediately.
            _frames_up: number of extra caller frames to skip when the
                resolution callback is built from a stack frame.
        """
        self._c = torch._C.CompilationUnit()
        if lang is not None:
            # +1 so the frame lookup skips this `__init__` frame as well.
            self.define(lang, _frames_up=_frames_up + 1)

    def define(self, lang, rcb=None, _frames_up=0):
        """
        Define the functions contained in ``lang`` on the wrapped unit.

        Arguments:
            lang: source text forwarded to the wrapped unit's ``define``.
            rcb: resolution callback; when falsy, one is created from a caller
                frame via ``_jit_internal.createResolutionCallbackFromFrame``.
            _frames_up: number of extra caller frames to skip for that lookup.
        """
        if not rcb:
            # Must be called directly from this method: the frame offset is
            # counted relative to the current call stack.
            rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
        self._c.define(lang, rcb)

    def __getattr__(self, attr):
        """
        Resolve ``attr`` as a function of the wrapped compilation unit.

        Raises:
            AttributeError: if no such function exists, or if the wrapped
                unit itself has not been created yet.
        """
        # `__getattr__` only runs when normal lookup fails.  If `_c` itself is
        # missing (instance created without `__init__`, e.g. by copy/pickle
        # machinery), reading `self._c` below would re-enter this method
        # forever, so fail fast with a regular AttributeError instead.
        if attr == "_c":
            raise AttributeError("'CompilationUnit' has no attribute '{}'".format(attr))
        r = self._c.find_function(attr)
        if r is None:
            raise AttributeError("'CompilationUnit' has no attribute '{}'".format(attr))
        return r
def _try_get_dispatched_fn(fn):
    """
    Look ``fn`` up in ``_jit_internal.boolean_dispatched``.

    Returns the registered entry, or ``None`` when ``fn`` is not callable or
    has no entry.
    """
    if callable(fn):
        return _jit_internal.boolean_dispatched.get(fn)
    return None
def _try_get_overloaded_fn(mod, field):
    """
    Return the entry stored under ``field`` in ``mod._overloads``.

    Only :class:`ScriptModule` instances are consulted; for any other object,
    or when ``field`` has no entry, the result is ``None``.
    """
    if not isinstance(mod, ScriptModule):
        return None
    return mod._overloads.get(field, None)
@contextlib.contextmanager
def _disable_emit_hooks():
    """
    Context manager that clears the JIT emit hooks for the duration of the
    ``with`` block.

    The pair of hooks returned by ``torch._C._jit_get_emit_hooks`` on entry is
    reinstalled on exit, including when the body raises, so a failure inside
    the block cannot leave the hooks permanently disabled.
    """
    hooks = torch._C._jit_get_emit_hooks()
    torch._C._jit_set_emit_hooks(None, None)
    try:
        yield
    finally:
        # Restore the hooks even if the body raised.
        torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
class _disable_emit_hooks_decorator(_DecoratorContextManager):  # noqa: F811
    """
    Class-based counterpart of ``_disable_emit_hooks``, built on
    ``_DecoratorContextManager``.

    The JIT emit hooks are cleared in ``__enter__`` and the pair captured
    there is reinstalled in ``__exit__``.

    This was previously declared with ``def``, which produced a function that
    merely defined two unused inner functions and returned ``None``; it has to
    be a class for ``__enter__``/``__exit__`` to take effect.
    """

    def __enter__(self):
        # Remember the current hooks on the instance so `__exit__` can restore them.
        self.hooks = torch._C._jit_get_emit_hooks()
        torch._C._jit_set_emit_hooks(None, None)

    def __exit__(self, *args):
        # Returns None, so an exception raised in the body still propagates.
        torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
def _script_if_tracing(fn):
    """
    Decorator that compiles ``fn`` with ``torch.jit.script`` only when it is
    called during tracing; outside of tracing ``fn`` runs as plain Python.

    ``torch.jit.script`` has a non-negligible start up time when it is first
    called due to lazy-initializations of many compiler builtins, so library
    code should not call it eagerly. When part of a library uses control flow
    but still has to work under tracing, use
    ``@torch.jit._script_if_tracing`` as a substitute for
    ``torch.jit.script``.
    """
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if is_tracing():
            # Compile the function recorded on the wrapper and run that instead.
            scripted = script(wrapper.__original_fn)
            return scripted(*args, **kwargs)
        # Not tracing, don't do anything
        return fn(*args, **kwargs)

    # Markers stored on the wrapper: the undecorated function and a flag
    # identifying this kind of wrapper.
    wrapper.__script_if_tracing_wrapper = True
    wrapper.__original_fn = fn
    return wrapper
def _unwrap_optional(x):
    """
    Return ``x`` unchanged after asserting that it is not ``None``.

    Raises:
        AssertionError: with the message "Unwrapping null optional" when
            ``x`` is ``None``.
    """
    assert x is not None, "Unwrapping null optional"
    return x
# Associate the Python-side helpers with the names of the builtin operators
# they are registered under.
_register_builtin(_unwrap_optional, 'aten::_unwrap_optional')
# `_wait` is the backwards-compatibility alias of `wait`, so both names are
# registered under 'aten::wait'.
_register_builtin(_wait, 'aten::wait')
_register_builtin(wait, 'aten::wait')
_register_builtin(is_scripting, 'aten::is_scripting')
# torch.jit.Error: public alias for the exception class `torch._C.JITException`.
Error = torch._C.JITException
set_module(Error, "torch.jit")
# This is not perfect but works in common cases
# (the class is renamed in place so it presents itself as "Error").
Error.__name__ = "Error"
Error.__qualname__ = "Error"
def _get_named_tuple_properties(obj):
    """
    Describe a named-tuple class as ``(name, fields, annotations)``.

    Arguments:
        obj: a ``tuple`` subclass that has a ``_fields`` attribute (this is
            asserted).

    Returns:
        A 3-tuple of
        * the unqualified name of the class,
        * the list of its field names, and
        * one type per field: the result of
          ``torch.jit.annotations.ann_to_type`` for fields listed in the
          class's ``__annotations__``, and ``torch._C.TensorType.get()`` for
          every other field.
    """
    assert issubclass(obj, tuple) and hasattr(obj, '_fields')
    fields = list(obj._fields)
    declared = getattr(obj, '__annotations__', {})
    annotations = []
    for field in fields:
        if field in declared:
            the_type = torch.jit.annotations.ann_to_type(declared[field], _jit_internal.fake_range())
            annotations.append(the_type)
        else:
            # Fields without an annotation fall back to the tensor type.
            annotations.append(torch._C.TensorType.get())
    # `obj` is itself a class, so `type(obj).__name__` would name its metaclass
    # ("type") rather than the named tuple; use the class's own name.
    return obj.__name__, fields, annotations
def _create_named_tuple(t, unqual_name, field_names):
    """
    Build a fresh ``collections.namedtuple`` type called ``unqual_name`` with
    the given ``field_names`` and return an instance of it filled
    positionally from the iterable ``t``.
    """
    return collections.namedtuple(unqual_name, field_names)(*t)
class _disable_tracing(object):
    """
    Context manager that clears the current tracing state for the duration of
    the ``with`` block and reinstalls it on exit.
    """

    def __enter__(self):
        # Stash whatever tracing state is active, then switch it off.
        self.state = torch._C._get_tracing_state()
        torch._C._set_tracing_state(None)

    def __exit__(self, *exc_info):
        saved_state = self.state
        torch._C._set_tracing_state(saved_state)
        # Drop the reference once the state has been handed back.
        self.state = None
# Python-side counterpart of `annotate`, for code that also runs outside of scripting.
def annotate(the_type, the_value):
    """
    Return ``the_value`` unchanged; ``the_type`` is ignored when this runs as
    ordinary Python.
    """
    # noop in python
    return the_value
# Module-level alias of the `torch._C` accessor; `_graph_for` calls it through this name.
last_executed_optimized_graph = torch._C._last_executed_optimized_graph


def _graph_for(self, *args, **kwargs):
    """
    Call ``self`` with the given arguments, then return whatever
    ``last_executed_optimized_graph()`` reports afterwards.
    """
    self(*args, **kwargs)
    return last_executed_optimized_graph()


# Install the helper as the `graph_for` method of both scripted callables.
torch._C.ScriptFunction.graph_for = _graph_for
torch._C.ScriptMethod.graph_for = _graph_for
# Re-export the `torch._C.ScriptFunction` class under this module's namespace
# and attach a user-facing docstring to it.
ScriptFunction = torch._C.ScriptFunction
ScriptFunction.__doc__ = """
Functionally equivalent to a :class:`ScriptModule`, but represents a single
function and does not have any attributes or Parameters.
"""
# NOTE(review): presumably makes the class report "torch.jit" as its module,
# mirroring what is done for `Future` and `Error` -- confirm against
# `torch.utils.set_module`.
set_module(ScriptFunction, "torch.jit")
# Run the JIT initialization at import time; a falsy return value is treated
# as a failure and aborts the import with a RuntimeError.
if not torch._C._jit_init():
    raise RuntimeError("JIT initialization failed")