[JIT] Add `__prepare_scriptable__` duck typing to allow replacing nn.modules with scriptable preparations (#45645) (#49242)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49242

Fixes https://github.com/pytorch/pytorch/issues/45072

As discussed with zdevito, gchanan, cpuhrsch, and suo, this change lets developers define custom preparations for their modules before scripting. A module opts in by defining a `__prepare_scriptable__` method that returns the prepared, scriptable module out of place (the original module is left unmodified); `torch.jit.script` calls this method when it is present. It does not expand the API surface for end users.
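
For example (a minimal sketch; `MyLayer` and `ScriptFriendlyLayer` are hypothetical names used for illustration, not part of this change), a module can return a scriptable replacement of itself:

    import torch

    class ScriptFriendlyLayer(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    class MyLayer(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

        def __prepare_scriptable__(self):
            # Called by torch.jit.script(); the returned module is scripted
            # in place of self, and self is left unmodified.
            return ScriptFriendlyLayer()

    scripted = torch.jit.script(MyLayer())  # compiles ScriptFriendlyLayer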

Prior art by jamesr66a: https://github.com/pytorch/pytorch/pull/42244

Test Plan: Imported from OSS

Reviewed By: dongreenberg

Differential Revision: D25500303

fbshipit-source-id: d3ec9005de27d8882fc29d02f0d08acd2a5c6b2c
diff --git a/test/jit/test_recursive_script.py b/test/jit/test_recursive_script.py
index bd9a2bb..a0dc99a 100644
--- a/test/jit/test_recursive_script.py
+++ b/test/jit/test_recursive_script.py
@@ -495,6 +495,59 @@
 
         self.checkModule(M(), (torch.randn(5, 5),))
 
+    def test_prepare_scriptable_basic(self):
+        class SeluButReluWhenScripted(torch.nn.SELU):
+            def __prepare_scriptable__(self):
+                return nn.ReLU()
+
+        t = torch.randn(5, 5)
+        m = SeluButReluWhenScripted()
+        sm = torch.jit.script(m)
+        eager_out = m(t)
+        script_out = sm(t)
+        self.assertNotEqual(eager_out, script_out)
+
+    def test_prepare_scriptable_iterable_modules(self):
+        class SeluButReluWhenScripted(torch.nn.SELU):
+            def __prepare_scriptable__(self):
+                return nn.ReLU()
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                shared = SeluButReluWhenScripted()
+                self.sequential = nn.Sequential(
+                    SeluButReluWhenScripted(),
+                    SeluButReluWhenScripted(),
+                    nn.Sequential(SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()),
+                    shared,
+                )
+                self.module_list = nn.ModuleList([SeluButReluWhenScripted(),
+                                                  shared,
+                                                  SeluButReluWhenScripted()])
+
+            def forward(self, x):
+                for mod in self.module_list:
+                    x += mod(x)
+                x += self.sequential(x)
+                return x
+
+        t = torch.randn(5, 5)
+        m = M()
+        eager_out = m(t.clone())
+        sm = torch.jit.script(m)
+        script_out = sm(t.clone())
+        self.assertNotEqual(eager_out, script_out)
+
+    def test_prepare_scriptable_cycle(self):
+        t = torch.randn(5, 5)
+        c = torch.nn.Module()
+        p = torch.nn.Module()
+        c.__dict__["_p"] = p
+        p.__dict__["_c"] = c
+
+        sm = torch.jit.script(p)
+
     def test_attributes(self):
         @torch.jit.script
         class Inner2(object):
diff --git a/test/jit/test_torchbind.py b/test/jit/test_torchbind.py
index 31eec81..7f43b31 100644
--- a/test/jit/test_torchbind.py
+++ b/test/jit/test_torchbind.py
@@ -62,6 +62,32 @@
             return ss1.pop() + ss2.pop()
         test_equality(f, lambda x: x)
 
+        # Test an nn.Module whose __prepare_scriptable__ swaps in a torchbind class before scripting
+        class NonJitableClass(object):
+            def __init__(self, int1, int2):
+                self.int1 = int1
+                self.int2 = int2
+
+            def return_vals(self):
+                return self.int1, self.int2
+
+        class CustomWrapper(torch.nn.Module):
+            def __init__(self, foo):
+                super(CustomWrapper, self).__init__()
+                self.foo = foo
+
+            def forward(self) -> None:
+                self.foo.increment(1)
+                return
+
+            def __prepare_scriptable__(self):
+                int1, int2 = self.foo.return_vals()
+                foo = torch.classes._TorchScriptTesting._Foo(int1, int2)
+                return CustomWrapper(foo)
+
+        foo = CustomWrapper(NonJitableClass(1, 2))
+        jit_foo = torch.jit.script(foo)
+
     def test_torchbind_take_as_arg(self):
         global StackString  # see [local resolution in python]
         StackString = torch.classes._TorchScriptTesting._StackString
diff --git a/torch/jit/_script.py b/torch/jit/_script.py
index cc84877..8bc8c61 100644
--- a/torch/jit/_script.py
+++ b/torch/jit/_script.py
@@ -741,6 +741,43 @@
         def __init__(self, arg=None):
             super().__init__()
 
+def call_prepare_scriptable_func_impl(obj, memo):
+    if not isinstance(obj, torch.nn.Module):
+        return obj
+
+    obj_id = id(obj)
+
+    # If obj_id is in memo, obj has already been prepared or is being
+    # prepared in another call up the stack.
+    if obj_id in memo:
+        return memo[obj_id]
+
+    obj = obj.__prepare_scriptable__() if hasattr(obj, '__prepare_scriptable__') else obj  # type: ignore
+    # Record obj in memo to avoid infinite recursion in the case of cycles in the module
+    # hierarchy when recursing below.
+    memo[obj_id] = obj
+
+    new_obj_dict = {}
+
+    for name in obj.__dict__:
+        sub_module = obj.__dict__.get(name)
+        if name == '_modules':
+            for k, v in sub_module.items():
+                sub_module[k] = call_prepare_scriptable_func_impl(v, memo)
+            new_obj_dict[name] = sub_module
+        elif isinstance(sub_module, torch.nn.Module) and not isinstance(sub_module, ScriptModule):
+            new_obj_dict[name] = call_prepare_scriptable_func_impl(sub_module, memo)
+        else:
+            new_obj_dict[name] = sub_module
+
+    for k, v in new_obj_dict.items():
+        obj.__dict__[k] = v
+
+    return obj
+
+def call_prepare_scriptable_func(obj):
+    memo: Dict[int, torch.nn.Module] = {}
+    return call_prepare_scriptable_func_impl(obj, memo)
 
 def script(obj, optimize=None, _frames_up=0, _rcb=None):
     r"""
@@ -894,6 +931,7 @@
         return obj
 
     if isinstance(obj, torch.nn.Module):
+        obj = call_prepare_scriptable_func(obj)
         return torch.jit._recursive.create_script_module(
             obj, torch.jit._recursive.infer_methods_to_compile
         )