Delete `WeakScriptModuleProxy` (#23398)

Summary:
This PR deletes `WeakScriptModuleProxy` and uses `ScriptModule` directly and moves the recursive script stuff into `torch/jit/_recursive.py`. The first commit is just moving code, the latter 2 contain the actual changes
](https://our.intern.facebook.com/intern/diff/16712340/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23398

Pulled By: driazati

Reviewed By: eellison

Differential Revision: D16712340

fbshipit-source-id: f907efcec59bb2694c079ab655304324c125e9bb
diff --git a/torch/csrc/jit/script/python_sugared_value.cpp b/torch/csrc/jit/script/python_sugared_value.cpp
index 8008041..01ff338 100644
--- a/torch/csrc/jit/script/python_sugared_value.cpp
+++ b/torch/csrc/jit/script/python_sugared_value.cpp
@@ -340,13 +340,13 @@
     // ScriptModule and add it as a submodule to the script::Module. This
     // enables lazy strong-ification of modules.
     auto result =
-        py::module::import("torch.jit")
-            .attr("_make_strong_submodule")(field, attr, py_module_);
+        py::module::import("torch.jit._recursive")
+            .attr("make_strong_submodule")(field, attr, py_module_);
     if (!result.is_none()) {
       auto submodule = as_module(result);
       TORCH_CHECK(
           submodule,
-          "Result of torch.jit._make_strong_submodule "
+          "Result of torch.torch.jit._recursive.make_strong_submodule "
           "was not a ScriptModule");
       // The module was a submodule of the nn.Module, so register it here
       // and return the submodule.
@@ -356,8 +356,8 @@
           m.graph()->insertGetAttr(self_, field), *v, result);
     }
   } else if (py::isinstance<py::function>(attr)) {
-    auto stub = py::module::import("torch.jit")
-                    .attr("_create_method_from_fn")(py_module_, attr);
+    auto stub = py::module::import("torch.jit._recursive")
+                    .attr("create_method_from_fn")(py_module_, attr);
     if (!stub.is_none()) {
       return SimpleValue(self_).attr(loc, m, field);
     }
@@ -585,7 +585,7 @@
     }
 
     auto compiled_fn =
-        py::module::import("torch.jit").attr("_try_compile_fn")(obj, loc);
+        py::module::import("torch.jit._recursive").attr("try_compile_fn")(obj, loc);
     if (auto callee = as_function(compiled_fn)) {
       return std::make_shared<FunctionValue>(*callee);
     }
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index aa120ac..74ed3a7 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -1,32 +1,34 @@
 import torch._C
-from torch.autograd import Variable, function
-from torch.serialization import validate_cuda_device
-from torch.nn import Module, ModuleList, Parameter, Sequential
-from torch.jit.frontend import get_jit_class_def, get_jit_def, get_default_args
+import torch._jit_internal as _jit_internal
 import torch.backends.cudnn as cudnn
 import torch.jit.annotations
-import torch._jit_internal as _jit_internal
+import torch.testing
+import torch.jit._recursive
+
+
 from torch._jit_internal import _qualified_name
-from torch._six import PY2, PY37, with_metaclass, get_function_from_type, \
-    string_classes
+from torch.autograd import Variable, function
+from torch.jit.frontend import get_jit_class_def, get_jit_def, get_default_args
+from torch.nn import Module, ModuleList, Sequential
+from torch.serialization import validate_cuda_device
+from torch._six import PY2, PY37, with_metaclass, string_classes
 from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
     _list_with_default
-import torch.testing
 
-import math
-from collections import OrderedDict, namedtuple
-import textwrap
-import sys
-import warnings
-import weakref
-import types
-import contextlib
-import os
-import functools
-import copy
 import collections
+import contextlib
+import copy
+import functools
 import inspect
+import math
+import os
 import pickle
+import sys
+import textwrap
+import types
+import warnings
+
+from collections import OrderedDict, namedtuple
 
 # These are imported so users can access them from the `torch.jit` module
 from torch._jit_internal import Final, _overload  # noqa: F401
@@ -953,61 +955,6 @@
     pass
 
 
