"""
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
with test_functionalization_with_native_python_assertion)
"""
# Owner(s): ["module: dynamo"]
import unittest
from typing import List, Set
import operator
import torch
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing import FileCheck
from torch._dynamo.eval_frame import is_dynamo_supported
from torch._export import export, dynamic_dim
from torch._export.constraints import constrain_as_value, constrain_as_size
from torch._export.exported_program import ExportGraphSignature
from torch._export.passes import (
ReplaceViewOpsWithViewCopyOpsPass,
)
from torch._export.passes.replace_view_ops_with_view_copy_ops_pass import (
is_view_op,
get_view_copy_of_view_op,
)
from torch._export.passes.functionalize_side_effectful_ops_pass import (
_FunctionalizeSideEffectfulOpsPass,
)
from functorch.experimental.control_flow import cond
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx._symbolic_trace import symbolic_trace
from torch.utils._pytree import tree_flatten
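

# Count `call_function` nodes in `graph` whose target is the given op overload.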
def count_call_function(graph: torch.fx.Graph, target: torch.ops.OpOverload) -> int:
count = 0
for node in graph.nodes:
if node.op == "call_function" and node.target == target:
count += 1
return count
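

# OperatorSupport implementations for the partitioner tests below; each one
# marks only its add op as supported (i.e. fusible).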
class _AddOperatorSupport(OperatorSupport):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in {operator.add}
class _AtenAddOperatorSupport(OperatorSupport):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in {
torch.ops.aten.add.Tensor
}
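

# Map each partition to the set of its node names, for order-insensitive comparison.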
def _to_partition_names(partitions: List[Partition]) -> List[Set[str]]:
return [{n.name for n in p.nodes} for p in partitions]
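

# Return the stringified names of a GraphModule's (flattened) output args.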
def _get_output_names(gm: torch.fx.GraphModule) -> List[str]:
output_node = next(n for n in gm.graph.nodes if n.op == "output")
args = tree_flatten(output_node.args)[0]
return [str(arg) for arg in args]


@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
class TestPasses(TestCase):
def test_replace_broken_ops(self) -> None:
x = torch.randn([2, 3, 4, 5])
model: torch.nn.Linear = torch.nn.Linear(5, 5)
def f(inp: torch.Tensor) -> torch.Tensor:
return model(inp)
ep = export(f, (x,)).transform(ReplaceViewOpsWithViewCopyOpsPass())
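        # After ReplaceViewOpsWithViewCopyOpsPass, no plain aten.view nodes should remain.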
count_after = 0
for node in ep.graph.nodes:
if node.target == torch.ops.aten.view.default:
count_after += 1
self.assertEqual(count_after, 0)
self.assertTrue(torch.allclose(ep(x), f(x)))

    def test_runtime_assert_one_dim(self) -> None:
class M(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.cos()
x = torch.zeros(2, 2, 3)
ep = export(M(), (x,), constraints=[dynamic_dim(x, 1) >= 2, dynamic_dim(x, 1) <= 6])
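        # Each runtime assert lowers to a scalar_tensor + _assert_async.msg pair, so the two counts match.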
num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg)
num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default)
self.assertEqual(num_assert, 3)
self.assertEqual(num_scalar_tensor, 3)
with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
ep(torch.zeros(2, 7, 3))
self.assertEqual(ep(torch.ones(2, 4, 3)), M().forward(torch.ones(2, 4, 3)))

    def test_runtime_assert_multiple_dims(self) -> None:
class M(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
return x.cos().sum() + y.sin().sum()
x = torch.zeros(4, 2, 3)
y = torch.zeros(5, 5, 5)
constraints = [
dynamic_dim(x, 1) >= 2,
dynamic_dim(x, 1) <= 6,
dynamic_dim(y, 0) >= 3,
dynamic_dim(x, 0) >= 3
]
ep = export(M(), (x, y), constraints=constraints)
num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg)
num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default)
self.assertEqual(num_assert, 6)
self.assertEqual(num_scalar_tensor, 6)
with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
with self.assertRaisesRegex(RuntimeError, "Input arg1_1"):
ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

    def test_runtime_assert_some_dims_not_specified(self) -> None:
class M(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
return x.cos().sum() + y.sin().sum()
x = torch.zeros(4, 2, 3)
y = torch.zeros(5, 5, 5)
constraints = [
dynamic_dim(x, 1) >= 2,
dynamic_dim(x, 1) <= 6,
dynamic_dim(x, 0) >= 3
]
ep = export(M(), (x, y), constraints=constraints)
num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg)
num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default)
        # there are 3 asserts from y, 2 from the dynamic x dims, and 1 from the static x dim
self.assertEqual(num_assert, 6)
self.assertEqual(num_scalar_tensor, 6)
with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
# y is specialized to 5
with self.assertRaisesRegex(RuntimeError, r"Input arg1_1.shape\[0\] is specialized at 5"):
ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))
        # Since no runtime assert was inserted for x[1] >= 2, this should also work when x[1] == 1
gm_result_for_1_size = ep(torch.ones(3, 1, 3), torch.ones(5, 5, 5))
eager_result_for_1_size = M().forward(torch.ones(3, 1, 3), torch.ones(5, 5, 5))
self.assertEqual(gm_result_for_1_size, eager_result_for_1_size)

    def test_runtime_assert_some_inps_not_used(self) -> None:
class M(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
return y.cos().sum()
x = torch.zeros(4, 2, 3)
y = torch.zeros(5, 5, 5)
constraints = [
dynamic_dim(y, 1) >= 3,
dynamic_dim(y, 1) <= 6,
]
ep = export(M(), (x, y), constraints=constraints)
num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg)
num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default)
# there are 4 asserts from y and 3 from x
self.assertEqual(num_assert, 7)
self.assertEqual(num_scalar_tensor, 7)
with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
# y is specialized to 5
with self.assertRaisesRegex(RuntimeError, r"Input arg1_1.shape\[0\] is specialized at 5"):
ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))
        # Inputs that satisfy all shape constraints should run and match eager
        gm_result = ep(torch.zeros(4, 2, 3), torch.ones(5, 5, 5))
        eager_result = M().forward(torch.zeros(4, 2, 3), torch.ones(5, 5, 5))
        self.assertEqual(gm_result, eager_result)

    def test_view_to_view_copy(self) -> None:
class M(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
z = x.view(x.shape)
return z.cos().sum()
x = torch.zeros(4, 2, 3)
ep = export(M(), (x,))
self.assertEqual(count_call_function(ep.graph, torch.ops.aten.view.default), 1)
ep = ep.transform(ReplaceViewOpsWithViewCopyOpsPass())
self.assertEqual(count_call_function(ep.graph, torch.ops.aten.view.default), 0)

    def test_functionalization_with_view_copy(self) -> None:
def foo(x):
y = x + 4
y.add_(4)
z = y.view(y.shape)
return x.cos() + z.cos()
x = torch.zeros(4, 2, 3)
ep = export(foo, (x,)).transform(ReplaceViewOpsWithViewCopyOpsPass())
# After this pass, there shouldn't be any view nodes in the graph
self.assertTrue(count_call_function(ep.graph, torch.ops.aten.view.default) == 0)
self.assertTrue(count_call_function(ep.graph, torch.ops.aten.view_copy.default) > 0)

    def test_views_op_having_view_copy(self) -> None:
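        # Every ATen op tagged `core` that is a view op should have a registered view_copy variant.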
schemas = torch._C._dispatch_get_registrations_for_dispatch_key("")
aten_schemas = [s[6:] for s in schemas if s.startswith("aten::")]
for aten_schema in aten_schemas:
val = aten_schema.split(".")
assert len(val) <= 2
name = ""
overload = ""
if len(val) == 1:
name = val[0]
overload = "default"
else:
name, overload = val[0], val[1]
op_overload = getattr(getattr(torch.ops.aten, name), overload)
if torch.Tag.core in op_overload.tags and is_view_op(op_overload._schema):
self.assertIsNotNone(get_view_copy_of_view_op(op_overload._schema))

    def test_runtime_assert_inline_constraints_for_item(self) -> None:
class M(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
b = x.item()
constrain_as_value(b, min=2, max=5)
return b
x = torch.tensor([2])
mod = M()
ep = export(mod, (x,))
num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg)
num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default)
# 1 constraint for shape of x, 2 constraints for b
self.assertEqual(num_assert, 3)
self.assertEqual(num_scalar_tensor, 3)
with self.assertRaisesRegex(RuntimeError, r"_local_scalar_dense_default is outside of inline constraint \[2, 5\]."):
ep(torch.tensor([6]))
new_inp = torch.tensor([5])
self.assertEqual(mod(new_inp), ep(new_inp))

    def test_runtime_assert_inline_constraints_for_nonzero(self) -> None:
class M(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
b = x.nonzero()
constrain_as_value(b.shape[0], min=3, max=5)
return b
x = torch.tensor([2, 1, 2, 3, 5, 0])
mod = M()
ep = export(mod, (x,), constraints=[dynamic_dim(x, 0) >= 2])
num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg)
num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default)
# TODO: De-duplicate assertions for same symbol.
self.assertEqual(num_assert, 4)
self.assertEqual(num_scalar_tensor, 4)
with self.assertRaisesRegex(RuntimeError, r"nonzero_default.shape\[0\] is outside of inline constraint \[3, 5\]."):
ep(torch.tensor([1, 1, 0, 0, 0]))
with self.assertRaisesRegex(RuntimeError, r"nonzero_default.shape\[0\] is outside of inline constraint \[3, 5\]."):
ep(torch.ones(6))
new_inp = torch.tensor([1, 1, 1, 1])
self.assertEqual(mod(new_inp), ep(new_inp))

    def test_runtime_assert_inline_constraints_for_cond(self) -> None:
class M(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, pred, x, y):
def true_fn(x, y):
b = x.item()
constrain_as_value(b, min=2, max=5)
return x - b
def false_fn(x, y):
c = y.item()
constrain_as_value(c, min=2, max=5)
return y - c
ret = cond(pred, true_fn, false_fn, [x, y])
return ret
x = torch.tensor([2])
y = torch.tensor([5])
mod = M()
ep = export(mod, (torch.tensor(True), x, y))
        with self.assertRaisesRegex(RuntimeError, r"is outside of inline constraint \[2, 5\]."):
ep(torch.tensor(False), torch.tensor([6]), torch.tensor([6]))

    def test_runtime_assert_equality_constraint(self):
class Adder(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
m = Adder()
x = torch.rand(3, 4)
y = torch.rand(3, 4)
exported = torch._export.export(
m, (x, y), constraints=[dynamic_dim(x, 1) == dynamic_dim(y, 1)]
)
x = torch.rand(3, 5)
y = torch.rand(3, 6)
with self.assertRaisesRegex(
RuntimeError,
r"Input arg0_1.shape\[1\] is not equal to input arg1_1.shape\[1\]"
):
exported(x, y)
y = torch.rand(3, 5)
dynamo_result = exported(x, y)
real_result = m(x, y)
self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_functionalize_inline_constraints(self) -> None:
def f(x):
a = x.item()
constrain_as_size(a, 4, 7)
return torch.empty((a, 4))
ep = torch._export.export(f, (torch.tensor([7]),))
gm = ep.graph_module
FileCheck().check_count(
"torch.ops.aten.sym_constrain_range.default",
1,
exactly=True,
).run(gm.code)
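        # The pass should replace sym_constrain_range with its functional variant,
        # which threads a dep_token through the graph.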
gm = ep.transform(_FunctionalizeSideEffectfulOpsPass()).graph_module
with self.assertRaisesRegex(
RuntimeError,
r"_local_scalar_dense_default is outside of inline constraint \[4, 7\]",
        ):
gm(torch.tensor([20]))
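        # An in-range input passes the functionalized assert; the graph returns
        # the real result plus a scalar dep_token.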
inp = torch.tensor([5])
res, dep_token = gm(inp)
self.assertEqual(res.shape, torch.Size([5, 4]))
self.assertEqual(dep_token.shape, torch.Size([]))
FileCheck().check_count(
"torch.ops.aten._functional_sym_constrain_range", 1, exactly=True
).run(gm.code)
FileCheck().check_count(
"torch.ops.aten.sym_constrain_range.default", 0, exactly=True
).run(gm.code)
dep_token_node = next(n for n in gm.graph.nodes if n.name == "dep_token3")
constrain_node = next(
n
for n in gm.graph.nodes
if n.target == torch.ops.aten._functional_sym_constrain_range
)
self.assertEqual(constrain_node.kwargs["dep_token"], dep_token_node)

    def test_functionalize_input_constraints(self) -> None:
def f(x):
return x * 2
inp = torch.zeros(4, 8)
ep = torch._export.export(
f,
(inp,),
constraints=[
dynamic_dim(inp, 0) < 10,
dynamic_dim(inp, 0) >= 3,
],
)
FileCheck().check_count(
"torch.ops.aten._assert_async.msg", 3, exactly=True
).run(ep.graph_module.code)
gm = ep.transform(_FunctionalizeSideEffectfulOpsPass()).graph_module
with self.assertRaisesRegex(
RuntimeError,
r"Input arg0_1.shape\[0\] is outside of specified dynamic range \[3, 9\]",
):
gm(torch.ones(11, 8))
inp = torch.ones(6, 8)
self.assertEqual(gm(inp)[0], f(inp))
FileCheck().check_count(
"torch.ops.aten._functional_assert_async.msg", 3, exactly=True
).run(gm.code)
FileCheck().check_count(
"torch.ops.aten._assert_async.msg", 0, exactly=True
).run(gm.code)

    def test_functionalization(self) -> None:
def f(x, y):
a = x.item()
constrain_as_size(a, 4, 7)
return x + 4, x + y * 2
inps = (torch.tensor([5]), torch.zeros((3, 4)))
ep = torch._export.export(
f,
inps,
constraints=[dynamic_dim(inps[1], 1) < 6],
_functionalize_runtime_assertions=True,
)
FileCheck().check_count(
"torch.ops.aten._functional_sym_constrain_range", 1, exactly=True
).run(ep.graph_module.code)
inps = (torch.tensor([7]), torch.ones((3, 5)))
self.assertTrue(torch._dynamo.utils.same(ep(*inps), f(*inps)))

    def test_functionalization_with_native_python_assertion(self) -> None:
def f(x):
b = x.sin()
assert x[0] == 3
return x.cos() + b
inp = torch.Tensor([3, 4, 5])
ep = torch._export.export(f, (inp,), _functionalize_runtime_assertions=True)
# Check native assertion has corresponding functional assertion nodes generated.
select_int_node = next(
n
for n in ep.graph_module.graph.nodes
if n.target == torch.ops.aten.select.int
)
equal_scalar_node = select_int_node.next
dep_token_node = next(
n
for n in ep.graph_module.graph.nodes
if (
n.target == torch.ops.aten._functional_assert_async.msg
and n.args[0] == equal_scalar_node
)
)
self.assertIn(
"call_function[target=torch.ops.aten._functional_assert_async.msg]"
"(args = (%eq_scalar, assertion error), kwargs = {dep_token: %dep_token1}",
dep_token_node.format_node(),
)

    def test_functionalization_with_mutated_buffer(self) -> None:
buf = torch.ones(6, 2)
weight = 0.01
bias = 0.2
d_in = 3
d_out = 4
class Foo(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer("buf", buf)
self.linear = torch.nn.Linear(d_in, d_out)
self.linear.weight.data.fill_(weight)
self.linear.bias.data.fill_(bias)
def forward(self, x):
self.buf.add_(5)
return self.linear(x).cos() + self.buf.sum()
inp = torch.ones(4, 3)
ep = torch._export.export(
Foo(),
(inp,),
constraints=[dynamic_dim(inp, 0) >= 3],
_functionalize_runtime_assertions=True,
)
gs = ep.graph_signature
self.assertEqual(
gs,
ExportGraphSignature(
parameters=["L__self___linear.weight", "L__self___linear.bias"],
buffers=["L__self___buf"],
user_inputs=["arg3_1"],
user_outputs=["add_tensor_1"],
inputs_to_parameters={
"arg0_1": "L__self___linear.weight",
"arg1_1": "L__self___linear.bias",
},
inputs_to_buffers={"arg2_1": "L__self___buf"},
buffers_to_mutate={"add_tensor": "L__self___buf"},
backward_signature=None,
assertion_dep_token={2: "dep_token7"},
),
)
outputs = next(n for n in ep.graph.nodes if n.op == "output").args[0]
self.assertEqual(
[str(o) for o in outputs],
["add_tensor", "add_tensor_1", "dep_token7"],
)
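        # +1 accounts for the dep_token output appended by assertion functionalization.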
self.assertEqual(
len(outputs), len(gs.buffers_to_mutate) + len(gs.user_outputs) + 1,
)
inp = torch.randn(5, 3)
self.assertTrue(
torch._dynamo.utils.same(
# Directly check run output of `ep.graph_module` which is
# functionalized.
ep.graph_module(
torch.full((d_out, d_in), weight),
torch.full((d_out,), bias),
buf.clone(),
inp,
),
(buf.add(5), Foo()(inp), torch.empty(0)),
)
)
self.assertTrue(torch._dynamo.utils.same(ep(inp), Foo()(inp)))

    def test_graph_partition_after_assertion_functionalization(self) -> None:
def f1(a, b):
add = a + b
add_1 = add + b
add_2 = add_1 + add
relu_1 = add_2.relu() # blocked by this
add_3 = add_2 + relu_1
add_4 = add_2 + add_3
return add_4, add_2
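        # relu is not a supported op, so the adds are split into two partitions around it.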
partitioner1 = CapabilityBasedPartitioner(
graph_module=symbolic_trace(f1),
operator_support=_AddOperatorSupport(),
)
partitions1 = partitioner1.propose_partitions()
self.assertEqual(
_to_partition_names(partitions1),
[{"add_3", "add_4"}, {"add", "add_1", "add_2"}],
)
def f2(a, b):
add = a + b
add_1 = add + b
add_2 = add_1 + add
assert add_1[0] == 5
relu_1 = add_2.relu() # blocked by this
add_3 = add_2 + relu_1
add_4 = add_2 + add_3
return add_4, add_2
inps = (torch.tensor([1, 3, 2]), torch.tensor([2, 3, 4]))
gm = export(
f2,
inps,
constraints=[dynamic_dim(inps[0], 0) == dynamic_dim(inps[1], 0)],
_functionalize_runtime_assertions=True,
).graph_module
partitioner2 = CapabilityBasedPartitioner(
graph_module=gm,
operator_support=_AtenAddOperatorSupport(),
)
partitions2 = partitioner2.propose_partitions()
self.assertEqual(
_to_partition_names(partitions2),
[
{"add_tensor_3", "add_tensor_4"},
{"add_tensor_1", "add_tensor_2", "add_tensor"},
]
)
fused_gm1 = partitioner1.fuse_partitions(partitions1)
fused_gm2 = partitioner2.fuse_partitions(partitions2)
inps = (torch.tensor([1, 4, 6]), torch.tensor([2, 4, 6]))
self.assertTrue(
torch._dynamo.utils.same(fused_gm1(*inps)[0], fused_gm2(*inps)[0]),
)
        # Sub-module `fused_1` corresponds to the partition {add, add_1, add_2}
output_names1 = _get_output_names(fused_gm1.get_submodule("fused_1"))
output_names2 = _get_output_names(fused_gm2.get_submodule("fused_1"))
self.assertEqual(output_names1, ["add_2"])
# The extra output `add_tensor_1` is consumed by assertion.
self.assertEqual(output_names2, ["add_tensor_1", "add_tensor_2"])


if __name__ == '__main__':
run_tests()