[JIT] Propagate type sharing setting to submodule compilation (#44226)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44226
**Summary**
At present, the `share_types` argument to `create_script_module` controls
whether a previously created type is reused for a top-level module that
has not yet been compiled. However, that setting is not applied when the
submodules of the top-level module are compiled; their types are still
reused if possible.
This commit modifies `create_script_module` so that the `share_types`
flag is honored during submodule compilation as well.
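
As a rough sketch of the new behavior (this example invokes the private
`torch.jit._recursive` entry points directly, the same way the new unit
test does, and compares types via the internal `_c._type()` accessor; the
`Sub` and `Wrapper` classes are made up for illustration):

```python
import torch

class Sub(torch.nn.Module):
    def forward(self, x):
        return x

class Wrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()

    def forward(self, x):
        return self.sub(x)

a = torch.jit._recursive.create_script_module(
    Wrapper(), torch.jit._recursive.infer_methods_to_compile, share_types=False)
b = torch.jit._recursive.create_script_module(
    Wrapper(), torch.jit._recursive.infer_methods_to_compile, share_types=False)

# Top-level types were already kept distinct before this change.
assert a._c._type() != b._c._type()
# With this change, the submodule types are distinct as well; previously
# they were still fetched from the shared type store.
assert a.sub._c._type() != b.sub._c._type()
```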
**Test Plan**
This commit adds a unit test to `TestTypeSharing` verifying that
submodule types are neither shared nor reused when `share_types` is
set to `False`.
**Fixes**
This commit fixes #43605.
Test Plan: Imported from OSS
Reviewed By: eellison
Differential Revision: D23602371
Pulled By: SplitInfinity
fbshipit-source-id: b909b8b6abbe3b4cb9be8319ac263ade90e83bd3
diff --git a/test/jit/test_type_sharing.py b/test/jit/test_type_sharing.py
index 4919974..7981ed9 100644
--- a/test/jit/test_type_sharing.py
+++ b/test/jit/test_type_sharing.py
@@ -519,5 +519,44 @@
one = A(1)
two = A(2)
- self.assertEquals(one(), 1)
- self.assertEquals(two(), 2)
+ self.assertEqual(one(), 1)
+ self.assertEqual(two(), 2)
+
+ def test_type_sharing_disabled(self):
+ """
+ Test that type sharing can be disabled.
+ """
+ class A(torch.nn.Module):
+ def __init__(self, sub):
+ super().__init__()
+ self.sub = sub
+
+ def forward(self, x):
+ return x
+
+ class B(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x
+
+ top1 = A(A(B()))
+ top2 = A(A(B()))
+
+ top1_s = torch.jit._recursive.create_script_module(
+ top1,
+ torch.jit._recursive.infer_methods_to_compile,
+ share_types=False,
+ )
+ top2_s = torch.jit._recursive.create_script_module(
+ top2,
+ torch.jit._recursive.infer_methods_to_compile,
+ share_types=False,
+ )
+
+ self.assertDifferentType(top1_s, top2_s)
+ self.assertDifferentType(top1_s, top1_s.sub)
+ self.assertDifferentType(top1_s, top2_s.sub)
+ self.assertDifferentType(top2_s, top2_s.sub)
+ self.assertDifferentType(top2_s, top1_s.sub)
diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py
index 11523cb..8fad8d9 100644
--- a/torch/jit/_recursive.py
+++ b/torch/jit/_recursive.py
@@ -75,7 +75,7 @@
super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len)
-def infer_concrete_type_builder(nn_module):
+def infer_concrete_type_builder(nn_module, share_types=True):
"""
Build a ConcreteModuleTypeBuilder from an nn.Module. This
ConcreteModuleType doesn't have a JIT type associated with it yet, it
@@ -136,7 +136,7 @@
sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type(attr_type)
else:
# otherwise we get the concrete module type for item and add it to concrete_type
- sub_concrete_type = concrete_type_store.get_or_create_concrete_type(item)
+ sub_concrete_type = get_module_concrete_type(item, share_types)
concrete_type_builder.add_module(name, sub_concrete_type)
added_names.add(name)
@@ -265,11 +265,6 @@
Infer a ConcreteType from this `nn.Module` instance. Underlying JIT
types are re-used if possible.
"""
- assert isinstance(nn_module, Module)
- if isinstance(nn_module, torch.jit.ScriptModule) and \
- hasattr(nn_module, "_concrete_type"):
- return nn_module._concrete_type
-
concrete_type_builder = infer_concrete_type_builder(nn_module)
nn_module_type = type(nn_module)
@@ -295,6 +290,36 @@
defaults = [get_default_args(m.original_method) for m in stubs]
concrete_type._create_methods(defs, rcbs, defaults)
+def get_module_concrete_type(nn_module, share_types=True):
+ """
+ Gets a concrete type for nn_module. If share_types is True, the concrete
+ type is fetched from concrete_type_store. If it is False, a new concrete type
+ is created without first searching concrete_type_store.
+
+ Arguments:
+ nn_module: The original Python nn.Module that we are creating a ScriptModule for.
+ share_types: Whether to share underlying JIT types between modules (if possible).
+
+ Returns:
+ A concrete type for nn_module.
+ """
+ assert isinstance(nn_module, Module)
+ if isinstance(nn_module, torch.jit.ScriptModule) and \
+ hasattr(nn_module, "_concrete_type"):
+ return nn_module._concrete_type
+
+ if share_types:
+ # Look into the store of cached JIT types
+ concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
+ else:
+ # Get a concrete type directly, without trying to re-use an existing JIT
+ # type from the type store.
+ concrete_type_builder = infer_concrete_type_builder(nn_module, share_types)
+ concrete_type_builder.set_poisoned()
+ concrete_type = concrete_type_builder.build()
+
+ return concrete_type
+
def create_script_module(nn_module, stubs_fn, share_types=True):
"""
Creates a new ScriptModule from an nn.Module
@@ -309,15 +334,7 @@
"""
assert not isinstance(nn_module, torch.jit.RecursiveScriptModule)
check_module_initialized(nn_module)
- if share_types:
- # Look into the store of cached JIT types
- concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
- else:
- # Get a concrete type directly, without trying to re-use an existing JIT
- # type from the type store.
- concrete_type_builder = infer_concrete_type_builder(nn_module)
- concrete_type_builder.set_poisoned()
- concrete_type = concrete_type_builder.build()
+ concrete_type = get_module_concrete_type(nn_module, share_types)
return create_script_module_impl(nn_module, concrete_type, stubs_fn)
def create_script_module_impl(nn_module, concrete_type, stubs_fn):