import io
import os
import sys
import typing

import torch
import torch.nn as nn

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

from torch.testing._internal.common_utils import suppress_warnings
from torch.testing._internal.jit_utils import JitTestCase
from torch.onnx import OperatorExportTypes

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")


class TestONNXExport(JitTestCase):
    def test_fuse_addmm(self):
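        # Smoke test: mm followed by add is expected to be fused into a single
        # addmm (emitted as ONNX Gemm) during export; here we only check that
        # export itself succeeds.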
        class AddmmModel(torch.nn.Module):
            def forward(self, x):
                return torch.mm(x, x) + x

        x = torch.ones(3, 3)
        f = io.BytesIO()
        torch.onnx._export(AddmmModel(), x, f, verbose=False)

    def test_onnx_transpose_incomplete_tensor_type(self):
        # Smoke test that exports a transpose op whose input is a TensorType
        # without size information. This used to fail because we took the
        # input's sizes and used their length as the number of dimensions in
        # the permutation.
        class Foo(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return x.contiguous().transpose(0, 1).sum()

        class TraceMe(torch.nn.Module):
            def __init__(self):
                super(TraceMe, self).__init__()
                self.foo = Foo()

            def forward(self, x):
                return self.foo(x)

        tm = TraceMe()
        tm = torch.jit.trace(tm, torch.rand(3, 4))
        example_outputs = (tm(torch.rand(3, 4)),)
        f = io.BytesIO()
        torch.onnx._export(tm, (torch.rand(3, 4),), f, example_outputs=example_outputs)

    def test_export_tensoroption_to(self):
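        # Exercises tensor-option ops (clone / detach / cpu) captured by tracing;
        # export to the pretty-string form should succeed.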
        def foo(x):
            return x[0].clone().detach().cpu() + x

        traced = torch.jit.trace(foo, (torch.rand([2])))
        example_outputs = traced(torch.rand([2]))
        f = io.BytesIO()
        torch.onnx._export_to_pretty_string(traced, (torch.rand([2]),), f,
                                            example_outputs=example_outputs)

    def test_onnx_export_script_module(self):
        class ModuleToExport(torch.jit.ScriptModule):
            def __init__(self):
                super(ModuleToExport, self).__init__()

            @torch.jit.script_method
            def forward(self, x):
                y = x - x
                return x + x

        mte = ModuleToExport()
        outputs = mte(torch.zeros(1, 2, 3))
        torch.onnx.export_to_pretty_string(
            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
            example_outputs=outputs)

    @suppress_warnings
    def test_onnx_export_func_with_warnings(self):
        @torch.jit.script
        def func_with_warning(inp):
            return torch.nn.functional.sigmoid(inp)  # triggers a deprecation warning

        class WarningTest(torch.nn.Module):
            def __init__(self):
                super(WarningTest, self).__init__()

            def forward(self, x):
                return func_with_warning(x)

        outputs = WarningTest()(torch.randn(42))
        # export should complete without raising, despite the deprecation warning
        torch.onnx.export_to_pretty_string(
            WarningTest(), torch.randn(42), None, verbose=False,
            example_outputs=outputs)

    def test_onnx_export_script_python_fail(self):
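        # PythonModule.forward is @torch.jit.ignore'd, so it remains a Python
        # call in the graph; the ONNX exporter cannot serialize Python calls
        # and should raise the error asserted below.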
        class PythonModule(torch.jit.ScriptModule):
            def __init__(self):
                super(PythonModule, self).__init__()

            @torch.jit.ignore
            def forward(self, x):
                return torch.neg(x)

        class ModuleToExport(torch.jit.ScriptModule):
            def __init__(self):
                super(ModuleToExport, self).__init__()
                self.mod = PythonModule()

            @torch.jit.script_method
            def forward(self, x):
                y = self.mod(x)
                return y + y

        mte = ModuleToExport()
        outputs = mte(torch.zeros(1, 2, 3))
        f = io.BytesIO()
        with self.assertRaisesRegex(RuntimeError, "Couldn't export Python"):
            torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False,
                               example_outputs=outputs)

    def test_onnx_export_script_inline_trace(self):
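        # A traced submodule called from a script module should be inlined into
        # the exported graph.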
        class ModuleToInline(torch.nn.Module):
            def __init__(self):
                super(ModuleToInline, self).__init__()

            def forward(self, x):
                return torch.neg(x)

        class ModuleToExport(torch.jit.ScriptModule):
            def __init__(self):
                super(ModuleToExport, self).__init__()
                self.mod = torch.jit.trace(ModuleToInline(), torch.zeros(1, 2, 3))

            @torch.jit.script_method
            def forward(self, x):
                y = self.mod(x)
                return y + y

        mte = ModuleToExport()
        outputs = mte(torch.zeros(1, 2, 3))
        torch.onnx.export_to_pretty_string(
            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
            example_outputs=outputs)

    def test_onnx_export_script_inline_script(self):
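        # Same as the trace-inlining test above, but the submodule is itself a
        # script module.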
        class ModuleToInline(torch.jit.ScriptModule):
            def __init__(self):
                super(ModuleToInline, self).__init__()

            @torch.jit.script_method
            def forward(self, x):
                return torch.neg(x)

        class ModuleToExport(torch.jit.ScriptModule):
            def __init__(self):
                super(ModuleToExport, self).__init__()
                self.mod = ModuleToInline()

            @torch.jit.script_method
            def forward(self, x):
                y = self.mod(x)
                return y + y

        mte = ModuleToExport()
        outputs = mte(torch.zeros(1, 2, 3))
        torch.onnx.export_to_pretty_string(
            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
            example_outputs=outputs)

    def test_onnx_export_script_module_loop(self):
        class ModuleToExport(torch.jit.ScriptModule):
            def __init__(self):
                super(ModuleToExport, self).__init__()

            @torch.jit.script_method
            def forward(self, x):
                # Check that we support end-to-end ONNX export of loops and
                # nested loops, with and without a loop index.
                for _ in range(5):
                    for i in range(3):
                        x = x + i
                return x

        mte = ModuleToExport()
        outputs = mte(torch.zeros(1, 2, 3))
        torch.onnx.export_to_pretty_string(
            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
            example_outputs=outputs)

    @suppress_warnings
    def test_onnx_export_script_truediv(self):
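        # x.size(0) / 2 is true division on an int and produces a float in
        # TorchScript; the exporter should handle the resulting scalar type.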
        class ModuleToExport(torch.jit.ScriptModule):
            def __init__(self):
                super(ModuleToExport, self).__init__()

            @torch.jit.script_method
            def forward(self, x):
                z = x.size(0) / 2
                return x + z

        mte = ModuleToExport()
        outputs = mte(torch.zeros(1, 2, 3))
        torch.onnx.export_to_pretty_string(
            mte, (torch.zeros(1, 2, 3, dtype=torch.float),), None, verbose=False,
            example_outputs=outputs)

    def test_onnx_export_script_non_alpha_add_sub(self):
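        # The int overloads of add/sub have no alpha argument (unlike the
        # tensor overloads); export is expected to handle them.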
        class ModuleToExport(torch.jit.ScriptModule):
            def __init__(self):
                super(ModuleToExport, self).__init__()

            @torch.jit.script_method
            def forward(self, x):
                bs = x.size(0) + 1
                return bs - 1

        mte = ModuleToExport()
        outputs = torch.LongTensor([mte(torch.rand(3, 4))])
        torch.onnx.export_to_pretty_string(
            mte, (torch.rand(3, 4),), None, verbose=False,
            example_outputs=outputs)

    def test_onnx_export_script_module_if(self):
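        # The condition depends on the input, so the `if` should survive into
        # the exported graph as a conditional rather than being folded away.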
        class ModuleToExport(torch.jit.ScriptModule):
            def __init__(self):
                super(ModuleToExport, self).__init__()

            @torch.jit.script_method
            def forward(self, x):
                if bool(torch.sum(x) > 0):
                    x = torch.neg(x)
                return x

        mte = ModuleToExport()
        outputs = mte(torch.zeros(1, 2, 3, dtype=torch.long))
        torch.onnx.export_to_pretty_string(
            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
            example_outputs=outputs)

    def test_onnx_export_script_inline_params(self):
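        # Parameters of the inlined submodule (including the unused one) need
        # to be carried over into the exported model.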
        class ModuleToInline(torch.jit.ScriptModule):
            def __init__(self):
                super(ModuleToInline, self).__init__()
                self.m = torch.nn.Parameter(torch.ones(3, 3))
                self.unused = torch.nn.Parameter(torch.ones(1, 2, 3))

            @torch.jit.script_method
            def forward(self, x):
                return torch.mm(x, self.m)

        class ModuleToExport(torch.jit.ScriptModule):
            def __init__(self):
                super(ModuleToExport, self).__init__()
                self.mod = ModuleToInline()
                self.param = torch.nn.Parameter(torch.ones(3, 4))

            @torch.jit.script_method
            def forward(self, x):
                y = self.mod(x)
                return torch.mm(y, self.param)

        mte = ModuleToExport()
        result = mte(torch.zeros(2, 3))
        reference = torch.mm(torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4))
        self.assertEqual(result, reference)
        torch.onnx.export_to_pretty_string(
            mte, (torch.ones(2, 3),), None, verbose=False,
            example_outputs=result)

    def test_onnx_export_speculate(self):
        class Foo(torch.jit.ScriptModule):
            def __init__(self, m):
                super(Foo, self).__init__()
                self.m = m

            @torch.jit.script_method
            def forward(self, x):
                x += x
                # Because we are testing whether the `if` statements are emitted
                # correctly, we cannot use `True` as the condition: constant
                # propagation would remove the `if` statements.
                c = torch.sum(x) > 4
                if bool(c):
                    if bool(c):
                        y = self.m(x)
                    else:
                        y = self.m(x)
                else:
                    y = self.m(x)
                return y

        linear = torch.jit.trace(nn.Linear(10, 20).float(), torch.zeros(1, 10, dtype=torch.float))

        @torch.jit.script
        def transpose(x):
            return x.t()

        f1 = Foo(transpose)
        outputs_f1 = f1(torch.ones(1, 10, dtype=torch.float))
        f2 = Foo(linear)
        outputs_f2 = f2(torch.ones(1, 10, dtype=torch.float))

        torch.onnx.export_to_pretty_string(
            f1,
            (torch.ones(1, 10, dtype=torch.float), ),
            None, verbose=False, example_outputs=outputs_f1)
        torch.onnx.export_to_pretty_string(
            f2,
            (torch.ones(1, 10, dtype=torch.float), ),
            None, verbose=False, example_outputs=outputs_f2)

    def test_onnx_export_shape_reshape(self):
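        # shape_as_tensor / reshape_from_tensor_shape keep the shape computation
        # dynamic (Shape + Reshape) instead of baking traced sizes into constants.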
        class Foo(torch.nn.Module):
            def forward(self, x):
                import torch.onnx.operators
                x = x.repeat(5, 1, 1)
                shape = torch.onnx.operators.shape_as_tensor(x)
                reshaped = torch.onnx.operators.reshape_from_tensor_shape(x, shape)
                return reshaped

        foo = torch.jit.trace(Foo(), torch.zeros(1, 2, 3))
        outputs = foo(torch.zeros(1, 2, 3))
        f = io.BytesIO()
        torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3)), f,
                                           example_outputs=outputs)

    def test_listconstruct_erasure(self):
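        # Boolean-mask indexing builds its index list with prim::ListConstruct;
        # presumably the exporter must erase that node for the ATen-fallback
        # export below to succeed.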
        class FooMod(torch.nn.Module):
            def forward(self, x):
                mask = x < 0.0
                return x[mask]

        f = io.BytesIO()
        torch.onnx.export_to_pretty_string(
            FooMod(), (torch.rand(3, 4),), f,
            add_node_names=False,
            do_constant_folding=False,
            operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)

    def test_export_dynamic_slice(self):
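        # Slice end points that depend on a runtime size need the dynamic Slice
        # op, which requires opset_version >= 10.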
        class DynamicSliceExportMod(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                retval = x[0]
                for i in range(x.size(1)):
                    retval += torch.sum(x[0:i], dim=0)
                return retval

        mod = DynamicSliceExportMod()
        input = torch.rand(3, 4, 5)
        example_outs = mod(input)
        f = io.BytesIO()
        torch.onnx.export_to_pretty_string(
            DynamicSliceExportMod(), (input,), f, example_outputs=example_outs, opset_version=10)

    def test_export_dict(self):
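        # Tracing flattens the dict output, so the first export succeeds;
        # scripting keeps the DictConstruct node, which the exporter rejects.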
        class DictModule(torch.nn.Module):
            def forward(self, x_in: torch.Tensor) -> typing.Dict[str, torch.Tensor]:
                return {"test_key_out": x_in}

        x_in = torch.tensor(1)
        mod = DictModule()
        mod.train(False)

        f = io.BytesIO()
        torch.onnx.export_to_pretty_string(mod, (x_in,), f)

        with self.assertRaisesRegex(RuntimeError, r"DictConstruct.+is not supported."):
            torch.onnx.export_to_pretty_string(
                torch.jit.script(mod), (x_in,), f, example_outputs=(mod(x_in),))