[JIT] Propagate type sharing setting to submodule compilation (#44226)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44226

**Summary**
At present, the `share_types` argument to `create_script_module` is used
to decide whether to reuse a previously created type for a top-level
module that has not yet been compiled. However, that setting does not apply
to the compilation of submodules of the top-level module; types are
still reused if possible.

This commit modifies `create_script_module` so that the `share_types`
flag is honoured during submodule compilation as well.

**Test Plan**
This commit adds a unit test to `TestTypeSharing` that checks that
submodule types are not shared or reused when `share_types` is set to
`False`.

**Fixes**
This commit fixes #43605.

Test Plan: Imported from OSS

Reviewed By: eellison

Differential Revision: D23602371

Pulled By: SplitInfinity

fbshipit-source-id: b909b8b6abbe3b4cb9be8319ac263ade90e83bd3
diff --git a/test/jit/test_type_sharing.py b/test/jit/test_type_sharing.py
index 4919974..7981ed9 100644
--- a/test/jit/test_type_sharing.py
+++ b/test/jit/test_type_sharing.py
@@ -519,5 +519,44 @@
         one = A(1)
         two = A(2)
 
-        self.assertEquals(one(), 1)
-        self.assertEquals(two(), 2)
+        self.assertEqual(one(), 1)
+        self.assertEqual(two(), 2)
+
+    def test_type_sharing_disabled(self):
+        """
+        Test that type sharing can be disabled.
+        """
+        class A(torch.nn.Module):
+            def __init__(self, sub):
+                super().__init__()
+                self.sub = sub
+
+            def forward(self, x):
+                return x
+
+        class B(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x
+
+        top1 = A(A(B()))
+        top2 = A(A(B()))
+
+        top1_s = torch.jit._recursive.create_script_module(
+            top1,
+            torch.jit._recursive.infer_methods_to_compile,
+            share_types=False,
+        )
+        top2_s = torch.jit._recursive.create_script_module(
+            top2,
+            torch.jit._recursive.infer_methods_to_compile,
+            share_types=False,
+        )
+
+        self.assertDifferentType(top1_s, top2_s)
+        self.assertDifferentType(top1_s, top1_s.sub)
+        self.assertDifferentType(top1_s, top2_s.sub)
+        self.assertDifferentType(top2_s, top2_s.sub)
+        self.assertDifferentType(top2_s, top1_s.sub)
diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py
index 11523cb..8fad8d9 100644
--- a/torch/jit/_recursive.py
+++ b/torch/jit/_recursive.py
@@ -75,7 +75,7 @@
         super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len)
 
 
-def infer_concrete_type_builder(nn_module):
+def infer_concrete_type_builder(nn_module, share_types=True):
     """
     Build a ConcreteModuleTypeBuilder from an nn.Module. This
     ConcreteModuleType doesn't have a JIT type associated with it yet, it
@@ -136,7 +136,7 @@
             sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type(attr_type)
         else:
             # otherwise we get the concrete module type for item and add it to concrete_type
-            sub_concrete_type = concrete_type_store.get_or_create_concrete_type(item)
+            sub_concrete_type = get_module_concrete_type(item, share_types)
         concrete_type_builder.add_module(name, sub_concrete_type)
 
         added_names.add(name)
@@ -265,11 +265,6 @@
         Infer a ConcreteType from this `nn.Module` instance. Underlying JIT
         types are re-used if possible.
         """
-        assert isinstance(nn_module, Module)
-        if isinstance(nn_module, torch.jit.ScriptModule) and \
-                hasattr(nn_module, "_concrete_type"):
-            return nn_module._concrete_type
-
         concrete_type_builder = infer_concrete_type_builder(nn_module)
 
         nn_module_type = type(nn_module)
@@ -295,6 +290,36 @@
     defaults = [get_default_args(m.original_method) for m in stubs]
     concrete_type._create_methods(defs, rcbs, defaults)
 
+def get_module_concrete_type(nn_module, share_types=True):
+    """
+    Gets a concrete type for nn_modules. If share_types is True, the concrete
+    type is fetched from concrete_type_store. If it is False, a new concrete type
+    is created without first searching concrete_type_store.
+
+    Arguments:
+        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
+        share_types = Whether to share underlying JIT types between modules (if possible).
+
+    Returns:
+        A concrete type for nn_module.
+    """
+    assert isinstance(nn_module, Module)
+    if isinstance(nn_module, torch.jit.ScriptModule) and \
+            hasattr(nn_module, "_concrete_type"):
+        return nn_module._concrete_type
+
+    if share_types:
+        # Look into the store of cached JIT types
+        concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
+    else:
+        # Get a concrete type directly, without trying to re-use an existing JIT
+        # type from the type store.
+        concrete_type_builder = infer_concrete_type_builder(nn_module, share_types)
+        concrete_type_builder.set_poisoned()
+        concrete_type = concrete_type_builder.build()
+
+    return concrete_type
+
 def create_script_module(nn_module, stubs_fn, share_types=True):
     """
     Creates a new ScriptModule from an nn.Module
@@ -309,15 +334,7 @@
     """
     assert not isinstance(nn_module, torch.jit.RecursiveScriptModule)
     check_module_initialized(nn_module)
-    if share_types:
-        # Look into the store of cached JIT types
-        concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
-    else:
-        # Get a concrete type directly, without trying to re-use an existing JIT
-        # type from the type store.
-        concrete_type_builder = infer_concrete_type_builder(nn_module)
-        concrete_type_builder.set_poisoned()
-        concrete_type = concrete_type_builder.build()
+    concrete_type = get_module_concrete_type(nn_module, share_types)
     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
 
 def create_script_module_impl(nn_module, concrete_type, stubs_fn):