| from torch.testing._internal.jit_utils import JitTestCase |
| import os |
| import sys |
| import unittest |
| |
| import torch |
| import torch._C |
| from pathlib import Path |
| from torch.testing._internal.common_utils import ( |
| IS_FBCODE, |
| IS_MACOS, |
| IS_SANDCASTLE, |
| IS_WINDOWS, |
| TEST_WITH_ROCM, |
| skipIfRocm, |
| ) |
| # Make the helper files in test/ importable |
| pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) |
| sys.path.append(pytorch_test_dir) |
| |
| if __name__ == "__main__": |
| raise RuntimeError( |
| "This test file is not meant to be run directly, use:\n\n" |
| "\tpython test/test_jit.py TESTNAME\n\n" |
| "instead." |
| ) |
| |
| |
| def to_test_backend(module, method_compile_spec): |
| return torch._C._jit_to_backend("test_backend", module, {"forward": method_compile_spec}) |
| |
| |
| def to_test_backend_multi(module, method_compile_spec): |
| return torch._C._jit_to_backend("test_backend", module, method_compile_spec) |
| |
| |
| class BasicModule(torch.nn.Module): |
| """ |
| A simple Module used to test to_backend lowering machinery. |
| """ |
| |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x, h): |
| return self.accum(x, h), self.sub_accum(x, h) |
| |
| def accum(self, x, h): |
| return x + h |
| |
| def sub_accum(self, x, h): |
| return x - h |
| |
| |
| class JitBackendTestCase(JitTestCase): |
| """ |
| A common base class for JIT backend tests that contains common utility |
| functions for output comparison and serialization/deserialization. |
| """ |
| |
| def setUp(self): |
| super().setUp() |
| if TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE: |
| raise unittest.SkipTest("non-portable load_library call used in test") |
| torch_root = Path(__file__).resolve().parent.parent.parent |
| p = torch_root / 'build' / 'lib' / 'libjitbackend_test.so' |
| torch.ops.load_library(str(p)) |
| # Subclasses are expected to set up three variables in their setUp methods: |
| # module - a regular, Python version of the module being tested |
| # scripted_module - a scripted version of module |
| # lowered_modle - a version of module lowered to a backend |
| |
| def check_function(self, function_name, input): |
| """ |
| Check that the function named 'function_name' produces the same output using |
| Python, regular JIT and the backend for the given 'input'. |
| """ |
| # Get handles for Python, JIT and backend methods. |
| python_method = self.module.__getattribute__(function_name) |
| jit_method = self.scripted_module.__getattr__(function_name) |
| backend_method = self.lowered_module.__getattr__(function_name) |
| |
| # Run methods. |
| python_output = python_method(input, input) |
| jit_output = jit_method(input, input) |
| backend_output = backend_method(input, input) |
| |
| # The answers returned by Python, JIT and to_backend should all match. |
| self.assertEqual(python_output, backend_output) |
| self.assertEqual(jit_output, backend_output) |
| |
| def save_load(self): |
| """ |
| Save and load the lowered module. |
| """ |
| self.lowered_module = self.getExportImportCopy(self.lowered_module) |
| |
| |
| class BasicModuleTest(JitBackendTestCase): |
| """ |
| Tests for BasicModule. |
| """ |
| |
| def setUp(self): |
| super().setUp() |
| # Create Python, JIT and backend versions of BasicModule. |
| self.module = BasicModule() |
| self.scripted_module = torch.jit.script(BasicModule()) |
| self.lowered_module = to_test_backend_multi( |
| self.scripted_module, |
| {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}}, |
| ) |
| |
| def test_execution(self): |
| # Test execution with backend against Python and JIT. |
| input = torch.randn(5) |
| |
| # Test all three module methods. |
| self.check_function("accum", input) |
| self.check_function("sub_accum", input) |
| self.check_function("forward", input) |
| |
| @skipIfRocm |
| def test_save_load(self): |
| # Lowered module should produce the same outputs. |
| self.test_execution() |
| |
| # Save the compile spec to compare against the version retrieved after loading. |
| pre_compile_spec = self.lowered_module.__getattr__("__method_compile_spec") |
| |
| # Save and load the lowered module. |
| self.save_load() |
| |
| # Get the compile spec after loading. |
| post_compile_spec = self.lowered_module.__getattr__("__method_compile_spec") |
| |
| # Compile specs should match. |
| self.assertEqual(pre_compile_spec, post_compile_spec) |
| |
| # Loaded module should produce the same outputs. |
| self.test_execution() |
| |
| |
| class NestedModuleTest(JitBackendTestCase): |
| """ |
| Tests for NestedModule that check that a module lowered to a backend can be used |
| as a submodule. |
| """ |
| class NestedModule(torch.nn.Module): |
| """ |
| A Module with one submodule that is used to test that lowered Modules |
| can be used as submodules. |
| """ |
| |
| def __init__(self, submodule): |
| super().__init__() |
| self.submodule = submodule |
| |
| def forward(self, x, h): |
| return self.submodule.forward(x, h) |
| |
| def setUp(self): |
| super().setUp() |
| # Create Python, JIT and backend versions of NestedModule. |
| # Both modules in self.module are regular Python modules. |
| self.module = NestedModuleTest.NestedModule(BasicModule()) |
| # Both modules in self.scripted_module are ScriptModules. |
| self.scripted_module = torch.jit.script(NestedModuleTest.NestedModule(BasicModule())) |
| lowered_module = to_test_backend_multi( |
| self.scripted_module, {"forward": {"": ""}} |
| ) |
| # self.lowered_module is a ScriptModule, but its submodule is a lowered module. |
| self.lowered_module = torch.jit.script(NestedModuleTest.NestedModule(lowered_module)) |
| |
| def test_execution(self): |
| # Test execution with backend against Python and JIT. |
| input = torch.randn(5) |
| |
| # Test forward. |
| self.check_function("forward", input) |
| |
| def test_save_load(self): |
| # Lowered module should produce the same outputs. |
| self.test_execution() |
| |
| # Save and load the lowered module. |
| self.save_load() |
| |
| # Loaded module should produce the same outputs. |
| self.test_execution() |
| |
| |
| class TestBackends(JitTestCase): |
| """ |
| This class wraps and invokes all subclasses of JitBackendTestCase so that each one |
| does not have to be individually imported in test_jit.py. |
| """ |
| |
| def __init__(self, name): |
| super().__init__(name) |
| self.basic_module_test = BasicModuleTest(name) |
| self.nested_module_test = NestedModuleTest(name) |
| |
| def setUp(self): |
| super().setUp() |
| if not TEST_WITH_ROCM: |
| self.basic_module_test.setUp() |
| self.nested_module_test.setUp() |
| |
| @skipIfRocm |
| def test_execution(self): |
| self.basic_module_test.test_execution() |
| self.nested_module_test.test_execution() |
| |
| @skipIfRocm |
| def test_save_load(self): |
| self.basic_module_test.test_save_load() |
| self.nested_module_test.test_save_load() |