# Owner(s): ["oncall: jit"]
import io
import os
import pathlib
import sys
import unittest
from typing import NamedTuple, Optional
import torch
from torch import Tensor
from torch.testing._internal.common_utils import TemporaryFileName
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, clear_class_registry
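# TestSaveLoadFlatbuffer below is skipped unless this env var is set to "1".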
ENABLE_FLATBUFFER = os.environ.get("ENABLE_FLATBUFFER", "0") == "1"
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestSaveLoad(JitTestCase):
def test_different_modules(self):
"""
Exercise the situation where we have the same qualified name
in two different CompilationUnits on save/load.
"""
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.foo = torch.nn.Linear(2, 2)
self.bar = torch.nn.Linear(2, 2)
def forward(self, x):
x = self.foo(x)
x = self.bar(x)
return x
first_script_module = torch.jit.script(Foo())
first_saved_module = io.BytesIO()
torch.jit.save(first_script_module, first_saved_module)
first_saved_module.seek(0)
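        # Clear the JIT class registry so the redefined Foo below can be scripted
        # under the same qualified name as the first one.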
clear_class_registry()
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.foo = torch.nn.Linear(2, 2)
def forward(self, x):
x = self.foo(x)
return x
second_script_module = torch.jit.script(Foo())
second_saved_module = io.BytesIO()
        torch.jit.save(second_script_module, second_saved_module)
second_saved_module.seek(0)
clear_class_registry()
self.assertEqual(
first_script_module._c.qualified_name,
second_script_module._c.qualified_name,
)
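        # Loading both saved modules into one parent forces the importer to
        # disambiguate the colliding qualified names.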
class ContainsBoth(torch.nn.Module):
def __init__(self):
super().__init__()
self.add_module("second", torch.jit.load(second_saved_module))
self.add_module("first", torch.jit.load(first_saved_module))
def forward(self, x):
x = self.first(x)
x = self.second(x)
return x
sm = torch.jit.script(ContainsBoth())
contains_both = io.BytesIO()
torch.jit.save(sm, contains_both)
contains_both.seek(0)
sm = torch.jit.load(contains_both)
def test_different_functions(self):
"""
Exercise the situation where we have the same qualified name
in two different CompilationUnits on save/load.
"""
def lol(x):
return x
class Foo(torch.nn.Module):
def forward(self, x):
return lol(x)
first_script_module = torch.jit.script(Foo())
first_saved_module = io.BytesIO()
torch.jit.save(first_script_module, first_saved_module)
first_saved_module.seek(0)
clear_class_registry()
def lol(x): # noqa: F811
return "hello"
class Foo(torch.nn.Module):
def forward(self, x):
return lol(x)
second_script_module = torch.jit.script(Foo())
second_saved_module = io.BytesIO()
        torch.jit.save(second_script_module, second_saved_module)
second_saved_module.seek(0)
clear_class_registry()
self.assertEqual(
first_script_module._c.qualified_name,
second_script_module._c.qualified_name,
)
class ContainsBoth(torch.nn.Module):
def __init__(self):
super().__init__()
self.add_module("second", torch.jit.load(second_saved_module))
self.add_module("first", torch.jit.load(first_saved_module))
def forward(self, x):
x = self.first(x)
x = self.second(x)
return x
sm = torch.jit.script(ContainsBoth())
contains_both = io.BytesIO()
torch.jit.save(sm, contains_both)
contains_both.seek(0)
sm = torch.jit.load(contains_both)
def test_different_interfaces(self):
"""
Exercise the situation where we have the same qualified name
in two different CompilationUnits on save/load.
"""
@torch.jit.interface
class MyInterface:
def bar(self, x: Tensor) -> Tensor:
pass
@torch.jit.script
class ImplementInterface:
def __init__(self):
pass
def bar(self, x):
return x
class Foo(torch.nn.Module):
__annotations__ = {"interface": MyInterface}
def __init__(self):
super().__init__()
self.interface = ImplementInterface()
def forward(self, x):
return self.interface.bar(x)
first_script_module = torch.jit.script(Foo())
first_saved_module = io.BytesIO()
torch.jit.save(first_script_module, first_saved_module)
first_saved_module.seek(0)
clear_class_registry()
@torch.jit.interface
class MyInterface:
def not_bar(self, x: Tensor) -> Tensor:
pass
@torch.jit.script # noqa: F811
class ImplementInterface: # noqa: F811
def __init__(self):
pass
def not_bar(self, x):
return x
class Foo(torch.nn.Module):
__annotations__ = {"interface": MyInterface}
def __init__(self):
super().__init__()
self.interface = ImplementInterface()
def forward(self, x):
return self.interface.not_bar(x)
second_script_module = torch.jit.script(Foo())
second_saved_module = io.BytesIO()
        torch.jit.save(second_script_module, second_saved_module)
second_saved_module.seek(0)
clear_class_registry()
self.assertEqual(
first_script_module._c.qualified_name,
second_script_module._c.qualified_name,
)
class ContainsBoth(torch.nn.Module):
def __init__(self):
super().__init__()
self.add_module("second", torch.jit.load(second_saved_module))
self.add_module("first", torch.jit.load(first_saved_module))
def forward(self, x):
x = self.first(x)
x = self.second(x)
return x
sm = torch.jit.script(ContainsBoth())
contains_both = io.BytesIO()
torch.jit.save(sm, contains_both)
contains_both.seek(0)
sm = torch.jit.load(contains_both)
def test_many_collisions(self):
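        """
        Exercise name collisions across modules, functions, interfaces, and
        NamedTuples simultaneously on save/load.
        """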
class MyCoolNamedTuple(NamedTuple):
a: int
@torch.jit.interface
class MyInterface:
def bar(self, x: Tensor) -> Tensor:
pass
@torch.jit.script
class ImplementInterface:
def __init__(self):
pass
def bar(self, x):
return x
def lol(x):
return x
class Foo(torch.nn.Module):
interface: MyInterface
def __init__(self):
super().__init__()
self.foo = torch.nn.Linear(2, 2)
self.bar = torch.nn.Linear(2, 2)
self.interface = ImplementInterface()
def forward(self, x):
x = self.foo(x)
x = self.bar(x)
x = lol(x)
x = self.interface.bar(x)
return x, MyCoolNamedTuple(a=5)
first_script_module = torch.jit.script(Foo())
first_saved_module = io.BytesIO()
torch.jit.save(first_script_module, first_saved_module)
first_saved_module.seek(0)
clear_class_registry()
@torch.jit.interface
class MyInterface:
def not_bar(self, x: Tensor) -> Tensor:
pass
@torch.jit.script # noqa: F811
class ImplementInterface: # noqa: F811
def __init__(self):
pass
def not_bar(self, x):
return x
def lol(x): # noqa: F811
return "asdofij"
class MyCoolNamedTuple(NamedTuple): # noqa: F811
a: str
class Foo(torch.nn.Module):
interface: MyInterface
def __init__(self):
super().__init__()
self.foo = torch.nn.Linear(2, 2)
self.interface = ImplementInterface()
def forward(self, x):
x = self.foo(x)
self.interface.not_bar(x)
x = lol(x)
return x, MyCoolNamedTuple(a="hello")
second_script_module = torch.jit.script(Foo())
second_saved_module = io.BytesIO()
torch.jit.save(second_script_module, second_saved_module)
second_saved_module.seek(0)
clear_class_registry()
self.assertEqual(
first_script_module._c.qualified_name,
second_script_module._c.qualified_name,
)
class ContainsBoth(torch.nn.Module):
def __init__(self):
super().__init__()
self.add_module("second", torch.jit.load(second_saved_module))
self.add_module("first", torch.jit.load(first_saved_module))
def forward(self, x):
x, named_tuple_1 = self.first(x)
x, named_tuple_2 = self.second(x)
return len(x + named_tuple_2.a) + named_tuple_1.a
sm = torch.jit.script(ContainsBoth())
contains_both = io.BytesIO()
torch.jit.save(sm, contains_both)
contains_both.seek(0)
sm = torch.jit.load(contains_both)
def test_save_load_with_extra_files(self):
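        """
        Check that _extra_files round-trip through both the ScriptModule and
        torch.jit save/load APIs.
        """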
class MyMod(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, a):
return a
# specifically test binary data
value = b"bar\x00\xffbaz"
expected_extra_files = {}
expected_extra_files["foo"] = value
# verify that str to bytes conversion also works
expected_extra_files["foo2"] = "bar"
m = MyMod()
# Save to file.
with TemporaryFileName() as fname:
m.save(fname, _extra_files=expected_extra_files)
# values don't matter
extra_files = {"foo": "", "foo2": None}
torch.jit.load(fname, _extra_files=extra_files)
self.assertEqual(value, extra_files["foo"])
            # results always come back as bytes
self.assertEqual(b"bar", extra_files["foo2"])
# Use torch.jit API
torch.jit.save(m, fname, _extra_files=expected_extra_files)
extra_files["foo"] = ""
torch.jit.load(fname, _extra_files=extra_files)
self.assertEqual(value, extra_files["foo"])
# Save to buffer.
buffer = io.BytesIO(m.save_to_buffer(_extra_files=expected_extra_files))
extra_files = {"foo": ""}
torch.jit.load(buffer, _extra_files=extra_files)
self.assertEqual(value, extra_files["foo"])
# Use torch.jit API
buffer = io.BytesIO()
torch.jit.save(m, buffer, _extra_files=expected_extra_files)
buffer.seek(0)
extra_files = {"foo": ""}
torch.jit.load(buffer, _extra_files=extra_files)
self.assertEqual(value, extra_files["foo"])
            # Requesting a nonexistent extra file 'bar' should raise
with self.assertRaises(RuntimeError):
extra_files["bar"] = ""
torch.jit.load(buffer, _extra_files=extra_files)
def test_save_load_using_pathlib(self):
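        """
        Saving and loading via a pathlib.Path should behave the same as a string path.
        """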
class MyMod(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, a):
return 2 * a
m = MyMod()
# Save then load.
with TemporaryFileName() as fname:
path = pathlib.Path(fname)
m.save(path)
m2 = torch.jit.load(path)
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
self.assertTrue(torch.equal(m(x), m2(x)))
    def test_save_nonexistent_file(self):
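        """
        Saving to a path in a nonexistent directory should raise a RuntimeError.
        """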
class Foo(torch.nn.Module):
def forward(self, x):
return 2 * x
script_module = torch.jit.script(Foo())
with self.assertRaises(RuntimeError):
script_module.save("NonExist/path/test.pt")
def test_save_namedtuple_input_only(self):
"""
Even if a NamedTuple is only used as an input argument, saving and
loading should work correctly.
"""
global FooTuple # see [local resolution in python]
class FooTuple(NamedTuple):
a: int
class MyModule(torch.nn.Module):
def forward(self, x: FooTuple) -> torch.Tensor:
return torch.tensor(3)
m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
output = m_loaded(FooTuple(a=5))
self.assertEqual(output, torch.tensor(3))
def test_save_namedtuple_output_only(self):
"""
        Even if a NamedTuple is only used as a return type, saving and
        loading should work correctly.
"""
global FooTuple # see [local resolution in python]
class FooTuple(NamedTuple):
a: int
class MyModule(torch.nn.Module):
def forward(self) -> Optional[FooTuple]:
return None
m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
output = m_loaded()
self.assertEqual(output, None)
def test_save_load_params_buffers_submodules(self):
"""
Check that parameters, buffers, and submodules are the same after loading.
"""
class Submodule(torch.nn.Module):
pass
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.add_module("submodule_a", Submodule())
self.register_parameter(
"parameter_a", torch.nn.Parameter(torch.randn(4))
)
self.register_buffer("buffer", torch.randn(4))
self.t = torch.rand(4) # not buffer
self.parameter_b = torch.nn.Parameter(torch.randn(4))
self.submodule_b = Submodule()
m = TestModule()
m_loaded = self.getExportImportCopy(torch.jit.script(m))
# Check submodules.
self.assertEqual(
len(list(m.named_modules())), len(list(m_loaded.named_modules()))
)
for m_s, loaded_s in zip(m.named_modules(), m_loaded.named_modules()):
m_name, _ = m_s
loaded_name, _ = loaded_s
self.assertEqual(m_name, loaded_name)
# Check parameters.
self.assertEqual(len(list(m.parameters())), len(list(m_loaded.parameters())))
for m_p, loaded_p in zip(m.parameters(), m_loaded.parameters()):
self.assertEqual(m_p, loaded_p)
# Check buffers.
self.assertEqual(
len(list(m.named_buffers())), len(list(m_loaded.named_buffers()))
)
for m_b, loaded_b in zip(m.named_buffers(), m_loaded.named_buffers()):
m_name, m_buffer = m_b
loaded_name, loaded_buffer = loaded_b
self.assertEqual(m_name, loaded_name)
self.assertEqual(m_buffer, loaded_buffer)
def test_save_load_meta_tensors(self):
"""
Check that parameters, buffers, and submodules are the same after loading
        for a module with parameters and buffers that are meta tensors.
"""
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
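                # Tensors on the meta device carry shape/dtype only, with no storage.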
self.foo = torch.nn.Linear(2, 3, device="meta")
self.bar = torch.nn.Linear(3, 4)
self.register_buffer("buffer", torch.randn(4, device="meta"))
def forward(self, x):
x = self.foo(x)
x = self.bar(x)
return x
m = Foo()
m_loaded = self.getExportImportCopy(torch.jit.script(m))
# Check submodules.
self.assertEqual(
len(list(m.named_modules())), len(list(m_loaded.named_modules()))
)
self.assertEqual(
{name for name, _ in m.named_modules()},
{name for name, _ in m_loaded.named_modules()},
)
# Check parameters.
m_params = dict(m.named_parameters())
m_loaded_params = dict(m_loaded.named_parameters())
self.assertEqual(len(m_params), len(m_loaded_params))
self.assertEqual(m_params, m_loaded_params)
# Check buffers.
m_buffers = dict(m.named_buffers())
m_loaded_buffers = dict(m_loaded.named_buffers())
self.assertEqual(len(m_buffers), len(m_loaded_buffers))
self.assertEqual(m_buffers, m_loaded_buffers)
# Check params and buffers that are/are not meta tensors
self.assertTrue(m_params["foo.weight"].is_meta)
self.assertTrue(m_loaded_params["foo.weight"].is_meta)
self.assertTrue(m_params["foo.bias"].is_meta)
self.assertTrue(m_loaded_params["foo.bias"].is_meta)
self.assertFalse(m_params["bar.weight"].is_meta)
self.assertFalse(m_loaded_params["bar.weight"].is_meta)
self.assertFalse(m_params["bar.bias"].is_meta)
self.assertFalse(m_loaded_params["bar.bias"].is_meta)
self.assertTrue(m_buffers["buffer"].is_meta)
self.assertTrue(m_loaded_buffers["buffer"].is_meta)
def test_save_load_with_saved_traced_inputs(self):
"""
        Check that saving and loading with traced inputs works as expected.
"""
class Module(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return torch.ones(1)
        def get_loaded_inputs(inputs):
            # Trace with the given inputs (the parameter was previously unused),
            # save, then reload with shapes restored.
            traced_module = torch.jit.trace(module, inputs)
            traced_inputs = list(traced_module.graph.inputs())
            with TemporaryFileName() as fname:
                path = pathlib.Path(fname)
                traced_module.save(path)
                loaded_module = torch.jit.load(path, _restore_shapes=True)
                return traced_inputs, list(loaded_module.graph.inputs())
module = Module()
input_tensor = torch.rand(1, 3, 24, 24)
        # Validate that traced inputs are stored by default
traced_module = torch.jit.trace(module, input_tensor)
traced_inputs = list(traced_module.graph.inputs())
        self.assertEqual(
            traced_module._c._retrieve_traced_inputs()["forward"], [input_tensor]
        )
with TemporaryFileName() as fname:
path = pathlib.Path(fname)
traced_module.save(path)
loaded_module = torch.jit.load(path, _restore_shapes=True)
loaded_inputs = list(loaded_module.graph.inputs())
self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())
self.assertEqual(traced_inputs[1].type().sizes(), loaded_inputs[1].type().sizes())
            # Validate that loading without _restore_shapes preserves the old behavior
loaded_module = torch.jit.load(path)
loaded_inputs = list(loaded_module.graph.inputs())
self.assertEqual(loaded_inputs[1].type().sizes(), None)
        # Validate that inputs are not stored when _store_inputs=False
traced_module = torch.jit.trace(module, input_tensor, _store_inputs=False)
traced_inputs = list(traced_module.graph.inputs())
        self.assertEqual(len(traced_module._c._retrieve_traced_inputs()), 0)
with TemporaryFileName() as fname:
path = pathlib.Path(fname)
traced_module.save(path)
loaded_module = torch.jit.load(path, _restore_shapes=True)
loaded_inputs = list(loaded_module.graph.inputs())
self.assertEqual(loaded_inputs[1].type().sizes(), None)
            # Validate that loading without _restore_shapes preserves the old behavior
loaded_module = torch.jit.load(path)
loaded_inputs = list(loaded_module.graph.inputs())
self.assertEqual(loaded_inputs[1].type().sizes(), None)
        # Validate that complex inputs work
        # Testing a dict of tuples containing empty tensors
input1 = {
"1000": (
torch.tensor([0]),
torch.tensor([], dtype=torch.int64),
torch.tensor([])
)
}
traced_inputs, loaded_inputs = get_loaded_inputs(input1)
self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())
        # Testing a dict of tensor tuples
input2 = {
"1000": (
torch.tensor([0]),
torch.tensor([1500000, 1500004], dtype=torch.int64),
torch.tensor([2.0, 3.0])
)
}
traced_inputs, loaded_inputs = get_loaded_inputs(input2)
self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())
        # Testing a list of tensors
input3 = [torch.tensor([0]),
torch.tensor([1500000, 1500004], dtype=torch.int64),
torch.tensor([2.0, 3.0])]
traced_inputs, loaded_inputs = get_loaded_inputs(input3)
self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())
        # Testing a list of dicts of tensor tuples
input4 = [{
"1000": (
torch.tensor([0]),
torch.tensor([1500000, 1500004], dtype=torch.int64),
torch.tensor([2.0, 3.0])
)
}]
traced_inputs, loaded_inputs = get_loaded_inputs(input4)
self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())
def script_module_to_buffer(script_module):
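    """Serialize a ScriptModule to an in-memory flatbuffer for the lite interpreter."""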
module_buffer = io.BytesIO(
script_module._save_to_buffer_for_lite_interpreter(_use_flatbuffer=True)
)
module_buffer.seek(0)
return module_buffer
@unittest.skipIf(
    not ENABLE_FLATBUFFER, "Set ENABLE_FLATBUFFER=1 to run the flatbuffer tests"
)
class TestSaveLoadFlatbuffer(JitTestCase):
def test_different_modules(self):
"""
Exercise the situation where we have the same qualified name
in two different CompilationUnits on save/load.
"""
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.foo = torch.nn.Linear(2, 2)
self.bar = torch.nn.Linear(2, 2)
def forward(self, x):
x = self.foo(x)
x = self.bar(x)
return x
first_script_module = torch.jit.script(Foo())
first_saved_module = script_module_to_buffer(first_script_module)
clear_class_registry()
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.foo = torch.nn.Linear(2, 2)
def forward(self, x):
x = self.foo(x)
return x
second_script_module = torch.jit.script(Foo())
second_saved_module = script_module_to_buffer(second_script_module)
clear_class_registry()
self.assertEqual(
first_script_module._c.qualified_name,
second_script_module._c.qualified_name,
)
class ContainsBoth(torch.nn.Module):
def __init__(self):
super().__init__()
self.add_module(
"second", torch.jit.load(second_saved_module)
)
self.add_module(
"first", torch.jit.load(first_saved_module)
)
def forward(self, x):
x = self.first(x)
x = self.second(x)
return x
sm = torch.jit.script(ContainsBoth())
contains_both = script_module_to_buffer(sm)
sm = torch.jit.load(contains_both)
def test_different_functions(self):
"""
Exercise the situation where we have the same qualified name
in two different CompilationUnits on save/load.
"""
def lol(x):
return x
class Foo(torch.nn.Module):
def forward(self, x):
return lol(x)
first_script_module = torch.jit.script(Foo())
first_saved_module = script_module_to_buffer(first_script_module)
clear_class_registry()
def lol(x): # noqa: F811
return "hello"
class Foo(torch.nn.Module):
def forward(self, x):
return lol(x)
second_script_module = torch.jit.script(Foo())
second_saved_module = script_module_to_buffer(second_script_module)
clear_class_registry()
self.assertEqual(
first_script_module._c.qualified_name,
second_script_module._c.qualified_name,
)
class ContainsBoth(torch.nn.Module):
def __init__(self):
super().__init__()
self.add_module(
"second", torch.jit.load(second_saved_module)
)
self.add_module(
"first", torch.jit.load(first_saved_module)
)
def forward(self, x):
x = self.first(x)
x = self.second(x)
return x
sm = torch.jit.script(ContainsBoth())
contains_both = script_module_to_buffer(sm)
sm = torch.jit.load(contains_both)
def test_different_interfaces(self):
"""
Exercise the situation where we have the same qualified name
in two different CompilationUnits on save/load.
"""
@torch.jit.interface
class MyInterface:
def bar(self, x: Tensor) -> Tensor:
pass
@torch.jit.script
class ImplementInterface:
def __init__(self):
pass
def bar(self, x):
return x
class Foo(torch.nn.Module):
__annotations__ = {"interface": MyInterface}
def __init__(self):
super().__init__()
self.interface = ImplementInterface()
def forward(self, x):
return self.interface.bar(x)
first_script_module = torch.jit.script(Foo())
first_saved_module = script_module_to_buffer(first_script_module)
clear_class_registry()
@torch.jit.interface
class MyInterface:
def not_bar(self, x: Tensor) -> Tensor:
pass
@torch.jit.script # noqa: F811
class ImplementInterface: # noqa: F811
def __init__(self):
pass
def not_bar(self, x):
return x
class Foo(torch.nn.Module):
__annotations__ = {"interface": MyInterface}
def __init__(self):
super().__init__()
self.interface = ImplementInterface()
def forward(self, x):
return self.interface.not_bar(x)
second_script_module = torch.jit.script(Foo())
second_saved_module = script_module_to_buffer(second_script_module)
clear_class_registry()
self.assertEqual(
first_script_module._c.qualified_name,
second_script_module._c.qualified_name,
)
class ContainsBoth(torch.nn.Module):
def __init__(self):
super().__init__()
self.add_module(
"second", torch.jit.load(second_saved_module)
)
self.add_module(
"first", torch.jit.load(first_saved_module)
)
def forward(self, x):
x = self.first(x)
x = self.second(x)
return x
sm = torch.jit.script(ContainsBoth())
contains_both = script_module_to_buffer(sm)
sm = torch.jit.load(contains_both)
def test_many_collisions(self):
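        """
        Exercise name collisions across modules, functions, interfaces, and
        NamedTuples simultaneously on save/load.
        """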
class MyCoolNamedTuple(NamedTuple):
a: int
@torch.jit.interface
class MyInterface:
def bar(self, x: Tensor) -> Tensor:
pass
@torch.jit.script
class ImplementInterface:
def __init__(self):
pass
def bar(self, x):
return x
def lol(x):
return x
class Foo(torch.nn.Module):
interface: MyInterface
def __init__(self):
super().__init__()
self.foo = torch.nn.Linear(2, 2)
self.bar = torch.nn.Linear(2, 2)
self.interface = ImplementInterface()
def forward(self, x):
x = self.foo(x)
x = self.bar(x)
x = lol(x)
x = self.interface.bar(x)
return x, MyCoolNamedTuple(a=5)
first_script_module = torch.jit.script(Foo())
first_saved_module = script_module_to_buffer(first_script_module)
clear_class_registry()
@torch.jit.interface
class MyInterface:
def not_bar(self, x: Tensor) -> Tensor:
pass
@torch.jit.script # noqa: F811
class ImplementInterface: # noqa: F811
def __init__(self):
pass
def not_bar(self, x):
return x
def lol(x): # noqa: F811
return "asdofij"
class MyCoolNamedTuple(NamedTuple): # noqa: F811
a: str
class Foo(torch.nn.Module):
interface: MyInterface
def __init__(self):
super().__init__()
self.foo = torch.nn.Linear(2, 2)
self.interface = ImplementInterface()
def forward(self, x):
x = self.foo(x)
self.interface.not_bar(x)
x = lol(x)
return x, MyCoolNamedTuple(a="hello")
second_script_module = torch.jit.script(Foo())
second_saved_module = script_module_to_buffer(second_script_module)
clear_class_registry()
self.assertEqual(
first_script_module._c.qualified_name,
second_script_module._c.qualified_name,
)
class ContainsBoth(torch.nn.Module):
def __init__(self):
super().__init__()
self.add_module(
"second", torch.jit.load(second_saved_module)
)
self.add_module(
"first", torch.jit.load(first_saved_module)
)
def forward(self, x):
x, named_tuple_1 = self.first(x)
x, named_tuple_2 = self.second(x)
return len(x + named_tuple_2.a) + named_tuple_1.a
sm = torch.jit.script(ContainsBoth())
contains_both = script_module_to_buffer(sm)
sm = torch.jit.load(contains_both)
def test_save_load_using_pathlib(self):
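        """
        Saving a module to a flatbuffer file via pathlib.Path should load back correctly.
        """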
class MyMod(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, a):
return 2 * a
m = MyMod()
# Save then load.
with TemporaryFileName() as fname:
path = pathlib.Path(fname)
torch.jit.save_jit_module_to_flatbuffer(m, path)
m2 = torch.jit.load(path)
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
self.assertTrue(torch.equal(m(x), m2(x)))
def test_save_namedtuple_input_only(self):
"""
Even if a NamedTuple is only used as an input argument, saving and
loading should work correctly.
"""
global FooTuple # see [local resolution in python]
class FooTuple(NamedTuple):
a: int
class MyModule(torch.nn.Module):
def forward(self, x: FooTuple) -> torch.Tensor:
return torch.tensor(3)
m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
output = m_loaded(FooTuple(a=5))
self.assertEqual(output, torch.tensor(3))
def test_save_namedtuple_output_only(self):
"""
        Even if a NamedTuple is only used as a return type, saving and
        loading should work correctly.
"""
global FooTuple # see [local resolution in python]
class FooTuple(NamedTuple):
a: int
class MyModule(torch.nn.Module):
def forward(self) -> Optional[FooTuple]:
return None
m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
output = m_loaded()
self.assertEqual(output, None)
def test_module_info_flatbuffer(self):
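        """
        Check that module metadata (versions, function names, operator arity)
        can be read back from the saved flatbuffer via get_flatbuffer_module_info.
        """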
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.foo = torch.nn.Linear(2, 2)
self.bar = torch.nn.Linear(2, 2)
def forward(self, x):
x = self.foo(x)
x = self.bar(x)
return x
first_script_module = torch.jit.script(Foo())
first_saved_module = io.BytesIO()
torch.jit.save_jit_module_to_flatbuffer(
first_script_module, first_saved_module)
first_saved_module.seek(0)
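        # Expected metadata for the current flatbuffer serializer; the version
        # numbers may need updating when the bytecode format is bumped.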
        expected = {
            "bytecode_version": 4,
            "operator_version": 4,
            "function_names": {"__torch__.___torch_mangle_0.Foo.forward"},
            "type_names": set(),
            "opname_to_num_args": {"aten::linear": 3},
        }
self.assertEqual(
torch.jit._serialization.get_flatbuffer_module_info(first_saved_module),
expected)
def test_save_load_params_buffers_submodules(self):
"""
Check that parameters, buffers, and submodules are the same after loading.
"""
class Submodule(torch.nn.Module):
pass
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.add_module("submodule_a", Submodule())
self.register_parameter(
"parameter_a", torch.nn.Parameter(torch.randn(4))
)
self.register_buffer("buffer", torch.randn(4))
self.t = torch.rand(4) # not buffer
self.parameter_b = torch.nn.Parameter(torch.randn(4))
self.submodule_b = Submodule()
m = TestModule()
m_loaded = self.getExportImportCopy(torch.jit.script(m))
# Check submodules.
self.assertEqual(
len(list(m.named_modules())), len(list(m_loaded.named_modules()))
)
for m_s, loaded_s in zip(m.named_modules(), m_loaded.named_modules()):
m_name, _ = m_s
loaded_name, _ = loaded_s
self.assertEqual(m_name, loaded_name)
# Check parameters.
self.assertEqual(len(list(m.parameters())), len(list(m_loaded.parameters())))
for m_p, loaded_p in zip(m.parameters(), m_loaded.parameters()):
self.assertEqual(m_p, loaded_p)
# Check buffers.
self.assertEqual(
len(list(m.named_buffers())), len(list(m_loaded.named_buffers()))
)
for m_b, loaded_b in zip(m.named_buffers(), m_loaded.named_buffers()):
m_name, m_buffer = m_b
loaded_name, loaded_buffer = loaded_b
self.assertEqual(m_name, loaded_name)
self.assertEqual(m_buffer, loaded_buffer)
def test_save_load_with_extra_files(self):
"""
Check that parameters, buffers, and submodules are the same after loading.
"""
class Module(torch.nn.Module):
def forward(self, x: Tensor):
return x
module = Module()
script_module = torch.jit.script(module)
script_module_io = io.BytesIO()
extra_files = {"abc.json": "[1,2,3]"}
        script_module._save_for_lite_interpreter(
            script_module_io, _extra_files=extra_files, _use_flatbuffer=True
        )
script_module_io.seek(0)
re_extra_files = {}
        torch._C._get_model_extra_files_from_buffer(
            script_module_io, _extra_files=re_extra_files
        )
self.assertEqual(extra_files, re_extra_files)