Python basic module execution unit test on delegation of backend_with_compiler_demo (#60468)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60468
Added a unit test for the execution of a basic module with a compiler
ghstack-source-id: 132307488
Test Plan:
Running python test/test_jit.py TestBackendsWithCompiler -v returns a successful test
Imported from OSS
Reviewed By: iseeyuan
Differential Revision: D29306225
fbshipit-source-id: bf1ff075ebc63acbbe46d6ea030086405e29d7d3
diff --git a/test/jit/test_backends.py b/test/jit/test_backends.py
index 0f67cd0..bcf2865 100644
--- a/test/jit/test_backends.py
+++ b/test/jit/test_backends.py
@@ -7,6 +7,7 @@
import torch
import torch._C
from torch.testing import FileCheck
+from torch.jit.mobile import _load_for_lite_interpreter
from pathlib import Path
from torch.testing._internal.common_utils import (
@@ -465,3 +466,121 @@
@skipIfRocm
def test_errors(self):
self.selective_lowering_test.test_errors()
+
+"""
+Unit Tests for backend with compiler
+This test case and the existing TestBackends are separate because they cover different aspects.
+The actual backend implementation in this test is different.
+It has a simple demo compiler to test the end-to-end flow in mobile.
+However, this test cannot cover the selective_lowering for now, which is covered in TestBackends.
+"""
+class BasicModuleAdd(torch.nn.Module):
+ """
+ A simple add Module used to test to_backend lowering machinery.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, h):
+ return x + h
+
+# This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends.
+@unittest.skipIf(TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
+ "Non-portable load_library call used in test")
+class JitBackendTestCaseWithCompiler(JitTestCase):
+ """
+ A common base class for JIT backend tests with compilers that contains common utility
+ functions for output comparison.
+ """
+
+ def setUp(self):
+ super().setUp()
+ torch_root = Path(__file__).resolve().parent.parent.parent
+ p = torch_root / 'build' / 'lib' / 'libbackend_with_compiler.so'
+ torch.ops.load_library(str(p))
+ # Subclasses are expected to set up four variables in their setUp methods:
+ # module - a regular, Python version of the module being tested
+ # scripted_module - a scripted version of module
+ # lowered_modle - a version of module lowered to a backend
+ # mobile_module - a module with a format that Pytorch Mobile can execute
+
+ def check_forward(self, input):
+ """
+ Check that the forward function produces the same output using
+ Python, regular JIT, the backend, and mobile for the given 'input'.
+ """
+
+ # Get outputs from forward.
+ python_output = self.module.forward(*input)
+ jit_output = self.scripted_module.forward(*input)
+ backend_output = self.lowered_module(*input)
+ mobile_output = self.mobile_module(*input)
+
+ # The answers returned by Python, JIT, to_backend, and mobile should all match.
+ self.assertEqual(python_output, backend_output)
+ self.assertEqual(jit_output, backend_output)
+ self.assertEqual(mobile_output, backend_output)
+
+ def test_execution(self):
+ """
+ Stub for correctness tests.
+ """
+ pass
+
+ def test_errors(self):
+ """
+ Stub for testing error checking.
+ """
+ pass
+
+class BasicModuleTestWithCompiler(JitBackendTestCaseWithCompiler):
+ """
+ Tests for BasicModuleAdd.
+ """
+
+ def setUp(self):
+ super().setUp()
+ # Create Python, JIT and backend versions of BasicModuleAdd.
+ self.module = BasicModuleAdd()
+ self.scripted_module = torch.jit.script(BasicModuleAdd())
+ compile_spec = {
+ "forward": {
+ "input_shapes": "((1, 1, 320, 240), (1, 3))",
+ "some_other_option": "True",
+ },
+ }
+ self.lowered_module = torch._C._jit_to_backend(
+ "backend_with_compiler_demo", self.scripted_module, compile_spec)
+ # Create mobile version of BasicModuleAdd
+ buffer = io.BytesIO(self.lowered_module._save_to_buffer_for_lite_interpreter())
+ buffer.seek(0)
+ self.mobile_module = _load_for_lite_interpreter(buffer)
+
+ def test_execution(self):
+ # Test execution with backend against Python and JIT.
+ input = torch.randn(5)
+ self.check_forward((input, input))
+
+
+# This is needed for IS_WINDOWS or IS_MACOS to skip the tests.
+@unittest.skipIf(TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
+ "Non-portable load_library call used in test")
+class TestBackendsWithCompiler(JitTestCase):
+ """
+ This class wraps and invokes all subclasses of JitBackendTestCaseWithCompiler
+ so that each one does not have to be individually imported in test_jit.py.
+ """
+
+ def __init__(self, name):
+ super().__init__(name)
+ self.basic_module_compiler_test = BasicModuleTestWithCompiler(name)
+
+ def setUp(self):
+ super().setUp()
+ if not TEST_WITH_ROCM:
+ self.basic_module_compiler_test.setUp()
+
+ @skipIfRocm
+ def test_execution(self):
+ self.basic_module_compiler_test.test_execution()
diff --git a/test/test_jit.py b/test/test_jit.py
index b8ee357..bf70b70 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -8,7 +8,7 @@
from jit.test_recursive_script import TestRecursiveScript # noqa: F401
from jit.test_type_sharing import TestTypeSharing # noqa: F401
from jit.test_logging import TestLogging # noqa: F401
-from jit.test_backends import TestBackends # noqa: F401
+from jit.test_backends import TestBackends, TestBackendsWithCompiler # noqa: F401
from jit.test_list_dict import TestList, TestDict, TestNamedTuple, TestScriptDict # noqa: F401
from jit.test_async import TestAsync # noqa: F401
from jit.test_data_parallel import TestDataParallel # noqa: F401