blob: 78267586b9e5697a508f4666cc45e2862f07dcf5 [file] [log] [blame]
from torch.testing._internal.jit_utils import JitTestCase
import os
import sys
import torch
import torch._C
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
def to_test_backend(module, method_compile_spec):
return torch._C._jit_to_test_backend(module, {"forward": method_compile_spec})
def to_test_backend_multi(module, method_compile_spec):
return torch._C._jit_to_test_backend(module, method_compile_spec)
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, x, h):
return self.accum(x, h), self.sub_accum(x, h)
def accum(self, x, h):
return x + h
def sub_accum(self, x, h):
return x - h
class TestBackends(JitTestCase):
def test_simple(self):
module = MyModule()
scripted_module = torch.jit.script(MyModule())
# Test compile.
lowered_module = to_test_backend_multi(
scripted_module._c, {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}})
# Test execution with backend against Python and JIT.
input = torch.randn(5)
def compare_py_jit_backend(name, input):
# Get handles for Python, JIT and backend methods.
python_method = module.__getattribute__(name)
jit_method = scripted_module.__getattr__(name)
backend_method = lowered_module._get_method(name)
# Run methods.
python_output = python_method(input, input)
jit_output = jit_method(input, input)
backend_output = backend_method(input, input)
# The answers returned by Python, JIT and to_backend should all match.
self.assertEqual(python_output, backend_output)
self.assertEqual(jit_output, backend_output)
# Test all three module methods.
compare_py_jit_backend("accum", input)
compare_py_jit_backend("sub_accum", input)
compare_py_jit_backend("forward", input)
# TODO: Test save and load.