-def _create_constant_iterable_module(module):
-    modules = OrderedDict()
-
-    for key, submodule in module._modules.items():
-        if isinstance(submodule, (ModuleList, Sequential)):
-            # Make each item in the module a constant
-            modules[key] = _create_constant_iterable_module(submodule)
-        else:
-            modules[key] = _convert_to_script_module(submodule)
-
-    if isinstance(module, Sequential):
-        return _ConstSequential(Sequential(modules))
-    elif isinstance(module, ModuleList):
-        return _ConstModuleList(modules)
-    else:
-        raise RuntimeError("Only nn.ModuleList and nn.Sequential can be made "
-                           "into constant modules, found {}".format(module))
-
-
-def _make_strong_submodule(field, module, parent):
-    if field not in parent._modules:
-        # It's not a submodule, don't do anything
-        return None
-
-    # Convert the module to a ScriptModule
-    new_strong_submodule = _convert_to_script_module(module)
-
-    # Install the ScriptModule on the python side
-    parent._modules._python_modules[field] = new_strong_submodule
-
-    return new_strong_submodule
-
-
-def _try_compile_fn(fn, loc):
-    if _jit_internal.is_ignored_fn(fn):
-        # Don't do anything for @ignore'd functions
-        return None
-
-    if isinstance(fn, torch.nn.Module):
-        # Since modules are callable pybind recognizes them as functions, but
-        # don't do anything for them
-        return None
-
-    if not inspect.isfunction(fn) and not inspect.ismethod(fn):
-        raise RuntimeError("`{}` is not a function. Recursive scripting only supports "
-                           "Python functions or methods currently.\n"
-                           "Consider manually annotating `{}` with @torch.jit.script.".format(fn, fn))
-
-    # We don't have the actual scope where the function was defined, but we can
-    # extract the necessary info from the closed over variables on the function
-    # object
-    rcb = _jit_internal.createResolutionCallbackFromClosure(fn)
-    return torch.jit.script(fn, _rcb=rcb)
-
-
 @contextlib.contextmanager
 def _disable_emit_hooks():
     hooks = torch._C._jit_get_emit_hooks()
@@ -1016,19 +963,6 @@
     torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
 
 
-def _create_method_from_fn(module, fn):
-    if _jit_internal.is_ignored_fn(fn):
-        return None
-    if not inspect.ismethod(fn):
-        return None
-    stub = script_method(fn, _jit_internal.createResolutionCallbackFromClosure(fn))
-    with _disable_emit_hooks():
-        # We don't want to call the hooks here since the graph that is calling
-        # this function is not yet complete
-        _create_methods_from_stubs(module, (stub,))
-    return stub
-
-
 # ScriptClasses must be new-style classes because we construct them using their
 # __new__ method.
 def _is_new_style_class(cls):
@@ -1166,7 +1100,7 @@
         warnings.warn("`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead")
 
     if isinstance(obj, torch.nn.Module):
-        return _convert_to_script_module(obj)
+        return torch.jit.torch.jit._recursive.recursive_script(obj)
 
     qualified_name = _qualified_name(obj)
     if inspect.isclass(obj):
@@ -1673,167 +1607,12 @@
         def graph_for(self, *args, **kwargs):
             return self.forward.graph_for(*args, **kwargs)
 
