| """TorchScript |
| |
| This module contains functionality to support the JIT's scripting frontend, notably: |
| - torch.jit.script |
| |
| This is not intended to be imported directly; please use the exposed |
| functionalities in `torch.jit`. |
| """ |
| import functools |
| import collections |
| import enum |
| import inspect |
| import copy |
| import pickle |
| import warnings |
| from typing import Any, Dict, List, Tuple, Union, Callable |
| |
| |
| import torch |
| import torch._jit_internal as _jit_internal |
| from torch.utils import set_module |
| from torch.jit._recursive import ScriptMethodStub, wrap_cpp_module, infer_methods_to_compile, _compile_and_register_class |
| from torch.nn import Module |
| from torch.jit._state import _enabled |
| from torch.jit._builtins import _register_builtin |
| from torch._six import with_metaclass |
| from torch.jit.frontend import get_jit_def, get_default_args, get_jit_class_def |
| from torch._jit_internal import _qualified_name |
| from torch.jit._fuser import _graph_for, _script_method_graph_for |
| from torch.jit._state import ( |
| _try_get_jit_cached_function, |
| _try_get_jit_cached_overloads, |
| _set_jit_function_cache, |
| _set_jit_overload_cache, |
| ) |
| from torch.overrides import ( |
| has_torch_function, has_torch_function_unary, has_torch_function_variadic) |
| from torch.package import PackageExporter, PackageImporter |
| from ._serialization import validate_map_location |
| |
| from torch.jit._monkeytype_config import ( |
| monkeytype_trace, |
| JitTypeTraceConfig , |
| JitTypeTraceStore |
| ) |
| from torch._classes import classes |
| |
| type_trace_db = JitTypeTraceStore() # DB to hold all call traces from MonkeyType |
| |
| torch._C.ScriptMethod.graph_for = _script_method_graph_for # type: ignore[attr-defined] |
| torch._C.ScriptFunction.graph_for = _graph_for # type: ignore[attr-defined] |
| ScriptFunction = torch._C.ScriptFunction |
| ScriptFunction.__doc__ = """ |
| Functionally equivalent to a :class:`ScriptModule`, but represents a single |
| function and does not have any attributes or Parameters. |
| """ |
| set_module(ScriptFunction, "torch.jit") |
| |
| # Throws an error if a jit function is pickled. |
| # Helps to avoid Python crashes for Python versions 3.9.5 + when protocol 0 or 1 is given as an argument. |
| def _reduce(cls): |
| raise pickle.PickleError("ScriptFunction cannot be pickled") |
| |
| ScriptFunction.__reduce__ = _reduce # type: ignore[assignment] |
| |
| |
| if _enabled: |
| Attribute = collections.namedtuple("Attribute", ["value", "type"]) |
| else: |
| |
| def Attribute(value, type): # type: ignore[no-redef] |
| return value |
| |
| Attribute.__doc__ = """ |
| This method is a pass-through function that returns `value`, mostly |
| used to indicate to the TorchScript compiler that the left-hand side |
| expression is a class instance attribute with type of `type`. Note that |
| `torch.jit.Attribute` should only be used in `__init__` method of `nn.Module` |
| subclasses. |
| |
| Though TorchScript can infer correct type for most Python expressions, there are some cases where |
| type inference can be wrong, including: |
| |
| - Empty containers like `[]` and `{}`, which TorchScript assumes to be container of `Tensor` |
| - Optional types like `Optional[T]` but assigned a valid value of type `T`, TorchScript would assume |
| it is type `T` rather than `Optional[T]` |
| |
| In eager mode, it is simply a pass-through function that returns `value` |
| without other implications. |
| |
| Example: |
| |
| .. testcode:: |
| |
| import torch |
| from typing import Dict |
| |
| class AttributeModule(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.foo = torch.jit.Attribute(0.1, float) |
| |
| # we should be able to use self.foo as a float here |
| assert 0.0 < self.foo |
| |
| self.names_ages = torch.jit.Attribute({}, Dict[str, int]) |
| self.names_ages["someone"] = 20 |
| assert isinstance(self.names_ages["someone"], int) |
| |
| m = AttributeModule() |
| # m will contain two attributes |
| # 1. foo of type float |
| # 2. names_ages of type Dict[str, int] |
| |
| .. testcleanup:: |
| |
| del AttributeModule |
| del m |
| |
| Args: |
| value: An initial value to be assigned to attribute. |
| type: A Python type |
| |
| Returns: |
| Returns `value` |
| """ |
| |
| def _get_type_trace_db(): |
| # This is a private API. Use of this for external purposes is discouraged. |
| return type_trace_db |
| |
| # Gets a function from the name of a method on a type |
| def _get_function_from_type(cls, name): |
| return getattr(cls, name, None) |
| |
| |
| # ScriptClasses must be new-style classes because we construct them using their |
| # __new__ method. |
| def _is_new_style_class(cls): |
| if hasattr(cls, "__class__"): |
| return "__dict__" in dir(cls) or hasattr(cls, "__slots__") |
| |
| |
| # These OrderedDictWrapper classes replace the actual OrderedDicts in |
| # module with versions that get/set properties inside of Module. |
| # This allows us to reuse most of nn.Module while still storing the |
| # data in C++. |
| # Each OrderedDict needs to support: |
| # x not in view |
| # x in view |
| # view[name] = ... |
| # view.values() |
| # del view[name] |
| # view.items() |
| # view.keys() |
| # len(view) |
| |
| |
| class OrderedDictWrapper(object): |
| def __init__(self, _c): |
| self._c = _c |
| |
| def keys(self): |
| return [k for k, v in self.items()] |
| |
| def values(self): |
| return [v for k, v in self.items()] |
| |
| def __len__(self): |
| return len(self.values()) |
| |
| def __delitem__(self, k): |
| raise RuntimeError("cannot delete methods or parameters of a script module") |
| |
| def items(self): |
| return self._c.items() |
| |
| def __setitem__(self, k, v): |
| if k not in self: |
| raise RuntimeError( |
| "Can't add a new parameter after ScriptModule construction." |
| " Tried to add '{}".format(k) |
| ) |
| self._c.setattr(k, v) |
| |
| def __contains__(self, k): |
| return self._c.contains(k) |
| |
| def __getitem__(self, k): |
| if k not in self: |
| raise KeyError(k) |
| return self._c.getattr(k) |
| |
| |
| class OrderedModuleDict(OrderedDictWrapper): |
| def __init__(self, module, python_dict): |
| super(OrderedModuleDict, self).__init__(torch._C.ModuleDict(module)) |
| # contains _both_ script modules and non-script python-only modules |
| |
| # because script modules are subclassed in python and the |
| # C++ Module class will not hold references to them, |
| # to ensure that you always get the same python value here |
| # we store it in the python dict as well |
| self._python_modules = python_dict |
| |
| def items(self): |
| r = self._python_modules.items() |
| return r |
| |
| def __contains__(self, k): |
| return k in self._python_modules |
| |
| def __setitem__(self, k, v): |
| # Cases where sub-module can be re-assigned after ScriptModule construction |
| # 1. If the attr is an module interface type, it's guaranteed that the module is |
| # not inlined in the graph, so it's safe to swap a new ScriptModule in. |
| # 2. if the new value if a ScriptModule with the same JIT type, IR won't change |
| # and it's legit to swap a new module in. |
| # In these two cases we allow swapping a new scripted module and update the |
| # corresponding python module dict to keep sync. |
| # Note: the value to be swapped in has to be ScriptModule instead of nn.Module, |
| # otherwise it's illegal and we throw error. |
| if isinstance(v, ScriptModule): |
| self._c.setattr(k, v) |
| self._python_modules[k] = v |
| else: |
| raise RuntimeError( |
| "Cannot re-assign modules in a ScriptModule with non-scripted " |
| "module, tried to replace existing module '{}': {}".format(k, v) |
| ) |
| |
| def __getitem__(self, k): |
| return self._python_modules[k] |
| |
| |
| # For each user-defined class that subclasses ScriptModule, this meta-class: |
| # (1) finds all the methods annotated with @script_method in a ScriptModule and |
| # removes them from the class attributes |
| # (2) puts a wrapper around the class's __init__ method to recursively compile |
| # all of the script_methods with the module after the original __init__ has |
| # run. This has to occur after the user-defined __init__ so that submodules and |
| # parameters are initialized _before_ the script compiler resolve references to |
| # `self.param` or `self.module`. |
| class ScriptMeta(type): |
| def __init__(cls, name, bases, attrs): # noqa: B902 |
| # Aggregate all the ScriptMethods and constants from superclasses |
| cls._methods: Dict[str, Any] = {} |
| cls._constants_set = set(getattr(cls, "__constants__", ())) |
| for base in reversed(bases): |
| for k, v in getattr(base, "_methods", {}).items(): |
| cls._methods[k] = v |
| base_constants = getattr(base, "_constants_set", set()) |
| cls._constants_set = cls._constants_set.union(base_constants) |
| |
| # find all the script methods of the current class |
| for k, v in sorted(attrs.items()): |
| if isinstance(v, ScriptMethodStub): |
| delattr(cls, k) |
| cls._methods[v.original_method.__name__] = v |
| |
| if getattr(cls, "_disable_script_meta", False): |
| # We leave built-in ScriptModule types alone, since this metaclass |
| # is only for compiling user classes that inherit from |
| # ScriptModule. |
| return super(ScriptMeta, cls).__init__(name, bases, attrs) |
| |
| original_init = getattr(cls, "__init__", lambda self: None) |
| |
| @functools.wraps(original_init) |
| def init_then_script(self, *args, **kwargs): |
| num_methods = len(cls._methods) |
| original_init(self, *args, **kwargs) |
| added_methods_in_init = len(cls._methods) > num_methods |
| |
| if type(self) == cls: |
| |
| def make_stubs(module): |
| cls = type(module) |
| if hasattr(cls, "_methods"): |
| return [v for k, v in sorted(cls._methods.items())] |
| else: |
| return infer_methods_to_compile(module) |
| |
| self.__dict__[ |
| "_actual_script_module" |
| ] = torch.jit._recursive.create_script_module(self, make_stubs, share_types=not added_methods_in_init) |
| |
| # Delete the Python attributes that now shadow the ScriptModule |
| # ones, so that __getattr__ and __setattr__ will properly find |
| # the scripted versions. |
| concrete_type = self._actual_script_module._concrete_type |
| for name in concrete_type.get_attributes(): |
| delattr(self, name) |
| for name, _ in concrete_type.get_modules(): |
| delattr(self, name) |
| for name in ("_parameters", "_buffers", "_modules"): |
| delattr(self, name) |
| |
| cls.__init__ = init_then_script # type: ignore[misc] |
| super(ScriptMeta, cls).__init__(name, bases, attrs) |
| |
| |
| class _CachedForward(object): |
| def __get__(self, obj, cls): |
| return self.__getattr__("forward") # type: ignore[attr-defined] |
| |
| |
| class ScriptWarning(Warning): |
| pass |
| |
| |
| def script_method(fn): |
| if not _enabled: |
| return fn |
| # NOTE: we need to traverse two frames here because the meta-class frame |
| # for ScriptModule will be present, as opposed to invoking @script on a |
| # a function or invoking define() on a CompilationUnit. |
| # The stack will look like: |
| # |
| # 0. createResolutionCallback() |
| # 1. script_method() |
| # 2. ScriptModule metaclass frame |
| # 3. Surrounding scope |
| # |
| # createResolutionCallback internally adds 1 to get us to the scope of this |
| # function (the calling function). Adding 2 gets us to the proper surrounding scope. |
| _rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=2) |
| ast = get_jit_def(fn, fn.__name__, self_name="ScriptModule") |
| return ScriptMethodStub(_rcb, ast, fn) |
| |
| |
| class ConstMap: |
| def __init__(self, const_mapping): |
| self.const_mapping = const_mapping |
| |
| def __getattr__(self, attr): |
| return self.const_mapping[attr] |
| |
| |
| def unpackage_script_module(importer: PackageImporter, script_module_id: str) -> torch.nn.Module: |
| """ |
| Called by ``torch.package.PackageImporter``'s Pickler's ``persistent_load`` function. |
| Performs work of loading and returning a ScriptModule from a ``torch.package`` archive. |
| """ |
| if not isinstance(importer.zip_reader, torch._C.PyTorchFileReader): |
| raise RuntimeError( |
| "Loading ScriptObjects from a PackageImporter created from a " |
| "directory is not supported. Use a package archive file instead." |
| ) |
| cu = torch._C.CompilationUnit() |
| cpp_module = torch._C._import_ir_module_from_package( |
| cu, |
| importer.zip_reader, |
| importer.storage_context, |
| validate_map_location(importer.last_map_location), |
| script_module_id, |
| ) |
| return wrap_cpp_module(cpp_module) |
| |
| |
| if _enabled: |
| _magic_methods = [ |
| "__iter__", |
| "__len__", |
| "__neg__", |
| "__mul__", |
| "__contains__", |
| "__add__", |
| "__sub__", |
| "__pow__", |
| "__truediv__", |
| "__mod__", |
| "__ne__", |
| "__eq__", |
| "__lt__", |
| "__gt__", |
| "__le__", |
| "__ge__", |
| "__and__", |
| "__or__", |
| "__xor__", |
| "__getitem__", |
| "__setitem__", |
| "__call__", |
| "__int__", |
| "__float__", |
| "__bool__", |
| "__str__", |
| "__enter__", |
| "__exit__", |
| ] |
| |
| class RecursiveScriptClass(object): |
| """ |
| An analogue of RecursiveScriptModule for regular objects that are not modules. |
| This class is a wrapper around a torch._C.ScriptObject that represents an instance |
| of a TorchScript class and allows it to be used in Python. |
| |
| Attributes: |
| _c [torch._C.ScriptObject]: The C++ object to which attribute lookups and method |
| calls are forwarded. |
| _props [Dict[str, property]]: A dictionary of properties fetched from self._c and |
| exposed on this wrppaer. |
| """ |
| def __init__(self, cpp_class): |
| super(RecursiveScriptClass, self).__init__() |
| self.__dict__["_initializing"] = True |
| self._c = cpp_class |
| |
| # Add wrapped object's properties to this class instance. |
| self._props = {prop.name: property(prop.getter, prop.setter) for prop in self._c._properties()} |
| |
| self.__dict__["_initializing"] = False |
| |
| def __getattr__(self, attr): |
| if "_initializing" in self.__dict__ and self.__dict__["_initializing"]: |
| return super(RecursiveScriptClass, self).__getattr__(attr) # type: ignore[misc] |
| |
| if attr in self._props: |
| return self._props[attr].fget() |
| |
| return getattr(self._c, attr) |
| |
| def __setattr__(self, attr, value): |
| if "_initializing" in self.__dict__ and self.__dict__["_initializing"]: |
| return super(RecursiveScriptClass, self).__setattr__(attr, value) |
| |
| if attr in self._props: |
| return self._props[attr].fset(value) |
| |
| setattr(self._c, attr, value) |
| |
| # Delegate calls to magic methods like __len__ to the C++ module backing the |
| # RecursiveScriptClass. |
| def forward_magic_method(self, method_name, *args, **kwargs): |
| if not self._c._has_method(method_name): |
| raise TypeError() |
| |
| self_method = self.__getattr__(method_name) |
| return self_method(*args, **kwargs) |
| |
| def __getstate__(self): |
| raise pickle.PickleError("ScriptClasses cannot be pickled") |
| |
| def __iadd__(self, other): |
| if self._c._has_method("__iadd__"): |
| return self.forward_magic_method("__iadd__", other) |
| else: |
| return self.forward_magic_method("__add__", other) |
| |
| |
| for method_name in _magic_methods: |
| def method_template(self, *args, **kwargs): |
| return self.forward_magic_method(method_name, *args, **kwargs) |
| |
| setattr(RecursiveScriptClass, method_name, method_template) |
| |
| # this is a Python 'non-data descriptor' that causes the first access |
| # to ScriptModule's forward to look up the forward method and stash |
| # it in the objects dict. Due to the standard rules for attribute lookup, |
| # subsequent lookups will just directly return the previously looked up method. |
| # This is necessary because nn.Module defines forward as a method. If we |
| # did nothing, __getattr__ would not be called. Instead we'd get nn.Module.forward |
| # which always throws an exception. |
| |
| class ScriptModule(with_metaclass(ScriptMeta, Module)): # type: ignore[misc] |
| r""" |
| A wrapper around C++ ``torch::jit::Module``. ``ScriptModule``\s |
| contain methods, attributes, parameters, and |
| constants. These can be accessed the same way as on a normal ``nn.Module``. |
| """ |
| __jit_unused_properties__ = ['code', 'code_with_constants', 'graph', 'inlined_graph', 'original_name'] |
| |
| def __init__(self): |
| super(ScriptModule, self).__init__() |
| |
| forward = _CachedForward() |
| |
| def __getattr__(self, attr): |
| if "_actual_script_module" not in self.__dict__: |
| return super(ScriptModule, self).__getattr__(attr) |
| return getattr(self._actual_script_module, attr) |
| |
| def __setattr__(self, attr, value): |
| if "_actual_script_module" not in self.__dict__: |
| # Unwrap torch.jit.Attribute into a regular setattr + record |
| # the provided type in __annotations__. |
| # |
| # This ensures that if we use the attr again in `__init__`, it |
| # will look like the actual value, not an instance of Attribute. |
| if isinstance(value, Attribute): |
| # NB: Ensure that we set __annotations__ on the specific |
| # class in question, and not on a superclass (which would |
| # be wrong wrong wrong!). |
| # See also https://github.com/pytorch/pytorch/issues/39463 |
| if "__annotations__" not in self.__class__.__dict__: |
| self.__class__.__annotations__ = {} |
| self.__annotations__[attr] = value.type |
| value = value.value |
| return super(ScriptModule, self).__setattr__(attr, value) |
| |
| setattr(self._actual_script_module, attr, value) |
| |
| def define(self, src): |
| if "_actual_script_module" in self.__dict__: |
| # If we have completed initialization, just defer to the |
| # backing RecursiveScriptModule to eagerly compile the provided |
| # source. |
| return self._actual_script_module.define(src) |
| |
| # Otherwise, we are still in the object's __init__. |
| # In that case, add `src` as a stub to be compiled. |
| # |
| # We use frames_up=1 to get to the proper surrounding scope. The stack |
| # will look like: |
| # 0. createResolutionCallback |
| # 1. define() |
| # 2. surrounding scope. |
| # |
| # createResolutionCallback internally adds 1 to get us to our frame, then |
| # we add 1 to get to the proper surrounding scope. |
| rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=1) |
| ast = torch._C._parse_source_def(src) |
| self._methods[ast.name().name] = ScriptMethodStub(rcb, ast, None) |
| |
| def _replicate_for_data_parallel(self): |
| return self._actual_script_module._replicate_for_data_parallel() |
| |
| def __reduce_package__(self, exporter: PackageExporter): |
| """ |
| Called by ``torch.package.PackageExporter``'s Pickler's ``persistent_id`` when |
| saving TorchScript objects. Performs act of saving a ScriptModule inside of |
| a ``torch.package`` archive. |
| |
| Returns method to load the ScriptModule from a ``torch.package.PackageImporter``'s |
| Pickler's ``persistent_load`` function. |
| """ |
| script_module_id = exporter.get_unique_id() |
| exporter.script_module_serializer.serialize(self._c, int(script_module_id)) |
| return (unpackage_script_module, (script_module_id,)) |
| |
| class RecursiveScriptModule(ScriptModule): |
| # XXX: RecursiveScriptModule inherits from ScriptModule for the sole |
| # reason that it retains the existing isinstance(ScriptModule) |
| # behavior. |
| r""" |
| The core data structure in TorchScript is the ``ScriptModule``. It is an |
| analogue of torch's ``nn.Module`` and represents an entire model as a tree of |
| submodules. Like normal modules, each individual module in a ``ScriptModule`` can |
| have submodules, parameters, and methods. In ``nn.Module``\s methods are implemented |
| as Python functions, but in ``ScriptModule``\s methods are implemented as |
| TorchScript functions, a statically-typed subset of Python that contains all |
| of PyTorch's built-in Tensor operations. This difference allows your |
| ``ScriptModule``\s code to run without the need for a Python interpreter. |
| |
| ``ScriptModule``\s should not be created manually, instead use |
| either :func:`tracing <torch.jit.trace>` or :func:`scripting <torch.jit.script>`. |
| Tracing and scripting can be applied incrementally and :ref:`composed as necessary <Types>`. |
| |
| * Tracing records the tensor operations as executed with a set of example inputs and uses these |
| operations to construct a computation graph. You can use the full dynamic behavior of Python with tracing, |
| but values other than Tensors and control flow aren't captured in the graph. |
| |
| * Scripting inspects the Python code of the model |
| and compiles it to TorchScript. Scripting allows the use of many `types`_ of values and supports dynamic control flow. |
| Many, but not all features of Python are supported by the compiler, so changes to the source code may be necessary. |
| """ |
| _disable_script_meta = True |
| |
| def __init__(self, cpp_module): |
| self.__dict__["_initializing"] = True |
| self._c = cpp_module |
| super(RecursiveScriptModule, self).__init__() |
| # Delete the 'training' attribute set up by `Module.__init__`. It |
| # will get set on the underlying cpp module, so we delete it here |
| # to avoid this version shadowing the cpp module version. |
| delattr(self, "training") |
| |
| @staticmethod |
| def _construct(cpp_module, init_fn): |
| """ |
| Construct a RecursiveScriptModule that's ready for use. PyTorch |
| code should use this to construct a RecursiveScriptModule instead |
| of instead of calling `__init__` directly, as it makes sure the |
| object is properly finalized (and in the future, we may take |
| control of how the RecursiveScriptModule instance is created). |
| |
| Args: |
| cpp_module: The C++ Module that will hold the actual state of |
| this RecursiveScriptModule instance. |
| init_fn: Lambda that initializes the RecursiveScriptModule passed to it. |
| """ |
| script_module = RecursiveScriptModule(cpp_module) |
| init_fn(script_module) |
| |
| # Finalize the ScriptModule: replace the nn.Module state with our |
| # custom implementations and flip the _initializing bit. |
| RecursiveScriptModule._finalize_scriptmodule(script_module) |
| return script_module |
| |
| @staticmethod |
| def _finalize_scriptmodule(script_module): |
| script_module._parameters = OrderedDictWrapper( |
| torch._C.ParameterDict(script_module._c) |
| ) |
| script_module._buffers = OrderedDictWrapper( |
| torch._C.BufferDict(script_module._c) |
| ) |
| script_module._modules = OrderedModuleDict( |
| script_module._c, script_module._modules |
| ) |
| script_module._initializing = False |
| |
| def _reconstruct(self, cpp_module): |
| """ |
| Re-construct an instance of RecursiveScriptModule using an instance of a C++ module. |
| |
| Args: |
| cpp_module: The C++ module that this RecursiveScriptModule will be rebuilt around. |
| """ |
| self.__init__(cpp_module) # type: ignore[misc] |
| |
| # Copy the concrete type from the C++ module to this ScriptModule. |
| self._concrete_type = torch._C.ConcreteModuleType.from_jit_type( |
| self._c._type() |
| ) |
| |
| # Copy submodules from the C++ module to this ScriptModule. |
| modules = {} |
| for name, cpp_module in torch._C.ModuleDict(self._c).items(): |
| modules[name] = wrap_cpp_module(cpp_module) |
| self._modules = OrderedModuleDict(self._c, modules) |
| |
| # Copy parameters and buffers. |
| self._parameters = OrderedDictWrapper(torch._C.ParameterDict(self._c)) |
| self._buffers = OrderedDictWrapper(torch._C.BufferDict(self._c)) |
| |
| # Get rid of the functions from the old C++ module. |
| self.__dict__ = { |
| k: v |
| for k, v in self.__dict__.items() |
| if not isinstance(v, torch._C.ScriptMethod) |
| } |
| self.__dict__["_initializing"] = False |
| |
| @property |
| def graph(self): |
| r""" |
| Returns a string representation of the internal graph for the |
| ``forward`` method. See :ref:`interpreting-graphs` for details. |
| """ |
| return self._c._get_method("forward").graph |
| |
| @property |
| def inlined_graph(self): |
| r""" |
| Returns a string representation of the internal graph for the |
| ``forward`` method. This graph will be preprocessed to inline all function and method calls. |
| See :ref:`interpreting-graphs` for details. |
| """ |
| return self.forward.inlined_graph |
| |
| @property |
| def code(self): |
| r""" |
| Returns a pretty-printed representation (as valid Python syntax) of |
| the internal graph for the ``forward`` method. See |
| :ref:`inspecting-code` for details. |
| """ |
| return self.forward.code |
| |
| @property |
| def code_with_constants(self): |
| r""" |
| Returns a tuple of: |
| |
| [0] a pretty-printed representation (as valid Python syntax) of |
| the internal graph for the ``forward`` method. See `code`. |
| [1] a ConstMap following the CONSTANT.cN format of the output in [0]. |
| The indices in the [0] output are keys to the underlying constant's values. |
| |
| See :ref:`inspecting-code` for details. |
| """ |
| r = self.forward.code_with_constants |
| return (r[0], ConstMap(r[1])) |
| |
| def save(self, f, **kwargs): |
| r""" |
| save(f, _extra_files={}) |
| |
| See :func:`torch.jit.save <torch.jit.save>` for details. |
| """ |
| return self._c.save(str(f), **kwargs) |
| |
| def _save_for_lite_interpreter(self, *args, **kwargs): |
| r""" |
| _save_for_lite_interpreter(f) |
| |
| Add (or update) the bytecode session to the script model. The updated model is used |
| in lite interpreter for mobile applications. |
| |
| Args: |
| f: a string containing a file name. |
| _extra_files: Map from filename to contents which will be stored as part of 'f'. |
| |
| """ |
| return self._c._save_for_mobile(*args, **kwargs) |
| |
| def _save_to_buffer_for_lite_interpreter(self, *args, **kwargs): |
| return self._c._save_to_buffer_for_mobile(*args, **kwargs) |
| |
| def save_to_buffer(self, *args, **kwargs): |
| return self._c.save_to_buffer(*args, **kwargs) |
| |
| def get_debug_state(self, *args, **kwargs): |
| return self._c.get_debug_state() |
| |
| def extra_repr(self): |
| return "original_name={}".format(self.original_name) |
| |
| def graph_for(self, *args, **kwargs): |
| return self.forward.graph_for(self, *args, **kwargs) |
| |
| @property |
| def original_name(self): |
| if type(self) == str(self._c._type().name()): |
| return "" |
| return str(self._c._type().name()) |
| |
| def define(self, src): |
| # We use frames_up=1 to get to the proper surrounding scope. The stack |
| # will look like: |
| # 0. createResolutionCallback |
| # 1. define() |
| # 2. surrounding scope. |
| # |
| # createResolutionCallback internally adds 1 to get us to our frame, then |
| # we add 1 to get to the proper surrounding scope. |
| rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=1) |
| self._c._define(self._concrete_type, src, rcb) |
| |
| def __getattr__(self, attr): |
| if "_initializing" not in self.__dict__: |
| raise RuntimeError( |
| "ScriptModule has not been initialized, did you forget to call super's init?" |
| ) |
| |
| if self._initializing: |
| return super(RecursiveScriptModule, self).__getattr__(attr) |
| |
| # _modules check is before hasattr since modules are included as attributes in _c, |
| # but we want to get the python wrapper from _modules instead of the raw _c object. |
| if attr in self._modules: |
| return self._modules[attr] |
| elif self._c.hasattr(attr): |
| return self._c.getattr(attr) |
| elif self._c._has_method(attr): |
| script_method = self._c._get_method(attr) |
| # cache method so future calls do not go through __getattr__ |
| # to improve invocation performance |
| self.__dict__[attr] = script_method |
| return script_method |
| |
| return super(RecursiveScriptModule, self).__getattr__(attr) |
| |
| def __setattr__(self, attr, value): |
| if self._initializing: |
| return super(RecursiveScriptModule, self).__setattr__(attr, value) |
| |
| if attr in self._modules: |
| self._modules[attr] = value |
| elif self._c.hasattr(attr): |
| self._c.setattr(attr, value) |
| elif ( |
| hasattr(self, "_concrete_type") |
| and attr in self._concrete_type.get_constants().keys() |
| ): |
| # TODO: we don't have _concrete_type set after load(), and in general we lose constant information. |
| # We should encode constants as class type attributes (or something) so it persists across save/load. |
| raise AttributeError( |
| "Cannot mutate TorchScript constant value: '{}'. Value: '{}'".format( |
| attr, value |
| ) |
| ) |
| else: |
| # We allow setting Python attributes on the ScriptModule, for |
| # when people want to stash some convenience info on it. |
| # TODO: it's possible that the following is confusing: |
| # s = torch.jit.script(...) |
| # s.python_attr = ... |
| # s.save() <--- this doesn't have `python_attr` |
| # It's fairly trivial to save enough info to warn in this case. |
| return super(RecursiveScriptModule, self).__setattr__(attr, value) |
| |
| def __copy__(self): |
| return torch.jit._recursive.wrap_cpp_module(copy.copy(self._c)) |
| |
| def __deepcopy__(self, memo): |
| return torch.jit._recursive.wrap_cpp_module(copy.deepcopy(self._c, memo)) |
| |
| # Python magic methods do method lookups on an object's class type, instead of looking up |
| # the method defines on the class instance. In order to continue to expose the magic methods |
| # of builtin-containers (ModuleList, Sequential, ModuleDict) to Python, we |
| # define magic methods here as a shim to the correct attribute. |
| def forward_magic_method(self, method_name, *args, **kwargs): |
| self_method = getattr(self, method_name) |
| if getattr(self_method, "__func__", None) == getattr( |
| RecursiveScriptModule, method_name |
| ): |
| raise NotImplementedError() |
| return self_method(*args, **kwargs) |
| |
| def __iter__(self): |
| return self.forward_magic_method("__iter__") |
| |
| def __getitem__(self, idx): |
| return self.forward_magic_method("__getitem__", idx) |
| |
| def __len__(self): |
| return self.forward_magic_method("__len__") |
| |
| def __contains__(self, key): |
| return self.forward_magic_method("__contains__", key) |
| |
| # dir is defined by the base nn.Module, so instead of throwing if |
| # it is not overridden, we call into the nn.Module __dir__ method |
| def __dir__(self): |
| self_method = self.__dir__ |
| if self_method.__func__ == _get_function_from_type( # type: ignore[attr-defined] |
| RecursiveScriptModule, "__dir__" |
| ): |
| return super(RecursiveScriptModule, self).__dir__() |
| return self_method() |
| |
| # to resolve bool(value), Python looks if __bool__ is defined then __iter__ |
| # is defined then returns true for classes. Since __iter__() on this |
| # class throws if it isn't overridden, we define __bool__ to preserve default behavior |
| def __bool__(self): |
| self_method = self.__bool__ |
| if self_method.__func__ == _get_function_from_type( # type: ignore[attr-defined] |
| RecursiveScriptModule, "__bool__" |
| ): |
| return True |
| return self_method() |
| |
| def _replicate_for_data_parallel(self): |
| # we have to initialize ScriptModule properly so that |
| # it works with pybind11 |
| def init_fn(script_module): |
| # Don't do anything here, we'll initialize the ScriptModule below |
| return |
| |
| return RecursiveScriptModule._construct( |
| self._c._replicate_for_data_parallel(), init_fn |
| ) |
| |
| # Need to copy all RecursiveScriptModule methods to ScriptModule. |
| # |
| # This is because `super(MyScriptModule, self).foo()` does not use |
| # `__getattr__` to look up `foo`. So we need to make each method available on |
| # the ScriptModule manually. |
| for name, item in RecursiveScriptModule.__dict__.items(): |
| if not callable(item) and not isinstance(item, property): |
| continue |
| if name.startswith("__") or hasattr(ScriptModule, name): |
| continue |
| # We can copy over the implementation wholesale because besides the |
| # `super()` thing above, ScriptModule behaves exactly like |
| # RecursiveScriptModule |
| setattr(ScriptModule, name, item) |
| |
| def _get_methods(cls): |
| import inspect |
| |
| # In Python 3 unbound methods are functions, but in Python 2 they are methods |
| return inspect.getmembers( |
| cls, predicate=lambda x: inspect.isfunction(x) or inspect.ismethod(x) |
| ) |
| |
| _compiled_methods_allowlist = { |
| "forward", |
| "register_buffer", |
| "register_parameter", |
| "register_module", |
| "add_module", |
| "_apply", |
| "apply", |
| "cuda", |
| "cpu", |
| "to", |
| "type", |
| "float", |
| "double", |
| "half", |
| "state_dict", |
| "_state_dict_impl", |
| "_save_to_state_dict", |
| "load_state_dict", |
| "_load_from_state_dict", |
| "_named_members", |
| "parameters", |
| "named_parameters", |
| "buffers", |
| "named_buffers", |
| "children", |
| "named_children", |
| "modules", |
| "named_modules", |
| "zero_grad", |
| "share_memory", |
| "_get_name", |
| "extra_repr", |
| "_slow_forward", |
| "_tracing_name", |
| "eval", |
| "train", |
| "get_extra_state", |
| "set_extra_state" |
| } |
| |
| def _make_fail(name): |
| def fail(self, *args, **kwargs): |
| raise RuntimeError(name + " is not supported on ScriptModules") |
| |
| return fail |
| |
| for name, method in _get_methods(torch.nn.Module): |
| if name.startswith("__"): |
| continue |
| if ( |
| name not in RecursiveScriptModule.__dict__ |
| and name not in _compiled_methods_allowlist |
| ): |
| setattr(RecursiveScriptModule, method.__name__, _make_fail(name)) |
| |
| |
| else: |
| # TODO MAKE SURE THAT DISABLING WORKS |
| class RecursiveScriptClass(object): # type: ignore[no-redef] |
| def __init__(self): |
| super().__init__() |
| |
| class ScriptModule(torch.nn.Module): # type: ignore[no-redef] |
| def __init__(self, arg=None): |
| super().__init__() |
| |
| class RecursiveScriptModule(ScriptModule): # type: ignore[no-redef] |
| def __init__(self, arg=None): |
| super().__init__() |
| |
| def call_prepare_scriptable_func_impl(obj, memo): |
| if not isinstance(obj, torch.nn.Module): |
| return obj |
| |
| obj_id = id(obj) |
| |
| # If obj_id is in memo, obj has already been prepared or is being |
| # prepared in another call up the stack. |
| if obj_id in memo: |
| return memo[id(obj)] |
| |
| obj = obj.__prepare_scriptable__() if hasattr(obj, '__prepare_scriptable__') else obj # type: ignore[operator] |
| # Record obj in memo to avoid infinite recursion in the case of cycles in the module |
| # hierarchy when recursing below. |
| memo[obj_id] = obj |
| |
| new_obj_dict = {} |
| |
| for name, sub_module in obj.__dict__.items(): |
| if name == '_modules': |
| for k, v in sub_module.items(): |
| sub_module[k] = call_prepare_scriptable_func_impl(v, memo) |
| new_obj_dict[name] = sub_module |
| elif isinstance(sub_module, torch.nn.Module) and not isinstance(sub_module, ScriptModule): |
| new_obj_dict[name] = call_prepare_scriptable_func_impl(sub_module, memo) |
| else: |
| new_obj_dict[name] = sub_module |
| |
| for k, v in new_obj_dict.items(): |
| obj.__dict__[name] = v |
| |
| return obj |
| |
| |
| def call_prepare_scriptable_func(obj): |
| memo: Dict[int, torch.nn.Module] = {} |
| return call_prepare_scriptable_func_impl(obj, memo) |
| |
| def create_script_dict(obj): |
| """ |
| Create a ``torch._C.ScriptDict`` instance with the data from ``obj``. |
| |
| Args: |
| obj (dict): The Python dictionary that is used to initialize the ``ScriptDict`` |
| returned by this function. |
| |
| Returns: |
| An instance of ``torch._C.ScriptDict`` that has the same data as ``obj`` |
| and can be passed between Python and TorchScript with reference semantics and |
| zero copy overhead. |
| """ |
| return torch._C.ScriptDict(obj) # type: ignore[attr-defined] |
| |
| |
| def create_script_list(obj, type_hint=None): |
| """ |
| Create a ``torch._C.ScriptList`` instance with the data from ``obj``. |
| Args: |
| obj (dict): The Python list that is used to initialize the ``ScriptList`` |
| returned by this function. |
| Returns: |
| An instance of ``torch._C.ScriptList`` that has the same data as ``obj`` |
| and can be passed between Python and TorchScript with reference semantics and |
| zero copy overhead. |
| """ |
| return torch._C.ScriptList(obj) # type: ignore[attr-defined] |
| |
| |
| def script(obj, optimize=None, _frames_up=0, _rcb=None, |
| example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None): |
| r""" |
| Scripting a function or ``nn.Module`` will inspect the source code, compile |
| it as TorchScript code using the TorchScript compiler, and return a :class:`ScriptModule` or |
| :class:`ScriptFunction`. TorchScript itself is a subset of the Python language, so not all |
| features in Python work, but we provide enough functionality to compute on |
| tensors and do control-dependent operations. For a complete guide, see the |
| :ref:`language-reference`. |
| |
| Scripting a dictionary or list copies the data inside it into a TorchScript instance than can be |
| subsequently passed by reference between Python and TorchScript with zero copy overhead. |
| |
| ``torch.jit.script`` can be used as a function for modules, functions, dictionaries and lists |
| and as a decorator ``@torch.jit.script`` for :ref:`torchscript-classes` and functions. |
| |
| Args: |
| obj (callable, class, or ``nn.Module``): The ``nn.Module``, function, class type, |
| dictionary, or list to compile. |
| example_inputs (Union[List[Tuple], Dict[Callable, List[Tuple]], None]): Provide example inputs |
| to annotate the arguments for a function or ``nn.Module``. |
| |
| Returns: |
| If ``obj`` is ``nn.Module``, ``script`` returns |
| a :class:`ScriptModule` object. The returned :class:`ScriptModule` will |
| have the same set of sub-modules and parameters as the |
| original ``nn.Module``. If ``obj`` is a standalone function, |
| a :class:`ScriptFunction` will be returned. If ``obj`` is a ``dict``, then |
| ``script`` returns an instance of `torch._C.ScriptDict`. If ``obj`` is a ``list``, |
| then ``script`` returns an instance of `torch._C.ScriptList`. |
| |
| **Scripting a function** |
| The ``@torch.jit.script`` decorator will construct a :class:`ScriptFunction` |
| by compiling the body of the function. |
| |
| Example (scripting a function): |
| |
| .. testcode:: |
| |
| import torch |
| |
| @torch.jit.script |
| def foo(x, y): |
| if x.max() > y.max(): |
| r = x |
| else: |
| r = y |
| return r |
| |
| print(type(foo)) # torch.jit.ScriptFunction |
| |
| # See the compiled graph as Python code |
| print(foo.code) |
| |
| # Call the function using the TorchScript interpreter |
| foo(torch.ones(2, 2), torch.ones(2, 2)) |
| |
| .. testoutput:: |
| :hide: |
| |
| ... |
| |
| ****Scripting a function using example_inputs** |
| Example inputs can be used to annotate a function arguments. |
| |
| Example (annotating a function before scripting): |
| |
| .. testcode:: |
| |
| import torch |
| |
| def test_sum(a, b): |
| return a + b |
| |
| # Annotate the arguments to be int |
| scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)]) |
| |
| print(type(scripted_fn)) # torch.jit.ScriptFunction |
| |
| # See the compiled graph as Python code |
| print(scripted_fn.code) |
| |
| # Call the function using the TorchScript interpreter |
| scripted_fn(20, 100) |
| |
| .. testoutput:: |
| :hide: |
| |
| ... |
| |
| **Scripting an nn.Module** |
| Scripting an ``nn.Module`` by default will compile the ``forward`` method and recursively |
| compile any methods, submodules, and functions called by ``forward``. If a ``nn.Module`` only uses |
| features supported in TorchScript, no changes to the original module code should be necessary. ``script`` |
| will construct :class:`ScriptModule` that has copies of the attributes, parameters, and methods of |
| the original module. |
| |
| Example (scripting a simple module with a Parameter): |
| |
| .. testcode:: |
| |
| import torch |
| |
| class MyModule(torch.nn.Module): |
| def __init__(self, N, M): |
| super(MyModule, self).__init__() |
| # This parameter will be copied to the new ScriptModule |
| self.weight = torch.nn.Parameter(torch.rand(N, M)) |
| |
| # When this submodule is used, it will be compiled |
| self.linear = torch.nn.Linear(N, M) |
| |
| def forward(self, input): |
| output = self.weight.mv(input) |
| |
| # This calls the `forward` method of the `nn.Linear` module, which will |
| # cause the `self.linear` submodule to be compiled to a `ScriptModule` here |
| output = self.linear(output) |
| return output |
| |
| scripted_module = torch.jit.script(MyModule(2, 3)) |
| |
| Example (scripting a module with traced submodules): |
| |
| .. testcode:: |
| |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| |
| class MyModule(nn.Module): |
| def __init__(self): |
| super(MyModule, self).__init__() |
| # torch.jit.trace produces a ScriptModule's conv1 and conv2 |
| self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16)) |
| self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16)) |
| |
| def forward(self, input): |
| input = F.relu(self.conv1(input)) |
| input = F.relu(self.conv2(input)) |
| return input |
| |
| scripted_module = torch.jit.script(MyModule()) |
| |
| To compile a method other than ``forward`` (and recursively compile anything it calls), add |
| the :func:`@torch.jit.export <torch.jit.export>` decorator to the method. To opt out of compilation |
| use :func:`@torch.jit.ignore <torch.jit.ignore>` or :func:`@torch.jit.unused <torch.jit.unused>`. |
| |
| Example (an exported and ignored method in a module):: |
| |
| import torch |
| import torch.nn as nn |
| |
| class MyModule(nn.Module): |
| def __init__(self): |
| super(MyModule, self).__init__() |
| |
| @torch.jit.export |
| def some_entry_point(self, input): |
| return input + 10 |
| |
| @torch.jit.ignore |
| def python_only_fn(self, input): |
| # This function won't be compiled, so any |
| # Python APIs can be used |
| import pdb |
| pdb.set_trace() |
| |
| def forward(self, input): |
| if self.training: |
| self.python_only_fn(input) |
| return input * 99 |
| |
| scripted_module = torch.jit.script(MyModule()) |
| print(scripted_module.some_entry_point(torch.randn(2, 2))) |
| print(scripted_module(torch.randn(2, 2))) |
| |
| Example ( Annotating forward of nn.Module using example_inputs):: |
| |
| import torch |
| import torch.nn as nn |
| from typing import NamedTuple |
| |
| class MyModule(NamedTuple): |
| result: List[int] |
| |
| class TestNNModule(torch.nn.Module): |
| def forward(self, a) -> MyModule: |
| result = MyModule(result=a) |
| return result |
| |
| pdt_model = TestNNModule() |
| |
| # Runs the pdt_model in eager model with the inputs provided and annotates the arguments of forward |
| scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], }) |
| |
| # Run the scripted_model with actual inputs |
| print(scripted_model([20])) |
| """ |
| global type_trace_db |
| if not _enabled: |
| return obj |
| |
| if optimize is not None: |
| warnings.warn( |
| "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead" |
| ) |
| |
| # No-op for modules, functions, class instances that are already scripted |
| if isinstance(obj, RecursiveScriptClass): |
| return obj |
| if isinstance(obj, ScriptModule): |
| return obj |
| if isinstance(obj, ScriptFunction): |
| return obj |
| |
| if example_inputs: |
| # If MonkeyType is installed, enable profile directed type annotation |
| # Check if example_inputs are defined and generate call traces |
| # for the method by running eager mode version of the method with |
| # the provide example inputs. This logs all the traces in type_trace_db |
| type_trace_db = JitTypeTraceStore() |
| if monkeytype_trace: |
| monkeytype_config = JitTypeTraceConfig(type_trace_db) |
| with monkeytype_trace(monkeytype_config): |
| if isinstance(example_inputs, Dict): |
| # If the obj is an nn.Module or a class, then each method is |
| # executed with the arguments provided in the example inputs. |
| # example inputs here will be of type Dict(class.method, (arguments)) |
| # This is used to infer type annotations for those methods |
| # which are not called directly under the hood of monkeytype. |
| for module, example_input in example_inputs.items(): |
| for example in example_input: |
| module(*example) |
| elif isinstance(example_inputs, List): |
| for examples in example_inputs: |
| obj(*examples) |
| else: |
| raise ValueError("Error: Unable to infer types. Please format the inputs to type `List[Tuple]`" |
| " or `Dict[Callable, List[Tuple]]` to be run with MonkeyType.") |
| else: |
| warnings.warn("Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType " |
| "to enable Profile-Directed Typing in TorchScript. Refer to " |
| "https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. ") |
| |
| if isinstance(obj, torch.nn.Module): |
| obj = call_prepare_scriptable_func(obj) |
| return torch.jit._recursive.create_script_module( |
| obj, torch.jit._recursive.infer_methods_to_compile |
| ) |
| |
| if isinstance(obj, dict): |
| return create_script_dict(obj) |
| if isinstance(obj, list): |
| return create_script_list(obj) |
| |
| if inspect.isclass(obj): |
| qualified_name = _qualified_name(obj) |
| # If this type is a `nn.Module` subclass, they probably meant to pass |
| # an instance instead of a Module |
| if issubclass(obj, torch.nn.Module): |
| raise RuntimeError( |
| "Type '{}' cannot be compiled since it inherits" |
| " from nn.Module," |
| " pass an instance instead".format(obj) |
| ) |
| |
| # Enums are automatically usable in TorchScript, explicitly scripting |
| # is not necessary, but not harmful either. |
| if issubclass(obj, enum.Enum): |
| return obj |
| |
| if not _is_new_style_class(obj): |
| raise RuntimeError( |
| "TorchScript classes must be new-style classes. " |
| "Please inherit from 'object'." |
| ) |
| if len(obj.mro()) > 2: |
| raise RuntimeError( |
| "TorchScript classes does not support inheritance yet. " |
| "Please directly inherit from 'object'." |
| ) |
| if _rcb is None: |
| _rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1) |
| _compile_and_register_class(obj, _rcb, qualified_name) |
| return obj |
| elif inspect.isfunction(obj) or inspect.ismethod(obj): |
| qualified_name = _qualified_name(obj) |
| # this is a decorated fn, and we need to the underlying fn and its rcb |
| if hasattr(obj, "__script_if_tracing_wrapper"): |
| obj = obj.__original_fn |
| _rcb = _jit_internal.createResolutionCallbackFromClosure(obj) |
| |
| # some functions are explicitly marked as not supported in script mode |
| if hasattr(obj, "__script_unsupported"): |
| raise RuntimeError("TorchScript error: " + obj.__script_unsupported) |
| |
| _check_directly_compile_overloaded(obj) |
| maybe_already_compiled_fn = _try_get_jit_cached_function(obj) |
| if maybe_already_compiled_fn: |
| return maybe_already_compiled_fn |
| ast = get_jit_def(obj, obj.__name__) |
| if _rcb is None: |
| _rcb = _jit_internal.createResolutionCallbackFromClosure(obj) |
| fn = torch._C._jit_script_compile( |
| qualified_name, ast, _rcb, get_default_args(obj) |
| ) |
| # Forward docstrings |
| fn.__doc__ = obj.__doc__ |
| _set_jit_function_cache(obj, fn) |
| return fn |
| else: |
| return torch.jit._recursive.create_script_class(obj) |
| |
| |
| # overloads are registered in _jit_internal and compiled here so that _overload |
| # can be used in nn/functional.py without an import cycle |
| |
| |
| def _check_overload_defaults(impl_defaults, overload_defaults, loc): |
| for name, overload_value in overload_defaults.items(): |
| if name not in impl_defaults or impl_defaults[name] != overload_value: |
| raise torch.jit.frontend.FrontendError( |
| loc, |
| "Default parameters on overloads do not affect the runtime so they " |
| "must equal to the default parameter on the implementation function. Found on " |
| "parameter {name}".format(name=name), |
| ) |
| |
| |
| def _compile_function_with_overload(overload_fn, qual_name, impl_fn): |
| overload_decl = get_jit_def(overload_fn, overload_fn.__name__).decl() |
| overload_signature = torch.jit.annotations.get_signature( |
| overload_fn, None, None, inspect.ismethod(overload_fn) |
| ) |
| impl_ast = get_jit_def(impl_fn, impl_fn.__name__) |
| overload_defaults = get_default_args(overload_fn) |
| implementation_defaults = get_default_args(impl_fn) |
| _rcb = _jit_internal.createResolutionCallbackFromClosure(impl_fn) |
| _check_overload_defaults( |
| implementation_defaults, overload_defaults, overload_decl.range() |
| ) |
| fn = torch._C._jit_script_compile_overload( |
| qual_name, |
| overload_decl, |
| impl_ast, |
| _rcb, |
| implementation_defaults, |
| overload_signature, |
| ) |
| return fn |
| |
| |
| def _get_overloads(obj): |
| # check for cached compiled fns |
| existing_compiled_fns = _try_get_jit_cached_overloads(obj) |
| qual_name = _qualified_name(obj) |
| uncompiled_overloads = _jit_internal._get_fn_overloads(qual_name) |
| if uncompiled_overloads is None: |
| return existing_compiled_fns |
| |
| if obj in uncompiled_overloads: |
| raise RuntimeError(_jit_internal.get_overload_no_implementation_error_message( |
| 'function', obj)) |
| |
| compiled_fns = [] |
| for overload_fn in uncompiled_overloads: |
| compiled_fns.append( |
| _compile_function_with_overload(overload_fn, qual_name, obj) |
| ) |
| |
| if existing_compiled_fns: |
| compiled_fns = existing_compiled_fns + compiled_fns |
| |
| # cache compilation, remove information stored to do compilation |
| _set_jit_overload_cache(obj, compiled_fns) |
| _jit_internal._clear_fn_overloads(qual_name) |
| return compiled_fns |
| |
| |
| def _check_directly_compile_overloaded(obj): |
| qual_name = _qualified_name(obj) |
| if _jit_internal._get_fn_overloads(qual_name) or _try_get_jit_cached_overloads(obj): |
| raise RuntimeError( |
| "Function {} cannot be directly compiled because it" |
| " is overloaded. It must be used in a context of a function" |
| " where its inputs can determine which overload to call.".format(qual_name) |
| ) |
| |
| |
| def interface(obj): |
| if not inspect.isclass(obj): |
| raise RuntimeError("interface must be applied to a class") |
| if not _is_new_style_class(obj): |
| raise RuntimeError("TorchScript interfaces must inherit from 'object'") |
| |
| # Expected MRO is: |
| # User module |
| # torch.nn.modules.module.Module |
| # object |
| is_module_interface = issubclass(obj, torch.nn.Module) and len(obj.mro()) == 3 |
| |
| if not is_module_interface and len(obj.mro()) > 2: |
| raise RuntimeError( |
| "TorchScript interface does not support inheritance yet. " |
| "Please directly inherit from 'object' or 'nn.Module'." |
| ) |
| |
| qualified_name = _qualified_name(obj) |
| rcb = _jit_internal.createResolutionCallbackFromFrame(1) |
| # if this type is a `nn.Module` subclass, generate a module interface type |
| # instead of a class interface type; a module interface type only compiles |
| # the user provided methods as part of the interface |
| ast = get_jit_class_def(obj, obj.__name__) |
| mangled_classname = torch._C._jit_script_interface_compile( |
| qualified_name, ast, rcb, is_module_interface |
| ) |
| obj.__torch_script_interface__ = mangled_classname |
| return obj |
| |
| |
| def _recursive_compile_class(obj, loc): |
| _qual_name = _qualified_name(obj) |
| # We're starting a new compilation, so update the error call stack in |
| # case it fails |
| error_stack = torch._C.CallStack(_qual_name, loc) |
| rcb = _jit_internal.createResolutionCallbackForClassMethods(obj) |
| return _compile_and_register_class(obj, rcb, _qual_name) |
| |
| CompilationUnit = torch._C.CompilationUnit |
| set_module(CompilationUnit, "torch.jit") |
| |
| |
| def pad(s: str, padding: int, offset: int = 0, char: str = ' '): |
| if padding >= len(s): |
| padding -= len(s) |
| return ''.join([char for _ in range(padding + offset)]) + s |
| |
| |
| class _ScriptProfileColumn: |
| def __init__(self, header: str, alignment: int = 4, offset: int = 0): |
| self.header = header |
| self.alignment = alignment |
| self.offset = offset |
| self.rows: Dict[int, Any] = {} |
| |
| def add_row(self, lineno: int, value: Any): |
| self.rows[lineno] = value |
| |
| def materialize(self): |
| max_length = len(self.header) |
| rows: List[Tuple[int, str]] = [] |
| for (key, value) in self.rows.items(): |
| cell = str(value) |
| rows.append((key, cell)) |
| max_length = max(len(cell), max_length) |
| |
| if self.alignment > 0: |
| padding = max_length + self.alignment |
| padding -= padding % self.alignment |
| else: |
| padding = 0 |
| |
| rows = [(key, pad(cell, padding, self.offset)) for key, cell in rows] |
| return pad(self.header, padding, self.offset), rows |
| |
| |
| class _ScriptProfileTable: |
| def __init__(self, cols: List[_ScriptProfileColumn], source_range: List[int]): |
| self.cols = cols |
| self.source_range = source_range |
| |
| def dump_string(self): |
| outputs: List[str] = [] |
| cells: List[Tuple[str, Dict[int, str]]] = [] |
| header_buffer = '' |
| for col in self.cols: |
| header, rows = col.materialize() |
| header_buffer += header |
| cells.append((header, dict(rows))) |
| |
| outputs.append(header_buffer) |
| outputs.append(pad('', len(header_buffer), 0, '=')) |
| for line in self.source_range: |
| row_buffer = '' |
| for header, rows in cells: |
| cell = rows.get(line) |
| if cell is None: |
| row_buffer += pad('', len(header)) |
| else: |
| row_buffer += cell |
| outputs.append(row_buffer) |
| return '\n'.join(outputs) |
| |
| |
| class _ScriptProfile: |
| def __init__(self): |
| self.profile = classes.profiling._ScriptProfile() |
| |
| def enable(self): |
| self.profile.enable() |
| |
| def disable(self): |
| self.profile.disable() |
| |
| def dump_string(self) -> str: |
| outputs: List[str] = [] |
| for source_stats in self.profile._dump_stats(): |
| source_ref = source_stats.source() |
| source_lines = source_ref.text().splitlines() |
| dedent = min([len(line) - len(line.lstrip(' ')) for line in source_lines]) |
| source_lines = [line[dedent:] for line in source_lines] |
| |
| start_line = source_ref.starting_lineno() |
| end_line = start_line + len(source_lines) |
| source_range = range(start_line, end_line) |
| lineno = _ScriptProfileColumn("Line #") |
| hits = _ScriptProfileColumn("Hits") |
| time_ns = _ScriptProfileColumn("Time (ns)") |
| line_contents = _ScriptProfileColumn("Line Contents", 0, 1) |
| stats = source_stats.line_map() |
| for line in source_range: |
| lineno.add_row(line, line) |
| line_contents.add_row(line, source_lines[line - start_line]) |
| stat = stats.get(line) |
| if stat is not None: |
| hits.add_row(line, stat.count()) |
| time_ns.add_row(line, stat.duration_ns()) |
| |
| table = _ScriptProfileTable([lineno, hits, time_ns, line_contents], list(source_range)) |
| outputs.append(table.dump_string()) |
| return '\n\n'.join(outputs) |
| |
| def dump(self): |
| print(self.dump_string()) |
| |
| |
| def _unwrap_optional(x): |
| assert x is not None, "Unwrapping null optional" |
| return x |
| |
| |
| _register_builtin(_unwrap_optional, "aten::_unwrap_optional") |
| _register_builtin(_jit_internal.is_scripting, "aten::is_scripting") |
| _register_builtin(has_torch_function, "aten::has_torch_function") |
| _register_builtin(has_torch_function_unary, "aten::has_torch_function") |
| _register_builtin(has_torch_function_variadic, "aten::has_torch_function") |