blob: 6baa41468c7c1905fad18ff23f944aa0989eb525 [file] [log] [blame]
# Owner(s): ["module: dynamo"]
import unittest
import torch
import torch._dynamo as torchdynamo
from torch._export import dynamic_dim, export
from torch._export.constraints import constrain_as_size
from torch._export.db.case import ExportCase, normalize_inputs, SupportLevel
from torch._export.db.examples import all_examples
from torch._export.serde.serialize import (
ExportedProgramDeserializer,
ExportedProgramSerializer,
deserialize,
serialize,
SerializeError,
)
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.symbolic_shapes import is_concrete_int
import torch.utils._pytree as pytree
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)
def get_filtered_export_db_tests():
unsupported_tags = {"torch.cond", "torch.map"}
unsupported_test_names = {
"dynamic_shape_constructor", # 'NoneType' object has no attribute 'from_tensor'
"dictionary", # Graph output must be a tuple()
"fn_with_kwargs", # export doesn't support kwargs yet
"scalar_output", # Tracing through 'f' must produce a single graph
}
return [
(name, case)
for name, case in all_examples().items()
if (
case.support_level == SupportLevel.SUPPORTED and
not (unsupported_tags & case.tags) and
name not in unsupported_test_names
)
]
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestSerialize(TestCase):
    """Checks on the serialized form of the last node of small exported programs."""

    def _serialize_last_node(self, fn, example_inputs):
        """Export ``fn``, serialize it, and return the last serialized graph node."""
        exported_module = export(fn, example_inputs)
        serialized, _ = ExportedProgramSerializer().serialize(exported_module)
        return serialized.graph_module.graph.nodes[-1]

    def _assert_unique_names(self, names) -> None:
        """Fail if any name in the iterable ``names`` appears more than once."""
        seen = set()
        for name in names:
            self.assertNotIn(name, seen)
            seen.add(name)

    def test_serialize_multiple_returns_from_node(self) -> None:
        class MyModule(torch.nn.Module):
            def forward(self, x, w, b):
                return torch.nn.functional.layer_norm(
                    x,
                    x.size()[1:],
                    weight=w,
                    bias=b,
                    eps=1e-5,
                )

        example_inputs = (
            torch.ones([512, 512], requires_grad=True),
            torch.ones([512]),
            torch.ones([512]),
        )
        node = self._serialize_last_node(MyModule(), example_inputs)

        self.assertEqual(node.target, "torch.ops.aten.native_layer_norm.default")
        # aten::native_layer_norm returns 3 tensors
        self.assertEqual(len(node.outputs), 3)

        # check the names are unique
        self._assert_unique_names(out.as_tensor.name for out in node.outputs)

    def test_serialize_list_returns(self) -> None:
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return torch.split(x, 2)

        example_input = torch.arange(10.0).reshape(5, 2)
        example_input.requires_grad = True
        node = self._serialize_last_node(MyModule(), (example_input,))

        self.assertEqual(node.target, "torch.ops.aten.split.Tensor")
        self.assertEqual(len(node.outputs), 1)
        # Input looks like:
        # tensor([[0, 1],
        #         [2, 3],
        #         [4, 5],
        #         [6, 7],
        #         [8, 9]])
        # Output looks like:
        # (tensor([[0, 1],
        #          [2, 3]]),
        #  tensor([[4, 5],
        #          [6, 7]]),
        #  tensor([[8, 9]]))
        tensor_list = node.outputs[0].as_tensors
        self.assertEqual(len(tensor_list), 3)

        # check the names are unique
        self._assert_unique_names(out.name for out in tensor_list)

    def test_multi_return_some_unused(self) -> None:
        """
        Make sure the serialized output matches the op schema, even if some of
        the arguments are never used in the graph.
        """

        class MyModule(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.var_mean.correction(x, [1])[0]

        node = self._serialize_last_node(
            MyModule(),
            (torch.ones([512, 512], requires_grad=True),),
        )

        self.assertEqual(node.target, "torch.ops.aten.var_mean.correction")
        self.assertEqual(len(node.outputs), 2)

        # check the names are unique
        self._assert_unique_names(out.as_tensor.name for out in node.outputs)

    def test_kwargs_default(self) -> None:
        """
        Tests that the kwargs default values are serialized even if they are not
        specified
        """

        def f(x: torch.Tensor) -> torch.Tensor:
            values = torch.randn(3, 2)
            return torch.searchsorted(x, values, side="right", right=True)

        sorted_x, _ = torch.sort(torch.randn(3, 4))
        node = self._serialize_last_node(f, (sorted_x,))

        self.assertEqual(node.target, "torch.ops.aten.searchsorted.Tensor")
        self.assertEqual(len(node.inputs), 6)
        serialized_args = [node_input.arg for node_input in node.inputs]
        self.assertEqual(serialized_args[2].as_bool, False)
        self.assertEqual(serialized_args[3].as_bool, True)
        self.assertEqual(serialized_args[4].as_string, "right")
        self.assertEqual(serialized_args[5].as_none, ())
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestDeserialize(TestCase):
    """Round-trip tests: export, serialize, deserialize, then compare programs."""

    def check_graph(self, fn, inputs, constraints=None) -> None:
        """Export a graph, serialize it, deserialize it, and compare the results.

        Args:
            fn: The callable or module handed to ``export``.
            inputs: Tuple of example inputs; used for the export and to run
                both the original and the deserialized program.
            constraints: Optional list of constraints forwarded to ``export``;
                ``None`` is treated as an empty list.
        """
        # TODO(angelayi): test better with some sort of wrapper
        constraints = [] if constraints is None else constraints
        ep = export(fn, inputs, {}, constraints)
        ep.graph.eliminate_dead_code()

        serialized_struct, state_dict = serialize(ep, opset_version={"aten": 0})
        deserialized_ep = deserialize(serialized_struct, state_dict, expected_opset_version={"aten": 0})
        deserialized_ep.graph.eliminate_dead_code()

        orig_outputs = ep(*inputs)
        loaded_outputs = deserialized_ep(*inputs)

        flat_orig_outputs, _ = pytree.tree_flatten(orig_outputs)
        flat_loaded_outputs, _ = pytree.tree_flatten(loaded_outputs)

        # zip() stops at the shorter sequence, so without this check a
        # deserialized program that drops (or adds) outputs would go unnoticed.
        self.assertEqual(len(flat_orig_outputs), len(flat_loaded_outputs))
        for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs):
            self.assertEqual(type(orig), type(loaded))
            if isinstance(orig, torch.Tensor):
                self.assertTrue(torch.allclose(orig, loaded))
            else:
                self.assertEqual(orig, loaded)

        self.assertEqual(len(ep.graph.nodes), len(deserialized_ep.graph.nodes))
        for node1, node2 in zip(ep.graph.nodes, deserialized_ep.graph.nodes):
            self.assertEqual(node1.op, node2.op)
            if node1.op == "call_function":
                # Check "val" metadata
                val1 = node1.meta.get("val", None)
                val2 = node2.meta.get("val", None)
                if val1 is None or val2 is None:
                    # Either both are None
                    self.assertEqual(val1, val2)
                elif isinstance(val1, FakeTensor) and isinstance(val2, FakeTensor):
                    # Or both are fake tensors with the same shape/dtype
                    self.assertEqual(len(val1.shape), len(val2.shape))
                    for s1, s2 in zip(val1.shape, val2.shape):
                        if is_concrete_int(s1) and is_concrete_int(s2):
                            self.assertEqual(s1, s2)
                        else:
                            # Symbolic sizes are compared by their string form.
                            self.assertEqual(str(s1), str(s2))
                    self.assertEqual(val1.dtype, val2.dtype)
                elif isinstance(val1, list) and isinstance(val2, list):
                    # Or both are fake tensors lists with one element and with the
                    # same shape/dtype
                    self.assertTrue(len(val1) == 1 and len(val2) == 1)
                    self.assertEqual(val1[0].shape, val2[0].shape)
                    self.assertEqual(val1[0].dtype, val2[0].dtype)
                else:
                    # For expressions like 's0 < 10' can only compare through string
                    self.assertEqual(str(val1), str(val2))

                # Check "stack_trace" metadata
                self.assertEqual(
                    node1.meta.get("stack_trace", None),
                    node2.meta.get("stack_trace", None),
                )

            if node1.op not in ("get_attr", "placeholder"):
                # Check "nn_module_stack" metadata
                self.assertEqual(
                    node1.meta.get("nn_module_stack", None),
                    node2.meta.get("nn_module_stack", None),
                )

                # Check "source_fn" metadata
                self.assertEqual(
                    node1.meta.get("source_fn", None),
                    node2.meta.get("source_fn", None),
                )

    def test_multi_return(self) -> None:
        """
        Test multiple return from a single node (ex. layer_norm is exported as
        native_layer_norm, which has 3 outputs)
        """

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, w, b):
                return torch.nn.functional.layer_norm(
                    x,
                    x.size()[1:],
                    weight=w,
                    bias=b,
                    eps=1e-5,
                )

        inputs = (
            torch.ones([512, 512], requires_grad=True),
            torch.ones([512]),
            torch.ones([512]),
        )
        self.check_graph(MyModule(), inputs)

    def test_basic(self) -> None:
        """Round-trip a module built from elementwise arithmetic with two outputs."""

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = x + x
                x = x * x
                x = x / x
                return x, x.clone()

        inputs = (torch.ones([512], requires_grad=True),)
        self.check_graph(MyModule(), inputs)

    def test_dynamic(self) -> None:
        """Round-trip a module exported with dynamic-dimension constraints."""

        class DynamicShapeSimpleModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b, c) -> torch.Tensor:
                d = (torch.matmul(a, b) + c) / 2
                d_s0 = d.shape[0]
                d_s1 = d.shape[1]
                d_s3 = d_s0 * d_s1
                e = d.view(d_s3)
                return torch.cat([e, e])

        inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
        constraints = [
            dynamic_dim(inputs[0], 0),
            dynamic_dim(inputs[2], 0),
            dynamic_dim(inputs[2], 0) == dynamic_dim(inputs[0], 0),
        ]
        self.check_graph(DynamicShapeSimpleModel(), inputs, constraints)

    def test_sym_bool(self):
        """Round-trip a function that returns the result of an ``in`` test."""

        def f(x, y):
            return x.size(0) in y

        self.check_graph(f, (torch.ones(2), torch.ones(3)))

    def test_shape(self):
        """Round-trip a function that returns values derived from ``x.size()``."""

        def f(x):
            z, y = x.size()
            return z + y + x[0], z

        inputs = (torch.ones(2, 3),)
        constraints = [
            dynamic_dim(inputs[0], 0),
            dynamic_dim(inputs[0], 1),
        ]
        self.check_graph(f, inputs, constraints)

    def test_module(self):
        """Round-trip a module that owns submodules and calls one of them twice."""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(3, 3)
                self.relu = torch.nn.ReLU()
                self.linear2 = torch.nn.Linear(3, 5)

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear1(x)
                x = torch.nn.functional.relu(x)
                x = self.linear2(x)
                return x

        inputs = (torch.randn(3, 3),)
        self.check_graph(M(), inputs)

    def test_cond(self):
        """Round-trip a module whose forward uses ``cond`` with two branches."""
        from functorch.experimental.control_flow import cond

        inputs = torch.ones(4, 3), torch.zeros(4, 3)

        class M(torch.nn.Module):
            def forward(self, x, y):
                def t(x, y):
                    return x + y

                def f(x, y):
                    return x - y

                return cond(x[0][0] > 4, t, f, [x, y])

        self.check_graph(M(), inputs)

    @parametrize(
        "name,case",
        get_filtered_export_db_tests(),
        name_fn=lambda name, case: f"case_{name}",
    )
    def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
        """Round-trip each export-db case selected by the filter helper."""
        model = case.model
        inputs = normalize_inputs(case.example_inputs)
        self.check_graph(model, inputs.args)

    def test_constraints(self):
        """Round-trip a function that constrains a value read via ``item()``."""

        def f(x, y):
            n = x.item()
            constrain_as_size(n, min=2)
            return y.sum() + torch.ones(n, 5).sum()

        self.check_graph(f, (torch.tensor(3), torch.randn(4, 5)))
