# blob: bb345292999d9a5a6ec9df3b2dd8505b7bf5c794 [file] [log] [blame]
from torch.testing._internal.jit_utils import JitTestCase
import io
import os
import sys
import torch
import torch._C
# The helper files live in test/; walk up from this file's resolved location
# (its directory, then that directory's parent) and put the result on sys.path
# so those helpers can be imported.
pytorch_test_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(pytorch_test_dir)
sys.path.append(pytorch_test_dir)

# Refuse direct execution and point the user at the supported entry point.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
def to_test_backend(module, method_compile_spec):
    """
    Lower `module` via torch._C._jit_to_test_backend using a single compile spec.

    The given `method_compile_spec` is registered under the "forward" key only;
    use to_test_backend_multi to supply specs for several methods.

    Returns whatever torch._C._jit_to_test_backend returns for that input.
    """
    forward_only_spec = {"forward": method_compile_spec}
    return torch._C._jit_to_test_backend(module, forward_only_spec)
def to_test_backend_multi(module, method_compile_spec):
    """
    Lower `module` via torch._C._jit_to_test_backend with per-method compile specs.

    `method_compile_spec` is passed through unchanged; the caller in this file
    supplies a dict keyed by method name (e.g. "accum", "sub_accum", "forward").

    Returns whatever torch._C._jit_to_test_backend returns for that input.
    """
    lowered = torch._C._jit_to_test_backend(module, method_compile_spec)
    return lowered
class MyModule(torch.nn.Module):
    """Minimal module with an add method, a subtract method and a forward returning both."""

    def __init__(self):
        super().__init__()

    def accum(self, x, h):
        """Return x + h."""
        total = x + h
        return total

    def sub_accum(self, x, h):
        """Return x - h."""
        difference = x - h
        return difference

    def forward(self, x, h):
        """Return the tuple (accum(x, h), sub_accum(x, h))."""
        added = self.accum(x, h)
        subtracted = self.sub_accum(x, h)
        return added, subtracted
class TestBackends(JitTestCase):
    """
    Tests that MyModule lowered to the test backend produces the same outputs as
    the plain Python module and the scripted (JIT) module.
    """

    def setUp(self):
        """Build the Python, JIT and backend versions of MyModule used by every test."""
        super().setUp()
        # Create Python, JIT and backend versions of MyModule.
        self.module = MyModule()
        self.scripted_module = torch.jit.script(MyModule())
        # Each of the three methods is lowered with the same {"": ""} compile spec.
        self.lowered_module = to_test_backend_multi(
            self.scripted_module._c,
            {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
        )

    def compare_py_jit_backend(self, name, input):
        """
        This is a helper function for comparing the outputs of self.module (Python), self.scripted_module (JIT)
        and self.lowered_module (backend) when the method named 'name' is invoked using 'input'.

        'input' is passed as both arguments of the method.
        """
        # Get handles for Python, JIT and backend methods. getattr() performs the
        # regular attribute lookup and still falls back to __getattr__ when the
        # name is not found the normal way.
        python_method = getattr(self.module, name)
        jit_method = getattr(self.scripted_module, name)
        backend_method = getattr(self.lowered_module, name)

        # Run methods.
        python_output = python_method(input, input)
        jit_output = jit_method(input, input)
        backend_output = backend_method(input, input)

        # The answers returned by Python, JIT and to_backend should all match.
        self.assertEqual(python_output, backend_output)
        self.assertEqual(jit_output, backend_output)

    def _compare_all_methods(self, input):
        """Run compare_py_jit_backend for each of the three MyModule methods."""
        for method_name in ("accum", "sub_accum", "forward"):
            self.compare_py_jit_backend(method_name, input)

    def test_simple(self):
        """
        This is a simple test that compiles MyModule for the test backend and ensures it produces the correct
        answers for each method.
        """
        # Test execution with backend against Python and JIT.
        sample = torch.randn(5)

        # Test all three module methods.
        self._compare_all_methods(sample)

    def test_save_load(self):
        """
        This method tests that a lowered module still produces the same output as a Python module and ScriptModule after
        saving and loading.
        """
        # Save the lowered module.
        buffer = io.BytesIO()
        torch.jit.save(self.lowered_module, buffer)

        # Save the compile spec to compare against the version retrieved after loading.
        # The name is passed as a string, so no private-name mangling applies.
        pre_compile_spec = getattr(self.lowered_module, "__method_compile_spec")

        # Load the lowered module. Rebinding self.lowered_module makes the
        # comparisons below run against the loaded copy.
        buffer.seek(0)
        self.lowered_module = torch.jit.load(buffer)

        # Get the compile spec after loading.
        post_compile_spec = getattr(self.lowered_module, "__method_compile_spec")

        # Compile specs should match.
        self.assertEqual(pre_compile_spec, post_compile_spec)

        # Test execution with backend against Python and JIT.
        sample = torch.randn(5)

        # Test all three module methods.
        self._compare_all_methods(sample)