# blob: a3cdde9166ffe4203eb5ffc1973534462124a87a
import io
from collections import namedtuple
from typing import Dict, List, NamedTuple

import torch
import torch.utils.bundled_inputs
from torch.utils.mobile_optimizer import *
from torch.utils.mobile_optimizer import optimize_for_mobile
from torch.jit.mobile import _load_for_lite_interpreter, _export_operator_list
from torch.testing._internal.common_utils import TestCase, run_tests
class TestLiteScriptModule(TestCase):
def test_load_mobile_module(self):
class MyTestModule(torch.nn.Module):
def __init__(self):
super(MyTestModule, self).__init__()
def forward(self, x):
return x + 10
input = torch.tensor([1])
script_module = torch.jit.script(MyTestModule())
script_module_result = script_module(input)
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
mobile_module_result = mobile_module(input)
torch.testing.assert_allclose(script_module_result, mobile_module_result)
mobile_module_forward_result = mobile_module.forward(input)
torch.testing.assert_allclose(script_module_result, mobile_module_forward_result)
mobile_module_run_method_result = mobile_module.run_method("forward", input)
torch.testing.assert_allclose(script_module_result, mobile_module_run_method_result)
def test_save_mobile_module_with_debug_info_with_trace(self):
class A(torch.nn.Module):
def __init__(self):
super(A, self).__init__()
def forward(self, x):
return x + 1
class B(torch.nn.Module):
def __init__(self):
super(B, self).__init__()
self.A0 = A()
self.A1 = A()
def forward(self, x):
return self.A0(x) + self.A1(x)
input = torch.tensor([5])
trace_module = torch.jit.trace(B(), input)
exported_module = trace_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True)
assert(b"mobile_debug.pkl" in exported_module)
assert(b"module_debug_info" in exported_module)
assert(b"top(B).forward" in exported_module)
assert(b"top(B).A0(A).forward" in exported_module)
assert(b"top(B).A1(A).forward" in exported_module)
def test_save_mobile_module_with_debug_info_with_script_duplicate_class(self):
class A(torch.nn.Module):
def __init__(self):
super(A, self).__init__()
def forward(self, x):
return x + 1
class B(torch.nn.Module):
def __init__(self):
super(B, self).__init__()
self.A0 = A()
self.A1 = A()
def forward(self, x):
return self.A0(x) + self.A1(x)
input_data = torch.tensor([5])
scripted_module = torch.jit.script(B(), input_data)
exported_module = scripted_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True)
assert(b"mobile_debug.pkl" in exported_module)
assert(b"module_debug_info" in exported_module)
assert(b"top(B).forward" in exported_module)
assert(b"top(B).A0(A).forward" in exported_module)
assert(b"top(B).A1(A).forward" in exported_module)
def test_save_mobile_module_with_debug_info_with_script_nested_call(self):
class A(torch.nn.Module):
def __init__(self):
super(A, self).__init__()
def forward(self, x):
return x + 1
class B(torch.nn.Module):
def __init__(self):
super(B, self).__init__()
def forward(self, x):
return x + 2
class C(torch.nn.Module):
def __init__(self):
super(C, self).__init__()
self.A0 = A()
self.B0 = B()
def forward(self, x):
return self.A0(self.B0(x)) + 1
input = torch.tensor([5])
scripted_module = torch.jit.script(C(), input)
optimized_scripted_module = optimize_for_mobile(scripted_module)
exported_module = scripted_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True)
optimized_exported_module = optimized_scripted_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True)
assert(b"mobile_debug.pkl" in exported_module)
assert(b"module_debug_info" in exported_module)
assert(b"top(C).forward" in exported_module)
assert(b"top(C).A0(A).forward" in exported_module)
assert(b"top(C).B0(B).forward" in exported_module)
assert(b"mobile_debug.pkl" in optimized_exported_module)
assert(b"module_debug_info" in optimized_exported_module)
assert(b"top(C).forward" in optimized_exported_module)
assert(b"top(C).A0(A).forward" in optimized_exported_module)
assert(b"top(C).B0(B).forward" in optimized_exported_module)
def test_load_mobile_module_with_debug_info(self):
class MyTestModule(torch.nn.Module):
def __init__(self):
super(MyTestModule, self).__init__()
def forward(self, x):
return x + 5
input = torch.tensor([3])
script_module = torch.jit.script(MyTestModule())
script_module_result = script_module(input)
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True))
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
mobile_module_result = mobile_module(input)
torch.testing.assert_allclose(script_module_result, mobile_module_result)
mobile_module_forward_result = mobile_module.forward(input)
torch.testing.assert_allclose(script_module_result, mobile_module_forward_result)
mobile_module_run_method_result = mobile_module.run_method("forward", input)
torch.testing.assert_allclose(script_module_result, mobile_module_run_method_result)
def test_find_and_run_method(self):
class MyTestModule(torch.nn.Module):
def forward(self, arg):
return arg
input = (torch.tensor([1]), )
script_module = torch.jit.script(MyTestModule())
script_module_result = script_module(*input)
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
has_bundled_inputs = mobile_module.find_method("get_all_bundled_inputs")
self.assertFalse(has_bundled_inputs)
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
script_module, [input], [])
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
has_bundled_inputs = mobile_module.find_method("get_all_bundled_inputs")
self.assertTrue(has_bundled_inputs)
bundled_inputs = mobile_module.run_method("get_all_bundled_inputs")
mobile_module_result = mobile_module.forward(*bundled_inputs[0])
torch.testing.assert_allclose(script_module_result, mobile_module_result)
def test_method_calls_with_optional_arg(self):
class A(torch.nn.Module):
def __init__(self):
super(A, self).__init__()
# opt arg in script-to-script invocation
def forward(self, x, two: int = 2):
return x + two
class B(torch.nn.Module):
def __init__(self):
super(B, self).__init__()
self.A0 = A()
# opt arg in Python-to-script invocation
def forward(self, x, one: int = 1):
return self.A0(x) + one
script_module = torch.jit.script(B())
buffer = io.BytesIO(
script_module._save_to_buffer_for_lite_interpreter()
)
mobile_module = _load_for_lite_interpreter(buffer)
input = torch.tensor([5])
script_module_forward_result = script_module.forward(input)
mobile_module_forward_result = mobile_module.forward(input)
torch.testing.assert_allclose(
script_module_forward_result,
mobile_module_forward_result
)
# change ref only
script_module_forward_result = script_module.forward(input, 2)
self.assertFalse(
(script_module_forward_result == mobile_module_forward_result)
.all()
.item()
)
# now both match again
mobile_module_forward_result = mobile_module.forward(input, 2)
torch.testing.assert_allclose(
script_module_forward_result,
mobile_module_forward_result
)
def test_unsupported_classtype(self):
class Foo():
def __init__(self):
return
def func(self, x: int, y: int):
return x + y
class MyTestModule(torch.nn.Module):
def forward(self, arg):
f = Foo()
return f.func(1, 2)
script_module = torch.jit.script(MyTestModule())
with self.assertRaisesRegex(RuntimeError,
r"Workaround: instead of using arbitrary class type \(class Foo\(\)\), "
r"define a pytorch class \(class Foo\(torch\.nn\.Module\)\)\.$"):
script_module._save_to_buffer_for_lite_interpreter()
def test_unsupported_return_typing_namedtuple(self):
myNamedTuple = NamedTuple('myNamedTuple', [('a', torch.Tensor)])
class MyTestModule(torch.nn.Module):
def forward(self):
return myNamedTuple(torch.randn(1))
script_module = torch.jit.script(MyTestModule())
with self.assertRaisesRegex(RuntimeError,
r"A named tuple type is not supported in mobile module. "
r"Workaround: instead of using a named tuple type\'s fields, "
r"use a dictionary type\'s key-value pair itmes or "
r"a pytorch class \(class Foo\(torch\.nn\.Module\)\)\'s attributes."):
script_module._save_to_buffer_for_lite_interpreter()
def test_unsupported_return_collections_namedtuple(self):
myNamedTuple = namedtuple('myNamedTuple', [('a')])
class MyTestModule(torch.nn.Module):
def forward(self):
return myNamedTuple(torch.randn(1))
script_module = torch.jit.script(MyTestModule())
with self.assertRaisesRegex(RuntimeError,
r"A named tuple type is not supported in mobile module. "
r"Workaround: instead of using a named tuple type\'s fields, "
r"use a dictionary type\'s key-value pair itmes or "
r"a pytorch class \(class Foo\(torch\.nn\.Module\)\)\'s attributes."):
script_module._save_to_buffer_for_lite_interpreter()
def test_unsupported_return_list_with_module_class(self):
class Foo(torch.nn.Module):
def __init__(self):
super(Foo, self).__init__()
class MyTestModuleForListWithModuleClass(torch.nn.Module):
def __init__(self):
super(MyTestModuleForListWithModuleClass, self).__init__()
self.foo = Foo()
def forward(self):
my_list: List[Foo] = [self.foo]
return my_list
script_module = torch.jit.script(MyTestModuleForListWithModuleClass())
with self.assertRaisesRegex(RuntimeError,
r"^Returining a list or dictionary with pytorch class type "
r"is not supported in mobile module "
r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. "
r"Workaround\: instead of using pytorch class as their element type\, "
r"use a combination of list\, dictionary\, and single types\.$"):
script_module._save_to_buffer_for_lite_interpreter()
def test_unsupported_return_dict_with_module_class(self):
class Foo(torch.nn.Module):
def __init__(self):
super(Foo, self).__init__()
class MyTestModuleForDictWithModuleClass(torch.nn.Module):
def __init__(self):
super(MyTestModuleForDictWithModuleClass, self).__init__()
self.foo = Foo()
def forward(self):
my_dict: Dict[int, Foo] = {1: self.foo}
return my_dict
script_module = torch.jit.script(MyTestModuleForDictWithModuleClass())
with self.assertRaisesRegex(RuntimeError,
r"^Returining a list or dictionary with pytorch class type "
r"is not supported in mobile module "
r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. "
r"Workaround\: instead of using pytorch class as their element type\, "
r"use a combination of list\, dictionary\, and single types\.$"):
script_module._save_to_buffer_for_lite_interpreter()
def test_module_export_operator_list(self):
class Foo(torch.nn.Module):
def __init__(self):
super(Foo, self).__init__()
self.weight = torch.ones((20, 1, 5, 5))
self.bias = torch.ones(20)
def forward(self, input):
x1 = torch.zeros(2, 2)
x2 = torch.empty_like(torch.empty(2, 2))
x3 = torch._convolution(
input,
self.weight,
self.bias,
[1, 1],
[0, 0],
[1, 1],
False,
[0, 0],
1,
False,
False,
True,
True,
)
return (x1, x2, x3)
m = torch.jit.script(Foo())
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
expected_ops = {
"aten::_convolution",
"aten::empty.memory_format",
"aten::empty_like",
"aten::zeros",
}
actual_ops = _export_operator_list(mobile_module)
self.assertEqual(actual_ops, expected_ops)
# Run the suite through PyTorch's test runner when executed as a script.
if __name__ == '__main__':
    run_tests()