# Generate the concrete per-case test methods for the @parametrize-decorated
# tests declared on TestDeserialize.
instantiate_parametrized_tests(TestDeserialize)
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestSchemaVersioning(TestCase):
    """Checks how the deserializer reacts to an unexpected schema version."""

    def test_error(self):
        """Deserializing a program whose schema version is -1 raises SerializeError."""

        def add_to_self(x):
            return x + x

        exported = export(add_to_self, (torch.randn(1, 3),))
        program, state_dict = ExportedProgramSerializer().serialize(exported)

        # Overwrite the version so it no longer matches what the deserializer expects.
        program.schema_version = -1

        expected_message = r"Serialized schema version -1 does not match our current"
        with self.assertRaisesRegex(SerializeError, expected_message):
            ExportedProgramDeserializer().deserialize(program, state_dict)
class TestOpVersioning(TestCase):
    """Test if serializer/deserializer behaves correctly if version mismatch."""

    def test_empty_model_opset_version_raises(self):
        """A model opset version of ``None`` is rejected with RuntimeError."""
        deserializer = ExportedProgramDeserializer({"aten": 4})
        with self.assertRaises(RuntimeError):
            deserializer._validate_model_opset_version(None)

    def test_opset_mismatch_raises(self):
        """Differing "aten" versions (model 3, compiler 4) raise NotImplementedError."""
        compiler_versions = {"aten": 4}
        model_versions = {"aten": 3}
        deserializer = ExportedProgramDeserializer(compiler_versions)
        with self.assertRaises(NotImplementedError):
            deserializer._validate_model_opset_version(model_versions)

    def test_model_op_namespace_version_missing_from_deserializer_do_not_raises(self):
        """A namespace present only in the model logs a warning and does not raise."""
        compiler_versions = {"aten": 3}
        model_versions = {"aten": 3, "custom": 4}
        deserializer = ExportedProgramDeserializer(compiler_versions)
        expected_fragment = "Compiler doesn't have a version table for op namespace"
        with self.assertLogs(level="WARN") as captured:
            deserializer._validate_model_opset_version(model_versions)
            self.assertIn(expected_fragment, captured.output[0])
# Mark two of the generated export-db round-trip tests as expected failures.
# NOTE(review): the underlying reason is not visible here — presumably known
# serialization gaps for these cases; confirm before removing the markers.
unittest.expectedFailure(
    TestDeserialize.test_exportdb_supported_case_tensor_setattr
)
unittest.expectedFailure(
    TestDeserialize.test_exportdb_supported_case_pytree_flatten
)
# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
    run_tests()