[JIT] Add support for backend-lowered submodules (#41146)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41146
**Summary**
This commit adds support for using `Modules` that have been lowered to a
backend as submodules of `ScriptModules`.
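For example, a module lowered with `torch._C._jit_to_backend` can now be held by a regular `Module` and scripted. A minimal sketch, modeled on the new `NestedModuleTest` in test_backends.py; it assumes the `test_backend` registered by the test harness, and `BasicModule` here is a simplified stand-in for the one in that file:
```python
import torch

class BasicModule(torch.nn.Module):
    """A simple Module to lower (simplified stand-in for the test's version)."""

    def forward(self, x, h):
        return x + h

class Wrapper(torch.nn.Module):
    """A regular Module that holds a backend-lowered submodule."""

    def __init__(self, submodule):
        super().__init__()
        self.submodule = submodule

    def forward(self, x, h):
        return self.submodule.forward(x, h)

# Lower a scripted module to a registered backend. After this change,
# _jit_to_backend returns a wrapped ScriptModule (via
# torch.jit._recursive.wrap_cpp_module), so the result can be assigned
# as a submodule like any other Module.
scripted = torch.jit.script(BasicModule())
lowered = torch._C._jit_to_backend(
    "test_backend", scripted._c, {"forward": {"": ""}}
)

# Script the wrapper; the lowered module is carried along as a submodule.
wrapper = torch.jit.script(Wrapper(lowered))
```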
**Test Plan**
This commit adds execution and save/load tests to test_backends.py for
backend-lowered submodules.
**Fixes**
This commit fixes #40069.
Test Plan: Imported from OSS
Reviewed By: ailzhang
Differential Revision: D22459543
Pulled By: SplitInfinity
fbshipit-source-id: 02e0c0ccdce26c671ade30a34aca3e99bcdc5ba7
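To support this, `_jit_to_backend` generates a `LoweredModule` whose methods dispatch to `__backend.execute` (see the CodeTemplates in backend_init.cpp below). The following sketch shows roughly what the instantiated template produces for a method returning a single `Tensor`; it is illustrative, not literal codegen output:
```python
# Rough shape of the generated LoweredModule (a sketch, not literal codegen
# output). Attributes registered in backend_init.cpp:
#   __processed_module    - the cloned, preprocessed original module
#   __method_compile_spec - the Dict[str, Any] passed to to_<backend>
#   __backend             - an instance of the backend custom class
#   __handles             - opaque handles returned by backend.compile
#
# One method is generated per key in method_compile_spec; for a method
# returning a single Tensor, method_ct instantiates to roughly:
def forward(self, x, h):
    typed_inputs: List[Any] = [x, h]
    _0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
    assert isinstance(_0, Tensor)
    return _0
```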
diff --git a/test/jit/test_backends.py b/test/jit/test_backends.py
index aa85051..b421960 100644
--- a/test/jit/test_backends.py
+++ b/test/jit/test_backends.py
@@ -1,5 +1,4 @@
from torch.testing._internal.jit_utils import JitTestCase
-import io
import os
import sys
@@ -11,10 +10,12 @@
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
-if __name__ == '__main__':
- raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
- "\tpython test/test_jit.py TESTNAME\n\n"
- "instead.")
+if __name__ == "__main__":
+ raise RuntimeError(
+ "This test file is not meant to be run directly, use:\n\n"
+ "\tpython test/test_jit.py TESTNAME\n\n"
+ "instead."
+ )
def to_test_backend(module, method_compile_spec):
@@ -25,9 +26,13 @@
return torch._C._jit_to_backend("test_backend", module, method_compile_spec)
-class MyModule(torch.nn.Module):
+class BasicModule(torch.nn.Module):
+ """
+ A simple Module used to test the to_backend lowering machinery.
+ """
+
def __init__(self):
- super(MyModule, self).__init__()
+ super().__init__()
def forward(self, x, h):
return self.accum(x, h), self.sub_accum(x, h)
@@ -39,26 +44,28 @@
return x - h
-class TestBackends(JitTestCase):
+class JitBackendTestCase(JitTestCase):
+ """
+ A common base class for JIT backend tests that provides utility functions
+ for output comparison and serialization/deserialization.
+ """
+
def setUp(self):
super().setUp()
+ # Subclasses are expected to set up three variables in their setUp methods:
+ # module - a regular Python version of the module being tested
+ # scripted_module - a scripted version of module
+ # lowered_module - a version of module lowered to a backend
- if not TEST_WITH_ROCM:
- # Create Python, JIT and backend versions of MyModule.
- self.module = MyModule()
- self.scripted_module = torch.jit.script(MyModule())
- self.lowered_module = to_test_backend_multi(
- self.scripted_module._c, {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}})
-
- def compare_py_jit_backend(self, name, input):
+ def check_function(self, function_name, input):
"""
- This is a helper function for comparing the outputs of self.module (Python), self.scripted_module (JIT)
- and self.lowered_module (backend) when the method named 'name' is invoked using 'input'.
+ Check that the function named 'function_name' produces the same output using
+ Python, regular JIT and the backend for the given 'input'.
"""
# Get handles for Python, JIT and backend methods.
- python_method = self.module.__getattribute__(name)
- jit_method = self.scripted_module.__getattr__(name)
- backend_method = self.lowered_module.__getattr__(name)
+ python_method = self.module.__getattribute__(function_name)
+ jit_method = self.scripted_module.__getattr__(function_name)
+ backend_method = self.lowered_module.__getattr__(function_name)
# Run methods.
python_output = python_method(input, input)
@@ -69,36 +76,47 @@
self.assertEqual(python_output, backend_output)
self.assertEqual(jit_output, backend_output)
- @skipIfRocm
- def test_simple(self):
+ def save_load(self):
"""
- This is a simple test that compiles MyModule for the test backend and ensures it produces the correct
- answers for each method.
+ Save and load the lowered module.
"""
+ self.lowered_module = self.getExportImportCopy(self.lowered_module)
+
+
+class BasicModuleTest(JitBackendTestCase):
+ """
+ Tests for BasicModule.
+ """
+
+ def setUp(self):
+ super().setUp()
+ # Create Python, JIT and backend versions of BasicModule.
+ self.module = BasicModule()
+ self.scripted_module = torch.jit.script(BasicModule())
+ self.lowered_module = to_test_backend_multi(
+ self.scripted_module._c,
+ {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
+ )
+
+ def test_execution(self):
# Test execution with backend against Python and JIT.
input = torch.randn(5)
# Test all three module methods.
- self.compare_py_jit_backend("accum", input)
- self.compare_py_jit_backend("sub_accum", input)
- self.compare_py_jit_backend("forward", input)
+ self.check_function("accum", input)
+ self.check_function("sub_accum", input)
+ self.check_function("forward", input)
@skipIfRocm
def test_save_load(self):
- """
- This method tests that a lowered module till produces the same output as a Python module and ScriptModule after
- saving and loading.
- """
- # Save the lowered module.
- buffer = io.BytesIO()
- torch.jit.save(self.lowered_module, buffer)
+ # Lowered module should produce the same outputs.
+ self.test_execution()
# Save the compile spec to compare against the version retrieved after loading.
pre_compile_spec = self.lowered_module.__getattr__("__method_compile_spec")
- # Load the lowered module.
- buffer.seek(0)
- self.lowered_module = torch.jit.load(buffer)
+ # Save and load the lowered module.
+ self.save_load()
# Get the compile spec after loading.
post_compile_spec = self.lowered_module.__getattr__("__method_compile_spec")
@@ -106,10 +124,82 @@
# Compile specs should match.
self.assertEqual(pre_compile_spec, post_compile_spec)
+ # Loaded module should produce the same outputs.
+ self.test_execution()
+
+
+class NestedModuleTest(JitBackendTestCase):
+ """
+ Tests for NestedModule, which check that a module lowered to a backend can
+ be used as a submodule.
+ """
+ class NestedModule(torch.nn.Module):
+ """
+ A Module with one submodule that is used to test that lowered Modules
+ can be used as submodules.
+ """
+
+ def __init__(self, submodule):
+ super().__init__()
+ self.submodule = submodule
+
+ def forward(self, x, h):
+ return self.submodule.forward(x, h)
+
+ def setUp(self):
+ super().setUp()
+ # Create Python, JIT and backend versions of NestedModule.
+ # Both modules in self.module are regular Python modules.
+ self.module = NestedModuleTest.NestedModule(BasicModule())
+ # Both modules in self.scripted_module are ScriptModules.
+ self.scripted_module = torch.jit.script(NestedModuleTest.NestedModule(BasicModule()))
+ lowered_module = to_test_backend_multi(
+ self.scripted_module._c, {"forward": {"": ""}}
+ )
+ # self.lowered_module is a ScriptModule, but its submodule is a lowered module.
+ self.lowered_module = torch.jit.script(NestedModuleTest.NestedModule(lowered_module))
+
+ def test_execution(self):
# Test execution with backend against Python and JIT.
input = torch.randn(5)
- # Test all three module methods.
- self.compare_py_jit_backend("accum", input)
- self.compare_py_jit_backend("sub_accum", input)
- self.compare_py_jit_backend("forward", input)
+ # Test forward.
+ self.check_function("forward", input)
+
+ def test_save_load(self):
+ # Lowered module should produce the same outputs.
+ self.test_execution()
+
+ # Save and load the lowered module.
+ self.save_load()
+
+ # Loaded module should produce the same outputs.
+ self.test_execution()
+
+
+class TestBackends(JitTestCase):
+ """
+ This class wraps and invokes all subclasses of JitBackendTestCase so that each one
+ does not have to be individually imported in test_jit.py.
+ """
+
+ def __init__(self, name):
+ super().__init__(name)
+ self.basic_module_test = BasicModuleTest(name)
+ self.nested_module_test = NestedModuleTest(name)
+
+ def setUp(self):
+ super().setUp()
+ if not TEST_WITH_ROCM:
+ self.basic_module_test.setUp()
+ self.nested_module_test.setUp()
+
+ @skipIfRocm
+ def test_execution(self):
+ self.basic_module_test.test_execution()
+ self.nested_module_test.test_execution()
+
+ @skipIfRocm
+ def test_save_load(self):
+ self.basic_module_test.test_save_load()
+ self.nested_module_test.test_save_load()
diff --git a/torch/csrc/jit/backends/backend_init.cpp b/torch/csrc/jit/backends/backend_init.cpp
index 51e324b..b01cb62 100644
--- a/torch/csrc/jit/backends/backend_init.cpp
+++ b/torch/csrc/jit/backends/backend_init.cpp
@@ -17,115 +17,111 @@
// this function must be called like
//
// torch._C._jit_to_backend("example_backend", module, spec)
- auto m = py::handle(module).cast<py::module>();
- m.def(
- "_jit_to_backend",
- [=](const std::string& backend_name,
- const Module& orig_module,
- const py::dict& method_compile_spec) {
- const c10::QualifiedName qual_backend_name({"__torch__",
- "torch",
- "classes",
- detail::kBackendsNamespace,
- backend_name});
- // TODO: Validate method_compile_spec.
+ auto codegen_lambda = [=](const std::string& backend_name,
+ const Module& orig_module,
+ const py::dict& method_compile_spec) {
+ const c10::QualifiedName qual_backend_name({"__torch__",
+ "torch",
+ "classes",
+ detail::kBackendsNamespace,
+ backend_name});
+ // TODO: Validate method_compile_spec.
- // Clone orig_module to make sure backend transformation is
- // functional.
- auto cloned_module = orig_module.clone();
+ // Clone orig_module to make sure backend transformation is
+ // functional.
+ auto cloned_module = orig_module.clone();
- // Represents of a Type of Dict[str, Any].
- auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
+ // Represents the type Dict[str, Any].
+ auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
- // Generate LoweredModule.
- Module loweredModule(
- "torch.jit." + backend_name + "LoweredModule",
- get_python_cu(),
- /*shouldMangle=*/true);
+ // Generate LoweredModule.
+ Module loweredModule(
+ "torch.jit." + backend_name + "LoweredModule",
+ get_python_cu(),
+ /*shouldMangle=*/true);
- // Generate attributes.
- // This is the original cloned and preprocessed module.
- loweredModule.register_attribute(
- "__processed_module",
- AnyType::get(),
- cloned_module._ivalue(),
- /*is_param=*/false);
+ // Generate attributes.
+ // This is the original cloned and preprocessed module.
+ loweredModule.register_attribute(
+ "__processed_module",
+ AnyType::get(),
+ cloned_module._ivalue(),
+ /*is_param=*/false);
- // This is for the method_compile_spec passed in to to_<backend> or
- // loaded from an exported model.
- loweredModule.register_attribute(
- "__method_compile_spec",
- any_dict_ty,
- toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
- /*is_param=*/false);
+ // This is for the method_compile_spec passed in to to_<backend> or
+ // loaded from an exported model.
+ loweredModule.register_attribute(
+ "__method_compile_spec",
+ any_dict_ty,
+ toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
+ /*is_param=*/false);
- // This is a pointer to a backend instance that is used to access
- // compile and execute functions.
- auto cls = getCustomClass(qual_backend_name.qualifiedName());
- TORCH_INTERNAL_ASSERT(cls);
- c10::intrusive_ptr<torch::CustomClassHolder> backend;
- loweredModule.register_attribute(
- "__backend", cls, IValue::make_capsule(backend));
+ // This is a pointer to a backend instance that is used to access
+ // compile and execute functions.
+ auto cls = getCustomClass(qual_backend_name.qualifiedName());
+ TORCH_INTERNAL_ASSERT(cls);
+ c10::intrusive_ptr<torch::CustomClassHolder> backend;
+ loweredModule.register_attribute(
+ "__backend", cls, IValue::make_capsule(backend));
- // This is the list of opaque backend handles returned by
- // backend.compile.
- loweredModule.register_attribute(
- "__handles",
- any_dict_ty,
- c10::impl::GenericDict(
- any_dict_ty->getKeyType(), any_dict_ty->getValueType()),
- /*is_param=*/false);
+ // This is the list of opaque backend handles returned by
+ // backend.compile.
+ loweredModule.register_attribute(
+ "__handles",
+ any_dict_ty,
+ c10::impl::GenericDict(
+ any_dict_ty->getKeyType(), any_dict_ty->getValueType()),
+ /*is_param=*/false);
- // Methods.
+ // Methods.
- // This is a helper function for creating a new instance of the
- // backend class.
- static const auto create_backend_ct = CodeTemplate(R"(
+ // This is a helper function for creating a new instance of the
+ // backend class.
+ static const auto create_backend_ct = CodeTemplate(R"(
def __create_backend(self):
self.__backend = $name()
)");
- TemplateEnv create_backend_te;
- create_backend_te.s("name", qual_backend_name.qualifiedName());
- loweredModule.define(
- create_backend_ct.format(create_backend_te),
- loweredModuleResolver());
+ TemplateEnv create_backend_te;
+ create_backend_te.s("name", qual_backend_name.qualifiedName());
+ loweredModule.define(
+ create_backend_ct.format(create_backend_te), loweredModuleResolver());
- // getstate and setstate are for serialization/deserialization of
- // the LoweredModule.
- loweredModule.define(
- R"(
+ // getstate and setstate are for serialization/deserialization of
+ // the LoweredModule.
+ loweredModule.define(
+ R"(
def __getstate__(self):
return self.__method_compile_spec, self.__processed_module
)",
- loweredModuleResolver());
+ loweredModuleResolver());
- loweredModule.define(
- R"(
+ loweredModule.define(
+ R"(
def __setstate__(self, state):
self.__method_compile_spec = state[0]
self.__processed_module = state[1]
self.__create_backend()
self.__handles = self.__backend.compile(self.__processed_module, self.__method_compile_spec)
)",
- loweredModuleResolver());
+ loweredModuleResolver());
- // This is never called during compilation or execution, but is
- // needed to generate the LoweredModule because we don't have access
- // to an instance of the backend as a C++ object with which to call
- // preprocess.
- loweredModule.define(
- R"(
+ // This is never called during compilation or execution, but is
+ // needed to generate the LoweredModule because we don't have access
+ // to an instance of the backend as a C++ object with which to call
+ // preprocess.
+ loweredModule.define(
+ R"(
def __preprocess(self, mod: Any, method_compile_spec: Dict[str, Any]):
self.__create_backend()
self.__processed_module = self.__backend.preprocess(mod, method_compile_spec)
)",
- loweredModuleResolver());
+ loweredModuleResolver());
- // This loop generates one method on the LoweredModule for every key
- // in method_compile_spec.
- for (auto& e : method_compile_spec) {
- std::string method_name = py::cast<std::string>(e.first);
- static const auto method_ct = CodeTemplate(R"(
+ // This loop generates one method on the LoweredModule for every key
+ // in method_compile_spec.
+ for (auto& e : method_compile_spec) {
+ std::string method_name = py::cast<std::string>(e.first);
+ static const auto method_ct = CodeTemplate(R"(
def $method(self${,def_inputs}):
typed_inputs: List[Any] = [${fwd_inputs,}]
$ret, = self.__backend.execute(self.__handles["$method"], typed_inputs)
@@ -133,98 +129,108 @@
return $ret
)");
- TemplateEnv method_te;
- method_te.s("method", method_name);
- auto method = orig_module.get_method(method_name);
- auto& function = method.function();
- auto& schema = function.getSchema();
+ TemplateEnv method_te;
+ method_te.s("method", method_name);
+ auto method = orig_module.get_method(method_name);
+ auto& function = method.function();
+ auto& schema = function.getSchema();
- // Generate the inputs for the function signature (def_inputs) and
- // for passing to backend.execute (fwd_inputs).
- std::vector<std::string> def_inputs, fwd_inputs;
- for (const auto& arg : schema.arguments()) {
- auto name = arg.name();
+ // Generate the inputs for the function signature (def_inputs) and
+ // for passing to backend.execute (fwd_inputs).
+ std::vector<std::string> def_inputs, fwd_inputs;
+ for (const auto& arg : schema.arguments()) {
+ auto name = arg.name();
- // Skip self since that is only and always present in the
- // signature.
- if (name == "self") {
- continue;
- }
-
- auto default_value = arg.default_value();
-
- if (arg.kwarg_only()) {
- // If this is a kwarg, it needs to be emitted as keyword=value
- // in the definition and keyword=keyword in the call to
- // backend_execute.
- TORCH_INTERNAL_ASSERT(default_value.has_value());
- std::stringstream def_ss, fwd_ss;
- def_ss << name << "=";
- fwd_ss << name << "=" << name;
- default_value->repr(
- def_ss,
- [](std::ostream&, const IValue&) -> bool { return false; });
- def_inputs.emplace_back(def_ss.str());
- fwd_inputs.emplace_back(fwd_ss.str());
- } else {
- // If this is not a kwarg, it should be emitted as is in the
- // signature and the call to backend_execute.
- def_inputs.emplace_back(name);
- fwd_inputs.emplace_back(name);
- }
- }
-
- // Generate a comma-delimited list of identifiers to unpack
- // outputs, as well as a list of isinstance checks to make sure
- // the backend returned the types it was supposed to.
- std::stringstream out_ss, type_check_ss;
- std::vector<std::string> type_checks;
- TORCH_INTERNAL_ASSERT(schema.returns().size() == 1);
- auto out_ty = schema.returns().at(0).type();
-
- out_ss << "_0";
- type_check_ss << "assert isinstance(_0, ";
-
- if (auto out_tuple_ty = out_ty->cast<TupleType>()) {
- auto tuple_elements = out_tuple_ty->elements();
- type_check_ss << tuple_elements[0]->str() << ")";
- type_checks.emplace_back(type_check_ss.str());
- for (unsigned i = 1, e = tuple_elements.size(); i < e; ++i) {
- type_check_ss.str(std::string());
- type_check_ss.clear();
- out_ss << ", _" << i;
- type_check_ss << "assert isinstance(_" << i << ", "
- << tuple_elements[i]->str() << ")";
- type_checks.emplace_back(type_check_ss.str());
- }
- } else {
- type_check_ss << out_ty->str() << ")";
- type_checks.emplace_back(type_check_ss.str());
- }
-
- method_te.v("def_inputs", def_inputs);
- method_te.v("fwd_inputs", fwd_inputs);
- method_te.v("refine", type_checks);
- method_te.s("ret", out_ss.str());
-
- loweredModule.define(
- method_ct.format(method_te), loweredModuleResolver());
+ // Skip self since that is only and always present in the
+ // signature.
+ if (name == "self") {
+ continue;
}
- // Run preprocess so that __processed_module is set correctly before
- // compilation.
- loweredModule.run_method(
- "__preprocess",
- cloned_module._ivalue(),
- toIValue(method_compile_spec, any_dict_ty).toGenericDict());
+ auto default_value = arg.default_value();
- // Call __setstate__ to ensure that the returned Module is ready to
- // run.
- auto state = at::ivalue::Tuple::create(
- toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
- loweredModule.attr("__processed_module"));
- loweredModule.run_method("__setstate__", state);
- return loweredModule;
+ if (arg.kwarg_only()) {
+ // If this is a kwarg, it needs to be emitted as keyword=value
+ // in the definition and keyword=keyword in the call to
+ // backend_execute.
+ TORCH_INTERNAL_ASSERT(default_value.has_value());
+ std::stringstream def_ss, fwd_ss;
+ def_ss << name << "=";
+ fwd_ss << name << "=" << name;
+ default_value->repr(def_ss, [](std::ostream&, const IValue&) -> bool {
+ return false;
+ });
+ def_inputs.emplace_back(def_ss.str());
+ fwd_inputs.emplace_back(fwd_ss.str());
+ } else {
+ // If this is not a kwarg, it should be emitted as is in the
+ // signature and the call to backend_execute.
+ def_inputs.emplace_back(name);
+ fwd_inputs.emplace_back(name);
+ }
+ }
+
+ // Generate a comma-delimited list of identifiers to unpack
+ // outputs, as well as a list of isinstance checks to make sure
+ // the backend returned the types it was supposed to.
+ std::stringstream out_ss, type_check_ss;
+ std::vector<std::string> type_checks;
+ TORCH_INTERNAL_ASSERT(schema.returns().size() == 1);
+ auto out_ty = schema.returns().at(0).type();
+
+ out_ss << "_0";
+ type_check_ss << "assert isinstance(_0, ";
+
+ if (auto out_tuple_ty = out_ty->cast<TupleType>()) {
+ auto tuple_elements = out_tuple_ty->elements();
+ type_check_ss << tuple_elements[0]->str() << ")";
+ type_checks.emplace_back(type_check_ss.str());
+ for (unsigned i = 1, e = tuple_elements.size(); i < e; ++i) {
+ type_check_ss.str(std::string());
+ type_check_ss.clear();
+ out_ss << ", _" << i;
+ type_check_ss << "assert isinstance(_" << i << ", "
+ << tuple_elements[i]->str() << ")";
+ type_checks.emplace_back(type_check_ss.str());
+ }
+ } else {
+ type_check_ss << out_ty->str() << ")";
+ type_checks.emplace_back(type_check_ss.str());
+ }
+
+ method_te.v("def_inputs", def_inputs);
+ method_te.v("fwd_inputs", fwd_inputs);
+ method_te.v("refine", type_checks);
+ method_te.s("ret", out_ss.str());
+
+ loweredModule.define(
+ method_ct.format(method_te), loweredModuleResolver());
+ }
+
+ // Run preprocess so that __processed_module is set correctly before
+ // compilation.
+ loweredModule.run_method(
+ "__preprocess",
+ cloned_module._ivalue(),
+ toIValue(method_compile_spec, any_dict_ty).toGenericDict());
+
+ // Call __setstate__ to ensure that the returned Module is ready to
+ // run.
+ auto state = at::ivalue::Tuple::create(
+ toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
+ loweredModule.attr("__processed_module"));
+ loweredModule.run_method("__setstate__", state);
+ return loweredModule;
+ };
+ auto m = py::handle(module).cast<py::module>();
+ m.def(
+ "_jit_to_backend",
+ [=](const std::string& backend_name,
+ const Module& orig_module,
+ const py::dict& method_compile_spec) {
+ return py::module::import("torch.jit._recursive")
+ .attr("wrap_cpp_module")(
+ codegen_lambda(backend_name, orig_module, method_compile_spec));
});
}
} // namespace jit
diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp
index 0a912b8..3e52a11 100644
--- a/torch/csrc/jit/python/script_init.cpp
+++ b/torch/csrc/jit/python/script_init.cpp
@@ -537,7 +537,8 @@
}
visited.emplace(item.a.internalToPointer());
}
- if (*unshapedType(item.a.type()) != *unshapedType(item.b.type())) {
+ if (!unshapedType(item.b.type())
+ ->isSubtypeOf(unshapedType(item.a.type()))) {
// Since named types are saved and loaded in the test suite, we cannot
// expect them to be equal. We should still check their slots however.
if (!item.a.type()->cast<c10::NamedType>()) {