blob: acddc0230a124b3d05222a62b00393f248f84959 [file] [log] [blame]
# Owner(s): ["module: dynamo"]
import unittest
import torch
from functorch.experimental import control_flow
from torch import Tensor
from torch._dynamo.eval_frame import is_dynamo_supported
from torch._export import export
from torch._export.verifier import ATenDialectVerifier, SpecViolationError, Verifier
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
from torch.testing._internal.common_utils import run_tests, TestCase
@unittest.skipIf(not is_dynamo_supported(), "dynamo isn't supported")
class TestVerifier(TestCase):
def test_verifier_basic(self) -> None:
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
ep = export(f, (torch.randn(100), torch.randn(100)))
verifier = Verifier()
verifier(ep.graph_module)
def test_verifier_call_module(self) -> None:
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x: Tensor) -> Tensor:
return self.linear(x)
gm = torch.fx.symbolic_trace(M())
verifier = Verifier()
with self.assertRaises(SpecViolationError):
verifier(gm)
def test_verifier_no_functional(self) -> None:
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
ep = export(f, (torch.randn(100), torch.randn(100)))
for node in ep.graph.nodes:
if node.target == torch.ops.aten.add.Tensor:
node.target = torch.ops.aten.add_.Tensor
verifier = Verifier()
with self.assertRaises(SpecViolationError):
verifier(ep.graph_module)
def test_verifier_higher_order(self) -> None:
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
def true_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x - y
return control_flow.cond(
x.shape[0] > 2, true_fn, false_fn, [x, y]
)
ep = export(f, (torch.randn(3, 3), torch.randn(3, 3)))
verifier = Verifier()
verifier(ep.graph_module)
def test_verifier_nested_invalid_module(self) -> None:
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
def true_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x - y
return control_flow.cond(
x.shape[0] > 2, true_fn, false_fn, [x, y]
)
ep = export(f, (torch.randn(3, 3), torch.randn(3, 3)))
for node in ep.graph_module.true_graph_0.graph.nodes:
if node.target == torch.ops.aten.add.Tensor:
node.target = torch.ops.aten.add_.Tensor
verifier = Verifier()
with self.assertRaises(SpecViolationError):
verifier(ep.graph_module)
def test_aten_verifier_wrong_op(self) -> None:
class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return torch.ops.aten._add_relu(x, x)
m = TestModel()
egm = torch.fx.symbolic_trace(m)
verifier = ATenDialectVerifier()
with self.assertRaises(SpecViolationError):
verifier(egm)
self.assertFalse(verifier.is_valid(egm))
def test_ep_verifier_basic(self) -> None:
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x: Tensor) -> Tensor:
return self.linear(x)
ep = export(M(), (torch.randn(10, 10),))
ep._validate()
def test_ep_verifier_invalid_param(self) -> None:
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
ep = export(f, (torch.randn(100), torch.randn(100)))
# Parameter doesn't exist in the state dict
ep.graph_signature.input_specs.insert(
0,
InputSpec(
kind=InputKind.PARAMETER,
arg=TensorArgument(name="arg0_1"),
target="bad_param"
)
)
with self.assertRaisesRegex(SpecViolationError, "not in the state dict"):
ep._validate()
# Add non-torch.nn.Parameter parameter to the state dict
ep.state_dict["bad_param"] = torch.randn(100)
with self.assertRaisesRegex(
SpecViolationError, "not an instance of torch.nn.Parameter"
):
ep._validate()
# Add torch.nn.Parameter to state dict, but this should still error
# because there are an incorrect number of placeholder nodes
ep.state_dict["bad_param"] = torch.nn.Parameter(torch.randn(100))
with self.assertRaisesRegex(
SpecViolationError, "the number of inputs specified by the graph signature"
):
ep._validate()
def test_ep_verifier_invalid_buffer(self) -> None:
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
ep = export(f, (torch.randn(100), torch.randn(100)))
# Buffer doesn't exist in the state dict
ep.graph_signature.input_specs.insert(
0,
InputSpec(
kind=InputKind.BUFFER,
arg=TensorArgument(name="arg0_1"),
target="bad_buffer"
)
)
with self.assertRaisesRegex(SpecViolationError, "not in the state dict"):
ep._validate()
# Incorrect number of placeholder nodes
ep.state_dict["bad_buffer"] = torch.randn(100)
with self.assertRaisesRegex(
SpecViolationError, "the number of inputs specified by the graph signature"
):
ep._validate()
def test_ep_verifier_buffer_mutate(self) -> None:
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
self.register_buffer("my_buffer1", torch.tensor(3.0))
self.register_buffer("my_buffer2", torch.tensor(4.0))
def forward(self, x1, x2):
# Use the parameter, buffers, and both inputs in the forward method
output = (
x1 + self.my_parameter
) * self.my_buffer1 + x2 * self.my_buffer2
# Mutate one of the buffers (e.g., increment it by 1)
self.my_buffer2.add_(1.0)
return output
ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0)))
ep._validate()
def test_ep_verifier_invalid_output(self) -> None:
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
self.register_buffer("my_buffer1", torch.tensor(3.0))
self.register_buffer("my_buffer2", torch.tensor(4.0))
def forward(self, x1, x2):
# Use the parameter, buffers, and both inputs in the forward method
output = (
x1 + self.my_parameter
) * self.my_buffer1 + x2 * self.my_buffer2
# Mutate one of the buffers (e.g., increment it by 1)
self.my_buffer2.add_(1.0)
return output
ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0)))
output_node = list(ep.graph.nodes)[-1]
with ep.graph.inserting_before(output_node):
additional_output_node = ep.graph.call_function(
torch.add, args=(output_node.args[0][0], output_node.args[0][0])
)
output_node.args = (
(
output_node.args[0][0],
additional_output_node,
output_node.args[0][1],
),
)
with self.assertRaisesRegex(SpecViolationError, "Number of output nodes"):
ep._validate()
if __name__ == '__main__':
run_tests()