[JIT] Add support for backend-lowered submodules (#41146)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41146

**Summary**
This commit adds support for using `Modules` that have been lowered to a
backend as submodules of `ScriptModules`.

**Test Plan**
This commit adds execution and save/load tests to test_backends.py for
backend-lowered submodules.

**Fixes**
This commit fixes #40069.

Test Plan: Imported from OSS

Reviewed By: ailzhang

Differential Revision: D22459543

Pulled By: SplitInfinity

fbshipit-source-id: 02e0c0ccdce26c671ade30a34aca3e99bcdc5ba7
diff --git a/test/jit/test_backends.py b/test/jit/test_backends.py
index aa85051..b421960 100644
--- a/test/jit/test_backends.py
+++ b/test/jit/test_backends.py
@@ -1,5 +1,4 @@
 from torch.testing._internal.jit_utils import JitTestCase
-import io
 import os
 import sys
 
@@ -11,10 +10,12 @@
 pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
 sys.path.append(pytorch_test_dir)
 
-if __name__ == '__main__':
-    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
-                       "\tpython test/test_jit.py TESTNAME\n\n"
-                       "instead.")
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test file is not meant to be run directly, use:\n\n"
+        "\tpython test/test_jit.py TESTNAME\n\n"
+        "instead."
+    )
 
 
 def to_test_backend(module, method_compile_spec):
@@ -25,9 +26,13 @@
     return torch._C._jit_to_backend("test_backend", module, method_compile_spec)
 
 
-class MyModule(torch.nn.Module):
+class BasicModule(torch.nn.Module):
+    """
+    A simple Module used to test to_backend lowering machinery.
+    """
+
     def __init__(self):
-        super(MyModule, self).__init__()
+        super().__init__()
 
     def forward(self, x, h):
         return self.accum(x, h), self.sub_accum(x, h)
@@ -39,26 +44,28 @@
         return x - h
 
 
-class TestBackends(JitTestCase):
+class JitBackendTestCase(JitTestCase):
+    """
+    A common base class for JIT backend tests that contains common utility
+    functions for output comparison and serialization/deserialization.
+    """
+
     def setUp(self):
         super().setUp()
+        # Subclasses are expected to set up three variables in their setUp methods:
+        # module - a regular, Python version of the module being tested
+        # scripted_module - a scripted version of module
+        # lowered_module - a version of module lowered to a backend
 
-        if not TEST_WITH_ROCM:
-            # Create Python, JIT and backend versions of MyModule.
-            self.module = MyModule()
-            self.scripted_module = torch.jit.script(MyModule())
-            self.lowered_module = to_test_backend_multi(
-                self.scripted_module._c, {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}})
-
-    def compare_py_jit_backend(self, name, input):
+    def check_function(self, function_name, input):
         """
-        This is a helper function for comparing the outputs of self.module (Python), self.scripted_module (JIT)
-        and self.lowered_module (backend) when the method named 'name' is invoked using 'input'.
+        Check that the function named 'function_name' produces the same output using
+        Python, regular JIT and the backend for the given 'input'.
         """
         # Get handles for Python, JIT and backend methods.
-        python_method = self.module.__getattribute__(name)
-        jit_method = self.scripted_module.__getattr__(name)
-        backend_method = self.lowered_module.__getattr__(name)
+        python_method = self.module.__getattribute__(function_name)
+        jit_method = self.scripted_module.__getattr__(function_name)
+        backend_method = self.lowered_module.__getattr__(function_name)
 
         # Run methods.
         python_output = python_method(input, input)
@@ -69,36 +76,47 @@
         self.assertEqual(python_output, backend_output)
         self.assertEqual(jit_output, backend_output)
 
-    @skipIfRocm
-    def test_simple(self):
+    def save_load(self):
         """
-        This is a simple test that compiles MyModule for the test backend and ensures it produces the correct
-        answers for each method.
+        Save and load the lowered module.
         """
