# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import copy
import unittest
from typing import Dict, List, Tuple
import executorch.exir as exir
import executorch.exir.tests.models as models
import torch
from executorch.exir import CaptureConfig
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.tests.common import register_additional_test_aten_ops
from executorch.exir.tracer import dynamo_trace, ExirDynamoConfig, using_dynamo
from functorch.experimental.control_flow import cond, map
from parameterized import parameterized
from torch._export.verifier import SpecViolationError
from torch.fx.experimental.symbolic_shapes import is_concrete_int
from torch.testing import FileCheck
class TestTorchDispatchFXTracer(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
register_additional_test_aten_ops()
def test_simple(self) -> None:
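        # Capture a simple model and lower it to edge; aten.sin should be
        # rewritten into its edge-dialect counterpart.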
f = models.BasicSinMax()
f = (
exir.capture(f, f.get_random_inputs(), exir.CaptureConfig())
.to_edge()
.exported_program.graph_module
)
FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").run(f.code)
def test_static_control_flow(self) -> None:
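        # pred is a plain Python bool, so tracing specializes on its value and
        # only the taken branch is captured in the graph.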
def f(pred: bool, x: torch.Tensor) -> torch.Tensor:
if pred:
return torch.sin(x).max()
else:
return torch.sin(x)
pred = True
x = torch.randn(100)
f_true = (
exir.capture(f, (pred, x), exir.CaptureConfig())
.to_edge()
.exported_program.graph_module
)
FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").check(
"executorch_exir_dialects_edge__ops_aten_max"
).run(f_true.code)
pred = False
f_false = (
exir.capture(f, (pred, x), exir.CaptureConfig())
.to_edge()
.exported_program.graph_module
)
FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").check_not(
"executorch_exir_dialects_edge__ops_aten_max"
).run(f_false.code)
def test_copy(self) -> None:
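        # Graph modules produced by to_edge should survive copy.deepcopy.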
f = models.BasicSinMax()
f = (
exir.capture(f, f.get_random_inputs(), exir.CaptureConfig())
.to_edge()
.exported_program.graph_module
)
        self.assertIsInstance(f, torch.fx.GraphModule)
        g = copy.deepcopy(f)
        self.assertIsInstance(g, torch.fx.GraphModule)
def test_stacktrace(self) -> None:
def f(x: torch.Tensor) -> torch.Tensor:
return x + x
traced_f = (
exir.capture(f, (torch.rand(2, 2),), exir.CaptureConfig())
.to_edge()
.exported_program.graph_module
)
# Check that stacktrace is populated and retained (by checking twice)
self.assertTrue(
any(node.meta.get("stack_trace", None) for node in traced_f.graph.nodes)
)
self.assertTrue(
any(node.meta.get("stack_trace", None) for node in traced_f.graph.nodes)
)
def test_possible_input_mutation(self) -> None:
def f(x: torch.Tensor) -> torch.Tensor:
return torch.add(torch.ones(5), torch.ones(5), out=x)
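        # add with out= mutates its input, so the edge verifier should reject
        # the op as non-functional.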
with self.assertRaisesRegex(
SpecViolationError,
r"operator .* is not functional",
):
exir.capture(f, (torch.zeros(5),), exir.CaptureConfig()).to_edge()
def test_tensor_spec_for_const_tensors(self) -> None:
class Module(torch.nn.Module):
def __init__(self) -> None:
                super().__init__()
self.linear = torch.nn.Linear(2, 3)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)
def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
return (torch.randn(2),)
model = Module()
graph_module = (
exir.capture(model, model.get_random_inputs(), exir.CaptureConfig())
# torch._ops.aten.t.default
.to_edge(
exir.EdgeCompileConfig(_check_ir_validity=False)
).exported_program.graph_module
)
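        # The linear layer contributes two constant tensors (weight and bias);
        # each should appear as a get_attr node carrying a "val" in its meta.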
num_get_attr_node = 0
num_get_attr_node_with_tensorspec = 0
for nd in graph_module.graph.nodes:
if nd.op == "get_attr":
num_get_attr_node += 1
if nd.meta.get("val") is not None:
num_get_attr_node_with_tensorspec += 1
self.assertEqual(2, num_get_attr_node)
self.assertEqual(2, num_get_attr_node_with_tensorspec)
def test_multiple_returns_spec(self) -> None:
def f(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.ops.aten.max.dim(x, 0, False)
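        # aten.max.dim returns (values, indices), so the node's meta["val"]
        # should be a tuple.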
cnt = 0
module = (
exir.capture(f, (torch.zeros(1, 2, 3),), exir.CaptureConfig())
.to_edge()
.exported_program.graph_module
)
for node in module.graph.nodes:
if node.target == exir_ops.edge.aten.max.dim:
cnt += 1
self.assertIsInstance(node.meta["val"], tuple)
self.assertEqual(cnt, 1)
def test_multiple_returns_pt2_mode(self) -> None:
def f(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
a = x * x
b = x + a
return a, b
inputs = (torch.ones(1, 2, 3),)
orig_res = f(*inputs)
module = (
exir.capture(
f,
inputs,
exir.CaptureConfig(),
)
.to_edge()
.exported_program.graph_module
)
new_res = module(*inputs)
for node in module.graph.nodes:
if node.op == "output":
self.assertIsInstance(node.meta["val"], list)
self.assertEqual(len(node.meta["val"]), 2)
self.assertTrue(torch.allclose(orig_res[0], new_res[0]))
self.assertTrue(torch.allclose(orig_res[1], new_res[1]))
def test_dynamo_capture_scalar_outputs(self) -> None:
def f(x: torch.Tensor) -> float:
return x.item()
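        # The test passes as long as dynamo_trace can capture the scalar
        # output of .item() without raising.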
gm, guards = dynamo_trace(
f,
(torch.ones(1),),
False,
"real",
ExirDynamoConfig(),
)
# pyre-ignore
@parameterized.expand([("stock_tensor",)])
def test_embedding_dynamic_shape(self, input_type: str) -> None:
class Module(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x + x
example_input = torch.ones(10, dtype=torch.int64)
m = Module()
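        # Capture should succeed with functionalization disabled and dynamic
        # shapes enabled; this test only exercises the path, with no assertions.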
gm = (
exir.capture(
m,
(example_input,),
exir.CaptureConfig(
enable_functionalization=False,
enable_dynamic_shape=True,
),
)
.to_edge()
.exported_program.graph_module
)
print(gm.graph)
def test_dynamic_shape(self) -> None:
def forward(x: torch.Tensor) -> torch.Tensor:
x = x.view(x.shape[0] - 1, -1)
return torch.cat([x, x])
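        # Every placeholder and call_function node should carry fake-tensor
        # metadata, even with symbolic shapes involved.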
gm = (
exir.capture(
forward,
(torch.ones(3, 2, dtype=torch.int64),),
exir.CaptureConfig(
enable_functionalization=False,
enable_dynamic_shape=True,
_dynamo_config=ExirDynamoConfig(assume_static_by_default=True),
),
            # sym_size is not a registered edge op, so IR validity checking is
            # disabled in to_edge below
)
.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
.exported_program.graph_module
)
for node in gm.graph.nodes:
if node.op in ("placeholder", "call_function"):
self.assertIn("val", node.meta)
def test_dynamo_frontend_container_input(self) -> None:
class Module(torch.nn.Module):
def __init__(self) -> None:
                super().__init__()
def forward(
self, x: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
) -> torch.Tensor:
a = x[0]
b = x[1]
cum = 0
for i in b:
cum += i.sum()
return a.cos() + cum.sin()
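        # Nested container inputs are flattened by the dynamo frontend; the
        # captured module should match eager results.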
with using_dynamo(True):
inp = ((torch.ones(6), (torch.ones(6), torch.ones(6))),)
gm = exir.capture(Module(), inp, exir.CaptureConfig())
self.assertTrue(torch.allclose(Module()(*inp), gm(*inp)))
# TODO (tmanlaibaatar) remove this test
def test_pt2_mode_with_dynamo_config(self) -> None:
def f(x: torch.Tensor) -> torch.Tensor:
return x[: x.shape[0] - 1]
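        # For a (4, 5) input, dropping the last row yields shape[0] == 3.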
inp = (torch.randn(4, 5),)
prog = exir.capture(
f,
inp,
# missing dispatch key
).to_edge()
        self.assertEqual(prog(torch.randn(4, 5)).shape[0], 3)
def test_input_container_type(self) -> None:
def f(x: torch.Tensor, y: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
# pyre-ignore
return {"a": x.sum() + sum(y).sum()}
inp = (torch.randn(6, 5), [torch.randn(6, 5), torch.randn(6, 5)])
# pyre-fixme[23]: Unable to unpack `(...) -> Tuple[GraphModule,
# Set[torch._guards.Guard]]` into 2 values.
gm, _ = torch._dynamo.export(f, *inp, aten_graph=True, tracing_mode="symbolic")
prog = exir.capture(f, inp, config=exir.CaptureConfig()).to_edge()
self.assertEqual(prog(*inp), f(*inp))
def test_aot_buffer_mutation(self) -> None:
class Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer(
"_bin_num_examples",
torch.empty([42]).fill_(
0.0,
),
)
def forward(self, x, y, z):
self._bin_num_examples.index_copy_(
dim=0,
index=y,
source=z,
)
self._bin_num_examples.index_add_(
dim=0, index=torch.arange(4), source=x
)
return self._bin_num_examples - 1, x * z
model = Module()
example_inputs = (
torch.randn(4, requires_grad=True),
torch.tensor(0),
torch.tensor(3.14),
)
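        # Capture with enable_aot should fail: the module mutates state while
        # one of its inputs requires gradients.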
with self.assertRaisesRegex(
RuntimeError,
"Found a graph input that requires gradients, and received a mutation.",
):
_ = exir.capture(
model,
example_inputs,
exir.CaptureConfig(
enable_aot=True,
),
)
        # Note that model._bin_num_examples is mutated during exir.capture,
        # so we need to create a fresh Module instance below.
new_model = Module()
example_inputs = (
torch.randn(4),
torch.tensor(0),
torch.tensor(3.14),
)
ep = exir.capture(
new_model,
example_inputs,
exir.CaptureConfig(
enable_aot=True,
),
)
test_inputs = (
torch.randn(4),
torch.tensor(0),
torch.tensor(2.1),
)
graph_outputs = ep(*test_inputs)
eager_outputs = Module()(*test_inputs)
self.assertEqual(len(graph_outputs), 2)
self.assertEqual(len(eager_outputs), 2)
self.assertTrue(torch.allclose(graph_outputs[0], eager_outputs[0]))
self.assertTrue(torch.allclose(graph_outputs[1], eager_outputs[1]))
def test_assume_constant_by_default_prop(self) -> None:
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
if x.shape[0] > 3:
return x.cos()
return x.sin()
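        # With assume_static_by_default, at least one input dim should be
        # recorded as a concrete (non-symbolic) int.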
dynamo_config = ExirDynamoConfig(assume_static_by_default=True)
capture_config = exir.CaptureConfig(
enable_dynamic_shape=True, _dynamo_config=dynamo_config
)
captured = exir.capture(
foo, (torch.ones(6, 2), torch.ones(6, 3)), capture_config
).exported_program.graph_module
found = False
for node in captured.graph.nodes:
# at least one input needs to have concrete dims
if "val" in node.meta:
fake_val = node.meta["val"]
for dim in fake_val.shape:
if is_concrete_int(dim):
found = True
self.assertTrue(found)
def test_aot_config(self) -> None:
class FooWithBuffer(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer", torch.zeros(42))
def forward(self, x):
return x.cos() + self.buffer.sum()
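        # enable_aot lifts the buffer into a placeholder, so the graph should
        # contain no get_attr nodes and both placeholders should feed the add.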
capture_config = exir.CaptureConfig(enable_aot=True)
captured_ep = exir.capture(FooWithBuffer(), (torch.ones(6, 2),), capture_config)
captured_gm = captured_ep.exported_program.graph_module
placeholder_nodes = set()
print(captured_gm.graph)
for node in captured_gm.graph.nodes:
self.assertFalse(node.op == "get_attr")
if node.op == "placeholder":
placeholder_nodes.add(node)
if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
# make sure the placeholders are used
arg_0, arg_1 = node.args
self.assertEqual(
placeholder_nodes,
{
list(arg_0._input_nodes.keys())[0],
list(arg_1._input_nodes.keys())[0],
},
)
self.assertEqual(len(placeholder_nodes), 2)
captured_ep.to_edge()
def test_export_unlift(self) -> None:
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer", torch.ones(6, 4))
def forward(self, x):
return x.cos() + self.buffer.sin()
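        # _unlift re-inlines lifted buffers so the exported program can be
        # called directly, like the original module.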
ep = exir.capture(
Foo(),
(torch.ones(6, 4),),
exir.CaptureConfig(enable_aot=True, _unlift=True),
)
self.assertTrue(torch.allclose(ep(torch.ones(6, 4)), Foo()(torch.ones(6, 4))))
def test_export_container_unlift(self) -> None:
class FooContainerInputOutput(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer", torch.ones(6, 4))
def forward(self, x):
return x[0][0].cos() + x[0][1].sin() + self.buffer.sin()
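        # Container-typed inputs should still work through the unlift path.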
inp = ((torch.ones(6, 4), torch.ones(6, 4)),)
ep = exir.capture(
FooContainerInputOutput(),
(inp,),
CaptureConfig(enable_aot=True, _unlift=True),
)
self.assertTrue(torch.allclose(ep(inp), FooContainerInputOutput()(inp)))
def test_export_container_input_unlift(self) -> None:
class FooContainerInputOutputV2(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer", torch.ones(6, 4))
def forward(self, x, y):
return x[0].cos() + y[0].sin() + self.buffer.sin()
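        # Same as above, but with two container-typed positional inputs.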
inp = ((torch.ones(6, 4),), (torch.ones(6, 4),))
ep = exir.capture(
FooContainerInputOutputV2(),
inp,
CaptureConfig(enable_aot=True, _unlift=True),
)
self.assertTrue(torch.allclose(ep(*inp), FooContainerInputOutputV2()(*inp)))
def test_export_cond(self) -> None:
class A(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer", torch.ones(6, 4))
def forward(self):
return self.buffer.cos()
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = A()
def forward(self, x):
def true_fn(x):
return x.cos() + self.a().sum()
def false_fn(x):
return x.sin()
return cond(x.shape[0] > 4, true_fn, false_fn, [x])
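        # The true branch closes over a submodule buffer; export with _unlift
        # should preserve eager semantics.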
inp = torch.ones(6, 4)
ep = exir.capture(
Foo(),
(inp,),
CaptureConfig(enable_aot=True, _unlift=True),
)
self.assertTrue(torch.allclose(ep(torch.ones(6, 4)), Foo()(torch.ones(6, 4))))
def test_export_cond_map(self) -> None:
class A(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer", torch.ones(6, 4))
def forward(self):
return self.buffer.sum()
class Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = A()
def inner(self, x, pred):
def true_fn(x):
return x + x + self.a()
def false_fn(x):
return x * x - self.a()
return cond(pred, true_fn, false_fn, [x])
def forward(self, pred, xs):
def body(x, pred):
return self.inner(x, pred) + self.a()
return map(body, xs, pred)
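        # map applies body over the leading dimension of xs; the cond inside
        # selects a branch based on pred.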
inp = torch.randn(3, 2, 1)
ep = exir.capture(
Module(),
(torch.tensor(True), inp),
CaptureConfig(enable_aot=True, _unlift=True),
)
inp_test = torch.randn(3, 2, 1)
self.assertTrue(
torch.allclose(
ep(torch.tensor(True), inp_test),
Module()(torch.tensor(True), inp_test),
)
)
self.assertTrue(
torch.allclose(
ep(torch.tensor(False), inp_test),
Module()(torch.tensor(False), inp_test),
)
)