| """ |
| The weak_script annotation lives here rather than inside torch/jit/ so that it |
| can be used in other places in torch/ (namely torch.nn) without running into |
| circular dependency problems. |
| """ |
| |
| import weakref |
| import inspect |
| from torch._six import builtins |
| |
| # Tracks standalone weak script functions |
| compiled_weak_fns = weakref.WeakKeyDictionary() # noqa: T484 |
| |
| # Tracks which methods should be converted to strong methods |
| weak_script_methods = weakref.WeakKeyDictionary() # noqa: T484 |
| |
| # Converted modules and their corresponding WeakScriptModuleProxy objects |
| weak_modules = weakref.WeakKeyDictionary() # noqa: T484 |
| |
| # Types that have been declared as weak modules |
| weak_types = weakref.WeakKeyDictionary() # noqa: T484 |
| |
| # Wrapper functions that can call either of 2 functions depending on a boolean |
| # argument |
| boolean_dispatched = weakref.WeakKeyDictionary() # noqa: T484 |
| |
| COMPILATION_PENDING = object() |
| COMPILED = object() |
| |
| |
| def createResolutionCallback(frames_up=0): |
| """ |
| Creates a function which, given a string variable name, |
| returns the value of the variable in the scope of the caller of |
| the function which called createResolutionCallback (by default). |
| |
| This is used to enable access to in-scope Python variables inside |
| TorchScript fragments. |
| |
| frames_up is the number of additional frames to go up on the stack. |
| The default value is 0, which corresponds to the frame of the caller |
| of createResolutionCallback. For example, if frames_up is set to 1, |
| then the frame of the caller's caller of createResolutionCallback |
| will be used instead. |
| |
| For example, the following program prints 2:: |
| |
| def bar(): |
| cb = createResolutionCallback(1) |
| print(cb("foo")) |
| |
| def baz(): |
| foo = 2 |
| bar() |
| |
| baz() |
| """ |
| frame = inspect.currentframe() |
| i = 0 |
| while i < frames_up + 1: |
| frame = frame.f_back |
| i += 1 |
| |
| f_locals = frame.f_locals |
| f_globals = frame.f_globals |
| |
| def env(key): |
| if key in f_locals: |
| return f_locals[key] |
| elif key in f_globals: |
| return f_globals[key] |
| elif hasattr(builtins, key): |
| return getattr(builtins, key) |
| else: |
| return None |
| |
| return env |
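| |
| # Example (illustrative): the callback returned above also resolves builtins, |
| # e.g. createResolutionCallback()('print') returns the built-in print |
| # function, and names that cannot be found anywhere resolve to None. |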
| |
| |
| def weak_script(fn, _frames_up=0): |
| """ |
| Marks a function as a weak script function. When used in a script function |
| or ScriptModule, the weak script function will be lazily compiled and |
| inlined in the graph. When not used in a script function, the weak script |
| annotation has no effect. |
| """ |
| compiled_weak_fns[fn] = { |
| "status": COMPILATION_PENDING, |
| "compiled_fn": None, |
| "rcb": createResolutionCallback(_frames_up + 1) |
| } |
| return fn |
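| |
| # Example (illustrative sketch): `my_fn` is a hypothetical name. A weak script |
| # function behaves like a normal Python function until a script function or |
| # ScriptModule calls it, at which point it is compiled and inlined: |
| # |
| #     @weak_script |
| #     def my_fn(x): |
| #         return x + 1 |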
| |
| |
| def weak_module(cls): |
| weak_types[cls] = { |
| "method_stubs": None |
| } |
| return cls |
| |
| |
| def weak_script_method(fn): |
| weak_script_methods[fn] = { |
| "rcb": createResolutionCallback(frames_up=2), |
| "original_method": fn |
| } |
| return fn |
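| |
| # Example (illustrative sketch): `MyLinear` is a hypothetical module (assumed |
| # to define `weight` in __init__). The class is declared weak and its forward |
| # is only compiled when the module is used from a ScriptModule: |
| # |
| #     @weak_module |
| #     class MyLinear(torch.nn.Module): |
| #         @weak_script_method |
| #         def forward(self, input): |
| #             return input.matmul(self.weight.t()) |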
| |
| |
| def boolean_dispatch(arg_name, arg_index, default, if_true, if_false, module_name, func_name): |
| """ |
| Dispatches to one of two weak script functions based on a boolean argument. |
| In TorchScript, the boolean argument must be constant so that the correct |
| function to use can be determined at compile time. |
| """ |
| if compiled_weak_fns.get(if_true) is None or compiled_weak_fns.get(if_false) is None: |
| raise RuntimeError("both functions must be weak script") |
| |
| def fn(*args, **kwargs): |
| dispatch_flag = False |
| if arg_name in kwargs: |
| dispatch_flag = kwargs[arg_name] |
| elif arg_index < len(args): |
| dispatch_flag = args[arg_index] |
| |
| if dispatch_flag: |
| return if_true(*args, **kwargs) |
| else: |
| return if_false(*args, **kwargs) |
| |
| if if_true.__doc__ is None and if_false.__doc__ is not None: |
| doc = if_false.__doc__ |
| if_true.__doc__ = doc |
| elif if_false.__doc__ is None and if_true.__doc__ is not None: |
| doc = if_true.__doc__ |
| if_false.__doc__ = doc |
| elif if_false.__doc__ is None and if_true.__doc__ is None: |
| # neither function has a docstring |
| doc = None |
| else: |
| raise RuntimeError("only one function can have a docstring") |
| fn.__doc__ = doc |
| |
| if module_name is not None: |
| fn.__module__ = module_name |
| if func_name is not None: |
| fn.__name__ = func_name |
| |
| boolean_dispatched[fn] = { |
| "if_true": if_true, |
| "if_false": if_false, |
| "index": arg_index, |
| "default": default, |
| "arg_name": arg_name |
| } |
| return fn |
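| |
| # Example (illustrative sketch): `pool_with_indices` and `pool_no_indices` are |
| # hypothetical weak script functions that share a signature; the wrapper |
| # chooses between them based on the `return_indices` argument: |
| # |
| #     pool = boolean_dispatch( |
| #         arg_name='return_indices', arg_index=2, default=False, |
| #         if_true=pool_with_indices, if_false=pool_no_indices, |
| #         module_name=__name__, func_name='pool') |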
| |
| |
| class FunctionModifiers(object): |
| """ |
| Used to denote the behavior of a function in TorchScript. See export() and |
| ignore() for details. |
| """ |
| IGNORE_AND_DROP = "ignore (leave as a call to Python, replace with a 'raise' on torch.jit.save)" |
| IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)" |
| EXPORT = "export (compile this function even if nothing calls it)" |
| DEFAULT = "default (compile if called from an exported function / forward)" |
| |
| |
| def export(fn): |
| """ |
| This decorator indicates that a method is used as an entry point into a |
| ScriptModule. `forward` is implicitly used as an entry point, so it does |
| not need this decorator. |
| |
| Methods are added to a ScriptModule as they are called in Python. If a |
| method is never called, it will not be included in the ScriptModule when |
| saving. This decorator explicitly marks that a method should be included |
| even if it is not called from Python. |
| """ |
| fn._torchscript_modifier = FunctionModifiers.EXPORT |
| return fn |
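| |
| # Example (illustrative sketch, assuming the module is compiled, e.g. with |
| # torch.jit.script): without @torch.jit.export, `some_entry_point` would only |
| # be compiled if another compiled method called it. |
| # |
| #     class MyModule(torch.nn.Module): |
| #         @torch.jit.export |
| #         def some_entry_point(self, x): |
| #             return x + 10 |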
| |
| |
| def ignore(maybe_fn=None, *, drop_on_export=False): |
| """ |
| This decorator indicates to the compiler that a function or method should |
| be ignored and left as a Python function. |
| |
| With `drop_on_export=False` (the default), calls to this function will |
| prevent saving a TorchScript model. |
| |
| With `drop_on_export=True`, any calls to this function from other |
| TorchScript code will be replaced with a `raise`. This allows you to leave |
| code in your TorchScript model that is only ever run when the Python |
| interpreter is present. |
| """ |
| if maybe_fn is None: |
| # No positional args passed, so the decorator has been used with a kwarg, |
| # like @torch.jit.ignore(drop_on_export=True) |
| def decorator(fn): |
| if drop_on_export: |
| fn._torchscript_modifier = FunctionModifiers.IGNORE_AND_DROP |
| else: |
| fn._torchscript_modifier = FunctionModifiers.IGNORE |
| return fn |
| return decorator |
| |
| if callable(maybe_fn): |
| # used without any args, so drop_on_export is actually a function |
| # @torch.jit.ignore |
| # def fn(...): |
| maybe_fn._torchscript_modifier = FunctionModifiers.IGNORE |
| return maybe_fn |
| else: |
| if isinstance(maybe_fn, bool): |
| correct_usage = "@torch.jit.ignore(drop_on_export={})".format("True" if maybe_fn else "False") |
| raise RuntimeError("drop_on_export must be used as a kwarg, e.g. " |
| "'{}' ".format(correct_usage)) |
| raise RuntimeError("Argument to @torch.jit.ignore must be a bool or " |
| "a function but got {}".format(maybe_fn)) |
| |
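| # Example (illustrative sketch): `python_only_helper` and `debug_helper` are |
| # hypothetical functions that need the Python interpreter. |
| # |
| #     @torch.jit.ignore |
| #     def python_only_helper(x): |
| #         # stays a Python call; the model cannot be saved while it is reachable |
| #         return x.tolist() |
| # |
| #     @torch.jit.ignore(drop_on_export=True) |
| #     def debug_helper(x): |
| #         # call sites become a `raise` when the model is saved |
| #         print(x) |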
| |
| def should_drop_on_export(fn): |
| attr = get_torchscript_modifier(fn) |
| if attr is None: |
| return False |
| return attr is FunctionModifiers.IGNORE_AND_DROP |
| |
| |
| def is_ignored_fn(fn): |
| mod = get_torchscript_modifier(fn) |
| return mod is FunctionModifiers.IGNORE_AND_DROP or mod is FunctionModifiers.IGNORE |
| |
| |
| def get_torchscript_modifier(fn): |
| if not callable(fn): |
| return None |
| if hasattr(fn, '__func__'): |
| fn = fn.__func__ |
| return getattr(fn, '_torchscript_modifier', FunctionModifiers.DEFAULT) |
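| |
| # Example (illustrative sketch): `debug_only` is hypothetical. |
| # |
| #     @torch.jit.ignore(drop_on_export=True) |
| #     def debug_only(x): |
| #         return x |
| # |
| #     is_ignored_fn(debug_only)          # True |
| #     should_drop_on_export(debug_only)  # True |
| #     get_torchscript_modifier(len)      # FunctionModifiers.DEFAULT |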
| |
| |
| def _parameter_list(parameter_names_fn): |
| """ |
| Decorator to denote that a function returns a list of all the parameters |
| in a module |
| """ |
| def decorator(fn): |
| fn._parameter_names_fn = parameter_names_fn |
| return fn |
| |
| return decorator |
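| |
| # Example (illustrative sketch): `_weight_names` and `_get_weights` are |
| # hypothetical; the decorator only records, via `_parameter_names_fn`, which |
| # function enumerates the module's parameter names. |
| # |
| #     def _weight_names(self): |
| #         return ['weight_ih', 'weight_hh'] |
| # |
| #     @_parameter_list(_weight_names) |
| #     def _get_weights(self): |
| #         return [getattr(self, name) for name in _weight_names(self)] |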
| |
| |
| try: |
| import typing |
| from typing import Tuple, List, Dict, Optional |
| |
| def is_tuple(ann): |
| # In Python 3.7+, X[A, B].__origin__ is the builtin container (e.g. tuple) |
| # rather than the typing generic (e.g. typing.Tuple), so check for both |
| return ann.__module__ == 'typing' and \ |
| (getattr(ann, '__origin__', None) is typing.Tuple or |
| getattr(ann, '__origin__', None) is tuple) |
| |
| def is_list(ann): |
| return ann.__module__ == 'typing' and \ |
| (getattr(ann, '__origin__', None) is typing.List or |
| getattr(ann, '__origin__', None) is list) |
| |
| def is_dict(ann): |
| return ann.__module__ == 'typing' and \ |
| (getattr(ann, '__origin__', None) is typing.Dict or |
| getattr(ann, '__origin__', None) is dict) |
| |
| def is_optional(ann): |
| # Optional[T] is just shorthand for Union[T, None], so check for both |
| union_optional = False |
| if ann.__module__ == 'typing' and \ |
| (getattr(ann, '__origin__', None) is typing.Union): |
| args = getattr(ann, '__args__', ()) |
| if len(args) == 2: |
| union_optional = (issubclass(args[1], type(None)) and not issubclass(args[0], type(None))) \ |
| or (issubclass(args[0], type(None)) and not issubclass(args[1], type(None))) |
| |
| optional = ann.__module__ == 'typing' and \ |
| (getattr(ann, '__origin__', None) is typing.Optional) |
| |
| return optional or union_optional |
| |
| except ImportError: |
| # A minimal polyfill for versions of Python that don't have typing. |
| # Note that this also means the fancy annotation syntax is unavailable, so |
| # these classes are only used by our tiny `# type:` comment interpreter. |
| |
| # The __getitem__ in typing is implemented using metaclasses, but I'm too lazy for that. |
| class TupleCls(object): |
| def __getitem__(self, types): |
| return TupleInstance(types) |
| |
| class TupleInstance(object): |
| __slots__ = ['__args__'] |
| |
| def __init__(self, types): |
| self.__args__ = types |
| |
| class ListInstance(object): |
| __slots__ = ['__args__'] |
| |
| def __init__(self, types): |
| self.__args__ = types |
| |
| class ListCls(object): |
| def __getitem__(self, types): |
| return ListInstance(types) |
| |
| class DictInstance(object): |
| __slots__ = ['__args__'] |
| |
| def __init__(self, types): |
| self.__args__ = types |
| |
| class DictCls(object): |
| def __getitem__(self, types): |
| return DictInstance(types) |
| |
| class OptionalInstance(object): |
| __slots__ = ['__args__'] |
| |
| def __init__(self, types): |
| self.__args__ = types |
| |
| class OptionalCls(object): |
| def __getitem__(self, types): |
| return OptionalInstance(types) |
| |
| Tuple = TupleCls() # noqa: T484 |
| List = ListCls() # noqa: T484 |
| Dict = DictCls() # noqa: T484 |
| Optional = OptionalCls() # noqa: T484 |
| |
| def is_tuple(ann): |
| return isinstance(ann, TupleInstance) |
| |
| def is_list(ann): |
| return isinstance(ann, ListInstance) |
| |
| def is_dict(ann): |
| return isinstance(ann, DictInstance) |
| |
| def is_optional(ann): |
| return isinstance(ann, OptionalInstance) |
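| |
| # Illustrative sketch of what the predicates above report (in either the |
| # typing-backed branch or the polyfill): |
| # |
| #     is_tuple(Tuple[int, int])    # True |
| #     is_list(List[str])           # True |
| #     is_dict(Dict[str, int])      # True |
| #     is_optional(Optional[int])   # True |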
| |
| |
| # Allows instances of BroadcastingList to be subscripted (the subscripted |
| # type itself is not used here) |
| class BroadcastingListCls(object): |
| def __getitem__(self, types): |
| return |
| |
| # mypy doesn't support parameters on types, so we have to explicitly type each |
| # list size |
| BroadcastingList1 = BroadcastingListCls() |
| for i in range(2, 7): |
| globals()["BroadcastingList{}".format(i)] = BroadcastingList1 |
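| |
| # Example (illustrative sketch): these names are intended for use in script |
| # annotations, e.g. a hypothetical signature like |
| # |
| #     def pool2d(input, kernel_size): |
| #         # type: (Tensor, BroadcastingList2[int]) -> Tensor |
| #         ... |
| # |
| # The compiler reads the annotation from the source text, so the Python |
| # objects defined here only need to tolerate being subscripted. |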