| import inspect |
| import torch |
| import collections |
| |
| import torch._jit_internal as _jit_internal |
| from torch.nn import Module, ModuleList, Parameter, Sequential |
| from torch._six import get_function_from_type |
| |
| |
| def copy_to_script_module(original, stubs): |
| """ |
| Copies the parameters, buffers, constants, attributes, and submodules |
| of an nn.Module into itself. |
| """ |
| qualified_name = torch.jit._qualified_name(type(original)) |
| script_module = torch.jit.ScriptModule(_qualified_name=qualified_name) |
| |
| constants_set = set(getattr(original, "__constants__", [])) |
| script_module.__dict__["_constants_set"] = {} |
| |
| |
| # Copy Parameters and Modules |
| for name in dir(original): |
| item = getattr(original, name) |
| if item is None and name in original._parameters: |
| # XXX: treat None value simply as module attributes instead of adding them to the parameter list |
| # TODO: need to handle this more generally when non-tensor attributes added to module |
| object.__setattr__(script_module, name, item) |
| elif item is script_module: |
| continue |
| elif isinstance(item, (Parameter, Module, torch.jit.Attribute)): |
| setattr(script_module, name, item) |
| |
| # Copy buffers |
| for name in original._buffers: |
| if original._buffers[name] is None: |
| object.__setattr__(script_module, name, None) |
| else: |
| script_module.register_buffer(name, original._buffers[name]) |
| |
| # Constants annotated via `Final[T]` rather than being added to `__constants__` |
| for name, ann in getattr(original, '__annotations__', {}).items(): |
| if torch._jit_internal.is_final(ann): |
| constants_set.add(name) |
| |
| # Copy constants |
| script_module.__dict__["_constants_set"] = constants_set |
| for name in script_module.__dict__["_constants_set"]: |
| if hasattr(original, name): |
| if (name in original._parameters or name in original._buffers) and item is not None: |
| # for 'None' parameters/buffers, don't actually add their values if it exists |
| continue |
| # don't recopy constants, should only occur for constant modules/params |
| if not hasattr(script_module, name): |
| setattr(script_module, name, getattr(original, name)) |
| |
| # Copy annotations, pull types from `__annotations__` or try to infer |
| # the type if possible |
| class_annotations = getattr(original, '__annotations__', {}) |
| for name in dir(original): |
| if name in ("training", "__dict__"): |
| # TODO: removing this skip should let us remove the code to add training as an |
| # attribute in python_sugared_value.cpp |
| continue |
| if hasattr(script_module, name): |
| # Don't re-copy properties |
| continue |
| item = getattr(original, name) |
| if name in class_annotations: |
| the_type = torch.jit.annotations.ann_to_type(class_annotations[name]) |
| else: |
| the_type = torch._C._jit_try_infer_type(item) |
| if the_type is not None: |
| script_module._c._register_attribute(name, the_type, item) |
| |
| # Copy overloads |
| script_module.__dict__["_overloads"] = dict(getattr(original, "__overloads__", {})) |
| |
| # Copy links to Python methods so they can be resolved when compiling |
| for name in dir(original): |
| item = getattr(original, name) |
| if hasattr(script_module, name): |
| # Skip Python builtins and all the module methods that are already |
| # attached to this since it inherits from nn.Module |
| continue |
| if inspect.ismethod(item): |
| setattr(script_module, name, item) |
| |
| torch.jit._create_methods_from_stubs(script_module, stubs) |
| |
| # Now that methods have been compiled, take methods that have been compiled |
| # and have them shadow their corresponding Python functions |
| for method_name in script_module._c._method_names(): |
| setattr(script_module, method_name, script_module._c._get_method(method_name)) |
| |
| return script_module |
| |
| |
| def recursive_script(mod, exclude_methods=()): |
| """ |
| Makes a ScriptModule from an nn.Module. If `_methods` is provided, |
| these methods are treated as @script_methods. If not, it defaults to |
| `('forward',)`. Methods accessed in forward are scripted on demand. |
| """ |
| if isinstance(mod, torch.jit.ScriptModule): |
| return mod |
| |
| if isinstance(mod, (torch.nn.ModuleList, torch.nn.Sequential)): |
| # Create constant versions for the iterable modules |
| return create_constant_iterable_module(mod) |
| |
| if not hasattr(mod, '_parameters'): |
| raise RuntimeError("'{}' has not been initialized, did you forget to call 'super()'?" |
| .format(type(mod).__name__)) |
| |
| methods = () |
| if hasattr(mod, 'forward'): |
| if mod.forward.__func__ == torch.nn.Module.forward: |
| raise RuntimeError("No forward method was defined on {}".format(mod)) |
| if not _jit_internal.is_ignored_fn(mod.forward): |
| methods = ('forward',) |
| exported = [] |
| overloads = [] |
| for name in dir(mod): |
| item = getattr(mod, name) |
| if callable(item): |
| if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.EXPORT: |
| exported.append(name) |
| |
| # builtin functions like repr() in python 2 do not have __module__ defined |
| if hasattr(item, "__module__") and item.__module__ is not None: |
| method_overloads = _jit_internal._get_overloaded_methods(item, mod.__class__) |
| |
| if method_overloads is not None: |
| overloads.append((item, method_overloads)) |
| |
| methods = methods + tuple(exported) |
| |
| methods = tuple(name for name in methods if name not in exclude_methods) |
| |
| overload_name_mappings = dict(getattr(mod, "__overloads__", {})) |
| overload_stubs = [] |
| |
| for orig_fn, overload_fns in overloads: |
| orig_ast = torch.jit.get_jit_def(orig_fn, self_name="ScriptModule") |
| names = list(map(lambda i: orig_ast.name().name + "__" + str(i), range(len(overload_fns)))) |
| overload_name_mappings[orig_ast.name().name] = names |
| for overload_fn, name in zip(overload_fns, names): |
| torch.jit._check_no_signature(overload_fn) |
| over_ast = torch.jit.get_jit_def(overload_fn, self_name="ScriptModule") |
| new_ast = torch._C._replace_overloaded_method_decl(over_ast.decl(), orig_ast, name) |
| _rcb = _jit_internal.createResolutionCallbackFromClosure(orig_fn) |
| overload_stubs.append(torch.jit.ScriptMethodStub(_rcb, new_ast, overload_fn)) |
| |
| mod.__overloads__ = overload_name_mappings |
| |
| # we shouldn't directly compile overloaded methods, just its overloads |
| def ignore_overloaded(method_name): |
| return method_name not in overload_name_mappings |
| |
| def make_stub(method): |
| func = get_function_from_type(type(mod), method) |
| return torch.jit.script_method(func, _jit_internal.createResolutionCallbackFromClosure(func)) |
| |
| filtered_methods = filter(ignore_overloaded, methods) |
| stubs = list(map(make_stub, filtered_methods)) |
| return copy_to_script_module(mod, overload_stubs + stubs) |
| |
| |
| def create_method_from_fn(module, fn): |
| if _jit_internal.is_ignored_fn(fn): |
| return None |
| if not inspect.ismethod(fn): |
| return None |
| stub = torch.jit.script_method(fn, _jit_internal.createResolutionCallbackFromClosure(fn)) |
| with torch.jit._disable_emit_hooks(): |
| # We don't want to call the hooks here since the graph that is calling |
| # this function is not yet complete |
| torch.jit._create_methods_from_stubs(module, (stub,)) |
| return stub |
| |
| |
| def make_strong_submodule(field, module, parent): |
| if field not in parent._modules: |
| # It's not a submodule, don't do anything |
| return None |
| |
| # Convert the module to a ScriptModule |
| new_strong_submodule = recursive_script(module) |
| |
| # Install the ScriptModule on the python side |
| parent._modules._python_modules[field] = new_strong_submodule |
| |
| return new_strong_submodule |
| |
| |
| def try_compile_fn(fn, loc): |
| if _jit_internal.is_ignored_fn(fn): |
| # Don't do anything for @ignore'd functions |
| return None |
| |
| if isinstance(fn, torch.nn.Module): |
| # Since modules are callable pybind recognizes them as functions, but |
| # don't do anything for them |
| return None |
| |
| if not inspect.isfunction(fn) and not inspect.ismethod(fn): |
| raise RuntimeError("`{}` is not a function. Recursive scripting only supports " |
| "Python functions or methods currently.\n" |
| "Consider manually annotating `{}` with @torch.jit.script.".format(fn, fn)) |
| |
| # We don't have the actual scope where the function was defined, but we can |
| # extract the necessary info from the closed over variables on the function |
| # object |
| rcb = _jit_internal.createResolutionCallbackFromClosure(fn) |
| return torch.jit.script(fn, _rcb=rcb) |
| |
| |
| def create_constant_iterable_module(module): |
| modules = collections.OrderedDict() |
| |
| for key, submodule in module._modules.items(): |
| if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): |
| # Make each item in the module a constant |
| modules[key] = create_constant_iterable_module(submodule) |
| else: |
| modules[key] = recursive_script(submodule) |
| |
| if isinstance(module, Sequential): |
| return torch.jit._ConstSequential(Sequential(modules)) |
| elif isinstance(module, ModuleList): |
| return torch.jit._ConstModuleList(modules) |
| else: |
| raise RuntimeError("Only nn.ModuleList and nn.Sequential can be made " |
| "into constant modules, found {}".format(module)) |