+        self.lowered_module = self.getExportImportCopy(self.lowered_module)
+
+
+class BasicModuleTest(JitBackendTestCase):
+    """
+    Tests for BasicModule.
+    """
+
+    def setUp(self):
+        super().setUp()
+        # Create Python, JIT and backend versions of BasicModule.
+        self.module = BasicModule()
+        self.scripted_module = torch.jit.script(BasicModule())
+        self.lowered_module = to_test_backend_multi(
+            self.scripted_module._c,
+            {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
+        )
+
+    def test_execution(self):
         # Test execution with backend against Python and JIT.
         input = torch.randn(5)
 
         # Test all three module methods.
-        self.compare_py_jit_backend("accum", input)
-        self.compare_py_jit_backend("sub_accum", input)
-        self.compare_py_jit_backend("forward", input)
+        self.check_function("accum", input)
+        self.check_function("sub_accum", input)
+        self.check_function("forward", input)
 
     @skipIfRocm
     def test_save_load(self):
-        """
-        This method tests that a lowered module till produces the same output as a Python module and ScriptModule after
-        saving and loading.
-        """
-        # Save the lowered module.
-        buffer = io.BytesIO()
-        torch.jit.save(self.lowered_module, buffer)
+        # Lowered module should produce the same outputs.
+        self.test_execution()
 
         # Save the compile spec to compare against the version retrieved after loading.
         pre_compile_spec = self.lowered_module.__getattr__("__method_compile_spec")
 
-        # Load the lowered module.
-        buffer.seek(0)
-        self.lowered_module = torch.jit.load(buffer)
+        # Save and load the lowered module.
+        self.save_load()
 
         # Get the compile spec after loading.
         post_compile_spec = self.lowered_module.__getattr__("__method_compile_spec")
@@ -106,10 +124,82 @@
         # Compile specs should match.
         self.assertEqual(pre_compile_spec, post_compile_spec)
 
+        # Loaded module should produce the same outputs.
+        self.test_execution()
+
+
+class NestedModuleTest(JitBackendTestCase):
+    """
+    Tests for NestedModule that check that a module lowered to a backend can be used
+    as a submodule.
+    """
+    class NestedModule(torch.nn.Module):
+        """
+        A Module with one submodule that is used to test that lowered Modules
+        can be used as submodules.
+        """
+
+        def __init__(self, submodule):
+            super().__init__()
+            self.submodule = submodule
+
+        def forward(self, x, h):
+            return self.submodule.forward(x, h)
+
+    def setUp(self):
+        super().setUp()
+        # Create Python, JIT and backend versions of NestedModule.
+        # Both modules in self.module are regular Python modules.
+        self.module = NestedModuleTest.NestedModule(BasicModule())
+        # Both modules in self.scripted_module are ScriptModules.
+        self.scripted_module = torch.jit.script(NestedModuleTest.NestedModule(BasicModule()))
+        lowered_module = to_test_backend_multi(
+            self.scripted_module._c, {"forward": {"": ""}}
+        )
+        # self.lowered_module is a ScriptModule, but its submodule is a lowered module.
+        self.lowered_module = torch.jit.script(NestedModuleTest.NestedModule(lowered_module))
+
+    def test_execution(self):
         # Test execution with backend against Python and JIT.
         input = torch.randn(5)
 
-        # Test all three module methods.
-        self.compare_py_jit_backend("accum", input)
-        self.compare_py_jit_backend("sub_accum", input)
-        self.compare_py_jit_backend("forward", input)
+        # Test forward.
+        self.check_function("forward", input)
+
+    def test_save_load(self):
+        # Lowered module should produce the same outputs.
+        self.test_execution()
+
+        # Save and load the lowered module.
+        self.save_load()
+
+        # Loaded module should produce the same outputs.
+        self.test_execution()
+
+
+class TestBackends(JitTestCase):
+    """
+    This class wraps and invokes all subclasses of JitBackendTestCase so that each one
+    does not have to be individually imported in test_jit.py.
+    """
+
+    def __init__(self, name):
+        super().__init__(name)
+        self.basic_module_test = BasicModuleTest(name)
+        self.nested_module_test = NestedModuleTest(name)
+
+    def setUp(self):
+        super().setUp()
+        if not TEST_WITH_ROCM:
+            self.basic_module_test.setUp()
+            self.nested_module_test.setUp()
+
+    @skipIfRocm
+    def test_execution(self):
+        self.basic_module_test.test_execution()
+        self.nested_module_test.test_execution()
+
+    @skipIfRocm
+    def test_save_load(self):
+        self.basic_module_test.test_save_load()
+        self.nested_module_test.test_save_load()
diff --git a/torch/csrc/jit/backends/backend_init.cpp b/torch/csrc/jit/backends/backend_init.cpp
index 51e324b..b01cb62 100644
--- a/torch/csrc/jit/backends/backend_init.cpp
+++ b/torch/csrc/jit/backends/backend_init.cpp
@@ -17,115 +17,111 @@
   // this function must be called like
   //
   //  torch._C._jit_to_backend("example_backend", module, spec)
-  auto m = py::handle(module).cast<py::module>();
-  m.def(
-      "_jit_to_backend",
-      [=](const std::string& backend_name,
-          const Module& orig_module,
-          const py::dict& method_compile_spec) {
-        const c10::QualifiedName qual_backend_name({"__torch__",
-                                                    "torch",
-                                                    "classes",
-                                                    detail::kBackendsNamespace,
-                                                    backend_name});
-        // TODO: Validate method_compile_spec.
+  auto codegen_lambda = [=](const std::string& backend_name,
+                            const Module& orig_module,
+                            const py::dict& method_compile_spec) {
+    const c10::QualifiedName qual_backend_name({"__torch__",
+                                                "torch",
+                                                "classes",
+                                                detail::kBackendsNamespace,
+                                                backend_name});
+    // TODO: Validate method_compile_spec.
 
-        // Clone orig_module to make sure backend transformation is
-        // functional.
-        auto cloned_module = orig_module.clone();
+    // Clone orig_module to make sure backend transformation is
+    // functional.
+    auto cloned_module = orig_module.clone();
 
-        // Represents of a Type of Dict[str, Any].
-        auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
+    // Represents a Type of Dict[str, Any].
+    auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
 
-        // Generate LoweredModule.
-        Module loweredModule(
-            "torch.jit." + backend_name + "LoweredModule",
-            get_python_cu(),
-            /*shouldMangle=*/true);
+    // Generate LoweredModule.
+    Module loweredModule(
+        "torch.jit." + backend_name + "LoweredModule",
+        get_python_cu(),
+        /*shouldMangle=*/true);
 
-        // Generate attributes.
-        // This is the original cloned and preprocessed module.
-        loweredModule.register_attribute(
-            "__processed_module",
-            AnyType::get(),
-            cloned_module._ivalue(),
-            /*is_param=*/false);
+    // Generate attributes.
+    // This is the original cloned and preprocessed module.
+    loweredModule.register_attribute(
+        "__processed_module",
+        AnyType::get(),
+        cloned_module._ivalue(),
+        /*is_param=*/false);
 
-        // This is for the method_compile_spec passed in to to_<backend> or
-        // loaded from an exported model.
-        loweredModule.register_attribute(
-            "__method_compile_spec",
-            any_dict_ty,
-            toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
-            /*is_param=*/false);
+    // This is for the method_compile_spec passed in to to_<backend> or
+    // loaded from an exported model.
+    loweredModule.register_attribute(
+        "__method_compile_spec",
+        any_dict_ty,
+        toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
+        /*is_param=*/false);
 
-        // This is a pointer to a backend instance that is used to access
-        // compile and execute functions.
-        auto cls = getCustomClass(qual_backend_name.qualifiedName());
-        TORCH_INTERNAL_ASSERT(cls);
-        c10::intrusive_ptr<torch::CustomClassHolder> backend;
-        loweredModule.register_attribute(
-            "__backend", cls, IValue::make_capsule(backend));
+    // This is a pointer to a backend instance that is used to access
+    // compile and execute functions.
+    auto cls = getCustomClass(qual_backend_name.qualifiedName());
+    TORCH_INTERNAL_ASSERT(cls);
+    c10::intrusive_ptr<torch::CustomClassHolder> backend;
+    loweredModule.register_attribute(
+        "__backend", cls, IValue::make_capsule(backend));
 
-        // This is the list of opaque backend handles returned by
-        // backend.compile.
-        loweredModule.register_attribute(
-            "__handles",
-            any_dict_ty,
-            c10::impl::GenericDict(
-                any_dict_ty->getKeyType(), any_dict_ty->getValueType()),
-            /*is_param=*/false);
+    // This is the dictionary of opaque backend handles returned by
+    // backend.compile.
+    loweredModule.register_attribute(
+        "__handles",
+        any_dict_ty,
+        c10::impl::GenericDict(
+            any_dict_ty->getKeyType(), any_dict_ty->getValueType()),
+        /*is_param=*/false);
 
-        // Methods.
+    // Methods.
 
-        // This is a helper function for creating a new instance of the
-        // backend class.
-        static const auto create_backend_ct = CodeTemplate(R"(
+    // This is a helper function for creating a new instance of the
+    // backend class.
+    static const auto create_backend_ct = CodeTemplate(R"(
             def __create_backend(self):
                 self.__backend = $name()
             )");
-        TemplateEnv create_backend_te;
-        create_backend_te.s("name", qual_backend_name.qualifiedName());
-        loweredModule.define(
-            create_backend_ct.format(create_backend_te),
-            loweredModuleResolver());
+    TemplateEnv create_backend_te;
+    create_backend_te.s("name", qual_backend_name.qualifiedName());
+    loweredModule.define(
+        create_backend_ct.format(create_backend_te), loweredModuleResolver());
 
-        // getstate and setstate are for serialization/deserialization of
-        // the LoweredModule.
-        loweredModule.define(
-            R"(
+    // getstate and setstate are for serialization/deserialization of
+    // the LoweredModule.
+    loweredModule.define(
+        R"(
             def __getstate__(self):
                 return self.__method_compile_spec, self.__processed_module
             )",
-            loweredModuleResolver());
+        loweredModuleResolver());
 
-        loweredModule.define(
-            R"(
+    loweredModule.define(
+        R"(
             def __setstate__(self, state):
                 self.__method_compile_spec = state[0]
                 self.__processed_module = state[1]
                 self.__create_backend()
                 self.__handles = self.__backend.compile(self.__processed_module, self.__method_compile_spec)
             )",
-            loweredModuleResolver());
+        loweredModuleResolver());
 
-        // This is never called during compilation or execution, but is
-        // needed to generate the LoweredModule because we don't have access
-        // to an instance of the backend as a C++ object with which to call
-        // preprocess.
-        loweredModule.define(
-            R"(
+    // This is never called during compilation or execution, but is
+    // needed to generate the LoweredModule because we don't have access
+    // to an instance of the backend as a C++ object with which to call
+    // preprocess.
+    loweredModule.define(
+        R"(
             def __preprocess(self, mod: Any, method_compile_spec: Dict[str, Any]):
                 self.__create_backend()
                 self.__processed_module = self.__backend.preprocess(mod, method_compile_spec)
           )",
-            loweredModuleResolver());
+        loweredModuleResolver());
 
-        // This loop generates one method on the LoweredModule for every key
-        // in method_compile_spec.
-        for (auto& e : method_compile_spec) {
-          std::string method_name = py::cast<std::string>(e.first);
-          static const auto method_ct = CodeTemplate(R"(
+    // This loop generates one method on the LoweredModule for every key
+    // in method_compile_spec.
+    for (auto& e : method_compile_spec) {
+      std::string method_name = py::cast<std::string>(e.first);
+      static const auto method_ct = CodeTemplate(R"(
             def $method(self${,def_inputs}):
                 typed_inputs: List[Any] = [${fwd_inputs,}]
                 $ret, = self.__backend.execute(self.__handles["$method"], typed_inputs)
@@ -133,98 +129,108 @@
                 return $ret
             )");
 
-          TemplateEnv method_te;
-          method_te.s("method", method_name);
-          auto method = orig_module.get_method(method_name);
-          auto& function = method.function();
-          auto& schema = function.getSchema();
+      TemplateEnv method_te;
+      method_te.s("method", method_name);
+      auto method = orig_module.get_method(method_name);
+      auto& function = method.function();
+      auto& schema = function.getSchema();
 
-          // Generate the inputs for the function signature (def_inputs) and
-          // for passing to backend.execute (fwd_inputs).
-          std::vector<std::string> def_inputs, fwd_inputs;
-          for (const auto& arg : schema.arguments()) {
-            auto name = arg.name();
+      // Generate the inputs for the function signature (def_inputs) and
+      // for passing to backend.execute (fwd_inputs).
+      std::vector<std::string> def_inputs, fwd_inputs;
+      for (const auto& arg : schema.arguments()) {
+        auto name = arg.name();
 
-            // Skip self since that is only and always present in the
-            // signature.
-            if (name == "self") {
-              continue;
-            }
-
-            auto default_value = arg.default_value();
-
-            if (arg.kwarg_only()) {
-              // If this is a kwarg, it needs to be emitted as keyword=value
-              // in the definition and keyword=keyword in the call to
-              // backend_execute.
-              TORCH_INTERNAL_ASSERT(default_value.has_value());
-              std::stringstream def_ss, fwd_ss;
-              def_ss << name << "=";
-              fwd_ss << name << "=" << name;
-              default_value->repr(
-                  def_ss,
-                  [](std::ostream&, const IValue&) -> bool { return false; });
-              def_inputs.emplace_back(def_ss.str());
-              fwd_inputs.emplace_back(fwd_ss.str());
-            } else {
-              // If this is not a kwarg, it should be emitted as is in the
-              // signature and the call to backend_execute.
-              def_inputs.emplace_back(name);
-              fwd_inputs.emplace_back(name);
-            }
-          }
-
-          // Generate a comma-delimited list of identifiers to unpack
-          // outputs, as well as a list of isinstance checks to make sure
-          // the backend returned the types it was supposed to.
-          std::stringstream out_ss, type_check_ss;
-          std::vector<std::string> type_checks;
-          TORCH_INTERNAL_ASSERT(schema.returns().size() == 1);
-          auto out_ty = schema.returns().at(0).type();
-
-          out_ss << "_0";
-          type_check_ss << "assert isinstance(_0, ";
-
-          if (auto out_tuple_ty = out_ty->cast<TupleType>()) {
-            auto tuple_elements = out_tuple_ty->elements();
-            type_check_ss << tuple_elements[0]->str() << ")";
-            type_checks.emplace_back(type_check_ss.str());
-            for (unsigned i = 1, e = tuple_elements.size(); i < e; ++i) {
-              type_check_ss.str(std::string());
-              type_check_ss.clear();
-              out_ss << ", _" << i;
-              type_check_ss << "assert isinstance(_" << i << ", "
-                            << tuple_elements[i]->str() << ")";
-              type_checks.emplace_back(type_check_ss.str());
-            }
-          } else {
-            type_check_ss << out_ty->str() << ")";
-            type_checks.emplace_back(type_check_ss.str());
-          }
-
-          method_te.v("def_inputs", def_inputs);
-          method_te.v("fwd_inputs", fwd_inputs);
-          method_te.v("refine", type_checks);
-          method_te.s("ret", out_ss.str());
-
-          loweredModule.define(
-              method_ct.format(method_te), loweredModuleResolver());
+        // Skip self since that is only and always present in the
+        // signature.
+        if (name == "self") {
+          continue;
         }
 
-        // Run preprocess so that __processed_module is set correctly before
-        // compilation.
-        loweredModule.run_method(
-            "__preprocess",
-            cloned_module._ivalue(),
-            toIValue(method_compile_spec, any_dict_ty).toGenericDict());
+        auto default_value = arg.default_value();
 
-        // Call __setstate__ to ensure that the returned Module is ready to
-        // run.
-        auto state = at::ivalue::Tuple::create(
-            toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
-            loweredModule.attr("__processed_module"));
-        loweredModule.run_method("__setstate__", state);
-        return loweredModule;
+        if (arg.kwarg_only()) {
+          // If this is a kwarg, it needs to be emitted as keyword=value
+          // in the definition and keyword=keyword in the call to
+          // backend_execute.
+          TORCH_INTERNAL_ASSERT(default_value.has_value());
+          std::stringstream def_ss, fwd_ss;
+          def_ss << name << "=";
+          fwd_ss << name << "=" << name;
+          default_value->repr(def_ss, [](std::ostream&, const IValue&) -> bool {
+            return false;
+          });
+          def_inputs.emplace_back(def_ss.str());
+          fwd_inputs.emplace_back(fwd_ss.str());
+        } else {
+          // If this is not a kwarg, it should be emitted as is in the
+          // signature and the call to backend_execute.
+          def_inputs.emplace_back(name);
+          fwd_inputs.emplace_back(name);
+        }
+      }
+
+      // Generate a comma-delimited list of identifiers to unpack
+      // outputs, as well as a list of isinstance checks to make sure
+      // the backend returned the types it was supposed to.
+      std::stringstream out_ss, type_check_ss;
+      std::vector<std::string> type_checks;
+      TORCH_INTERNAL_ASSERT(schema.returns().size() == 1);
+      auto out_ty = schema.returns().at(0).type();
+
+      out_ss << "_0";
+      type_check_ss << "assert isinstance(_0, ";
+
+      if (auto out_tuple_ty = out_ty->cast<TupleType>()) {
+        auto tuple_elements = out_tuple_ty->elements();
+        type_check_ss << tuple_elements[0]->str() << ")";
+        type_checks.emplace_back(type_check_ss.str());
+        for (unsigned i = 1, e = tuple_elements.size(); i < e; ++i) {
+          type_check_ss.str(std::string());
+          type_check_ss.clear();
+          out_ss << ", _" << i;
+          type_check_ss << "assert isinstance(_" << i << ", "
+                        << tuple_elements[i]->str() << ")";
+          type_checks.emplace_back(type_check_ss.str());
+        }
+      } else {
+        type_check_ss << out_ty->str() << ")";
+        type_checks.emplace_back(type_check_ss.str());
+      }
+
+      method_te.v("def_inputs", def_inputs);
+      method_te.v("fwd_inputs", fwd_inputs);
+      method_te.v("refine", type_checks);
+      method_te.s("ret", out_ss.str());
+
+      loweredModule.define(
+          method_ct.format(method_te), loweredModuleResolver());
+    }
+
+    // Run preprocess so that __processed_module is set correctly before
+    // compilation.
+    loweredModule.run_method(
+        "__preprocess",
+        cloned_module._ivalue(),
+        toIValue(method_compile_spec, any_dict_ty).toGenericDict());
+
+    // Call __setstate__ to ensure that the returned Module is ready to
+    // run.
+    auto state = at::ivalue::Tuple::create(
+        toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
+        loweredModule.attr("__processed_module"));
+    loweredModule.run_method("__setstate__", state);
+    return loweredModule;
+  };
+  auto m = py::handle(module).cast<py::module>();
+  m.def(
+      "_jit_to_backend",
+      [=](const std::string& backend_name,
+          const Module& orig_module,
+          const py::dict& method_compile_spec) {
+        return py::module::import("torch.jit._recursive")
+            .attr("wrap_cpp_module")(
+                codegen_lambda(backend_name, orig_module, method_compile_spec));
       });
 }
 } // namespace jit
diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp
index 0a912b8..3e52a11 100644
--- a/torch/csrc/jit/python/script_init.cpp
+++ b/torch/csrc/jit/python/script_init.cpp
@@ -537,7 +537,8 @@
       }
       visited.emplace(item.a.internalToPointer());
     }
-    if (*unshapedType(item.a.type()) != *unshapedType(item.b.type())) {
+    if (!unshapedType(item.b.type())
+             ->isSubtypeOf(unshapedType(item.b.type()))) {
       // Since named types are saved and loaded in the test suite, we cannot
       // expect them to be equal. We should still check their slots however.
       if (!item.a.type()->cast<c10::NamedType>()) {