Python basic module execution unit test on delegation of backend_with_compiler_demo (#60468)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60468

Added a unit test for the execution of a basic module with a compiler
ghstack-source-id: 132307488

Test Plan:
Running python test/test_jit.py TestBackendsWithCompiler -v returns a successful test

Imported from OSS

Reviewed By: iseeyuan

Differential Revision: D29306225

fbshipit-source-id: bf1ff075ebc63acbbe46d6ea030086405e29d7d3
diff --git a/test/jit/test_backends.py b/test/jit/test_backends.py
index 0f67cd0..bcf2865 100644
--- a/test/jit/test_backends.py
+++ b/test/jit/test_backends.py
@@ -7,6 +7,7 @@
 import torch
 import torch._C
 from torch.testing import FileCheck
+from torch.jit.mobile import _load_for_lite_interpreter
 from pathlib import Path
 
 from torch.testing._internal.common_utils import (
@@ -465,3 +466,121 @@
     @skipIfRocm
     def test_errors(self):
         self.selective_lowering_test.test_errors()
+
+"""
+Unit Tests for backend with compiler
+This test case and the existing TestBackends are separate because they cover different aspects.
+The actual backend implementation in this test is different.
+It has a simple demo compiler to test the end-to-end flow in mobile.
+However, this test cannot cover the selective_lowering for now, which is covered in TestBackends.
+"""
+class BasicModuleAdd(torch.nn.Module):
+    """
+    A simple add Module used to test to_backend lowering machinery.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, h):
+        return x + h
+
+# This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends.
+@unittest.skipIf(TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
+                 "Non-portable load_library call used in test")
+class JitBackendTestCaseWithCompiler(JitTestCase):
+    """
+    A common base class for JIT backend tests with compilers that contains common utility
+    functions for output comparison.
+    """
+
+    def setUp(self):
+        super().setUp()
+        torch_root = Path(__file__).resolve().parent.parent.parent
+        p = torch_root / 'build' / 'lib' / 'libbackend_with_compiler.so'
+        torch.ops.load_library(str(p))
+        # Subclasses are expected to set up four variables in their setUp methods:
+        # module - a regular, Python version of the module being tested
+        # scripted_module - a scripted version of module
+        # lowered_modle - a version of module lowered to a backend
+        # mobile_module - a module with a format that Pytorch Mobile can execute
+
+    def check_forward(self, input):
+        """
+        Check that the forward function produces the same output using
+        Python, regular JIT, the backend, and mobile for the given 'input'.
+        """
+
+        # Get outputs from forward.
+        python_output = self.module.forward(*input)
+        jit_output = self.scripted_module.forward(*input)
+        backend_output = self.lowered_module(*input)
+        mobile_output = self.mobile_module(*input)
+
+        # The answers returned by Python, JIT, to_backend, and mobile should all match.
+        self.assertEqual(python_output, backend_output)
+        self.assertEqual(jit_output, backend_output)
+        self.assertEqual(mobile_output, backend_output)
+
+    def test_execution(self):
+        """
+        Stub for correctness tests.
+        """
+        pass
+
+    def test_errors(self):
+        """
+        Stub for testing error checking.
+        """
+        pass
+
+class BasicModuleTestWithCompiler(JitBackendTestCaseWithCompiler):
+    """
+    Tests for BasicModuleAdd.
+    """
+
+    def setUp(self):
+        super().setUp()
+        # Create Python, JIT and backend versions of BasicModuleAdd.
+        self.module = BasicModuleAdd()
+        self.scripted_module = torch.jit.script(BasicModuleAdd())
+        compile_spec = {
+            "forward": {
+                "input_shapes": "((1, 1, 320, 240), (1, 3))",
+                "some_other_option": "True",
+            },
+        }
+        self.lowered_module = torch._C._jit_to_backend(
+            "backend_with_compiler_demo", self.scripted_module, compile_spec)
+        # Create mobile version of BasicModuleAdd
+        buffer = io.BytesIO(self.lowered_module._save_to_buffer_for_lite_interpreter())
+        buffer.seek(0)
+        self.mobile_module = _load_for_lite_interpreter(buffer)
+
+    def test_execution(self):
+        # Test execution with backend against Python and JIT.
+        input = torch.randn(5)
+        self.check_forward((input, input))
+
+
+# This is needed for IS_WINDOWS or IS_MACOS to skip the tests.
+@unittest.skipIf(TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
+                 "Non-portable load_library call used in test")
+class TestBackendsWithCompiler(JitTestCase):
+    """
+    This class wraps and invokes all subclasses of JitBackendTestCaseWithCompiler
+    so that each one does not have to be individually imported in test_jit.py.
+    """
+
+    def __init__(self, name):
+        super().__init__(name)
+        self.basic_module_compiler_test = BasicModuleTestWithCompiler(name)
+
+    def setUp(self):
+        super().setUp()
+        if not TEST_WITH_ROCM:
+            self.basic_module_compiler_test.setUp()
+
+    @skipIfRocm
+    def test_execution(self):
+        self.basic_module_compiler_test.test_execution()
diff --git a/test/test_jit.py b/test/test_jit.py
index b8ee357..bf70b70 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -8,7 +8,7 @@
 from jit.test_recursive_script import TestRecursiveScript  # noqa: F401
 from jit.test_type_sharing import TestTypeSharing  # noqa: F401
 from jit.test_logging import TestLogging  # noqa: F401
-from jit.test_backends import TestBackends  # noqa: F401
+from jit.test_backends import TestBackends, TestBackendsWithCompiler  # noqa: F401
 from jit.test_list_dict import TestList, TestDict, TestNamedTuple, TestScriptDict  # noqa: F401
 from jit.test_async import TestAsync  # noqa: F401
 from jit.test_data_parallel import TestDataParallel  # noqa: F401