# Owner(s): ["oncall: export"]
# flake8: noqa
import unittest
from typing import Dict, List, Tuple
import torch
import torch._dynamo
from torch._dynamo.test_case import run_tests, TestCase
from torch._export.wrappers import _mark_strict_experimental
from torch._functorch.aot_autograd import aot_export_module
from torch.export._trace import _convert_ts_to_export_experimental
from torch.export.experimental import _export_forward_backward
from torch.testing import FileCheck


@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExperiment(TestCase):
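    # Verifies that a submodule decorated with _mark_strict_experimental is
    # traced into a strict_mode higher-order op under non-strict export, with
    # the submodule's buffer lifted to a graph input, and that exported
    # results match eager results on repeated calls.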
def test_with_buffer_as_submodule(self):
@_mark_strict_experimental
class B(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.ones(3))
def forward(self, x):
y = x + 2
y.add_(4)
                # this doesn't work today with HOO
# self.buffer1.add_(6)
buffer_updated = self.buffer1 + 6
return x.sum() + y.sum() + buffer_updated.sum()
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.submodule = B()
def forward(self, x):
x_v2 = x.sin()
return (self.submodule(x_v2), x + 3)
inp = torch.randn(3)
ep = torch.export.export(M(), (inp,), strict=False)
self.assertExpectedInline(
str(ep.graph_module.code.strip()),
"""\
def forward(self, b_submodule_buffer1, x):
sin = torch.ops.aten.sin.default(x)
strict_graph_0 = self.strict_graph_0
strict_mode = torch.ops.higher_order.strict_mode(strict_graph_0, (sin, b_submodule_buffer1)); strict_graph_0 = sin = b_submodule_buffer1 = None
getitem_2 = strict_mode[0]; strict_mode = None
add = torch.ops.aten.add.Tensor(x, 3); x = None
return (getitem_2, add)""",
)
self.assertExpectedInline(
str(ep.graph_module.strict_graph_0.code.strip()),
"""\
def forward(self, arg0_1, arg1_1):
add = torch.ops.aten.add.Tensor(arg0_1, 2)
add_1 = torch.ops.aten.add.Tensor(add, 4); add = None
add_2 = torch.ops.aten.add.Tensor(arg1_1, 6); arg1_1 = None
sum_1 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
sum_2 = torch.ops.aten.sum.default(add_1); add_1 = None
add_3 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
sum_3 = torch.ops.aten.sum.default(add_2); add_2 = None
add_4 = torch.ops.aten.add.Tensor(add_3, sum_3); add_3 = sum_3 = None
return (add_4,)""",
)
eager_mod = M()
ep = torch.export.export(eager_mod, (inp,), strict=True)
graph_res_1, graph_res_2 = ep.module()(inp)
eager_res_1, eager_res_2 = eager_mod(inp)
self.assertTrue(torch.allclose(graph_res_2, eager_res_2))
self.assertTrue(torch.allclose(graph_res_1, eager_res_1))
graph_res_1, graph_res_2 = ep.module()(inp)
eager_res_1, eager_res_2 = eager_mod(inp)
self.assertTrue(torch.allclose(graph_res_2, eager_res_2))
self.assertTrue(torch.allclose(graph_res_1, eager_res_1))
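
    # Container inputs that nest tensors (here, a tuple of tuples) are not yet
    # supported by the strict_mode HOO, so non-strict export should raise.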
def test_mark_strict_with_container_type(self):
@_mark_strict_experimental
class B(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
x0 = x[0][0]
return x0.sum()
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.submodule = B()
def forward(self, x):
return self.submodule(x)
inp = ((torch.randn(3),),)
with self.assertRaisesRegex(
RuntimeError, "strict_mode HOO doesn't work unless"
):
ep = torch.export.export(M(), inp, strict=False)
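
    # Round-trips a torch.jit.trace'd module through
    # _convert_ts_to_export_experimental and checks numerical parity with the
    # eager module.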
def test_torchscript_module_export(self):
class M(torch.nn.Module):
def forward(self, x):
return x.cos() + x.sin()
model_to_trace = M()
inps = (torch.randn(4, 4),)
traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps)
exported_module = _convert_ts_to_export_experimental(
traced_module_by_torchscript, inps
)
self.assertTrue(torch.allclose(exported_module(*inps), model_to_trace(*inps)))
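
    # Same conversion path, exercised with a single (non-tuple) example input.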
def test_torchscript_module_export_single_input(self):
class M(torch.nn.Module):
def forward(self, x):
return x.cos() + x.sin()
model_to_trace = M()
inps = torch.randn(4, 4)
traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps)
exported_module = _convert_ts_to_export_experimental(
traced_module_by_torchscript, inps
)
self.assertTrue(torch.allclose(exported_module(inps), model_to_trace(inps)))
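
    # For tuple, list, and dict inputs, conversion from TorchScript should
    # preserve the same placeholder names that a direct torch.export.export of
    # the original module produces, in addition to matching outputs.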
def test_torchscript_module_export_various_inputs_with_annotated_input_names(self):
def _check_equality_and_annotations(m_func, inps):
# Original module.
model_to_trace = m_func()
# ExportedProgram from TorchScript module.
traced_module_by_torchscript = torch.jit.trace(
m_func(), example_inputs=inps
)
exported_module = _convert_ts_to_export_experimental(
traced_module_by_torchscript, inps
)
# ExportedProgram from original module.
original_exported_module = torch.export.export(m_func(), inps)
# Check whether input annotations are the same as tracing the original module.
orig_ph_name_list = [
n.name
for n in original_exported_module.graph.nodes
if n.op == "placeholder"
]
ph_name_list = [
n.name for n in exported_module.graph.nodes if n.op == "placeholder"
]
self.assertEqual(orig_ph_name_list, ph_name_list)
# Check results equality.
self.assertTrue(
torch.allclose(exported_module(*inps), model_to_trace(*inps))
)
# Tuple
class MTuple(torch.nn.Module):
            def forward(self, x: Tuple[torch.Tensor, torch.Tensor]):
return x[0] + x[1]
_check_equality_and_annotations(MTuple, ((torch.randn(4), torch.randn(4)),))
# List
class MList(torch.nn.Module):
def forward(self, x: List[torch.Tensor]):
return x[0] + x[1]
_check_equality_and_annotations(MList, ([torch.randn(4), torch.randn(4)],))
# Dict
class MDict(torch.nn.Module):
def forward(self, x: Dict[str, torch.Tensor]):
return x["0"] + x["1"]
_check_equality_and_annotations(
MDict, ({"0": torch.randn(4), "1": torch.randn(4)},)
)
def test_joint_basic(self) -> None:
class Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)
self.loss = torch.nn.CrossEntropyLoss()
def forward(self, x):
return self.loss(
self.linear(x).softmax(dim=0), torch.tensor([1.0, 0.0, 0.0])
)
m = Module()
example_inputs = (torch.randn(3),)
m(*example_inputs)
ep = torch.export._trace._export(m, example_inputs, pre_dispatch=True)
joint_ep = _export_forward_backward(ep)
print(joint_ep)
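        # The docstring below records the expected joint ExportedProgram for
        # reference; it is printed above but not asserted against.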
"""
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]"):
# No stacktrace found for following nodes
view: "f32[1, 3]" = torch.ops.aten.view.default(arg3_1, [1, 3]); arg3_1 = None
t: "f32[3, 3]" = torch.ops.aten.t.default(arg0_1); arg0_1 = None
addmm: "f32[1, 3]" = torch.ops.aten.addmm.default(arg1_1, view, t); arg1_1 = t = None
view_1: "f32[3]" = torch.ops.aten.view.default(addmm, [3]); addmm = None
_softmax: "f32[3]" = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None
detach_1: "f32[3]" = torch.ops.aten.detach.default(_softmax)
clone: "f32[3]" = torch.ops.aten.clone.default(arg2_1); arg2_1 = None
detach_5: "f32[3]" = torch.ops.aten.detach.default(clone); clone = None
_log_softmax: "f32[3]" = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None
detach_12: "f32[3]" = torch.ops.aten.detach.default(_log_softmax)
mul: "f32[3]" = torch.ops.aten.mul.Tensor(_log_softmax, detach_5); _log_softmax = None
sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None
neg: "f32[]" = torch.ops.aten.neg.default(sum_1); sum_1 = None
div: "f32[]" = torch.ops.aten.div.Scalar(neg, 1); neg = None
ones_like: "f32[]" = torch.ops.aten.ones_like.default(div, pin_memory = False, memory_format = torch.preserve_format)
div_1: "f32[]" = torch.ops.aten.div.Scalar(ones_like, 1); ones_like = None
neg_1: "f32[]" = torch.ops.aten.neg.default(div_1); div_1 = None
expand: "f32[3]" = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
mul_1: "f32[3]" = torch.ops.aten.mul.Tensor(expand, detach_5); expand = detach_5 = None
_log_softmax_backward_data: "f32[3]" = torch.ops.aten._log_softmax_backward_data.default(mul_1, detach_12, 0, torch.float32); mul_1 = detach_12 = None
_softmax_backward_data: "f32[3]" = torch.ops.aten._softmax_backward_data.default(_log_softmax_backward_data, detach_1, 0, torch.float32); _log_softmax_backward_data = detach_1 = None
view_2: "f32[1, 3]" = torch.ops.aten.view.default(_softmax_backward_data, [1, 3]); _softmax_backward_data = None
t_1: "f32[3, 1]" = torch.ops.aten.t.default(view_2)
mm: "f32[3, 3]" = torch.ops.aten.mm.default(t_1, view); t_1 = view = None
t_2: "f32[3, 3]" = torch.ops.aten.t.default(mm); mm = None
sum_2: "f32[1, 3]" = torch.ops.aten.sum.dim_IntList(view_2, [0], True); view_2 = None
view_3: "f32[3]" = torch.ops.aten.view.default(sum_2, [3]); sum_2 = None
t_3: "f32[3, 3]" = torch.ops.aten.t.default(t_2); t_2 = None
return (div, t_3, view_3)
Graph signature: ExportGraphSignature(
input_specs=[
InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='linear.weight', persistent=None),
InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg1_1'), target='linear.bias', persistent=None),
InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='arg2_1'), target='lifted_tensor_0', persistent=None),
InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None, persistent=None)
],
output_specs=[
OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='div'), target=None),
OutputSpec(kind=<OutputKind.GRADIENT_TO_PARAMETER: 4>, arg=TensorArgument(name='t_3'), target='linear.weight'),
OutputSpec(kind=<OutputKind.GRADIENT_TO_PARAMETER: 4>, arg=TensorArgument(name='view_3'), target='linear.bias')
]
)
Range constraints: {}
"""
def test_joint_dynamic(self) -> None:
from torch.export import Dim
class Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.y = torch.nn.Parameter(torch.randn(3))
def forward(self, x):
x = torch.ones(x.shape[0], 3)
return (self.y + x).sum()
m = Module()
example_inputs = (torch.randn(3),)
m(*example_inputs)
ep = torch.export._trace._export(
m, example_inputs, pre_dispatch=True, dynamic_shapes={"x": {0: Dim("x0")}}
)
joint_ep = _export_forward_backward(ep)
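        # Minimal sanity check: joint export with a dynamic dim should still
        # produce a non-empty graph.
        self.assertTrue(len(list(joint_ep.graph.nodes)) > 0)
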
if __name__ == "__main__":
run_tests()