[JIT] Add `__prepare_scriptable__` duck typing to allow replacing nn.modules with scriptable preparations (#45645) (#49242)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49242
Fixes https://github.com/pytorch/pytorch/issues/45072
As discussed with zdevito gchanan cpuhrsch and suo, this change allows developers to create custom preparations for their modules before scripting. This is done by adding a `__prepare_scriptable__` method to a module which returns the prepared scriptable module out-of-place. It does not expand the API surface for end users.
Prior art by jamesr66a: https://github.com/pytorch/pytorch/pull/42244
Test Plan: Imported from OSS
Reviewed By: dongreenberg
Differential Revision: D25500303
fbshipit-source-id: d3ec9005de27d8882fc29d02f0d08acd2a5c6b2c
diff --git a/test/jit/test_recursive_script.py b/test/jit/test_recursive_script.py
index bd9a2bb..a0dc99a 100644
--- a/test/jit/test_recursive_script.py
+++ b/test/jit/test_recursive_script.py
@@ -495,6 +495,59 @@
self.checkModule(M(), (torch.randn(5, 5),))
+ def test_prepare_scriptable_basic(self):
+ class SeluButReluWhenScripted(torch.nn.SELU):
+ def __prepare_scriptable__(self):
+ return nn.ReLU()
+
+ t = torch.randn(5, 5)
+ m = SeluButReluWhenScripted()
+ sm = torch.jit.script(m)
+ eager_out = m(t)
+ script_out = sm(t)
+ self.assertNotEqual(eager_out, script_out)
+
+ def test_prepare_scriptable_iterable_modules(self):
+ class SeluButReluWhenScripted(torch.nn.SELU):
+ def __prepare_scriptable__(self):
+ return nn.ReLU()
+
+ class M(torch.nn.Module):
+ def __init__(self):
+ super(M, self).__init__()
+ shared = SeluButReluWhenScripted()
+ self.sequential = nn.Sequential(
+ SeluButReluWhenScripted(),
+ SeluButReluWhenScripted(),
+ nn.Sequential(SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()),
+ shared,
+ )
+ self.module_list = nn.ModuleList([SeluButReluWhenScripted(),
+ shared,
+ SeluButReluWhenScripted()])
+
+ def forward(self, x):
+ for mod in self.module_list:
+ x += mod(x)
+ x += self.sequential(x)
+ return x
+
+ t = torch.randn(5, 5)
+ m = M()
+ eager_out = m(t.clone())
+ sm = torch.jit.script(m)
+ script_out = sm(t.clone())
+ self.assertNotEqual(eager_out, script_out)
+
+ def test_prepare_scriptable_cycle(self):
+ t = torch.randn(5, 5)
+ c = torch.nn.Module()
+ p = torch.nn.Module()
+ c.__dict__["_p"] = p
+ p.__dict__["_c"] = c
+
+ sm = torch.jit.script(p)
+
def test_attributes(self):
@torch.jit.script
class Inner2(object):
diff --git a/test/jit/test_torchbind.py b/test/jit/test_torchbind.py
index 31eec81..7f43b31 100644
--- a/test/jit/test_torchbind.py
+++ b/test/jit/test_torchbind.py
@@ -62,6 +62,32 @@
return ss1.pop() + ss2.pop()
test_equality(f, lambda x: x)
+ # test nn module with prepare_scriptable function
+ class NonJitableClass(object):
+ def __init__(self, int1, int2):
+ self.int1 = int1
+ self.int2 = int2
+
+ def return_vals(self):
+ return self.int1, self.int2
+
+ class CustomWrapper(torch.nn.Module):
+ def __init__(self, foo):
+ super(CustomWrapper, self).__init__()
+ self.foo = foo
+
+ def forward(self) -> None:
+ self.foo.increment(1)
+ return
+
+ def __prepare_scriptable__(self):
+ int1, int2 = self.foo.return_vals()
+ foo = torch.classes._TorchScriptTesting._Foo(int1, int2)
+ return CustomWrapper(foo)
+
+ foo = CustomWrapper(NonJitableClass(1, 2))
+ jit_foo = torch.jit.script(foo)
+
def test_torchbind_take_as_arg(self):
global StackString # see [local resolution in python]
StackString = torch.classes._TorchScriptTesting._StackString
diff --git a/torch/jit/_script.py b/torch/jit/_script.py
index cc84877..8bc8c61 100644
--- a/torch/jit/_script.py
+++ b/torch/jit/_script.py
@@ -741,6 +741,43 @@
def __init__(self, arg=None):
super().__init__()
+def call_prepare_scriptable_func_impl(obj, memo):
+ if not isinstance(obj, torch.nn.Module):
+ return obj
+
+ obj_id = id(obj)
+
+ # If obj_id is in memo, obj has already been prepared or is being
+ # prepared in another call up the stack.
+ if obj_id in memo:
+ return memo[id(obj)]
+
+ obj = obj.__prepare_scriptable__() if hasattr(obj, '__prepare_scriptable__') else obj # type: ignore
+ # Record obj in memo to avoid infinite recursion in the case of cycles in the module
+ # hierarchy when recursing below.
+ memo[obj_id] = obj
+
+ new_obj_dict = {}
+
+ for name in obj.__dict__:
+ sub_module = obj.__dict__.get(name)
+ if name == '_modules':
+ for k, v in sub_module.items():
+ sub_module[k] = call_prepare_scriptable_func_impl(v, memo)
+ new_obj_dict[name] = sub_module
+ elif isinstance(sub_module, torch.nn.Module) and not isinstance(sub_module, ScriptModule):
+ new_obj_dict[name] = call_prepare_scriptable_func_impl(sub_module, memo)
+ else:
+ new_obj_dict[name] = sub_module
+
+ for k, v in new_obj_dict.items():
+ obj.__dict__[name] = v
+
+ return obj
+
+def call_prepare_scriptable_func(obj):
+ memo: Dict[int, torch.nn.Module] = {}
+ return call_prepare_scriptable_func_impl(obj, memo)
def script(obj, optimize=None, _frames_up=0, _rcb=None):
r"""
@@ -894,6 +931,7 @@
return obj
if isinstance(obj, torch.nn.Module):
+ obj = call_prepare_scriptable_func(obj)
return torch.jit._recursive.create_script_module(
obj, torch.jit._recursive.infer_methods_to_compile
)