Delete `WeakScriptModuleProxy` (#23398)
Summary:
This PR deletes `WeakScriptModuleProxy`, uses `ScriptModule` directly, and moves the recursive scripting logic into `torch/jit/_recursive.py`. The first commit just moves code; the latter two contain the actual changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23398
Pulled By: driazati
Reviewed By: eellison
Differential Revision: D16712340
fbshipit-source-id: f907efcec59bb2694c079ab655304324c125e9bb
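
For context, a minimal sketch of the user-facing path this refactor serves, assuming the routing shown in the diff below (torch.jit.script hands nn.Module instances to torch.jit._recursive.recursive_script); the module and layer names here are made up for illustration:

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))

# torch.jit.script on an nn.Module instance goes through
# torch.jit._recursive.recursive_script, which copies parameters, buffers,
# constants, and attributes into a ScriptModule and compiles forward.
scripted = torch.jit.script(MyModule())
print(isinstance(scripted, torch.jit.ScriptModule))  # True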
diff --git a/torch/csrc/jit/script/python_sugared_value.cpp b/torch/csrc/jit/script/python_sugared_value.cpp
index 8008041..01ff338 100644
--- a/torch/csrc/jit/script/python_sugared_value.cpp
+++ b/torch/csrc/jit/script/python_sugared_value.cpp
@@ -340,13 +340,13 @@
// ScriptModule and add it as a submodule to the script::Module. This
// enables lazy strong-ification of modules.
auto result =
- py::module::import("torch.jit")
- .attr("_make_strong_submodule")(field, attr, py_module_);
+ py::module::import("torch.jit._recursive")
+ .attr("make_strong_submodule")(field, attr, py_module_);
if (!result.is_none()) {
auto submodule = as_module(result);
TORCH_CHECK(
submodule,
- "Result of torch.jit._make_strong_submodule "
+          "Result of torch.jit._recursive.make_strong_submodule "
"was not a ScriptModule");
// The module was a submodule of the nn.Module, so register it here
// and return the submodule.
@@ -356,8 +356,8 @@
m.graph()->insertGetAttr(self_, field), *v, result);
}
} else if (py::isinstance<py::function>(attr)) {
- auto stub = py::module::import("torch.jit")
- .attr("_create_method_from_fn")(py_module_, attr);
+ auto stub = py::module::import("torch.jit._recursive")
+ .attr("create_method_from_fn")(py_module_, attr);
if (!stub.is_none()) {
return SimpleValue(self_).attr(loc, m, field);
}
@@ -585,7 +585,7 @@
}
auto compiled_fn =
- py::module::import("torch.jit").attr("_try_compile_fn")(obj, loc);
+ py::module::import("torch.jit._recursive").attr("try_compile_fn")(obj, loc);
if (auto callee = as_function(compiled_fn)) {
return std::make_shared<FunctionValue>(*callee);
}
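
The C++ changes above only repoint the pybind lookups at torch.jit._recursive; the lazy strong-ification and on-demand compilation behavior is unchanged. A hedged Python-level sketch of what the try_compile_fn path enables (names are illustrative, not part of this diff):

import torch
import torch.nn as nn

def double(x):
    # A plain function called from forward; the compiler resolves it through
    # torch.jit._recursive.try_compile_fn and scripts it on demand.
    return x * 2

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

    def forward(self, x):
        return double(x) + 1

scripted = torch.jit.script(Net())
print(scripted(torch.ones(2)))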
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index aa120ac..74ed3a7 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -1,32 +1,34 @@
import torch._C
-from torch.autograd import Variable, function
-from torch.serialization import validate_cuda_device
-from torch.nn import Module, ModuleList, Parameter, Sequential
-from torch.jit.frontend import get_jit_class_def, get_jit_def, get_default_args
+import torch._jit_internal as _jit_internal
import torch.backends.cudnn as cudnn
import torch.jit.annotations
-import torch._jit_internal as _jit_internal
+import torch.testing
+import torch.jit._recursive
+
+
from torch._jit_internal import _qualified_name
-from torch._six import PY2, PY37, with_metaclass, get_function_from_type, \
- string_classes
+from torch.autograd import Variable, function
+from torch.jit.frontend import get_jit_class_def, get_jit_def, get_default_args
+from torch.nn import Module, ModuleList, Sequential
+from torch.serialization import validate_cuda_device
+from torch._six import PY2, PY37, with_metaclass, string_classes
from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
_list_with_default
-import torch.testing
-import math
-from collections import OrderedDict, namedtuple
-import textwrap
-import sys
-import warnings
-import weakref
-import types
-import contextlib
-import os
-import functools
-import copy
import collections
+import contextlib
+import copy
+import functools
import inspect
+import math
+import os
import pickle
+import sys
+import textwrap
+import types
+import warnings
+
+from collections import OrderedDict, namedtuple
# These are imported so users can access them from the `torch.jit` module
from torch._jit_internal import Final, _overload # noqa: F401
@@ -953,61 +955,6 @@
pass
-def _create_constant_iterable_module(module):
- modules = OrderedDict()
-
- for key, submodule in module._modules.items():
- if isinstance(submodule, (ModuleList, Sequential)):
- # Make each item in the module a constant
- modules[key] = _create_constant_iterable_module(submodule)
- else:
- modules[key] = _convert_to_script_module(submodule)
-
- if isinstance(module, Sequential):
- return _ConstSequential(Sequential(modules))
- elif isinstance(module, ModuleList):
- return _ConstModuleList(modules)
- else:
- raise RuntimeError("Only nn.ModuleList and nn.Sequential can be made "
- "into constant modules, found {}".format(module))
-
-
-def _make_strong_submodule(field, module, parent):
- if field not in parent._modules:
- # It's not a submodule, don't do anything
- return None
-
- # Convert the module to a ScriptModule
- new_strong_submodule = _convert_to_script_module(module)
-
- # Install the ScriptModule on the python side
- parent._modules._python_modules[field] = new_strong_submodule
-
- return new_strong_submodule
-
-
-def _try_compile_fn(fn, loc):
- if _jit_internal.is_ignored_fn(fn):
- # Don't do anything for @ignore'd functions
- return None
-
- if isinstance(fn, torch.nn.Module):
- # Since modules are callable pybind recognizes them as functions, but
- # don't do anything for them
- return None
-
- if not inspect.isfunction(fn) and not inspect.ismethod(fn):
- raise RuntimeError("`{}` is not a function. Recursive scripting only supports "
- "Python functions or methods currently.\n"
- "Consider manually annotating `{}` with @torch.jit.script.".format(fn, fn))
-
- # We don't have the actual scope where the function was defined, but we can
- # extract the necessary info from the closed over variables on the function
- # object
- rcb = _jit_internal.createResolutionCallbackFromClosure(fn)
- return torch.jit.script(fn, _rcb=rcb)
-
-
@contextlib.contextmanager
def _disable_emit_hooks():
hooks = torch._C._jit_get_emit_hooks()
@@ -1016,19 +963,6 @@
torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
-def _create_method_from_fn(module, fn):
- if _jit_internal.is_ignored_fn(fn):
- return None
- if not inspect.ismethod(fn):
- return None
- stub = script_method(fn, _jit_internal.createResolutionCallbackFromClosure(fn))
- with _disable_emit_hooks():
- # We don't want to call the hooks here since the graph that is calling
- # this function is not yet complete
- _create_methods_from_stubs(module, (stub,))
- return stub
-
-
# ScriptClasses must be new-style classes because we construct them using their
# __new__ method.
def _is_new_style_class(cls):
@@ -1166,7 +1100,7 @@
warnings.warn("`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead")
if isinstance(obj, torch.nn.Module):
- return _convert_to_script_module(obj)
+        return torch.jit._recursive.recursive_script(obj)
qualified_name = _qualified_name(obj)
if inspect.isclass(obj):
@@ -1673,167 +1607,12 @@
def graph_for(self, *args, **kwargs):
return self.forward.graph_for(*args, **kwargs)
- class WeakScriptModuleProxy(ScriptModule):
- # TODO: [weak script refactor]
- # WeakScriptModule proxy should be deleted since its functionality is
- # subsumed by recursive scripting, and the copying code in init moved
- # to a function to create a ScriptModule from an nn.Module without
- # making a WeakScriptModuleProxy
- """
- Copies the parameters, buffers, constants, attributes, and submodules
- of an nn.Module into itself.
- """
- def __init__(self, original, stubs):
- # Guards behavior of __setattr__ and __getattr__ so ScriptModule
- # __init__ can run correctly
- self.__dict__['_initialized'] = False
- super(WeakScriptModuleProxy, self).__init__(_qualified_name=_qualified_name(type(original)))
- # Store a weak reference to the original module
- self.__dict__["_original"] = weakref.ref(original)
-
- constants_set = set(getattr(original, "__constants__", []))
- self.__dict__["_constants_set"] = {}
-
- if not hasattr(original, '_parameters'):
- raise RuntimeError("'{}' has not been initialized, did you forget to call 'super()'?"
- .format(type(original).__name__))
-
- # Copy Parameters and Modules
- for name in dir(original):
- item = getattr(original, name)
- if item is None and name in original._parameters:
- # XXX: treat None value simply as module attributes instead of adding them to the parameter list
- # TODO: need to handle this more generally when non-tensor attributes added to module
- object.__setattr__(self, name, item)
- elif item is self:
- continue
- elif isinstance(item, (Parameter, Module, Attribute)):
- ScriptModule.__setattr__(self, name, item)
-
- # Copy buffers
- for name in original._buffers:
- if original._buffers[name] is None:
- object.__setattr__(self, name, None)
- else:
- self.register_buffer(name, original._buffers[name])
-
- # Constants annotated via `Final[T]` rather than being added to `__constants__`
- for name, ann in getattr(original, '__annotations__', {}).items():
- if torch._jit_internal.is_final(ann):
- constants_set.add(name)
-
- # Copy constants
- self.__dict__["_constants_set"] = constants_set
- for name in self.__dict__["_constants_set"]:
- if hasattr(original, name):
- if (name in original._parameters or name in original._buffers) and item is not None:
- # for 'None' parameters/buffers, don't actually add their values if it exists
- continue
- ScriptModule.__setattr__(self, name, getattr(original, name))
-
- # Copy annotations, pull types from `__annotations__` or try to infer
- # the type if possible
- class_annotations = getattr(original, '__annotations__', {})
- for name in dir(original):
- if name in ("training", "__dict__"):
- # TODO: removing this skip should let us remove the code to add training as an
- # attribute in python_sugared_value.cpp
- continue
- if hasattr(self, name):
- # Don't re-copy properties
- continue
- item = getattr(original, name)
- if name in class_annotations:
- the_type = torch.jit.annotations.ann_to_type(class_annotations[name])
- else:
- the_type = torch._C._jit_try_infer_type(item)
- if the_type is not None:
- self._c._register_attribute(name, the_type, item)
-
- # Copy overloads
- self.__dict__["_overloads"] = dict(getattr(original, "__overloads__", {}))
-
- self.__dict__["_initialized"] = True
- self.__dict__["_original_type"] = type(original)
- _create_methods_from_stubs(self, stubs)
-
- def __getattr__(self, attr):
- # Try to get the attribute directly, if that fails, fall back to the
- # weak module itself
- try:
- return ScriptModule.__getattr__(self, attr)
- except AttributeError as e:
- # unwrap the original
- original_module = self.__dict__["_original"]()
- if original_module and self.__dict__["_initialized"]:
- # get attr from original if it is still alive
- return getattr(original_module, attr)
- elif self.__dict__["_initialized"]:
- # original module is dead, try looking up the value on the
- # original type
- fn = getattr(self.__dict__["_original_type"], attr, None)
- if fn is not None and inspect.isroutine(fn):
- # bind the function to this instance and return it
- return fn.__get__(self, self.__dict__["_original_type"])
- # If it's not on this module and it wasn't on the original
- # module (or the original is dead), throw the exception
- raise e
-
- def __setattr__(self, attr, value):
- # Once constructed, no new properties can be set
-
- if not self.__dict__["_initialized"]:
- # If constructing, don't fall back to original module
- return ScriptModule.__setattr__(self, attr, value)
-
- if hasattr(self, attr):
- return ScriptModule.__setattr__(self, attr, value)
- else:
- raise AttributeError("Cannot set new attribute '{}' on "
- "weak script module once it has been "
- "created".format(attr))
-
else:
class ScriptModule(torch.nn.Module):
def __init__(self):
super(ScriptModule, self).__init__()
-def _convert_to_script_module(mod):
- """
- Makes a ScriptModule from an nn.Module. If `_methods` is provided,
- these methods are treated as @script_methods. If not, it defaults to
- `('forward',)`. Methods accessed in forward are scripted on demand.
- """
- if isinstance(mod, ScriptModule):
- return mod
-
- if isinstance(mod, (ModuleList, Sequential)):
- # Create constant versions for the iterable modules
- return _create_constant_iterable_module(mod)
-
- methods = ()
- if hasattr(mod, 'forward'):
- if mod.forward.__func__ == torch.nn.Module.forward:
- raise RuntimeError("No forward method was defined on {}".format(mod))
- if not _jit_internal.is_ignored_fn(mod.forward):
- methods = ('forward',)
- exported = []
- for name in dir(mod):
- item = getattr(mod, name)
- if callable(item):
- if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.EXPORT:
- exported.append(name)
- methods = methods + tuple(exported)
-
- def make_stub(method):
- func = get_function_from_type(type(mod), method)
- return script_method(func, _jit_internal.createResolutionCallbackFromClosure(func))
-
- stubs = list(map(make_stub, methods))
- return WeakScriptModuleProxy(mod, stubs)
-
-
def _get_methods(cls):
import inspect
# In Python 3 unbound methods are functions, but in Python 2 they are methods
@@ -1931,12 +1710,12 @@
if isinstance(modules, OrderedDict):
for key, module in modules.items():
if isinstance(module, torch.nn.Module):
- module = _convert_to_script_module(module)
+ module = torch.jit._recursive.recursive_script(module)
self.add_module(key, module)
else:
for i, module in enumerate(modules):
if isinstance(module, torch.nn.Module):
- module = _convert_to_script_module(module)
+ module = torch.jit._recursive.recursive_script(module)
self.add_module(str(i), module)
def __getitem__(self, idx):
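
The container hunk above swaps _convert_to_script_module for torch.jit._recursive.recursive_script when wrapping child modules; nn.Sequential and nn.ModuleList children should still be converted through the constant-container path. A rough, illustrative example of the kind of module this handles (assuming that path behaves as before):

import torch
import torch.nn as nn

class Stack(nn.Module):
    def __init__(self):
        super(Stack, self).__init__()
        # nn.Sequential / nn.ModuleList children go through
        # create_constant_iterable_module and become constant containers.
        self.layers = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))

    def forward(self, x):
        return self.layers(x)

scripted = torch.jit.script(Stack())
print(scripted(torch.randn(2, 8)).shape)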
diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py
new file mode 100644
index 0000000..1be9806
--- /dev/null
+++ b/torch/jit/_recursive.py
@@ -0,0 +1,200 @@
+import inspect
+import torch
+import collections
+
+import torch._jit_internal as _jit_internal
+from torch.nn import Module, ModuleList, Parameter, Sequential
+from torch._six import get_function_from_type
+
+
+def copy_to_script_module(original, stubs):
+ """
+    Copies the parameters, buffers, constants, attributes, and submodules
+    of an nn.Module `original` into a new ScriptModule and returns it.
+ """
+ if not hasattr(original, '_parameters'):
+ raise RuntimeError("'{}' has not been initialized, did you forget to call 'super()'?"
+ .format(type(original).__name__))
+
+ qualified_name = torch.jit._qualified_name(type(original))
+ script_module = torch.jit.ScriptModule(_qualified_name=qualified_name)
+
+ constants_set = set(getattr(original, "__constants__", []))
+ script_module.__dict__["_constants_set"] = {}
+
+ # Copy Parameters and Modules
+ for name in dir(original):
+ item = getattr(original, name)
+ if item is None and name in original._parameters:
+ # XXX: treat None value simply as module attributes instead of adding them to the parameter list
+ # TODO: need to handle this more generally when non-tensor attributes added to module
+ object.__setattr__(script_module, name, item)
+ elif item is script_module:
+ continue
+ elif isinstance(item, (Parameter, Module, torch.jit.Attribute)):
+ setattr(script_module, name, item)
+
+ # Copy buffers
+ for name in original._buffers:
+ if original._buffers[name] is None:
+ object.__setattr__(script_module, name, None)
+ else:
+ script_module.register_buffer(name, original._buffers[name])
+
+ # Constants annotated via `Final[T]` rather than being added to `__constants__`
+ for name, ann in getattr(original, '__annotations__', {}).items():
+ if torch._jit_internal.is_final(ann):
+ constants_set.add(name)
+
+ # Copy constants
+ script_module.__dict__["_constants_set"] = constants_set
+ for name in script_module.__dict__["_constants_set"]:
+ if hasattr(original, name):
+ if (name in original._parameters or name in original._buffers) and item is not None:
+ # for 'None' parameters/buffers, don't actually add their values if it exists
+ continue
+ setattr(script_module, name, getattr(original, name))
+
+ # Copy annotations, pull types from `__annotations__` or try to infer
+ # the type if possible
+ class_annotations = getattr(original, '__annotations__', {})
+ for name in dir(original):
+ if name in ("training", "__dict__"):
+ # TODO: removing this skip should let us remove the code to add training as an
+ # attribute in python_sugared_value.cpp
+ continue
+ if hasattr(script_module, name):
+ # Don't re-copy properties
+ continue
+ item = getattr(original, name)
+ if name in class_annotations:
+ the_type = torch.jit.annotations.ann_to_type(class_annotations[name])
+ else:
+ the_type = torch._C._jit_try_infer_type(item)
+ if the_type is not None:
+ script_module._c._register_attribute(name, the_type, item)
+
+ # Copy overloads
+ script_module.__dict__["_overloads"] = dict(getattr(original, "__overloads__", {}))
+
+ # Copy links to Python methods so they can be resolved when compiling
+ for name in dir(original):
+ item = getattr(original, name)
+ if hasattr(script_module, name):
+ # Skip Python builtins and all the module methods that are already
+ # attached to this since it inherits from nn.Module
+ continue
+ if inspect.ismethod(item):
+ setattr(script_module, name, item)
+
+ torch.jit._create_methods_from_stubs(script_module, stubs)
+
+ # Now that methods have been compiled, take methods that have been compiled
+ # and have them shadow their corresponding Python functions
+ for method_name in script_module._c._method_names():
+ setattr(script_module, method_name, script_module._c._get_method(method_name))
+
+ return script_module
+
+
+def recursive_script(mod):
+ """
+    Makes a ScriptModule from an nn.Module. `forward` (if defined and not
+    @ignore'd) and any methods marked with the EXPORT modifier are compiled;
+    functions and submodules used in forward are scripted on demand.
+ """
+ if isinstance(mod, torch.jit.ScriptModule):
+ return mod
+
+ if isinstance(mod, (torch.nn.ModuleList, torch.nn.Sequential)):
+ # Create constant versions for the iterable modules
+ return create_constant_iterable_module(mod)
+
+ methods = ()
+ if hasattr(mod, 'forward'):
+ if mod.forward.__func__ == torch.nn.Module.forward:
+ raise RuntimeError("No forward method was defined on {}".format(mod))
+ if not _jit_internal.is_ignored_fn(mod.forward):
+ methods = ('forward',)
+ exported = []
+ for name in dir(mod):
+ item = getattr(mod, name)
+ if callable(item):
+ if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.EXPORT:
+ exported.append(name)
+ methods = methods + tuple(exported)
+
+ def make_stub(method):
+ func = get_function_from_type(type(mod), method)
+ return torch.jit.script_method(func, _jit_internal.createResolutionCallbackFromClosure(func))
+
+ stubs = list(map(make_stub, methods))
+ return copy_to_script_module(mod, stubs)
+
+
+def create_method_from_fn(module, fn):
+ if _jit_internal.is_ignored_fn(fn):
+ return None
+ if not inspect.ismethod(fn):
+ return None
+ stub = torch.jit.script_method(fn, _jit_internal.createResolutionCallbackFromClosure(fn))
+ with torch.jit._disable_emit_hooks():
+ # We don't want to call the hooks here since the graph that is calling
+ # this function is not yet complete
+ torch.jit._create_methods_from_stubs(module, (stub,))
+ return stub
+
+
+def make_strong_submodule(field, module, parent):
+ if field not in parent._modules:
+ # It's not a submodule, don't do anything
+ return None
+
+ # Convert the module to a ScriptModule
+ new_strong_submodule = recursive_script(module)
+
+ # Install the ScriptModule on the python side
+ parent._modules._python_modules[field] = new_strong_submodule
+
+ return new_strong_submodule
+
+
+def try_compile_fn(fn, loc):
+ if _jit_internal.is_ignored_fn(fn):
+ # Don't do anything for @ignore'd functions
+ return None
+
+ if isinstance(fn, torch.nn.Module):
+ # Since modules are callable pybind recognizes them as functions, but
+ # don't do anything for them
+ return None
+
+ if not inspect.isfunction(fn) and not inspect.ismethod(fn):
+ raise RuntimeError("`{}` is not a function. Recursive scripting only supports "
+ "Python functions or methods currently.\n"
+ "Consider manually annotating `{}` with @torch.jit.script.".format(fn, fn))
+
+ # We don't have the actual scope where the function was defined, but we can
+ # extract the necessary info from the closed over variables on the function
+ # object
+ rcb = _jit_internal.createResolutionCallbackFromClosure(fn)
+ return torch.jit.script(fn, _rcb=rcb)
+
+
+def create_constant_iterable_module(module):
+ modules = collections.OrderedDict()
+
+ for key, submodule in module._modules.items():
+ if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+ # Make each item in the module a constant
+ modules[key] = create_constant_iterable_module(submodule)
+ else:
+ modules[key] = recursive_script(submodule)
+
+ if isinstance(module, Sequential):
+ return torch.jit._ConstSequential(Sequential(modules))
+ elif isinstance(module, ModuleList):
+ return torch.jit._ConstModuleList(modules)
+ else:
+ raise RuntimeError("Only nn.ModuleList and nn.Sequential can be made "
+ "into constant modules, found {}".format(module))
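
Per the recursive_script body above, compilation covers forward plus any methods carrying the EXPORT modifier. Assuming @torch.jit.export is the decorator that sets that modifier, a minimal hedged sketch (names are illustrative):

import torch
import torch.nn as nn

class Counter(nn.Module):
    def __init__(self):
        super(Counter, self).__init__()
        self.scale = 2.0

    @torch.jit.export
    def scaled(self, x):
        # Collected by recursive_script because of the EXPORT modifier.
        return x * self.scale

    def forward(self, x):
        return self.scaled(x) + 1.0

scripted = torch.jit.script(Counter())
print(scripted.scaled(torch.ones(3)))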