-    class WeakScriptModuleProxy(ScriptModule):
-        # TODO: [weak script refactor]
-        # WeakScriptModule proxy should be deleted since its functionality is
-        # subsumed by recursive scripting, and the copying code in init moved
-        # to a function to create a ScriptModule from an nn.Module without
-        # making a WeakScriptModuleProxy
-        """
-        Copies the parameters, buffers, constants, attributes, and submodules
-        of an nn.Module into itself.
-        """
-        def __init__(self, original, stubs):
-            # Guards behavior of __setattr__ and __getattr__ so ScriptModule
-            # __init__ can run correctly
-            self.__dict__['_initialized'] = False
-            super(WeakScriptModuleProxy, self).__init__(_qualified_name=_qualified_name(type(original)))
-            # Store a weak reference to the original module
-            self.__dict__["_original"] = weakref.ref(original)
-
-            constants_set = set(getattr(original, "__constants__", []))
-            self.__dict__["_constants_set"] = {}
-
-            if not hasattr(original, '_parameters'):
-                raise RuntimeError("'{}' has not been initialized, did you forget to call 'super()'?"
-                                   .format(type(original).__name__))
-
-            # Copy Parameters and Modules
-            for name in dir(original):
-                item = getattr(original, name)
-                if item is None and name in original._parameters:
-                    # XXX: treat None value simply as module attributes instead of adding them to the parameter list
-                    # TODO: need to handle this more generally when non-tensor attributes added to module
-                    object.__setattr__(self, name, item)
-                elif item is self:
-                    continue
-                elif isinstance(item, (Parameter, Module, Attribute)):
-                    ScriptModule.__setattr__(self, name, item)
-
-            # Copy buffers
-            for name in original._buffers:
-                if original._buffers[name] is None:
-                    object.__setattr__(self, name, None)
-                else:
-                    self.register_buffer(name, original._buffers[name])
-
-            # Constants annotated via `Final[T]` rather than being added to `__constants__`
-            for name, ann in getattr(original, '__annotations__', {}).items():
-                if torch._jit_internal.is_final(ann):
-                    constants_set.add(name)
-
-            # Copy constants
-            self.__dict__["_constants_set"] = constants_set
-            for name in self.__dict__["_constants_set"]:
-                if hasattr(original, name):
-                    if (name in original._parameters or name in original._buffers) and item is not None:
-                        # for 'None' parameters/buffers, don't actually add their values if it exists
-                        continue
-                    ScriptModule.__setattr__(self, name, getattr(original, name))
-
-            # Copy annotations, pull types from `__annotations__` or try to infer
-            # the type if possible
-            class_annotations = getattr(original, '__annotations__', {})
-            for name in dir(original):
-                if name in ("training", "__dict__"):
-                    # TODO: removing this skip should let us remove the code to add training as an
-                    # attribute in python_sugared_value.cpp
-                    continue
-                if hasattr(self, name):
-                    # Don't re-copy properties
-                    continue
-                item = getattr(original, name)
-                if name in class_annotations:
-                    the_type = torch.jit.annotations.ann_to_type(class_annotations[name])
-                else:
-                    the_type = torch._C._jit_try_infer_type(item)
-                if the_type is not None:
-                    self._c._register_attribute(name, the_type, item)
-
-            # Copy overloads
-            self.__dict__["_overloads"] = dict(getattr(original, "__overloads__", {}))
-
-            self.__dict__["_initialized"] = True
-            self.__dict__["_original_type"] = type(original)
-            _create_methods_from_stubs(self, stubs)
-
-        def __getattr__(self, attr):
-            # Try to get the attribute directly, if that fails, fall back to the
-            # weak module itself
-            try:
-                return ScriptModule.__getattr__(self, attr)
-            except AttributeError as e:
-                # unwrap the original
-                original_module = self.__dict__["_original"]()
-                if original_module and self.__dict__["_initialized"]:
-                    # get attr from original if it is still alive
-                    return getattr(original_module, attr)
-                elif self.__dict__["_initialized"]:
-                    # original module is dead, try looking up the value on the
-                    # original type
-                    fn = getattr(self.__dict__["_original_type"], attr, None)
-                    if fn is not None and inspect.isroutine(fn):
-                        # bind the function to this instance and return it
-                        return fn.__get__(self, self.__dict__["_original_type"])
-                # If it's not on this module and it wasn't on the original
-                # module (or the original is dead), throw the exception
-                raise e
-
-        def __setattr__(self, attr, value):
-            # Once constructed, no new properties can be set
-
-            if not self.__dict__["_initialized"]:
-                # If constructing, don't fall back to original module
-                return ScriptModule.__setattr__(self, attr, value)
-
-            if hasattr(self, attr):
-                return ScriptModule.__setattr__(self, attr, value)
-            else:
-                raise AttributeError("Cannot set new attribute '{}' on "
-                                     "weak script module once it has been "
-                                     "created".format(attr))
-
 else:
     class ScriptModule(torch.nn.Module):
         def __init__(self):
             super(ScriptModule, self).__init__()
 
 
-def _convert_to_script_module(mod):
-    """
-    Makes a ScriptModule from an nn.Module. If `_methods` is provided,
-    these methods are treated as @script_methods. If not, it defaults to
-    `('forward',)`. Methods accessed in forward are scripted on demand.
-    """
-    if isinstance(mod, ScriptModule):
-        return mod
-
-    if isinstance(mod, (ModuleList, Sequential)):
-        # Create constant versions for the iterable modules
-        return _create_constant_iterable_module(mod)
-
-    methods = ()
-    if hasattr(mod, 'forward'):
-        if mod.forward.__func__ == torch.nn.Module.forward:
-            raise RuntimeError("No forward method was defined on {}".format(mod))
-        if not _jit_internal.is_ignored_fn(mod.forward):
-            methods = ('forward',)
-    exported = []
-    for name in dir(mod):
-        item = getattr(mod, name)
-        if callable(item):
-            if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.EXPORT:
-                exported.append(name)
-    methods = methods + tuple(exported)
-
-    def make_stub(method):
-        func = get_function_from_type(type(mod), method)
-        return script_method(func, _jit_internal.createResolutionCallbackFromClosure(func))
-
-    stubs = list(map(make_stub, methods))
-    return WeakScriptModuleProxy(mod, stubs)
-
-
 def _get_methods(cls):
     import inspect
     # In Python 3 unbound methods are functions, but in Python 2 they are methods
