# Owner(s): ["module: dynamo"]
import copy
import unittest
from typing import Tuple

import torch
import torch._dynamo as torchdynamo
import torch.nn as nn
from functorch import make_fx
from functorch.experimental import functionalize
from torch import Tensor
from torch._dynamo.eval_frame import is_dynamo_supported
from torch._export.verifier import (
    ATenDialectVerifier,
    SpecViolationError,
    Verifier,
)
from torch.testing._internal.common_utils import run_tests, TestCase


@torch.no_grad()
def capture(f, args):
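    """Export `f` to an ATen-level FX graph with dynamo, functionalize it,
    then retrace with make_fx in fake-tensor mode to produce the functional
    graph module that the verifiers expect.
    """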
    torchdynamo.config.allow_rnn = True
    torchdynamo.reset()
    graphmodule, _ = torchdynamo.export(
        f,
        *copy.deepcopy(args),
        aten_graph=True,
    )

    def graph_with_interpreter(*args):
        with torch.fx.traceback.preserve_node_meta():
            return torch.fx.Interpreter(graphmodule).run(*args)

    functionalized_callable = functionalize(
        graph_with_interpreter,
        remove='mutations_and_views',
    )
    gm = make_fx(
        functionalized_callable,
        tracing_mode='fake',
        _allow_non_fake_inputs=True,
    )(*args)
    return gm


class Transpose(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, dim0: int, dim1: int) -> Tensor:
        return x.transpose(dim0, dim1)


class Mul(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor, other: Tensor) -> Tensor:
        # or return torch.mul(input, other)
        return input * other

    def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
        return (torch.randn(3, 2), torch.randn(3, 2))


class ElementwiseAdd(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        return x + y

    def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
        return (torch.randn(1, 3), torch.randn(1, 3))


class Cat(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    # def forward(self, tensors, dim=0):
    def forward(self, *args: Tensor, dim: int) -> Tensor:
        # `dim` is keyword-only, so every positional arg is a tensor;
        # slicing with args[:-1] would silently drop the last one.
        return torch.cat(args, dim)


class FeedForwardBlock(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.layer_norm = nn.LayerNorm(input_dim)
        self.relu = nn.ReLU()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.dropout1 = nn.Dropout()
        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.dropout2 = nn.Dropout()

    def forward(self, x: Tensor) -> Tensor:
        # LayerNorm -> Linear -> Dropout -> ReLU -> Linear -> Dropout
        y = self.layer_norm(x)
        y = self.linear1(y)
        y = self.dropout1(y)
        y = self.relu(y)
        y = self.linear2(y)
        y = self.dropout2(y)
        return y


class VerifierTest(TestCase):
    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_verifier(self) -> None:
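        """A functional ATen graph captured from a simple module passes verification."""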
        m = ElementwiseAdd()
        egm = capture(m, (torch.randn(100), torch.randn(100)))
        # should not raise
        verifier = Verifier()
        verifier(egm)
        self.assertTrue(verifier.is_valid(egm))

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_verifier_call_module(self) -> None:
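        """Graphs that still contain call_module nodes are rejected."""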
        m = FeedForwardBlock(10, 10)
        gm = torch.fx.symbolic_trace(m)
        # symbolic_trace keeps the submodules as call_module nodes instead of
        # inlining them, which the Verifier does not allow
        verifier = Verifier()
        with self.assertRaises(SpecViolationError):
            verifier(gm)
        self.assertFalse(verifier.is_valid(gm))

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_verifier_no_functional(self) -> None:
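        """Rewriting a functional op to its out= variant makes verification fail."""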
        m = ElementwiseAdd()
        egm = capture(m, (torch.randn(100), torch.randn(100)))
        for node in egm.graph.nodes:
            if node.target == torch.ops.aten.add.Tensor:
                node.target = torch.ops.aten.add.out
        verifier = Verifier()
        with self.assertRaises(SpecViolationError):
            verifier(egm)
        self.assertFalse(verifier.is_valid(egm))

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_aten_dialect(self) -> None:
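        """A captured functional graph also satisfies the stricter ATen dialect."""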
        m = ElementwiseAdd()
        egm = capture(m, (torch.randn(100), torch.randn(100)))
        verifier = ATenDialectVerifier()
        verifier(egm)
        self.assertTrue(verifier.is_valid(egm))

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_aten_wrong_mem_format(self) -> None:
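        """A parameter in channels_last memory format violates the ATen dialect."""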
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.nn.parameter.Parameter(
                    torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last)
                )

            def forward(self, x):
                return self.a + x

        m = TestModel()
        egm = capture(m, (torch.randn(1, 3, 100, 100),))
        egm._apply(lambda t: t.to(memory_format=torch.channels_last))
        verifier = ATenDialectVerifier()
        with self.assertRaises(SpecViolationError):
            verifier(egm)
        self.assertFalse(verifier.is_valid(egm))

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_aten_wrong_mem_format_buffer(self) -> None:
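        """Same as above, but the channels_last tensor is a buffer rather than a parameter."""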
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer(
                    "a",
                    torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last),
                )

            def forward(self, x):
                return self.a + x

        m = TestModel()
        egm = capture(m, (torch.randn(1, 3, 100, 100),))
        egm._apply(lambda t: t.to(memory_format=torch.channels_last))
        verifier = ATenDialectVerifier()
        with self.assertRaises(SpecViolationError):
            verifier(egm)
        self.assertFalse(verifier.is_valid(egm))

    def test_aten_wrong_op(self) -> None:
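        """An op outside the core ATen dialect fails verification."""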
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.ops.aten._add_relu(x, x)

        m = TestModel()
        egm = torch.fx.symbolic_trace(m)
        verifier = ATenDialectVerifier()
        with self.assertRaises(SpecViolationError):
            verifier(egm)
        self.assertFalse(verifier.is_valid(egm))


if __name__ == '__main__':
    run_tests()