# blob: 66f1a60ca97269110547f5c2baa4a8d0bf306e70
# Owner(s): ["oncall: export"]
# flake8: noqa
import unittest
import torch
import torch._dynamo
from torch._dynamo.test_case import run_tests, TestCase
from torch._export.wrappers import _mark_strict_experimental
from torch._functorch.aot_autograd import aot_export_module
from torch.testing import FileCheck
@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExperiment(TestCase):
def test_with_buffer_as_submodule(self):
@_mark_strict_experimental
class B(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.ones(3))
def forward(self, x):
y = x + 2
y.add_(4)
# this doesnt' work today with HOO
# self.buffer1.add_(6)
buffer_updated = self.buffer1 + 6
return x.sum() + y.sum() + buffer_updated.sum()
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.submodule = B()
def forward(self, x):
x_v2 = x.sin()
return (self.submodule(x_v2), x + 3)
inp = torch.randn(3)
ep = torch.export.export(M(), (inp,), strict=False)
self.assertExpectedInline(
str(ep.graph_module.code.strip()),
"""\
def forward(self, b_submodule_buffer1, x):
sin = torch.ops.aten.sin.default(x)
strict_graph_0 = self.strict_graph_0
strict_mode = torch.ops.higher_order.strict_mode(strict_graph_0, (sin, b_submodule_buffer1)); strict_graph_0 = sin = b_submodule_buffer1 = None
getitem = strict_mode[0]; strict_mode = None
add = torch.ops.aten.add.Tensor(x, 3); x = None
return (getitem, add)""",
)
self.assertExpectedInline(
str(ep.graph_module.strict_graph_0.code.strip()),
"""\
def forward(self, arg0_1, arg1_1):
add = torch.ops.aten.add.Tensor(arg0_1, 2)
add_1 = torch.ops.aten.add.Tensor(add, 4); add = None
add_2 = torch.ops.aten.add.Tensor(arg1_1, 6); arg1_1 = None
sum_1 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
sum_2 = torch.ops.aten.sum.default(add_1); add_1 = None
add_3 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
sum_3 = torch.ops.aten.sum.default(add_2); add_2 = None
add_4 = torch.ops.aten.add.Tensor(add_3, sum_3); add_3 = sum_3 = None
return (add_4,)""",
)
eager_mod = M()
ep = torch.export.export(eager_mod, (inp,), strict=True)
graph_res_1, graph_res_2 = ep.module()(inp)
eager_res_1, eager_res_2 = eager_mod(inp)
self.assertTrue(torch.allclose(graph_res_2, eager_res_2))
self.assertTrue(torch.allclose(graph_res_1, eager_res_1))
graph_res_1, graph_res_2 = ep.module()(inp)
eager_res_1, eager_res_2 = eager_mod(inp)
self.assertTrue(torch.allclose(graph_res_2, eager_res_2))
self.assertTrue(torch.allclose(graph_res_1, eager_res_1))
def test_mark_strict_with_container_type(self):
@_mark_strict_experimental
class B(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
x0 = x[0][0]
return x0.sum()
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.submodule = B()
def forward(self, x):
return self.submodule(x)
inp = ((torch.randn(3),),)
with self.assertRaisesRegex(
RuntimeError, "strict_mode HOO doesn't work unless"
):
ep = torch.export.export(M(), inp, strict=False)
if __name__ == "__main__":
    # Run the tests only when this file is executed directly, not on import.
    run_tests()