blob: fbc1443024cb5a87dfb2e741c936219c38a99d5b [file] [log] [blame]
# Owner(s): ["oncall: jit"]
from typing import NamedTuple, Optional
import io
import os
import pathlib
import sys
from torch import Tensor
from torch.testing._internal.common_utils import TemporaryFileName
import torch
# Make the helper files in test/ importable: the directory two levels above
# this file (after resolving symlinks) is appended to sys.path.
pytorch_test_dir = os.path.realpath(__file__)
pytorch_test_dir = os.path.dirname(pytorch_test_dir)  # directory containing this file
pytorch_test_dir = os.path.dirname(pytorch_test_dir)  # ...and its parent
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import (JitTestCase,
clear_class_registry)
if __name__ == "__main__":
    # These tests are collected through test/test_jit.py (see the message);
    # executing this module as a script is refused with a usage hint.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
class TestSaveLoad(JitTestCase):
def test_different_modules(self):
    """
    Exercise the situation where we have the same qualified name
    in two different CompilationUnits on save/load.

    Two different ``Foo`` modules are scripted and saved, with the class
    registry cleared in between so that both end up with the same qualified
    name. Both are then loaded as submodules of one parent module, which is
    itself scripted, saved, loaded, inspected and run.
    """
    class Foo(torch.nn.Module):
        def __init__(self):
            # Zero-argument super(): ``super(Foo, self)`` would look up the
            # name ``Foo`` at call time, and that name is rebound to a
            # different class further down in this test.
            super().__init__()
            self.foo = torch.nn.Linear(2, 2)
            self.bar = torch.nn.Linear(2, 2)

        def forward(self, x):
            x = self.foo(x)
            x = self.bar(x)
            return x

    first_script_module = torch.jit.script(Foo())
    first_saved_module = io.BytesIO()
    torch.jit.save(first_script_module, first_saved_module)
    first_saved_module.seek(0)

    clear_class_registry()

    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.foo = torch.nn.Linear(2, 2)

        def forward(self, x):
            x = self.foo(x)
            return x

    second_script_module = torch.jit.script(Foo())
    second_saved_module = io.BytesIO()
    # Save the very module whose qualified name is compared below, instead
    # of scripting yet another copy of ``Foo``.
    torch.jit.save(second_script_module, second_saved_module)
    second_saved_module.seek(0)

    clear_class_registry()

    self.assertEqual(
        first_script_module._c.qualified_name, second_script_module._c.qualified_name
    )

    class ContainsBoth(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.add_module("second", torch.jit.load(second_saved_module))
            self.add_module("first", torch.jit.load(first_saved_module))

        def forward(self, x):
            x = self.first(x)
            x = self.second(x)
            return x

    sm = torch.jit.script(ContainsBoth())
    contains_both = io.BytesIO()
    torch.jit.save(sm, contains_both)
    contains_both.seek(0)
    sm = torch.jit.load(contains_both)

    # The colliding types must stay distinct after the round trip: only the
    # first ``Foo`` has a ``bar`` layer, and the combined module must run.
    self.assertTrue(hasattr(sm.first, "bar"))
    self.assertFalse(hasattr(sm.second, "bar"))
    self.assertEqual(sm(torch.randn(3, 2)).shape, torch.Size([3, 2]))
def test_different_functions(self):
    """
    Exercise the situation where we have the same qualified name
    in two different CompilationUnits on save/load.

    Here the colliding names are a free function ``lol`` (identity in the
    first version, returning "hello" in the second) and the module ``Foo``
    that calls it.
    """
    def lol(x):
        return x

    class Foo(torch.nn.Module):
        def forward(self, x):
            return lol(x)

    first_script_module = torch.jit.script(Foo())
    first_saved_module = io.BytesIO()
    torch.jit.save(first_script_module, first_saved_module)
    first_saved_module.seek(0)

    clear_class_registry()

    def lol(x):  # noqa: F811
        return "hello"

    class Foo(torch.nn.Module):
        def forward(self, x):
            return lol(x)

    second_script_module = torch.jit.script(Foo())
    second_saved_module = io.BytesIO()
    # Save the very module whose qualified name is compared below, instead
    # of scripting yet another copy of ``Foo``.
    torch.jit.save(second_script_module, second_saved_module)
    second_saved_module.seek(0)

    clear_class_registry()

    self.assertEqual(
        first_script_module._c.qualified_name, second_script_module._c.qualified_name
    )

    class ContainsBoth(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.add_module("second", torch.jit.load(second_saved_module))
            self.add_module("first", torch.jit.load(first_saved_module))

        def forward(self, x):
            x = self.first(x)
            x = self.second(x)
            return x

    sm = torch.jit.script(ContainsBoth())
    contains_both = io.BytesIO()
    torch.jit.save(sm, contains_both)
    contains_both.seek(0)
    sm = torch.jit.load(contains_both)

    # ``first`` passes x through unchanged; ``second`` must still call the
    # ``lol`` that returns "hello", not the identity version.
    self.assertEqual(sm(torch.ones(2)), "hello")
def test_different_interfaces(self):
    """
    Exercise the situation where we have the same qualified name
    in two different CompilationUnits on save/load.

    Here the colliding names are an interface (``MyInterface``), a scripted
    class implementing it (``ImplementInterface``) and the module ``Foo``
    holding an interface-typed attribute; the two versions expose different
    method names (``bar`` vs ``not_bar``).
    """
    @torch.jit.interface
    class MyInterface(object):
        def bar(self, x: Tensor) -> Tensor:
            pass

    @torch.jit.script
    class ImplementInterface(object):
        def __init__(self):
            pass

        def bar(self, x):
            return x

    class Foo(torch.nn.Module):
        __annotations__ = {"interface": MyInterface}

        def __init__(self):
            super().__init__()
            self.interface = ImplementInterface()

        def forward(self, x):
            return self.interface.bar(x)

    first_script_module = torch.jit.script(Foo())
    first_saved_module = io.BytesIO()
    torch.jit.save(first_script_module, first_saved_module)
    first_saved_module.seek(0)

    clear_class_registry()

    @torch.jit.interface
    class MyInterface(object):
        def not_bar(self, x: Tensor) -> Tensor:
            pass

    @torch.jit.script  # noqa: F811
    class ImplementInterface(object):  # noqa: F811
        def __init__(self):
            pass

        def not_bar(self, x):
            return x

    class Foo(torch.nn.Module):
        __annotations__ = {"interface": MyInterface}

        def __init__(self):
            super().__init__()
            self.interface = ImplementInterface()

        def forward(self, x):
            return self.interface.not_bar(x)

    second_script_module = torch.jit.script(Foo())
    second_saved_module = io.BytesIO()
    # Save the very module whose qualified name is compared below, instead
    # of scripting yet another copy of ``Foo``.
    torch.jit.save(second_script_module, second_saved_module)
    second_saved_module.seek(0)

    clear_class_registry()

    self.assertEqual(
        first_script_module._c.qualified_name, second_script_module._c.qualified_name
    )

    class ContainsBoth(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.add_module("second", torch.jit.load(second_saved_module))
            self.add_module("first", torch.jit.load(first_saved_module))

        def forward(self, x):
            x = self.first(x)
            x = self.second(x)
            return x

    sm = torch.jit.script(ContainsBoth())
    contains_both = io.BytesIO()
    torch.jit.save(sm, contains_both)
    contains_both.seek(0)
    sm = torch.jit.load(contains_both)

    # Both interface implementations return their input unchanged, so the
    # reloaded combined module must act as the identity.
    x = torch.randn(3, 2)
    self.assertEqual(sm(x), x)
def test_many_collisions(self):
    """
    Same-qualified-name collisions of several kinds at once: a NamedTuple,
    an interface, a scripted class, a free function and the module itself
    are all redefined differently after the class registry is cleared, and
    both versions are combined into one saved/loaded module.
    """
    class MyCoolNamedTuple(NamedTuple):
        a: int

    @torch.jit.interface
    class MyInterface(object):
        def bar(self, x: Tensor) -> Tensor:
            pass

    @torch.jit.script
    class ImplementInterface(object):
        def __init__(self):
            pass

        def bar(self, x):
            return x

    def lol(x):
        return x

    class Foo(torch.nn.Module):
        interface: MyInterface

        def __init__(self):
            super().__init__()
            self.foo = torch.nn.Linear(2, 2)
            self.bar = torch.nn.Linear(2, 2)
            self.interface = ImplementInterface()

        def forward(self, x):
            x = self.foo(x)
            x = self.bar(x)
            x = lol(x)
            x = self.interface.bar(x)

            return x, MyCoolNamedTuple(a=5)

    first_script_module = torch.jit.script(Foo())
    first_saved_module = io.BytesIO()
    torch.jit.save(first_script_module, first_saved_module)
    first_saved_module.seek(0)

    clear_class_registry()

    @torch.jit.interface
    class MyInterface(object):
        def not_bar(self, x: Tensor) -> Tensor:
            pass

    @torch.jit.script  # noqa: F811
    class ImplementInterface(object):  # noqa: F811
        def __init__(self):
            pass

        def not_bar(self, x):
            return x

    def lol(x):  # noqa: F811
        return "asdofij"

    class MyCoolNamedTuple(NamedTuple):  # noqa: F811
        a: str

    class Foo(torch.nn.Module):
        interface: MyInterface

        def __init__(self):
            super().__init__()
            self.foo = torch.nn.Linear(2, 2)
            self.interface = ImplementInterface()

        def forward(self, x):
            x = self.foo(x)
            self.interface.not_bar(x)
            x = lol(x)
            return x, MyCoolNamedTuple(a="hello")

    second_script_module = torch.jit.script(Foo())
    second_saved_module = io.BytesIO()
    torch.jit.save(second_script_module, second_saved_module)
    second_saved_module.seek(0)

    clear_class_registry()

    self.assertEqual(
        first_script_module._c.qualified_name, second_script_module._c.qualified_name
    )

    class ContainsBoth(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.add_module("second", torch.jit.load(second_saved_module))
            self.add_module("first", torch.jit.load(first_saved_module))

        def forward(self, x):
            x, named_tuple_1 = self.first(x)
            x, named_tuple_2 = self.second(x)
            return len(x + named_tuple_2.a) + named_tuple_1.a

    sm = torch.jit.script(ContainsBoth())
    contains_both = io.BytesIO()
    torch.jit.save(sm, contains_both)
    contains_both.seek(0)
    sm = torch.jit.load(contains_both)

    # Each submodule must keep its own ``lol`` and ``MyCoolNamedTuple``:
    # first yields MyCoolNamedTuple(a=5); second yields "asdofij" and
    # MyCoolNamedTuple(a="hello"), so len("asdofij" + "hello") + 5 == 17.
    self.assertEqual(sm(torch.randn(3, 2)), 17)
def test_save_load_with_extra_files(self):
    """
    ``_extra_files`` written by ``ScriptModule.save`` / ``save_to_buffer``
    and by ``torch.jit.save`` must be filled back in by ``torch.jit.load``,
    both for file paths and for in-memory buffers.
    """
    class MyMod(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            return a

    # specifically test binary data
    value = b"bar\x00\xffbaz"

    expected_extra_files = {}
    expected_extra_files['foo'] = value
    # verify that str to bytes conversion also works
    expected_extra_files['foo2'] = "bar"
    m = MyMod()

    # Save to file.
    with TemporaryFileName() as fname:
        m.save(fname, _extra_files=expected_extra_files)
        # values don't matter
        extra_files = {'foo': '', 'foo2': None}
        torch.jit.load(fname, _extra_files=extra_files)
        self.assertEqual(value, extra_files['foo'])
        # results come back always as bytes
        self.assertEqual(b"bar", extra_files['foo2'])

        # Use torch.jit API. Reset both entries so the assertions below
        # cannot pass on values left over from the previous load.
        torch.jit.save(m, fname, _extra_files=expected_extra_files)
        extra_files['foo'] = ''
        extra_files['foo2'] = None
        torch.jit.load(fname, _extra_files=extra_files)
        self.assertEqual(value, extra_files['foo'])
        self.assertEqual(b"bar", extra_files['foo2'])

    # Save to buffer.
    buffer = io.BytesIO(m.save_to_buffer(_extra_files=expected_extra_files))
    extra_files = {'foo': ''}
    torch.jit.load(buffer, _extra_files=extra_files)
    self.assertEqual(value, extra_files['foo'])

    # Use torch.jit API
    buffer = io.BytesIO()
    torch.jit.save(m, buffer, _extra_files=expected_extra_files)
    buffer.seek(0)
    extra_files = {'foo': ''}
    torch.jit.load(buffer, _extra_files=extra_files)
    self.assertEqual(value, extra_files['foo'])

    # Non-existent file 'bar'
    extra_files['bar'] = ''
    # NOTE(review): ``buffer`` is not rewound with ``buffer.seek(0)`` after
    # the load above, so this RuntimeError may come from reading an already
    # consumed buffer rather than from 'bar' being absent from the archive.
    # Confirm which failure is intended before changing this check.
    with self.assertRaises(RuntimeError):
        torch.jit.load(buffer, _extra_files=extra_files)
def test_save_load_using_pathlib(self):
    """A module saved to, and loaded from, a ``pathlib.Path`` behaves the same."""
    class MyMod(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            return 2 * a

    original = MyMod()

    # Save then load, addressing the temporary file through a pathlib.Path.
    with TemporaryFileName() as fname:
        target = pathlib.Path(fname)
        original.save(target)
        reloaded = torch.jit.load(target)

    inputs = torch.tensor([1., 2., 3., 4.])
    outputs_match = torch.equal(original(inputs), reloaded(inputs))
    self.assertTrue(outputs_match)
def test_save_nonexit_file(self):
    """Saving a scripted module under a directory that does not exist raises RuntimeError."""
    class Foo(torch.nn.Module):
        def forward(self, x):
            return 2 * x

    scripted = torch.jit.script(Foo())
    missing_target = "NonExist/path/test.pt"
    with self.assertRaises(RuntimeError):
        scripted.save(missing_target)
def test_save_namedtuple_input_only(self):
    """
    A NamedTuple that appears only as a forward() argument type must still
    survive a save/load round trip, and the loaded module must accept it.
    """
    global FooTuple  # see [local resolution in python]

    class FooTuple(NamedTuple):
        a: int

    class MyModule(torch.nn.Module):
        def forward(self, x: FooTuple) -> torch.Tensor:
            return torch.tensor(3)

    scripted = torch.jit.script(MyModule())
    reloaded = self.getExportImportCopy(scripted)
    result = reloaded(FooTuple(a=5))
    self.assertEqual(result, torch.tensor(3))
def test_save_namedtuple_output_only(self):
    """
    A NamedTuple that appears only in forward()'s return annotation must
    still survive a save/load round trip.
    """
    global FooTuple  # see [local resolution in python]

    class FooTuple(NamedTuple):
        a: int

    class MyModule(torch.nn.Module):
        def forward(self) -> Optional[FooTuple]:
            return None

    scripted = torch.jit.script(MyModule())
    reloaded = self.getExportImportCopy(scripted)
    result = reloaded()
    self.assertEqual(result, None)
def test_save_load_params_buffers_submodules(self):
    """
    Check that parameters, buffers, and submodules are the same after loading.

    Names are compared for all three collections; parameter and buffer values
    are compared as well. ``self.t`` is a plain tensor attribute (not a
    buffer), so it is not part of these checks.
    """
    class Submodule(torch.nn.Module):
        def __init__(self):
            super().__init__()

    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.add_module("submodule_a", Submodule())
            self.register_parameter("parameter_a", torch.nn.Parameter(torch.randn(4)))
            self.register_buffer("buffer", torch.randn(4))
            self.t = torch.rand(4)  # not buffer

            self.parameter_b = torch.nn.Parameter(torch.randn(4))
            self.submodule_b = Submodule()

    def assert_same_named_items(expected, actual, compare_values=True):
        # Materialize both (name, value) iterators once so the length check
        # and the pairwise comparison see exactly the same items.
        expected, actual = list(expected), list(actual)
        self.assertEqual(len(expected), len(actual))
        for (expected_name, expected_value), (actual_name, actual_value) in zip(expected, actual):
            self.assertEqual(expected_name, actual_name)
            if compare_values:
                self.assertEqual(expected_value, actual_value)

    m = TestModule()
    m_loaded = self.getExportImportCopy(torch.jit.script(m))

    # Check submodules (names only: the module objects necessarily differ).
    assert_same_named_items(m.named_modules(), m_loaded.named_modules(), compare_values=False)

    # Check parameters (names as well as values).
    assert_same_named_items(m.named_parameters(), m_loaded.named_parameters())

    # Check buffers.
    assert_same_named_items(m.named_buffers(), m_loaded.named_buffers())