blob: efca721b7d4ef5daa69ca733779c4d554436b1c7 [file] [log] [blame]
# Owner(s): ["oncall: export"]
import unittest
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.utils._pytree as pytree
from torch._dynamo.test_case import TestCase
from torch._export.converter import TS2EPConverter
from torch.export import ExportedProgram
from torch.testing._internal.common_utils import run_tests
# Test decorator: skips the decorated test when no CUDA device is available.
requires_cuda = unittest.skipUnless(
    condition=torch.cuda.is_available(),
    reason="requires cuda",
)
class TestConverter(TestCase):
    """Tests for ``TS2EPConverter`` (TorchScript -> ExportedProgram).

    Most tests build a small module or function, convert its
    ``torch.jit.trace`` and/or ``torch.jit.script`` form with
    ``TS2EPConverter`` and compare the converted program's outputs against
    the TorchScript model's outputs.
    """

    def _check_equal_ts_ep_converter(
        self,
        M,
        inp,
        option: Optional[List[str]] = None,
        check_persistent=False,
        lifted_tensor_constants=None,
    ) -> List[ExportedProgram]:
        """Convert ``M`` through TorchScript and check the result matches.

        Args:
            M: When ``check_persistent`` is False, an ``nn.Module`` instance
                or plain function handed directly to ``torch.jit.script`` /
                ``torch.jit.trace``. When ``check_persistent`` is True, a
                zero-argument callable (e.g. an ``nn.Module`` subclass) that
                is called to build fresh, independent instances.
            inp: Tuple of example inputs, unpacked as ``*inp`` on every call.
            option: ``torch.jit`` modes to test, each ``"trace"`` or
                ``"script"``. Defaults to ``["trace", "script"]``.
            check_persistent: If True, run both models 10 times on the same
                inputs so that state mutated by one call must carry over
                consistently to the next; otherwise run once.
            lifted_tensor_constants: Optional mapping of extra entries that
                must appear in the exported program's state dict in addition
                to the TorchScript model's own state dict.

        Returns:
            One ``ExportedProgram`` per entry of ``option``, in the same order.

        Raises:
            RuntimeError: If ``option`` contains an unrecognized mode.
        """
        # By default, it tests both jit.trace and jit.script.
        if option is None:
            option = ["trace", "script"]
        num_iterations = 10 if check_persistent else 1

        ep_list = []
        for opt in option:
            if opt == "script":
                # Separate two models for testing non-functional effects:
                # ``ts_model`` is converted while ``original_ts_model`` is run
                # as the reference, so mutations of one never affect the other.
                if check_persistent:
                    original_ts_model = torch.jit.script(M())
                    ts_model = torch.jit.script(M())
                    eager_model = M()
                else:
                    original_ts_model = torch.jit.script(M)
                    ts_model = torch.jit.script(M)
                    eager_model = M
            elif opt == "trace":
                if check_persistent:
                    original_ts_model = torch.jit.trace(M(), inp)
                    ts_model = torch.jit.trace(M(), inp)
                    eager_model = M()
                else:
                    original_ts_model = torch.jit.trace(M, inp)
                    ts_model = torch.jit.trace(M, inp)
                    eager_model = M
            else:
                raise RuntimeError(f"Unrecognized mode for torch.jit: {opt}")

            ep = TS2EPConverter(ts_model, inp).convert()
            ep_list.append(ep)

            # Check module. The set of state-dict keys is fixed at conversion
            # time and ``ts_model`` is never run, so checking once suffices.
            if isinstance(eager_model, torch.nn.Module):
                expected_state_dict = OrderedDict()
                expected_state_dict.update(ts_model.state_dict())
                if lifted_tensor_constants:
                    expected_state_dict.update(lifted_tensor_constants)
                self.assertEqual(
                    ep.state_dict.keys(),
                    expected_state_dict.keys(),
                )

            for _ in range(num_iterations):
                orig_out, _ = pytree.tree_flatten(original_ts_model(*inp))
                ep_out, _ = pytree.tree_flatten(ep.module()(*inp))
                # Check results
                self._check_tensor_list_equal(ep_out, orig_out)
        return ep_list

    def _check_tensor_list_equal(self, xs: List[torch.Tensor], ys: List[torch.Tensor]):
        """Assert two flattened output lists are element-wise equal.

        Tensor pairs must have equal shapes and be ``torch.allclose``; any
        other pair (e.g. Python scalars or bools) must have the same type and
        compare equal.
        """
        self.assertEqual(len(xs), len(ys))
        for x, y in zip(xs, ys):
            if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
                self.assertEqual(x.shape, y.shape)
                self.assertTrue(torch.allclose(x, y))
            else:
                self.assertEqual(type(x), type(y))
                self.assertEqual(x, y)

    def test_ts2ep_converter_basic(self):
        class MSingle(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        class MMulti(torch.nn.Module):
            def forward(self, x, y):
                x = x.cos() + 1
                y = y.sin() - 1
                return x, y

        inp = (torch.ones(1, 3), torch.ones(1, 3))
        self._check_equal_ts_ep_converter(MSingle(), inp)
        self._check_equal_ts_ep_converter(MMulti(), inp)

    def test_ts2ep_converter_container_output(self):
        # Output is a List.
        class MOutputList(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                a = x * x
                b = y + y
                return [a, b]

        # Output is a Tuple.
        class MOutputTuple(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                a = x * x
                b = y + y
                return (a, b)

        # Output is a Dict.
        class MOutputDict(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                a = x * x
                b = y + y
                return {"data": {"mul": a, "add": b}}

        inp = (torch.tensor(4), torch.tensor(4))

        # Traced function must use immutable structure as output, so the
        # list and dict outputs are only tested with jit.script.
        self._check_equal_ts_ep_converter(MOutputList(), inp, ["script"])
        self._check_equal_ts_ep_converter(MOutputTuple(), inp)
        self._check_equal_ts_ep_converter(MOutputDict(), inp, ["script"])

    def test_aten_dim(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                num_dim = x.dim()
                return torch.ones(num_dim)

        inp = (torch.ones(1, 3),)
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_aten_len(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                length = len(x)
                return torch.ones(length)

        # aten::len.Tensor
        inp = (torch.ones(2, 3),)
        self._check_equal_ts_ep_converter(Module(), inp)

        class Module(torch.nn.Module):
            def forward(self, x: List[int]):
                length = len(x)
                return torch.ones(length)

        # aten::len.t
        inp = ([1, 2, 3],)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: Dict[int, str]):
                length = len(x)
                return torch.ones(length)

        # aten::len.Dict_int
        inp = ({1: "a", 2: "b", 3: "c"},)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: Dict[bool, str]):
                length = len(x)
                return torch.ones(length)

        # aten::len.Dict_bool
        inp = ({True: "a", False: "b"},)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: Dict[float, str]):
                length = len(x)
                return torch.ones(length)

        # aten::len.Dict_float
        inp = ({1.2: "a", 3.4: "b"},)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: Dict[torch.Tensor, str]):
                length = len(x)
                return torch.ones(length)

        # aten::len.Dict_Tensor
        inp = ({torch.zeros(2, 3): "a", torch.ones(2, 3): "b"},)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        # aten::len.str and aten::len.Dict_str are not supported
        # since torch._C._jit_flatten does not support str
        # inp = ("abcdefg",)
        # self._check_equal_ts_ep_converter(Module(), inp)
        # inp = ({"a": 1, "b": 2},)
        # self._check_equal_ts_ep_converter(Module(), inp)

    def test_prim_min(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                x_len = len(x)
                y_len = len(y)

                # prim::min.int
                len_int = min(x_len, y_len)

                # prim::min.float
                len_float = int(min(x_len * 2.0, y_len * 2.0))

                # prim::min.self_int
                len_self_int = min([x_len, y_len])

                # prim::min.self_float
                len_self_float = int(min([x_len * 2.0, y_len * 2.0]))

                # prim::min.float_int
                len_float_int = int(min(x_len * 2.0, y_len))

                # prim::min.int_float
                len_int_float = int(min(x_len, y_len * 2.0))

                return torch.ones(
                    len_int
                    + len_float
                    + len_self_int
                    + len_self_float
                    + len_float_int
                    + len_int_float
                )

        inp = (torch.randn(10, 2), torch.randn(5))
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_aten___getitem___list(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                y = torch.split(x, 2)
                return y[0]

        inp = (torch.rand((3, 2)),)
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_aten___getitem___dict(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                y = torch.split(x, 2)
                d_int = {0: y[0], 1: y[1]}
                d_str = {"0": y[0], "1": y[1]}
                d_bool = {True: y[0], False: y[1]}
                d_float = {0.1: y[0], 2.3: y[1]}
                return d_int[0], d_str["0"], d_bool[True], d_float[0.1]

        inp = (torch.rand((3, 2)),)
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_prim_device(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                device = x.device
                return torch.ones(2, 3, device=device)

        inp = (torch.rand(3, 4),)
        self._check_equal_ts_ep_converter(Module(), inp)

    @requires_cuda
    def test_prim_device_cuda(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                device = x.device
                return torch.ones(2, 3, device=device)

        inp = (torch.rand((3, 4), device="cuda:0"),)
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_prim_dtype(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                dtype = x.dtype
                return torch.ones(2, 3, dtype=dtype)

        for dtype in [
            torch.float32,
            torch.double,
        ]:
            inp = (torch.rand((3, 4), dtype=dtype),)
            self._check_equal_ts_ep_converter(Module(), inp)

        for dtype in [
            torch.uint8,
            torch.int8,
            torch.int32,
        ]:
            inp = (torch.randint(high=128, size=(3, 4), dtype=dtype),)
            self._check_equal_ts_ep_converter(Module(), inp)

    def test_convert_if_basic(self):
        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                if x:
                    return y * y
                else:
                    return y + y

        inp = (torch.tensor(True), torch.tensor(4))
        ep_list = self._check_equal_ts_ep_converter(M(), inp)

        # ep_list[0] comes from jit.trace (the helper's default option order
        # is ["trace", "script"]); tracing records only the branch taken for
        # the example input, so only later programs are checked on the other
        # branch.
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(torch.tensor(False), torch.tensor(4)),
                M()(torch.tensor(False), torch.tensor(4)),
            )

    def test_convert_if_multiple_out(self):
        class M(torch.nn.Module):
            def true_fn(self, y, z):
                return (z * z, z + z)

            def false_fn(self, y, z):
                return (y * y * y, y + y)

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                z = y * y

                if x:
                    res = self.true_fn(y, z)
                else:
                    res = self.false_fn(y, z)

                return res[0] + res[1]

        inp = (torch.tensor(True), torch.tensor(4))
        ep_list = self._check_equal_ts_ep_converter(M(), inp)

        # Skip the traced program (ep_list[0]): it only records the branch
        # taken for the example input.
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(torch.tensor(False), torch.tensor(4)),
                M()(torch.tensor(False), torch.tensor(4)),
            )

    def test_profiler__record_function(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                handle = torch.ops.profiler._record_function_enter_new("foo", None)
                y = x * 2 + 4
                torch.ops.profiler._record_function_exit(handle)
                return y

        x = torch.randn(10, 10)
        self._check_equal_ts_ep_converter(Module(), (x,))

    def test_aten_floordiv(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x // 2

        x = torch.randn(10, 10)
        self._check_equal_ts_ep_converter(Module(), (x,))

    def test_aten___is__(self):
        class Module(torch.nn.Module):
            def forward(
                self, x: torch.Tensor, y: torch.Tensor
            ) -> Tuple[bool, torch.Tensor]:
                z = x + 1
                return x is y, z

        # Traced function must return output that has tensors.
        inp = (torch.randn(10, 10), torch.rand(10, 10))
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_aten___isnot__(self):
        class Module(torch.nn.Module):
            def forward(
                self, x: torch.Tensor, y: torch.Tensor
            ) -> Tuple[bool, torch.Tensor]:
                z = x + 1
                return x is not y, z

        # Traced function must return output that has tensors.
        inp = (torch.randn(10, 10), torch.rand(10, 10))
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_aten___not__(self):
        class Module(torch.nn.Module):
            def forward(
                self, x: torch.Tensor, y: torch.Tensor
            ) -> Tuple[bool, torch.Tensor]:
                z = x + 1
                return not (x is not y), z

        # Traced function must return output that has tensors.
        inp = (torch.randn(10, 10), torch.rand(10, 10))
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_ts2ep_converter_unpack(self):
        class MUnpackList(torch.nn.Module):
            def forward(self, x):
                x, y = torch.split(x, 2)
                return x + y

        class MUnpackTuple(torch.nn.Module):
            def forward(self, x_tuple: Tuple[torch.Tensor, torch.Tensor]):
                x, y = x_tuple
                x = x.cos()
                return x + y

        inp = (torch.ones(4),)
        self._check_equal_ts_ep_converter(MUnpackList(), inp)
        inp = ((torch.zeros(1, 4), torch.ones(1, 4)),)
        self._check_equal_ts_ep_converter(MUnpackTuple(), inp)

    def test_convert_nn_module_with_nested_param(self):
        class M(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor):
                return self.linear(x)

        class NestedM(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(dim, dim)
                self.m = M(dim)

            def forward(self, x: torch.Tensor):
                return self.linear(self.m(x))

        class SuperNestedM(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(dim, dim)
                self.m = NestedM(dim)

            def forward(self, x: torch.Tensor):
                return self.linear(self.m(x))

        inp = (torch.ones(3),)
        orig_m = NestedM(3)
        self._check_equal_ts_ep_converter(orig_m, inp)
        orig_m = SuperNestedM(3)
        self._check_equal_ts_ep_converter(orig_m, inp)

    def test_convert_nn_module_with_nested_buffer(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.register_buffer("w", torch.randn(1))

            def forward(self, x: torch.Tensor):
                return self.w + x

        class NestedM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m = M()
                self.register_buffer("w", torch.randn(1))

            def forward(self, x: torch.Tensor):
                return self.w + self.m(x)

        class SuperNestedM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m = NestedM()
                self.register_buffer("w", torch.randn(1))

            def forward(self, x: torch.Tensor):
                return self.w + self.m(x)

        inp = (torch.ones(1),)
        orig_m = NestedM()
        self._check_equal_ts_ep_converter(orig_m, inp)
        orig_m = SuperNestedM()
        self._check_equal_ts_ep_converter(orig_m, inp)

    def test_convert_nn_module_with_nested_if_and_buffer(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.register_buffer("w", torch.randn(1))
                self.count = 1

            def forward(self, x: torch.Tensor):
                return self.w + x + self.count

        class NestedM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = M()
                self.m2 = M()
                self.register_buffer("w", torch.randn(1))

            def forward(self, x: torch.Tensor):
                if torch.sum(x) > 1:
                    return self.w + self.m1(x)
                else:
                    return self.w + self.m2(x)

        # Super nested, parameters need to be lifted
        # multiple times.
        class SuperNestedM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = NestedM()
                self.m2 = NestedM()
                self.register_buffer("w", torch.randn(1))

            def forward(self, x: torch.Tensor):
                if torch.max(x) > 1:
                    return self.w + self.m1(x)
                else:
                    return self.w + self.m2(x)

        # Super nested module testing.
        # NOTE(review): with torch.ones(1), both `torch.max(x) > 1` and
        # `torch.sum(x) > 1` are False before and after `t -= 1`, so only the
        # else-branches are exercised at runtime here.
        inp = (torch.ones(1),)
        orig_m = SuperNestedM()
        # TODO: fix trace: state_dict is not equal.
        ep_list = self._check_equal_ts_ep_converter(orig_m, inp, ["script"])

        # Mutate the input in place to re-run on different values.
        t = inp[0]
        t -= 1
        for ep in ep_list:
            torch.testing.assert_close(
                ep.module()(*inp),
                orig_m(*inp),
            )

    def test_convert_nn_module_with_nested_if_and_param(self):
        class M(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor):
                return self.linear(x)

        class NestedM(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.m1 = M(dim)
                self.m2 = M(dim)
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor):
                if torch.sum(x) > 1:
                    return self.linear(self.m1(x))
                else:
                    return self.linear(self.m2(x))

        # Super nested, parameters need to be lifted
        # multiple times.
        class SuperNestedM1(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.m1 = NestedM(dim)
                self.m2 = NestedM(dim)
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor):
                if torch.max(x) > 1:
                    return self.linear(self.m1(x))
                else:
                    return self.linear(self.m2(x))

        # Super nested, even the input needs to be
        # lifted recursively due to value propagation optimization.
        class SuperNestedM2(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.m1 = NestedM(dim)
                self.m2 = NestedM(dim)
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor):
                if torch.sum(x) > 1:
                    return self.linear(self.m1(x))
                else:
                    return self.linear(self.m2(x))

        # Basic module testing.
        inp = (torch.ones(3),)
        orig_m = M(3)
        ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

        t = inp[0]
        t -= 0.8
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(*inp),
                orig_m(*inp),
            )

        # Nested module testing.
        inp = (torch.ones(3),)
        orig_m = NestedM(3)
        # TODO: fix trace: state_dict is not equal.
        ep_list = self._check_equal_ts_ep_converter(orig_m, inp, ["script"])

        t = inp[0]
        t -= 0.8
        for ep in ep_list:
            torch.testing.assert_close(
                ep.module()(*inp),
                orig_m(*inp),
            )

        # Super nested module testing.
        # NOTE(review): with torch.ones(3) the outer `torch.max(x) > 1` is
        # False both before and after `t -= 0.8`, so only `self.m2` of
        # SuperNestedM1 is exercised at runtime here.
        inp = (torch.ones(3),)
        orig_m = SuperNestedM1(3)
        # TODO: fix trace: state_dict is not equal.
        ep_list = self._check_equal_ts_ep_converter(orig_m, inp, ["script"])

        t = inp[0]
        t -= 0.8
        for ep in ep_list:
            torch.testing.assert_close(
                ep.module()(*inp),
                orig_m(*inp),
            )

        # TODO: the SuperNestedM2 case below is currently disabled. The helper
        # returns a list of ExportedPrograms, one per jit mode.
        # # Super nested module testing.
        # inp = (torch.ones(3),)
        # orig_m = SuperNestedM2(3)
        # ep_list = self._check_equal_ts_ep_converter(orig_m, inp)
        # t = inp[0]
        # t -= 0.8
        # for ep in ep_list[1:]:
        #     torch.testing.assert_close(
        #         ep.module()(*inp),
        #         orig_m(*inp),
        #     )

    def test_ts2ep_converter_contains(self):
        class MIn(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return x.dtype in [torch.float32, torch.float64]

        class MNotIn(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return x.dtype in [torch.int8]

        class MTensorIn(torch.nn.Module):
            def forward(self, x: torch.Tensor, x_dict: Dict[torch.Tensor, str]):
                return x in x_dict

        # Traced function must return output that has tensors.
        inp = (torch.tensor(4),)
        self._check_equal_ts_ep_converter(MIn(), inp, ["script"])
        self._check_equal_ts_ep_converter(MNotIn(), inp, ["script"])

        # TODO: update test to use reference for in.
        inp = (torch.tensor(4), {torch.tensor(4): "foo"})
        self._check_equal_ts_ep_converter(MTensorIn(), inp, ["script"])
        inp = (torch.tensor(1), {torch.tensor(4): "foo"})
        self._check_equal_ts_ep_converter(MTensorIn(), inp, ["script"])

    def test_ts2ep_converter_custom_op(self):
        # Patch the dynamo config only for the duration of this test so the
        # flags are restored afterwards instead of leaking into other tests.
        with torch._dynamo.config.patch(
            capture_scalar_outputs=True,
            capture_dynamic_output_shape_ops=True,
        ), torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor x) -> Tensor",
                lib=lib,
            )

            # PyTorch custom op implementation
            @torch.library.impl(
                "mylib::foo",
                "CompositeExplicitAutograd",
                lib=lib,
            )
            def foo_impl(x):
                return x + x

            # Meta function of the custom op.
            @torch.library.impl_abstract(
                "mylib::foo",
                lib=lib,
            )
            def foo_meta(x):
                return x + x

            class M(torch.nn.Module):
                def forward(self, x):
                    return torch.ops.mylib.foo(x)

            inp = (torch.randn(3, 3),)
            m = M()
            self._check_equal_ts_ep_converter(m, inp)

    def test_convert_func_without_param(self):
        def func1(x, y):
            return x + y

        def func2(x, y):
            if x.sum() > 0:
                return x + y
            else:
                return x - y

        inp = (
            torch.tensor(1),
            torch.tensor(1),
        )
        self._check_equal_ts_ep_converter(func1, inp)

        ep_list = self._check_equal_ts_ep_converter(func2, inp)

        # Flip the branch taken by func2; the traced program (ep_list[0]) only
        # recorded the original branch, so it is skipped.
        t = inp[0]
        t -= 1
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(*inp),
                func2(*inp),
            )

    def test_implicit_constant_to_tensor_handling(self):
        def func1(x):
            return x + 2

        def func2(x, y):
            return x * y / (x - 2 * y) + y

        def func3(x):
            return x + torch.tensor([3])

        def func4():
            val = torch.tensor(float("inf"))
            return torch.full((10, 10), val)

        def func5():
            x = -1
            return x * torch.ones(1, dtype=torch.float), torch.zeros(
                1, dtype=torch.float
            )

        def func6(x1, x2, x3, x4):
            return (
                x1.numel(),
                x1.size(),
                x2.numel(),
                x2.size(),
                x3.numel(),
                x3.size(),
                x4.numel(),
                x4.size(),
                torch.ones(x1.numel()),  # Just make sure downstream ops still work.
                torch.ones(x1.size()),  # Just make sure downstream ops still work.
            )

        class M1(torch.nn.Module):
            def __init__(self, value):
                super().__init__()
                self.x = torch.tensor(value)

            def forward(self):
                return self.x.clone()

        class M2(torch.nn.Module):
            def forward(self, x):
                return torch.tensor(4) + x

        inp = (torch.randn([2, 2]),)
        self._check_equal_ts_ep_converter(func1, inp)
        inp = (torch.randn([2, 2]), torch.randn([2, 2]))
        self._check_equal_ts_ep_converter(func2, inp)

        inp = (torch.randn([2, 2]),)
        self._check_equal_ts_ep_converter(func3, inp)

        self._check_equal_ts_ep_converter(func4, ())
        self._check_equal_ts_ep_converter(M1(5), ())

        inp = (torch.randn(2),)
        self._check_equal_ts_ep_converter(M2(), inp)

        self._check_equal_ts_ep_converter(func5, ())
        inp = (
            torch.randn([2, 3, 4]).to(torch.int8),
            torch.randn([2, 3, 4]).to(torch.int32),
            torch.randn([2, 3, 4]).to(torch.float32),
            torch.randn([2, 3, 4]).to(torch.float64),
        )
        self._check_equal_ts_ep_converter(func6, inp)

    def test_prim_tolist(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> List[int]:
                return x.tolist()

        inp = (torch.tensor([1, 2, 3]),)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> List[List[int]]:
                return x.tolist()

        inp = (torch.tensor([[1, 2, 3], [4, 5, 6]]),)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_get_tensor_constants(self):
        # Since self.data is only read but not written, it is lifted as
        # constant tensors.
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.data = torch.randn(3, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + self.data

        class Goo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.data = torch.randn(3, 2)
                self.foo = Foo()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + self.data + self.foo.data + self.foo(x)

        inp = (torch.randn(3, 2),)
        goo = Goo()
        self._check_equal_ts_ep_converter(goo, inp)

    def test_prim_SetAttr(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("data", torch.ones(3, 2))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.data = self.data + x
                return x + x

        inp = (torch.ones(3, 2),)
        self._check_equal_ts_ep_converter(
            Module, inp, ["script"], check_persistent=True
        )

        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("data", torch.ones(3, 2))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.data = self.data + x
                return x + self.data

        inp = (torch.ones(3, 2),)
        self._check_equal_ts_ep_converter(
            Module, inp, ["script"], check_persistent=True
        )

        # export lifts a tensor constant (self.data) as an input if it is not assigned.
        # If it is assigned, export will error and ask users to register it as a buffer.
        # In converter, we change tensor constants that are assigned as a buffer automatically,
        # since it might be hard to manually register them as buffers.
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.data = torch.ones(3, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.data = self.data + x
                return x + self.data

        inp = (torch.ones(3, 2),)
        self._check_equal_ts_ep_converter(
            Module,
            inp,
            ["script"],
            check_persistent=True,
            lifted_tensor_constants=OrderedDict([("data", torch.ones(3, 2))]),
        )

        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.count = 0

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.count += 1
                return x + self.count

        # check_persistent is False since export specializes on non-tensor constants
        inp = (torch.ones(3, 2),)
        self._check_equal_ts_ep_converter(
            Module(), inp, ["script"], check_persistent=False
        )

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.count = 0

            def forward(self, x):
                count1 = self.count
                self.count += 1
                count2 = self.count
                self.count += 1
                count3 = self.count
                return x + count1 + count2 + count3

        inp = (torch.ones(1),)
        self._check_equal_ts_ep_converter(M(), inp, ["script"], check_persistent=False)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.register_buffer("w2", torch.ones(1))

            def forward(self, x: torch.Tensor):
                self.w2 += 1
                return self.w2

        inp = (torch.ones(1),)
        self._check_equal_ts_ep_converter(M, inp, ["script"], check_persistent=True)
if __name__ == "__main__":
    # Script entry point: hand control to PyTorch's shared test runner.
    run_tests()