| import io |
| import os |
| import sys |
| import typing |
| |
| import torch |
| import torch.nn as nn |
| |
# Make the helper files in test/ importable: append the directory two levels
# above this file (after resolving symlinks via realpath) to sys.path.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import suppress_warnings
from torch.testing._internal.jit_utils import JitTestCase
from torch.onnx import OperatorExportTypes

# This module only defines test cases; running it as a script is an error,
# and the message tells the user to go through test/test_jit.py instead.
if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")
| |
| class TestONNXExport(JitTestCase): |
| def test_fuse_addmm(self): |
| class AddmmModel(torch.nn.Module): |
| def forward(self, x): |
| return torch.mm(x, x) + x |
| |
| x = torch.ones(3, 3) |
| f = io.BytesIO() |
| torch.onnx._export(AddmmModel(), x, f, verbose=False) |
| |
| def test_onnx_transpose_incomplete_tensor_type(self): |
| # Smoke test to get us into the state where we are attempting to export |
| # a transpose op, where the input is a TensorType without size information. |
| # This would previously not work, since we would |
| # take the size of the input and use the length of its sizes as the |
| # number of dimensions in the permutation. |
| class Foo(torch.jit.ScriptModule): |
| @torch.jit.script_method |
| def forward(self, x): |
| return x.contiguous().transpose(0, 1).sum() |
| |
| class TraceMe(torch.nn.Module): |
| def __init__(self): |
| super(TraceMe, self).__init__() |
| self.foo = Foo() |
| |
| def forward(self, x): |
| return self.foo(x) |
| |
| tm = TraceMe() |
| tm = torch.jit.trace(tm, torch.rand(3, 4)) |
| example_outputs = (tm(torch.rand(3, 4)),) |
| f = io.BytesIO() |
| torch.onnx._export(tm, (torch.rand(3, 4),), f, example_outputs=example_outputs) |
| |
| def test_export_tensoroption_to(self): |
| def foo(x): |
| return x[0].clone().detach().cpu() + x |
| |
| traced = torch.jit.trace(foo, (torch.rand([2]))) |
| example_outputs = traced(torch.rand([2])) |
| |
| f = io.BytesIO() |
| torch.onnx._export_to_pretty_string(traced, (torch.rand([2]),), f, |
| example_outputs=example_outputs) |
| |
| def test_onnx_export_script_module(self): |
| class ModuleToExport(torch.jit.ScriptModule): |
| def __init__(self): |
| super(ModuleToExport, self).__init__() |
| |
| @torch.jit.script_method |
| def forward(self, x): |
| y = x - x |
| return x + x |
| |
| mte = ModuleToExport() |
| outputs = mte(torch.zeros(1, 2, 3)) |
| torch.onnx.export_to_pretty_string( |
| mte, (torch.zeros(1, 2, 3),), None, verbose=False, |
| example_outputs=outputs) |
| |
| @suppress_warnings |
| def test_onnx_export_func_with_warnings(self): |
| @torch.jit.script |
| def func_with_warning(inp): |
| return torch.nn.functional.sigmoid(inp) # triggers a deprecation warning |
| |
| class WarningTest(torch.nn.Module): |
| def __init__(self): |
| super(WarningTest, self).__init__() |
| |
| def forward(self, x): |
| return func_with_warning(x) |
| |
| outputs = WarningTest()(torch.randn(42)) |
| # no exception |
| torch.onnx.export_to_pretty_string( |
| WarningTest(), torch.randn(42), None, verbose=False, |
| example_outputs=outputs) |
| |
| def test_onnx_export_script_python_fail(self): |
| class PythonModule(torch.jit.ScriptModule): |
| def __init__(self): |
| super(PythonModule, self).__init__() |
| |
| @torch.jit.ignore |
| def forward(self, x): |
| return torch.neg(x) |
| |
| class ModuleToExport(torch.jit.ScriptModule): |
| def __init__(self): |
| super(ModuleToExport, self).__init__() |
| self.mod = PythonModule() |
| |
| @torch.jit.script_method |
| def forward(self, x): |
| y = self.mod(x) |
| return y + y |
| |
| mte = ModuleToExport() |
| outputs = mte(torch.zeros(1, 2, 3)) |
| f = io.BytesIO() |
| with self.assertRaisesRegex(RuntimeError, "Couldn't export Python"): |
| torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False, |
| example_outputs=outputs) |
| |
| def test_onnx_export_script_inline_trace(self): |
| class ModuleToInline(torch.nn.Module): |
| def __init__(self): |
| super(ModuleToInline, self).__init__() |
| |
| def forward(self, x): |
| return torch.neg(x) |
| |
| class ModuleToExport(torch.jit.ScriptModule): |
| def __init__(self): |
| super(ModuleToExport, self).__init__() |
| self.mod = torch.jit.trace(ModuleToInline(), torch.zeros(1, 2, 3)) |
| |
| @torch.jit.script_method |
| def forward(self, x): |
| y = self.mod(x) |
| return y + y |
| |
| mte = ModuleToExport() |
| outputs = mte(torch.zeros(1, 2, 3)) |
| torch.onnx.export_to_pretty_string( |
| mte, (torch.zeros(1, 2, 3),), None, verbose=False, |
| example_outputs=outputs) |
| |
| def test_onnx_export_script_inline_script(self): |
| class ModuleToInline(torch.jit.ScriptModule): |
| def __init__(self): |
| super(ModuleToInline, self).__init__() |
| |
| @torch.jit.script_method |
| def forward(self, x): |
| return torch.neg(x) |
| |
| class ModuleToExport(torch.jit.ScriptModule): |
| def __init__(self): |
| super(ModuleToExport, self).__init__() |
| self.mod = ModuleToInline() |
| |
| @torch.jit.script_method |
| def forward(self, x): |
| y = self.mod(x) |
| return y + y |
| |
| mte = ModuleToExport() |
| outputs = mte(torch.zeros(1, 2, 3)) |
| torch.onnx.export_to_pretty_string( |
| mte, (torch.zeros(1, 2, 3),), None, verbose=False, |
| example_outputs=outputs) |
| |
| def test_onnx_export_script_module_loop(self): |
| class ModuleToExport(torch.jit.ScriptModule): |
| def __init__(self): |
| super(ModuleToExport, self).__init__() |
| |
| @torch.jit.script_method |
| def forward(self, x): |
| # test if we support end to end onnx export on loop and |
| # nested loops with and without loop index |
| for _ in range(5): |
| for i in range(3): |
| x = x + i |
| return x |
| |
| mte = ModuleToExport() |
| outputs = mte(torch.zeros(1, 2, 3)) |
| torch.onnx.export_to_pretty_string( |
| mte, (torch.zeros(1, 2, 3),), None, verbose=False, |
| example_outputs=outputs) |
| |
| @suppress_warnings |
| def test_onnx_export_script_truediv(self): |
| class ModuleToExport(torch.jit.ScriptModule): |
| def __init__(self): |
| super(ModuleToExport, self).__init__() |
| |
| @torch.jit.script_method |
| def forward(self, x): |
| z = x.size(0) / 2 |
| return x + z |
| |
| mte = ModuleToExport() |
| outputs = mte(torch.zeros(1, 2, 3)) |
| |
| torch.onnx.export_to_pretty_string( |
| mte, (torch.zeros(1, 2, 3, dtype=torch.float),), None, verbose=False, |
| example_outputs=outputs) |
| |
| def test_onnx_export_script_non_alpha_add_sub(self): |
| class ModuleToExport(torch.jit.ScriptModule): |
| def __init__(self): |
| super(ModuleToExport, self).__init__() |
| |
| @torch.jit.script_method |
| def forward(self, x): |
| bs = x.size(0) + 1 |
| return bs - 1 |
| |
| mte = ModuleToExport() |
| outputs = torch.LongTensor([mte(torch.rand(3, 4))]) |
| torch.onnx.export_to_pretty_string( |
| mte, (torch.rand(3, 4),), None, verbose=False, |
| example_outputs=outputs) |
| |
| def test_onnx_export_script_module_if(self): |
| class ModuleToExport(torch.jit.ScriptModule): |
| def __init__(self): |
| super(ModuleToExport, self).__init__() |
| |
| @torch.jit.script_method |
| def forward(self, x): |
| if bool(torch.sum(x) > 0): |
| x = torch.neg(x) |
| return x |
| |
| mte = ModuleToExport() |
| outputs = mte(torch.zeros(1, 2, 3, dtype=torch.long)) |
| torch.onnx.export_to_pretty_string( |
| mte, (torch.zeros(1, 2, 3),), None, verbose=False, |
| example_outputs=outputs) |
| |
| def test_onnx_export_script_inline_params(self): |
| class ModuleToInline(torch.jit.ScriptModule): |
| def __init__(self): |
| super(ModuleToInline, self).__init__() |
| self.m = torch.nn.Parameter(torch.ones(3, 3)) |
| self.unused = torch.nn.Parameter(torch.ones(1, 2, 3)) |
| |
| @torch.jit.script_method |
| def forward(self, x): |
| return torch.mm(x, self.m) |
| |
| class ModuleToExport(torch.jit.ScriptModule): |
| def __init__(self): |
| super(ModuleToExport, self).__init__() |
| self.mod = ModuleToInline() |
| self.param = torch.nn.Parameter(torch.ones(3, 4)) |
| |
| @torch.jit.script_method |
| def forward(self, x): |
| y = self.mod(x) |
| return torch.mm(y, self.param) |
| |
| mte = ModuleToExport() |
| result = mte(torch.zeros(2, 3)) |
| reference = torch.mm(torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4)) |
| self.assertEqual(result, reference) |
| torch.onnx.export_to_pretty_string( |
| mte, (torch.ones(2, 3),), None, verbose=False, |
| example_outputs=result) |
| |
| def test_onnx_export_speculate(self): |
| |
| class Foo(torch.jit.ScriptModule): |
| def __init__(self, m): |
| super(Foo, self).__init__() |
| self.m = m |
| |
| @torch.jit.script_method |
| def forward(self, x): |
| x += x |
| # because we are testing if we emit `if` statement correctly |
| # we cannot use `True` as the condition. Constant prop |
| # would remove the `if` statements. |
| c = torch.sum(x) > 4 |
| if bool(c): |
| if bool(c): |
| y = self.m(x) |
| else: |
| y = self.m(x) |
| else: |
| y = self.m(x) |
| return y |
| |
| linear = torch.jit.trace(nn.Linear(10, 20).float(), torch.zeros(1, 10, dtype=torch.float)) |
| |
| @torch.jit.script |
| def transpose(x): |
| return x.t() |
| |
| f1 = Foo(transpose) |
| outputs_f1 = f1(torch.ones(1, 10, dtype=torch.float)) |
| f2 = Foo(linear) |
| outputs_f2 = f2(torch.ones(1, 10, dtype=torch.float)) |
| |
| torch.onnx.export_to_pretty_string( |
| f1, |
| (torch.ones(1, 10, dtype=torch.float), ), |
| None, verbose=False, example_outputs=outputs_f1) |
| torch.onnx.export_to_pretty_string( |
| f2, |
| (torch.ones(1, 10, dtype=torch.float), ), |
| None, verbose=False, example_outputs=outputs_f2) |
| |
| def test_onnx_export_shape_reshape(self): |
| class Foo(torch.nn.Module): |
| def forward(self, x): |
| import torch.onnx.operators |
| x = x.repeat(5, 1, 1) |
| shape = torch.onnx.operators.shape_as_tensor(x) |
| reshaped = torch.onnx.operators.reshape_from_tensor_shape(x, shape) |
| return reshaped |
| |
| foo = torch.jit.trace(Foo(), torch.zeros(1, 2, 3)) |
| outputs = foo(torch.zeros(1, 2, 3)) |
| f = io.BytesIO() |
| torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3)), f, |
| example_outputs=outputs) |
| |
| def test_listconstruct_erasure(self): |
| class FooMod(torch.nn.Module): |
| def forward(self, x): |
| mask = x < 0.0 |
| return x[mask] |
| |
| import io |
| f = io.BytesIO() |
| torch.onnx.export_to_pretty_string( |
| FooMod(), (torch.rand(3, 4),), f, |
| add_node_names=False, |
| do_constant_folding=False, |
| operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK) |
| |
| |
| def test_export_dynamic_slice(self): |
| class DynamicSliceExportMod(torch.jit.ScriptModule): |
| @torch.jit.script_method |
| def forward(self, x): |
| retval = x[0] |
| for i in range(x.size(1)): |
| retval += torch.sum(x[0:i], dim=0) |
| return retval |
| |
| mod = DynamicSliceExportMod() |
| |
| input = torch.rand(3, 4, 5) |
| example_outs = mod(input) |
| |
| f = io.BytesIO() |
| torch.onnx.export_to_pretty_string( |
| DynamicSliceExportMod(), (input,), f, example_outputs=example_outs, opset_version=10) |
| |
| def test_export_dict(self): |
| class DictModule(torch.nn.Module): |
| def forward(self, x_in: torch.Tensor) -> typing.Dict[str, torch.Tensor]: |
| return {"test_key_out": x_in} |
| |
| x_in = torch.tensor(1) |
| mod = DictModule() |
| mod.train(False) |
| |
| f = io.BytesIO() |
| torch.onnx.export_to_pretty_string(mod, (x_in,), f) |
| |
| with self.assertRaisesRegex(RuntimeError, r"DictConstruct.+is not supported."): |
| torch.onnx.export_to_pretty_string( |
| torch.jit.script(mod), (x_in,), f, example_outputs=(mod(x_in),)) |