@@ -1931,12 +1710,12 @@
         if isinstance(modules, OrderedDict):
             for key, module in modules.items():
                 if isinstance(module, torch.nn.Module):
-                    module = _convert_to_script_module(module)
+                    module = torch.jit._recursive.recursive_script(module)
                 self.add_module(key, module)
         else:
             for i, module in enumerate(modules):
                 if isinstance(module, torch.nn.Module):
-                    module = _convert_to_script_module(module)
+                    module = torch.jit._recursive.recursive_script(module)
                 self.add_module(str(i), module)
 
     def __getitem__(self, idx):
diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py
new file mode 100644
index 0000000..1be9806
--- /dev/null
+++ b/torch/jit/_recursive.py
@@ -0,0 +1,200 @@
+import inspect
+import torch
+import collections
+
+import torch._jit_internal as _jit_internal
+from torch.nn import Module, ModuleList, Parameter, Sequential
+from torch._six import get_function_from_type
+
+
+def copy_to_script_module(original, stubs):
+    """
+    Copies the parameters, buffers, constants, attributes, and submodules
+    of an nn.Module into itself.
+    """
+    if not hasattr(original, '_parameters'):
+        raise RuntimeError("'{}' has not been initialized, did you forget to call 'super()'?"
+                           .format(type(original).__name__))
+
+    qualified_name = torch.jit._qualified_name(type(original))
+    script_module = torch.jit.ScriptModule(_qualified_name=qualified_name)
+
+    constants_set = set(getattr(original, "__constants__", []))
+    script_module.__dict__["_constants_set"] = {}
+
+    # Copy Parameters and Modules
+    for name in dir(original):
+        item = getattr(original, name)
+        if item is None and name in original._parameters:
+            # XXX: treat None value simply as module attributes instead of adding them to the parameter list
+            # TODO: need to handle this more generally when non-tensor attributes added to module
+            object.__setattr__(script_module, name, item)
+        elif item is script_module:
+            continue
+        elif isinstance(item, (Parameter, Module, torch.jit.Attribute)):
+            setattr(script_module, name, item)
+
+    # Copy buffers
+    for name in original._buffers:
+        if original._buffers[name] is None:
+            object.__setattr__(script_module, name, None)
+        else:
+            script_module.register_buffer(name, original._buffers[name])
+
+    # Constants annotated via `Final[T]` rather than being added to `__constants__`
+    for name, ann in getattr(original, '__annotations__', {}).items():
+        if torch._jit_internal.is_final(ann):
+            constants_set.add(name)
+
+    # Copy constants
+    script_module.__dict__["_constants_set"] = constants_set
+    for name in script_module.__dict__["_constants_set"]:
+        if hasattr(original, name):
+            if (name in original._parameters or name in original._buffers) and item is not None:
+                # for 'None' parameters/buffers, don't actually add their values if it exists
+                continue
+            setattr(script_module, name, getattr(original, name))
+
+    # Copy annotations, pull types from `__annotations__` or try to infer
+    # the type if possible
+    class_annotations = getattr(original, '__annotations__', {})
+    for name in dir(original):
+        if name in ("training", "__dict__"):
+            # TODO: removing this skip should let us remove the code to add training as an
+            # attribute in python_sugared_value.cpp
+            continue
+        if hasattr(script_module, name):
+            # Don't re-copy properties
+            continue
+        item = getattr(original, name)
+        if name in class_annotations:
+            the_type = torch.jit.annotations.ann_to_type(class_annotations[name])
+        else:
+            the_type = torch._C._jit_try_infer_type(item)
+        if the_type is not None:
+            script_module._c._register_attribute(name, the_type, item)
+
+    # Copy overloads
+    script_module.__dict__["_overloads"] = dict(getattr(original, "__overloads__", {}))
+
+    # Copy links to Python methods so they can be resolved when compiling
+    for name in dir(original):
+        item = getattr(original, name)
+        if hasattr(script_module, name):
+            # Skip Python builtins and all the module methods that are already
+            # attached to this since it inherits from nn.Module
+            continue
+        if inspect.ismethod(item):
+            setattr(script_module, name, item)
+
+    torch.jit._create_methods_from_stubs(script_module, stubs)
+
+    # Now that methods have been compiled, take methods that have been compiled
+    # and have them shadow their corresponding Python functions
+    for method_name in script_module._c._method_names():
+        setattr(script_module, method_name, script_module._c._get_method(method_name))
+
+    return script_module
+
+
+def recursive_script(mod):
+    """
+    Makes a ScriptModule from an nn.Module. If `_methods` is provided,
+    these methods are treated as @script_methods. If not, it defaults to
+    `('forward',)`. Methods accessed in forward are scripted on demand.
+    """
+    if isinstance(mod, torch.jit.ScriptModule):
+        return mod
+
+    if isinstance(mod, (torch.nn.ModuleList, torch.nn.Sequential)):
+        # Create constant versions for the iterable modules
+        return create_constant_iterable_module(mod)
+
+    methods = ()
+    if hasattr(mod, 'forward'):
+        if mod.forward.__func__ == torch.nn.Module.forward:
+            raise RuntimeError("No forward method was defined on {}".format(mod))
+        if not _jit_internal.is_ignored_fn(mod.forward):
+            methods = ('forward',)
+    exported = []
+    for name in dir(mod):
+        item = getattr(mod, name)
+        if callable(item):
+            if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.EXPORT:
+                exported.append(name)
+    methods = methods + tuple(exported)
+
+    def make_stub(method):
+        func = get_function_from_type(type(mod), method)
+        return torch.jit.script_method(func, _jit_internal.createResolutionCallbackFromClosure(func))
+
+    stubs = list(map(make_stub, methods))
+    return copy_to_script_module(mod, stubs)
+
+
+def create_method_from_fn(module, fn):
+    if _jit_internal.is_ignored_fn(fn):
+        return None
+    if not inspect.ismethod(fn):
+        return None
+    stub = torch.jit.script_method(fn, _jit_internal.createResolutionCallbackFromClosure(fn))
+    with torch.jit._disable_emit_hooks():
+        # We don't want to call the hooks here since the graph that is calling
+        # this function is not yet complete
+        torch.jit._create_methods_from_stubs(module, (stub,))
+    return stub
+
+
+def make_strong_submodule(field, module, parent):
+    if field not in parent._modules:
+        # It's not a submodule, don't do anything
+        return None
+
+    # Convert the module to a ScriptModule
+    new_strong_submodule = recursive_script(module)
+
+    # Install the ScriptModule on the python side
+    parent._modules._python_modules[field] = new_strong_submodule
+
+    return new_strong_submodule
+
+
+def try_compile_fn(fn, loc):
+    if _jit_internal.is_ignored_fn(fn):
+        # Don't do anything for @ignore'd functions
+        return None
+
+    if isinstance(fn, torch.nn.Module):
+        # Since modules are callable pybind recognizes them as functions, but
+        # don't do anything for them
+        return None
+
+    if not inspect.isfunction(fn) and not inspect.ismethod(fn):
+        raise RuntimeError("`{}` is not a function. Recursive scripting only supports "
+                           "Python functions or methods currently.\n"
+                           "Consider manually annotating `{}` with @torch.jit.script.".format(fn, fn))
+
+    # We don't have the actual scope where the function was defined, but we can
+    # extract the necessary info from the closed over variables on the function
+    # object
+    rcb = _jit_internal.createResolutionCallbackFromClosure(fn)
+    return torch.jit.script(fn, _rcb=rcb)
+
+
+def create_constant_iterable_module(module):
+    modules = collections.OrderedDict()
+
+    for key, submodule in module._modules.items():
+        if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+            # Make each item in the module a constant
+            modules[key] = create_constant_iterable_module(submodule)
+        else:
+            modules[key] = recursive_script(submodule)
+
+    if isinstance(module, Sequential):
+        return torch.jit._ConstSequential(Sequential(modules))
+    elif isinstance(module, ModuleList):
+        return torch.jit._ConstModuleList(modules)
+    else:
+        raise RuntimeError("Only nn.ModuleList and nn.Sequential can be made "
+                           "into constant modules, found {}".format(module))