blob: 89d34ea4f2b639423c5de3afaac1cd5e93d9c0f5 [file] [log] [blame]
# Owner(s): ["module: dynamo"]
"""
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
with test_export_persist_assert)
"""
import copy
import functools
import inspect
import io
import operator
import unittest
from enum import Enum
from typing import Dict, List, Sequence
from unittest.mock import patch
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from functorch.experimental.control_flow import cond
from torch._dynamo import config
from torch._dynamo.exc import UserError
from torch._dynamo.testing import normalize_gm
from torch._higher_order_ops.out_dtype import out_dtype
from torch._subclasses import fake_tensor
from torch.export import dynamic_dim
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
DimDynamic,
ShapeEnv,
StatelessSymbolicContext,
)
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import TEST_CUDA
class ExportTests(torch._dynamo.test_case.TestCase):
# TODO(voz): Refactor to a shared test function.
# The tests in this file are a little redundant: they all take a func, run it
# with eager, then export it, then compare the eager and exported results.
def test_export(self):
    """Export a zero-argument function that returns a nested list of tensors."""

    def pre_attention_state_ops(input, mems, state):
        lc_key = state[0]
        lc_val = state[1]
        rows = []
        for _ in range(0, 4):
            row = []
            for _ in range(0, 3):
                row.append(
                    lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1])
                )
            rows.append(row)
        return rows

    def func():
        mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]])
        state = [
            torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
            torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
        ]
        query = torch.tensor(
            [
                [0.0313, -0.1487, -0.3846, -0.5321],
                [-1.7073, 1.3331, -0.0890, -1.4935],
                [-0.8314, -0.1862, -0.5935, 1.5232],
            ]
        )
        return pre_attention_state_ops(query, mems, state)

    eager_fn = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    expected = eager_fn()
    # Clear Dynamo state left by the eager compile before exporting.
    torch._dynamo.reset()
    graph_module = torch._dynamo.export(func)()[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module()))
def test_export_mismatched_out(self):
    """Export a function whose outputs mix a list and a tuple of repeated tensors."""

    def make_input():
        # A fresh tensor for every call, mirroring separate literal constructions.
        return torch.tensor([[[1.3737, 0.1]]])

    def func(x):
        y = x + 1
        return ([x, x], (y, y))

    compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    eager_out = compiled(make_input())
    torch._dynamo.reset()
    out_graph, _guards = torch._dynamo.export(func)(make_input())
    self.assertTrue(torch._dynamo.utils.same(eager_out, out_graph(make_input())))
def test_export_shape_control_flow_1(self):
    """Export a shape-dependent branch and inspect the first SHAPE_ENV guard."""
    from torch._guards import GuardSource

    def func(x):
        if x.shape[0] > 10:
            return x.cos()
        return x.sin()

    eager_out = torch._dynamo.optimize("eager")(func)(torch.ones(6, 4))
    torch._dynamo.reset()
    out_graph, out_guards = torch._dynamo.export(func)(torch.ones(6, 4))
    exported_out = out_graph(torch.ones(6, 4))
    self.assertTrue(torch._dynamo.utils.same(eager_out, exported_out))

    # Only the first guard sourced from the shape environment is checked.
    shape_guard = next(
        (g for g in out_guards if g.source == GuardSource.SHAPE_ENV), None
    )
    self.assertTrue(shape_guard is not None)
    self.assertExpectedInline(
        shape_guard.code_list,
        """["L['x'].stride()[0] == L['x'].size()[1]", "L['x'].stride()[1] == 1", "L['x'].storage_offset() == 0", "2 <= L['x'].size()[0] <= 10", "2 <= L['x'].size()[1]"]""",  # noqa: B950
    )
def test_export_control_flow_with_getattr(self):
    """Export a module whose forward branches on a string attribute vs an Enum value."""

    class Animal(Enum):
        COW = "moo"

    class MyModule(torch.nn.Module):
        def __init__(self, a):
            super().__init__()
            self.a = a

        def forward(self, x):
            if self.a != Animal.COW.value:
                raise ValueError("bad")
            return x * x

    module = MyModule("moo")
    args = (torch.ones(4, 3),)
    expected = module(*args)
    graph, _ = torch._dynamo.export(module)(*args)
    self.assertTrue(torch._dynamo.utils.same(expected, graph(*args)))
def test_export_graph_bypass(self):
    """Export a function that reads the same list element twice and multiplies."""
    inp = [torch.tensor([v, v]) for v in (0.1, 0.2, 0.3)]

    def func(x):
        lhs = x[2]
        rhs = x[2]
        return lhs * rhs

    compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    reference = compiled(inp)
    torch._dynamo.reset()
    exported_graph = torch._dynamo.export(func)(inp).graph_module
    self.assertTrue(torch._dynamo.utils.same(reference, exported_graph(inp)))
def test_list_unpack(self):
    """Export a function that returns list elements interleaved with a product."""
    inp = [torch.tensor([v, v]) for v in (0.1, 0.2, 0.3)]

    def func(x):
        lhs = x[2]
        rhs = x[2]
        return x[0], lhs * rhs, x[1], x[2]

    eager_fn = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    expected = eager_fn(inp)
    torch._dynamo.reset()
    gm, _ = torch._dynamo.export(func)(inp)
    actual = gm(inp)
    self.assertTrue(torch._dynamo.utils.same(expected, actual))
def test_export_with_shallow_list_copy_wo_side_effects(self):
    """A shallow ``list.copy()`` without mutation exports to an equivalent aten graph."""

    def f(x):
        y = x.copy()
        return y[0] + y[1]

    inp = [torch.tensor([1.3, 3.77, 0.1]), torch.tensor([8.7, 6.23, 9.9])]
    exporter = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")
    gm = exporter(inp).graph_module
    graph_out = gm(inp)
    eager_out = f(inp)
    self.assertTrue(torch._dynamo.utils.same(graph_out, eager_out))
def test_export_with_shallow_list_copy_with_side_effects(self):
    """Export a function that copies a list, then mutates both the copy and the input."""

    def f(x):
        y = x.copy()
        x[0] = x[1]
        y.append(torch.tensor([[100]]))
        return x[0] + x[1], y[0] + y[1], y[2]

    inp = [torch.tensor([1.3, 3.77, 0.1]), torch.tensor([8.7, 6.23, 9.9])]
    exporter = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")
    gm = exporter(inp).graph_module
    # Keep this order: `f` assigns into `x[0]`, so calling it touches `inp`.
    graph_out = gm(inp)
    eager_out = f(inp)
    self.assertTrue(torch._dynamo.utils.same(graph_out, eager_out))
    # After `x[0] = x[1]`, both summands of the first output are the same tensor.
    self.assertEqual(graph_out[0], graph_out[1])
def test_export_mismatched_out_2(self):
    """Second check of mixed list/tuple outputs containing repeated tensors."""

    def func(x):
        y = x + 1
        return ([x, x], (y, y))

    # Three independent but identical inputs: eager run, export, graph call.
    eager_in, export_in, graph_in = (
        torch.tensor([[[1.3737, 0.1]]]) for _ in range(3)
    )
    compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    eager_out = compiled(eager_in)
    torch._dynamo.reset()
    gm = torch._dynamo.export(func)(export_in)[0]
    self.assertTrue(torch._dynamo.utils.same(eager_out, gm(graph_in)))
def test_export_graph_with_list(self):
    """Export a function that returns a computed tensor together with its list input."""
    inp = [torch.tensor([v, v]) for v in (0.1, 0.2, 0.3, 0.4)]

    def func(x):
        lhs = x[2]
        rhs = x[2]
        return lhs * rhs, x

    eager_fn = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    expected = eager_fn(inp)
    torch._dynamo.reset()
    graph_module = torch._dynamo.export(func)(inp).graph_module
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(inp)))
def test_export_graph_with_complex_reorder(self):
    """Export a function whose outputs reorder inputs and mix in products."""
    inp = [torch.tensor([v, v]) for v in (0.1, 0.2, 0.3, 0.4)]

    def func(x):
        a = x[0]
        b = x[1]
        c = x[2]
        return c, a, b, a * b, a * c

    compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    reference = compiled(inp)
    torch._dynamo.reset()
    gm, _ = torch._dynamo.export(func)(inp)
    self.assertTrue(torch._dynamo.utils.same(reference, gm(inp)))
def test_dupes(self):
    """Exporting a function that returns one tensor twice matches eager Dynamo."""
    sample = torch.tensor([0.1, 0.1])

    def func(x):
        out = x + 1
        return out, out

    eager_fn = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    expected = eager_fn(sample)
    # Clear Dynamo state left by the eager compile before exporting.
    torch._dynamo.reset()
    graph_module = torch._dynamo.export(func)(sample).graph_module
    actual = graph_module(sample)
    self.assertTrue(torch._dynamo.utils.same(expected, actual))
def test_dupes_2(self):
    """Second duplicate-output check: ``(y, y)`` survives export unchanged."""
    inp = torch.tensor([0.1, 0.1])

    def func(x):
        y = x + 1
        return y, y

    real_result = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(
        func
    )(inp)
    torch._dynamo.reset()
    out_graph, _ = torch._dynamo.export(func)(inp)
    self.assertTrue(torch._dynamo.utils.same(real_result, out_graph(inp)))
def test_dupes_and_bypass(self):
    """Duplicated computed output plus an input passed straight through."""
    args = [torch.tensor([0.1, 0.1]), torch.tensor([0.4, 0.4])]

    def func(x, z):
        out = x + 1
        return out, out, z

    compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    expected = compiled(*args)
    torch._dynamo.reset()
    gm = torch._dynamo.export(func)(*args)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, gm(*args)))
def test_dupes_and_bypass_with_non_tensor_arg(self):
    """Duplicated output and pass-through input, with an extra Python int argument."""
    args = [torch.tensor([0.1, 0.1]), torch.tensor([0.1, 0.1]), 4]

    def func(x, z, k):
        shifted = x + k
        return shifted, shifted, z

    eager_fn = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    reference = eager_fn(*args)
    torch._dynamo.reset()
    graph_module = torch._dynamo.export(func)(*args).graph_module
    self.assertTrue(torch._dynamo.utils.same(reference, graph_module(*args)))
def test_dupes_and_bypass_reorder_with_non_tensor_arg(self):
    """Like the non-tensor-arg dupes test, but the pass-through input comes first."""
    x_in = torch.tensor([0.1, 0.1])
    z_in = torch.tensor([0.1, 0.1])
    args = [x_in, z_in, 4]

    def func(x, z, k):
        shifted = x + k
        return z, shifted, shifted

    compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    expected = compiled(*args)
    torch._dynamo.reset()
    gm, _ = torch._dynamo.export(func)(*args)
    actual = gm(*args)
    self.assertTrue(torch._dynamo.utils.same(expected, actual))
@config.patch(capture_scalar_outputs=True)
def test_dupes_and_bypass_with_non_tensor_output(self):
    """Export a function returning a Python scalar from ``.item()`` plus tensors."""
    args = [torch.tensor([0.1, 0.1]), torch.tensor([0.1, 0.1]), 4]

    def func(x, z, k):
        shifted = x + k
        return shifted[0].item(), shifted, z

    eager_fn = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    expected = eager_fn(*args)
    torch._dynamo.reset()
    graph_module = torch._dynamo.export(func)(*args)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*args)))
def test_zeroes_in_and_out_different_shape_on_test(self):
    """Export on all-zero inputs, then run the graph on random inputs of the same shape."""
    export_args = [torch.zeros(10) for _ in range(3)]
    random_args = [torch.randn(10) for _ in range(3)]

    def func(a, b, c):
        return [[a], [b, c], [a + b], [[c + c]]]

    compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    expected = compiled(*random_args)
    torch._dynamo.reset()
    gm, _ = torch._dynamo.export(func)(*export_args)
    self.assertTrue(torch._dynamo.utils.same(expected, gm(*random_args)))
@config.patch(capture_scalar_outputs=True)
def test_zeroes_in_new_shape_scalar_out(self):
    """Scalar output built from ``.item()`` calls: export on zeros, run on random data."""
    export_args = [torch.zeros(10) for _ in range(3)]
    random_args = [torch.randn(10) for _ in range(3)]

    def func(a, b, c):
        return a[0].item() + b[0].item() + c[0].item()

    eager_fn = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    reference = eager_fn(*random_args)
    torch._dynamo.reset()
    graph_module = torch._dynamo.export(func)(*export_args).graph_module
    self.assertTrue(
        torch._dynamo.utils.same(reference, graph_module(*random_args))
    )
@config.patch(capture_scalar_outputs=True)
def test_zeroes_in_new_shape_scalar_out_permute(self):
    """Like the scalar-out test, but with permuted and repeated ``.item()`` reads."""
    export_args = [torch.zeros(10) for _ in range(3)]
    random_args = [torch.randn(10) for _ in range(3)]

    def func(a, b, c):
        return b[0].item() + c[0].item() + a[0].item() + a[0].item()

    compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    expected = compiled(*random_args)
    torch._dynamo.reset()
    out_graph = torch._dynamo.export(func)(*export_args)[0]
    actual = out_graph(*random_args)
    self.assertTrue(torch._dynamo.utils.same(expected, actual))
@config.patch(capture_scalar_outputs=True)
def test_zeroes_in_new_shape_scalar_out_permute_dupe_and_bypass(self):
    """Scalar output sandwiched between two pass-throughs of the same input."""
    export_args = [torch.zeros(10) for _ in range(3)]
    random_args = [torch.randn(10) for _ in range(3)]

    def func(a, b, c):
        return a, b[0].item() + c[0].item() + a[0].item() + a[0].item(), a

    eager_fn = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    reference = eager_fn(*random_args)
    torch._dynamo.reset()
    gm, _ = torch._dynamo.export(func)(*export_args)
    self.assertTrue(torch._dynamo.utils.same(reference, gm(*random_args)))
def test_func_return(self):
    """Export a function that returns the result of a nested closure call."""
    export_args = [torch.zeros(10) for _ in range(3)]
    random_args = [torch.randn(10) for _ in range(3)]

    def func(a, b, c):
        total = a + b + c

        def func2(y):
            return total * y

        return func2(total)

    compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    expected = compiled(*random_args)
    torch._dynamo.reset()
    graph_module = torch._dynamo.export(func)(*export_args).graph_module
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*random_args)))
def test_dict_return(self):
    """Export a function whose output is a single-entry dict."""
    export_args = [torch.zeros(10) for _ in range(3)]
    random_args = [torch.randn(10) for _ in range(3)]

    def func(a, b, c):
        total = a + b + c
        return {"a": total}

    eager_fn = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    reference = eager_fn(*random_args)
    torch._dynamo.reset()
    out_graph, _ = torch._dynamo.export(func)(*export_args)
    actual = out_graph(*random_args)
    self.assertTrue(torch._dynamo.utils.same(reference, actual))
def test_export_with_aten_graph(self):
    """Aten-graph export of a zero-argument function returning nested tensor lists."""

    def pre_attention_state_ops(input, mems, state):
        lc_key = state[0]
        lc_val = state[1]
        rows = []
        for _ in range(0, 4):
            row = []
            for _ in range(0, 3):
                row.append(
                    lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1])
                )
            rows.append(row)
        return rows

    def func():
        mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]])
        state = [
            torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
            torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
        ]
        query = torch.tensor(
            [
                [0.0313, -0.1487, -0.3846, -0.5321],
                [-1.7073, 1.3331, -0.0890, -1.4935],
                [-0.8314, -0.1862, -0.5935, 1.5232],
            ]
        )
        return pre_attention_state_ops(query, mems, state)

    compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    expected = compiled()
    torch._dynamo.reset()
    gm, _ = torch._dynamo.export(func, aten_graph=True)()
    self.assertTrue(torch._dynamo.utils.same(expected, gm()))
def test_export_mismatched_out_with_aten_graph(self):
    """Aten-graph export of mixed list/tuple outputs with repeated tensors."""

    def make_input():
        return torch.tensor([[[1.3737, 0.1]]])

    def func(x):
        y = x + 1
        return ([x, x], (y, y))

    eager_fn = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    expected = eager_fn(make_input())
    torch._dynamo.reset()
    graph_module = torch._dynamo.export(func, aten_graph=True)(make_input())[0]
    actual = graph_module(make_input())
    self.assertTrue(torch._dynamo.utils.same(expected, actual))
def test_export_graph_bypass_with_aten_graph(self):
    """Aten-graph export of a product of one list element with itself."""
    inp = [torch.tensor([v, v]) for v in (0.1, 0.2, 0.3)]

    def func(x):
        lhs = x[2]
        rhs = x[2]
        return lhs * rhs

    compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    reference = compiled(inp)
    torch._dynamo.reset()
    gm, _ = torch._dynamo.export(func, aten_graph=True)(inp)
    self.assertTrue(torch._dynamo.utils.same(reference, gm(inp)))
def test_list_unpack_with_aten_graph(self):
    """Aten-graph export returning list elements interleaved with a product."""
    inp = [torch.tensor([v, v]) for v in (0.1, 0.2, 0.3)]

    def func(x):
        lhs = x[2]
        rhs = x[2]
        return x[0], lhs * rhs, x[1], x[2]

    eager_fn = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    expected = eager_fn(inp)
    torch._dynamo.reset()
    graph_module = torch._dynamo.export(func, aten_graph=True)(inp).graph_module
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(inp)))
def test_export_mismatched_out_2_with_aten_graph(self):
    """Second aten-graph check of mixed list/tuple outputs."""

    def func(x):
        y = x + 1
        return ([x, x], (y, y))

    eager_in, export_in, graph_in = (
        torch.tensor([[[1.3737, 0.1]]]) for _ in range(3)
    )
    compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    eager_out = compiled(eager_in)
    torch._dynamo.reset()
    out_graph, _ = torch._dynamo.export(func, aten_graph=True)(export_in)
    self.assertTrue(torch._dynamo.utils.same(eager_out, out_graph(graph_in)))
def test_export_graph_with_list_with_aten_graph(self):
    """Aten-graph export returning a product together with the list input."""
    inp = [torch.tensor([v, v]) for v in (0.1, 0.2, 0.3, 0.4)]

    def func(x):
        lhs = x[2]
        rhs = x[2]
        return lhs * rhs, x

    compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    reference = compiled(inp)
    torch._dynamo.reset()
    gm = torch._dynamo.export(func, aten_graph=True)(inp)[0]
    actual = gm(inp)
    self.assertTrue(torch._dynamo.utils.same(reference, actual))
def test_export_graph_with_complex_reorder_with_aten_graph(self):
    """Aten-graph export of reordered inputs mixed with products."""
    inp = [torch.tensor([v, v]) for v in (0.1, 0.2, 0.3, 0.4)]

    def func(x):
        a = x[0]
        b = x[1]
        c = x[2]
        return c, a, b, a * b, a * c

    eager_fn = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    expected = eager_fn(inp)
    torch._dynamo.reset()
    graph_module = torch._dynamo.export(func, aten_graph=True)(inp).graph_module
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(inp)))
def test_dupes_with_aten_graph(self):
    """Aten-graph export of a function that returns one tensor twice."""
    sample = torch.tensor([0.1, 0.1])

    def func(x):
        out = x + 1
        return out, out

    compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    expected = compiled(sample)
    torch._dynamo.reset()
    gm, _ = torch._dynamo.export(func, aten_graph=True)(sample)
    self.assertTrue(torch._dynamo.utils.same(expected, gm(sample)))
def test_dupes_2_with_aten_graph(self):
    """Second aten-graph duplicate-output check."""
    inp = torch.tensor([0.1, 0.1])

    def func(x):
        y = x + 1
        return y, y

    real_result = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(
        func
    )(inp)
    torch._dynamo.reset()
    out_graph = torch._dynamo.export(func, aten_graph=True)(inp).graph_module
    dynamo_result = out_graph(inp)
    self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
def test_dupes_and_bypass_with_aten_graph(self):
    """Aten-graph export: duplicated output plus a pass-through input."""
    args = [torch.tensor([0.1, 0.1]), torch.tensor([0.4, 0.4])]

    def func(x, z):
        out = x + 1
        return out, out, z

    eager_fn = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    reference = eager_fn(*args)
    torch._dynamo.reset()
    graph_module = torch._dynamo.export(func, aten_graph=True)(*args)[0]
    self.assertTrue(torch._dynamo.utils.same(reference, graph_module(*args)))
def test_dupes_and_bypass_with_non_tensor_arg_with_aten_graph(self):
    """Aten-graph export with duplicated output, pass-through and an int argument."""
    args = [torch.tensor([0.1, 0.1]), torch.tensor([0.1, 0.1]), 4]

    def func(x, z, k):
        shifted = x + k
        return shifted, shifted, z

    compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    expected = compiled(*args)
    torch._dynamo.reset()
    gm, _ = torch._dynamo.export(func, aten_graph=True)(*args)
    actual = gm(*args)
    self.assertTrue(torch._dynamo.utils.same(expected, actual))
def test_dupes_and_bypass_reorder_with_non_tensor_arg_with_aten_graph(self):
    """Aten-graph variant with the pass-through input returned first."""
    x_in = torch.tensor([0.1, 0.1])
    z_in = torch.tensor([0.1, 0.1])
    args = [x_in, z_in, 4]

    def func(x, z, k):
        shifted = x + k
        return z, shifted, shifted

    eager_fn = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    reference = eager_fn(*args)
    torch._dynamo.reset()
    graph_module = torch._dynamo.export(func, aten_graph=True)(*args).graph_module
    self.assertTrue(torch._dynamo.utils.same(reference, graph_module(*args)))
@config.patch(capture_scalar_outputs=True)
def test_dupes_and_bypass_with_non_tensor_output_with_aten_graph(self):
    """Aten-graph export of a function returning a Python scalar, a tensor and an input.

    ``capture_scalar_outputs`` is patched on so the ``.item()`` call can be
    captured by Dynamo rather than causing a graph break.
    """
    inp = torch.tensor([0.1, 0.1])
    inp2 = torch.tensor([0.1, 0.1])
    inp3 = 4
    inps = [inp, inp2, inp3]

    def func(x, z, k):
        y = x + k
        return y[0].item(), y, z

    opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    real_result = opt_func(*inps)

    torch._dynamo.reset()

    # aten_graph=True is required for this test to cover what its name (and
    # every sibling *_with_aten_graph test) promises; without it, it is an
    # exact duplicate of test_dupes_and_bypass_with_non_tensor_output.
    exported = torch._dynamo.export(func, aten_graph=True)(*inps)
    out_graph = exported[0]

    dynamo_result = out_graph(*inps)

    self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
def test_zeroes_in_and_out_different_shape_on_test_with_aten_graph(self):
    """Aten-graph export on zeros, then run the graph on random inputs."""
    export_args = [torch.zeros(10) for _ in range(3)]
    random_args = [torch.randn(10) for _ in range(3)]

    def func(a, b, c):
        return [[a], [b, c], [a + b], [[c + c]]]

    eager_fn = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    reference = eager_fn(*random_args)
    torch._dynamo.reset()
    gm = torch._dynamo.export(func, aten_graph=True)(*export_args)[0]
    self.assertTrue(torch._dynamo.utils.same(reference, gm(*random_args)))
def test_func_return_with_aten_graph(self):
    """Aten-graph export of a function that returns a nested closure's result."""
    export_args = [torch.zeros(10) for _ in range(3)]
    random_args = [torch.randn(10) for _ in range(3)]

    def func(a, b, c):
        total = a + b + c

        def func2(y):
            return total * y

        return func2(total)

    compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    expected = compiled(*random_args)
    torch._dynamo.reset()
    out_graph, _ = torch._dynamo.export(func, aten_graph=True)(*export_args)
    actual = out_graph(*random_args)
    self.assertTrue(torch._dynamo.utils.same(expected, actual))
def test_dict_return_with_aten_graph(self):
    """Aten-graph export of a function returning a single-entry dict."""
    export_args = [torch.zeros(10) for _ in range(3)]
    random_args = [torch.randn(10) for _ in range(3)]

    def func(a, b, c):
        total = a + b + c
        return {"a": total}

    eager_fn = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    reference = eager_fn(*random_args)
    torch._dynamo.reset()
    graph_module = torch._dynamo.export(func, aten_graph=True)(
        *export_args
    ).graph_module
    self.assertTrue(torch._dynamo.utils.same(reference, graph_module(*random_args)))
def test_export_with_stack_trace(self):
    """Exported nodes carry stack-trace and module metadata in torch and aten graphs."""
    inp = torch.randn(4, 4)

    class MyBlock(torch.nn.Module):
        def forward(self, x):
            x = torch.nn.functional.linear(x, torch.randn(4, 4))
            return torch.cos(x).relu() + 1

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.block = MyBlock()

        def forward(self, x):
            out = self.block(x)
            return out

    # Torch-level graph: every node other than placeholders/outputs is checked.
    torch_graph = torch._dynamo.export(MyModule(), aten_graph=False)(inp)[0]
    for node in torch_graph.graph.nodes:
        if node.op in {"placeholder", "output"}:
            continue
        self.assertTrue(node.stack_trace is not None)
        self.assertTrue(node.meta["nn_module_stack"] is not None)
        self.assertTrue(node.meta["source_fn_stack"] is not None)

    torch._dynamo.reset()

    # Aten-level graph: only call_function nodes are checked, with extra keys.
    aten_gm = torch._dynamo.export(MyModule(), aten_graph=True)(inp)[0]
    for node in (n for n in aten_gm.graph.nodes if n.op == "call_function"):
        self.assertTrue(node.stack_trace is not None)
        for key in ("nn_module_stack", "source_fn_stack", "val", "original_aten"):
            self.assertTrue(node.meta[key] is not None)
def test_export_preserves_nn_module_stack_for_get_attr(self):
    """get_attr nodes keep nn_module_stack metadata in both torch and aten exports."""
    inp = torch.randn(4, 4)

    class MyBlock(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(1, 1))
            self.register_buffer("buffer", torch.ones(1, 1))

        def forward(self, x):
            x = torch.nn.functional.linear(x, torch.randn(4, 4))
            return torch.cos(x).relu() + self.weight + self.buffer

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.block = MyBlock()

        def forward(self, x):
            out = self.block(x)
            return out

    m = MyModule()

    def count_checked_get_attrs(graph_module):
        # Count get_attr nodes, asserting that each one has nn_module_stack set.
        count = 0
        for node in graph_module.graph.nodes:
            if node.op != "get_attr":
                continue
            count += 1
            self.assertTrue(node.meta["nn_module_stack"] is not None)
        return count

    # Two attribute reads are expected: the parameter and the buffer.
    torch_graph = torch._dynamo.export(m, aten_graph=False)(inp)[0]
    self.assertEqual(count_checked_get_attrs(torch_graph), 2)
    torch._dynamo.reset()
    aten_graph = torch._dynamo.export(m, aten_graph=True)(inp)[0]
    self.assertEqual(count_checked_get_attrs(aten_graph), 2)
def test_export_compare_optimize_with_make_fx(self):
    """Aten export agrees with make_fx, both via a Dynamo backend and called directly."""
    inp = torch.tensor([0.1, 0.1])
    linear = torch.nn.Linear(2, 2)

    def func(x):
        x = x + 1
        y = x.t()
        y = y.relu()
        y = linear(y)
        return y

    export_result = torch._dynamo.export(func, aten_graph=True)(inp)[0](inp)
    torch._dynamo.reset()

    def compiler(gm, sample_inputs):
        # Backend that re-traces the Dynamo graph with make_fx on every call.
        def fw(*args):
            return make_fx(gm)(*args)(*args)

        return fw

    backend_fn = torch._dynamo.optimize(compiler, nopython=True, dynamic=True)(func)
    backend_result = backend_fn(inp)
    direct_result = make_fx(func)(inp)(inp)
    for candidate in (backend_result, direct_result):
        self.assertTrue(torch._dynamo.utils.same(candidate, export_result))
def test_export_with_constant_method_on_module(self):
    """An ``assume_constant_result`` method's output is fixed into the exported graph."""

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 2))
            self.linear = torch.nn.Linear(2, 2)

        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            return torch.nonzero(x)

        def forward(self, x):
            y = torch.sin(x)
            x = self.linear(x)
            y = self.helper_fn(x)
            return y

    eager_out = MyModule()(torch.tensor([[1.0, 0], [0, 0]]))
    graph, _ = torch._dynamo.export(MyModule())(torch.tensor([[0.0, 0], [0, 0]]))
    # The graph must reproduce eager_out whichever input it is later given.
    probes = (
        torch.tensor([[1.0, 0.0], [0, 0]]),
        torch.tensor([[1, 0], [0.25, 0.25]]),
    )
    for probe in probes:
        self.assertTrue(torch._dynamo.utils.same(graph(probe), eager_out))
def test_export_with_constant_method_on_module_invoke_twice(self):
    """Calling the same constant-result method twice still exports to a constant."""

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 2))
            self.linear = torch.nn.Linear(2, 2)

        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            return torch.nonzero(x)

        def forward(self, x):
            y = torch.sin(x)
            x = self.linear(x)
            y = self.helper_fn(x) + self.helper_fn(x)
            return y

    reference_module = MyModule()
    reference = reference_module(torch.tensor([[1.0, 0], [0, 0]]))
    export_module = MyModule()
    graph = torch._dynamo.export(export_module)(
        torch.tensor([[0.0, 0], [0, 0]])
    ).graph_module
    first = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
    self.assertTrue(torch._dynamo.utils.same(first, reference))
    second = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
    self.assertTrue(torch._dynamo.utils.same(second, reference))
def test_export_with_constant_free_function(self):
    """Mixing a constant free function and a constant method in one forward."""

    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        return torch.nonzero(x)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 2))
            self.linear = torch.nn.Linear(2, 2)

        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            return torch.nonzero(x)

        def forward(self, x):
            y = torch.sin(x)
            x = self.linear(x)
            y = helper_fn(x) + self.helper_fn(x)
            return y

    eager_out = MyModule()(torch.tensor([[1.0, 0], [0, 0]]))
    graph, _ = torch._dynamo.export(MyModule())(torch.tensor([[0.0, 0], [0, 0]]))
    probes = (
        torch.tensor([[1.0, 0.0], [0, 0]]),
        torch.tensor([[1, 0], [0.25, 0.25]]),
    )
    for probe in probes:
        self.assertTrue(torch._dynamo.utils.same(graph(probe), eager_out))
def test_export_with_constant_free_function_and_class_method(self):
    """A constant free function called from a module's forward is exported as a constant."""

    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        return torch.nonzero(x)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 2))
            self.linear = torch.nn.Linear(2, 2)

        def forward(self, x):
            y = torch.sin(x)
            x = self.linear(x)
            y = helper_fn(x)
            return y

    reference = MyModule()(torch.tensor([[1.0, 0], [0, 0]]))
    gm = torch._dynamo.export(MyModule())(torch.tensor([[0.0, 0], [0, 0]]))[0]
    out_a = gm(torch.tensor([[1.0, 0.0], [0, 0]]))
    self.assertTrue(torch._dynamo.utils.same(out_a, reference))
    out_b = gm(torch.tensor([[1, 0], [0.25, 0.25]]))
    self.assertTrue(torch._dynamo.utils.same(out_b, reference))
def test_export_with_constant_free_function_and_class_method_multiarg(self):
    """Constant free function applied to two forward arguments."""

    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        return torch.nonzero(x)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 2))
            self.linear = torch.nn.Linear(2, 2)

        def forward(self, x, z):
            y = torch.sin(x)
            x = self.linear(x)
            y = helper_fn(x) + helper_fn(z)
            return y

    eager_out = MyModule()(
        torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
    )
    graph, _ = torch._dynamo.export(MyModule())(
        torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
    )
    probe_pairs = (
        (torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[1.0, 0.0], [0, 0]])),
        (
            torch.tensor([[1, 0], [0.25, 0.25]]),
            torch.tensor([[1, 0], [0.25, 0.25]]),
        ),
    )
    for pair in probe_pairs:
        self.assertTrue(torch._dynamo.utils.same(graph(*pair), eager_out))
def test_export_with_constant_free_function_and_class_method_multiarg_diff(self):
    """Two-argument constant export where export-time and run-time inputs all differ."""

    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        return torch.nonzero(x)

    class MyModule(torch.nn.Module):
        def forward(self, x, z):
            y = helper_fn(x) + helper_fn(z)
            return y

    reference = MyModule()(
        torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
    )
    gm = torch._dynamo.export(MyModule())(
        torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[0.0, 0], [0.5, 0]])
    ).graph_module
    first = gm(torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[0.0, 1.0], [0, 0]]))
    self.assertTrue(torch._dynamo.utils.same(first, reference))
    second = gm(
        torch.tensor([[1, 0], [0.25, 0.25]]),
        torch.tensor([[0.33, 0.33], [0.25, 0.25]]),
    )
    self.assertTrue(torch._dynamo.utils.same(second, reference))
def test_export_with_constant_tuple_nonzero(self):
    """A constant method returning a tuple of tensors is iterated and baked in."""

    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            return (torch.nonzero(x), torch.nonzero(x))

        def forward(self, x):
            y = torch.tensor([0.5])
            elements = self.helper_fn(x)
            all_y = []
            for element in elements:
                for item in element:
                    all_y.append(y * item)
            return all_y

    module = MyModule()
    expected = module(torch.tensor([1.0, 1.0]))
    graph, _guards = torch._dynamo.export(module)(torch.tensor([1.0, 1.0]))
    # The graph replays the compile-time constant, so even an input of a very
    # different shape yields the original result.
    odd_input = torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])
    self.assertTrue(torch._dynamo.utils.same(graph(odd_input), expected))
def test_export_with_constant_list_nonzero(self):
    """A constant method returning a list of tensors is iterated and baked in."""

    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            return [torch.nonzero(x), torch.nonzero(x)]

        def forward(self, x):
            y = torch.tensor([0.5])
            elements = self.helper_fn(x)
            all_y = []
            for element in elements:
                for item in element:
                    all_y.append(y * item)
            return all_y

    module = MyModule()
    expected = module(torch.tensor([1.0, 1.0]))
    gm = torch._dynamo.export(module)(torch.tensor([1.0, 1.0])).graph_module
    # The input is effectively ignored: the graph returns the compile-time constant.
    actual = gm(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
    self.assertTrue(torch._dynamo.utils.same(actual, expected))
def test_export_with_constant_list_nonzero_free_function(self):
    """Export a module calling a constant-assumed free function that returns a list.

    Like the method variant, but ``helper_fn`` is a free function defined in
    the test rather than a method on the module.
    """

    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        return [torch.nonzero(x), torch.nonzero(x)]

    class MyModule(torch.nn.Module):
        def forward(self, x):
            y = torch.tensor([0.5])
            elements = helper_fn(x)
            all_y = []
            # Iterate the constant list, then the rows of each nonzero result.
            for element in elements:
                for item in element:
                    all_y.append(y * item)
            return all_y

    module = MyModule()
    real_result = module(torch.tensor([1.0, 1.0]))
    graph, guards = torch._dynamo.export(module)(torch.tensor([1.0, 1.0]))

    # Tensor input can be almost anything here, and the result will capture what we
    # made constant at compile time.
    result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
    self.assertTrue(torch._dynamo.utils.same(result, real_result))
def test_export_with_constant_dict_values(self):
    """Export a module whose constant-assumed method returns a dict of tensors.

    The dict values are computed from the export-time input and captured as
    constants, so calling the graph with a different input gives the same result.
    """

    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            return {"x": x, "x^2": x * x}

        def forward(self, x):
            y = torch.tensor([0.5])
            elements = self.helper_fn(x)
            y = y * elements["x"]
            y = y * elements["x^2"]
            return y

    module = MyModule()
    real_result = module(torch.tensor([2.0, 2.0]))
    graph, guards = torch._dynamo.export(module)(torch.tensor([2.0, 2.0]))

    # Tensor input can be almost anything here, and the result will capture what we
    # made constant at compile time.
    result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
    self.assertTrue(torch._dynamo.utils.same(result, real_result))
def test_export_with_constant_none_control_flow(self):
    """A constant-assumed method that returned None at export time stays None.

    Export runs with a negative input, so ``helper_fn`` returns ``None``. The
    ``x is None`` branch is then fixed in the graph even for positive inputs.
    """

    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            if x.item() < 0:
                return None
            else:
                return x

        def forward(self, x):
            y = torch.tensor([0.5])
            x = self.helper_fn(x)
            if x is None:
                return y
            return y * x

    module = MyModule()
    real_result = module(torch.tensor([-1]))

    # X is negative, so .item() < 0, which means we return y
    self.assertEqual(real_result, torch.tensor([0.5]))

    graph, guards = torch._dynamo.export(module)(torch.tensor([-1]))
    result = graph(torch.tensor([2]))
    # X is positive, but we compiled helper_fn to return None, so it will still return y
    self.assertTrue(torch._dynamo.utils.same(result, real_result))
def test_export_with_constant_not_none_control_flow(self):
    """A constant-assumed method that returned a tensor at export time keeps returning it.

    Export runs with a positive input, so ``helper_fn`` returns ``x`` and the
    ``y * x`` path is fixed in the graph even for negative inputs.
    """

    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            if x.item() < 0:
                return None
            else:
                return x

        def forward(self, x):
            y = torch.tensor([0.5])
            x = self.helper_fn(x)
            if x is None:
                return y
            return y * x

    module = MyModule()
    real_result = module(torch.tensor([2]))

    # X is positive, so .item() > 0, which means we return y * x
    self.assertEqual(real_result, torch.tensor([1.0]))

    graph, guards = torch._dynamo.export(module)(torch.tensor([2]))
    result = graph(torch.tensor([-0.5]))
    # X is negative, but we compiled helper_fn to return x, so it will still return y * x
    self.assertTrue(torch._dynamo.utils.same(result, real_result))
def test_export_with_constant_none_control_flow_free_func(self):
    """Free-function variant: a constant-assumed None result fixes the branch taken.

    ``helper_fn`` returns None for the negative export-time input, so the
    exported graph always returns ``y``.
    """

    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        if x.item() < 0:
            return None
        else:
            return x

    class MyModule(torch.nn.Module):
        def forward(self, x):
            y = torch.tensor([0.5])
            x = helper_fn(x)
            if x is None:
                return y
            return y * x

    module = MyModule()
    real_result = module(torch.tensor([-1]))

    # X is negative, so .item() < 0, which means we return y
    self.assertEqual(real_result, torch.tensor([0.5]))

    graph, guards = torch._dynamo.export(module)(torch.tensor([-1]))
    result = graph(torch.tensor([2]))
    # X is positive, but we compiled helper_fn to return None, so it will still return y
    self.assertTrue(torch._dynamo.utils.same(result, real_result))
def test_export_with_constant_not_none_control_flow_pos(self):
    """A constant-assumed method returning a tensor keeps the ``y * x`` path.

    NOTE(review): this body is identical to
    ``test_export_with_constant_not_none_control_flow``; consider removing one
    or making this variant exercise something different.
    """

    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            if x.item() < 0:
                return None
            else:
                return x

        def forward(self, x):
            y = torch.tensor([0.5])
            x = self.helper_fn(x)
            if x is None:
                return y
            return y * x

    module = MyModule()
    real_result = module(torch.tensor([2]))

    # X is positive, so .item() > 0, which means we return y * x
    self.assertEqual(real_result, torch.tensor([1.0]))

    graph, guards = torch._dynamo.export(module)(torch.tensor([2]))
    result = graph(torch.tensor([-0.5]))
    # X is negative, but we compiled helper_fn to return x, so it will still return y * x
    self.assertTrue(torch._dynamo.utils.same(result, real_result))
def test_export_with_constant_not_none_control_flow_free_func(self):
    """Free-function variant: a constant-assumed tensor result keeps ``y * x``.

    ``helper_fn`` returns ``x`` for the positive export-time input, so the
    exported graph keeps multiplying by that captured value.
    """

    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        if x.item() < 0:
            return None
        else:
            return x

    class MyModule(torch.nn.Module):
        def forward(self, x):
            y = torch.tensor([0.5])
            x = helper_fn(x)
            if x is None:
                return y
            return y * x

    module = MyModule()
    real_result = module(torch.tensor([2]))

    # X is positive, so .item() > 0, which means we return y * x
    self.assertEqual(real_result, torch.tensor([1.0]))

    graph, guards = torch._dynamo.export(module)(torch.tensor([2]))
    result = graph(torch.tensor([-0.5]))
    # X is negative, but we compiled helper_fn to return x, so it will still return y * x
    self.assertTrue(torch._dynamo.utils.same(result, real_result))
def test_export_with_constant_not_return_const(self):
    """A constant-assumed method returning a non-tensor attribute is frozen at export.

    ``helper_fn`` returns ``self.val`` (a string). Changing ``module.val`` after
    export must not change the exported graph's output.
    """

    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            return self.val

        def forward(self, x):
            y = torch.tensor([0.5])
            x = self.helper_fn(x)
            if x == "A":
                return y
            return -1

    module = MyModule()
    module.val = "A"
    resA = module(torch.tensor([2]))
    graph, guards = torch._dynamo.export(module)(torch.tensor([2]))
    # Mutate the attribute after export; the graph captured the value "A".
    module.val = "B"
    resB = graph(torch.tensor([2]))
    self.assertTrue(torch._dynamo.utils.same(resA, resB))
def test_export_with_builtin_op_on_assume_constant(self):
    """Compare a constant-assumed parameter with ``<`` and branch on the result.

    ``get_y`` is marked with ``assume_constant_result``, so the comparison
    ``get_y(self.y) < self.p`` can be resolved during export. The exported
    graph must match eager output.
    """

    @torch._dynamo.assume_constant_result
    def get_y(y) -> torch.Tensor:
        return y

    class Bob(torch.nn.Module):
        def __init__(self, p, val) -> None:
            super().__init__()
            self.p = p
            self.y = torch.nn.Parameter(torch.tensor(val))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # This only looks dynamic but it's actually a constant value
            if get_y(self.y) < self.p:
                return torch.cat([x, x])
            else:
                return x

    model = Bob(0.5, 0.3)
    inp = torch.ones(3, 4)
    graph, guards = torch._dynamo.export(model)(inp)
    self.assertEqual(model(inp), graph(inp))
def test_export_decomp(self):
    """A user ``decomposition_table`` removes ``aten.t`` from an aten-level export.

    With a table mapping ``aten.t.default`` to ``nop`` no transpose nodes
    remain. Without a table both ``x.t()`` calls show up as ``aten.t.default``.
    """

    def f(x):
        return x.t() + x.t()

    def nop(x):
        return x.cos()

    def count_transposes(gm):
        # Number of graph nodes whose target is still aten.t.default.
        return sum(
            1 for node in gm.graph.nodes if node.target == torch.ops.aten.t.default
        )

    decomposed, _ = torch._dynamo.export(
        f,
        aten_graph=True,
        decomposition_table={torch.ops.aten.t.default: nop},
    )(torch.randn(5))
    self.assertEqual(count_transposes(decomposed), 0)

    undecomposed, _ = torch._dynamo.export(
        f, aten_graph=True, decomposition_table=None
    )(torch.randn(5))
    self.assertEqual(count_transposes(undecomposed), 2)
def test_export_decomp_asserts_bad_args(self):
    """A ``decomposition_table`` without ``aten_graph=True`` must raise AssertionError."""

    def f(x):
        return x.t() + x.t()

    def nop(x):
        return x.cos()

    with self.assertRaises(AssertionError):
        # Curried form used throughout this file: export options go to
        # export(), and the example input goes to the returned callable.
        torch._dynamo.export(
            f,
            aten_graph=False,
            decomposition_table={torch.ops.aten.t.default: nop},
        )(torch.randn(5))
@config.patch(capture_scalar_outputs=True)
def test_export_with_module_layer(self):
    """Export ``cond`` whose branch closures call a submodule (``self.linear``).

    The graph is exported once and then run with a sign-flipped input, which
    flips ``pred``, to show the export is not specialized to one branch.
    """
    from functorch.experimental.control_flow import cond

    class Module(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, pred, x):
            def true_fn(val):
                return self.linear(val) * torch.tensor(2)

            def false_fn(val):
                return self.linear(val) * torch.tensor(-1)

            return cond(pred, true_fn, false_fn, [x])

    mod = Module()
    x = torch.randn([3, 3])
    pred = torch.tensor(x[0][0].item() < 0)
    real_result = mod.forward(pred, x)

    torch._dynamo.reset()

    exported = torch._dynamo.export(mod.forward)(pred, x)
    out_graph = exported[0]

    dynamo_result = out_graph(pred, x)
    self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    # New X, just to show we did not specialize
    x = x * -1
    pred = torch.tensor(x[0][0].item() < 0)
    real_result_2 = mod.forward(pred, x)
    dynamo_result_2 = out_graph(pred, x)
    self.assertTrue(torch._dynamo.utils.same(real_result_2, dynamo_result_2))
@config.patch(capture_scalar_outputs=True)
def test_export_with_cond_branches_calling_methods(self):
    """Export ``cond`` whose branches are bound methods that call other methods."""
    from functorch.experimental.control_flow import cond

    class Module(torch.nn.Module):
        # The cond branches are bound methods that call other methods on self.
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 3)

        def t(self, val):
            return val + 1

        def f(self, val):
            return val - 1

        def true_fn(self, val):
            return self.linear(val) + self.t(val)

        def false_fn(self, val):
            return self.linear(val) - self.f(val)

        def forward(self, pred, x):
            return cond(pred, self.true_fn, self.false_fn, [x])

    mod = Module()
    x = torch.randn([3, 3])
    pred = torch.tensor(x[0][0].item() < 0)
    real_result = mod.forward(pred, x)
    out_graph, _ = torch._dynamo.export(mod.forward)(pred, x)
    dynamo_result = out_graph(pred, x)
    self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
@config.patch(capture_scalar_outputs=True)
def test_export_with_cond_closure(self):
    """Export ``cond`` with locally defined branch functions in three variants.

    - Foo: plain local branch functions.
    - Bar: the operand is a computed value (``x + 1``).
    - FooBar: two operands, and a branch that closes over ``self.linear``.
    """
    from functorch.experimental.control_flow import cond

    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, pred, x):
            def true_fn(x):
                return x * 2

            def false_fn(x):
                return x - 2

            return cond(pred, true_fn, false_fn, [x])

    class Bar(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, pred, x):
            def true_fn(x):
                return x * 2

            def false_fn(x):
                return x - 2

            return cond(pred, true_fn, false_fn, [x + 1])

    class FooBar(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, pred, x):
            y = x + x

            def true_fn(x, y):
                return self.linear(x) * (x + y)

            def false_fn(x, y):
                return x * (y - x)

            return cond(pred, true_fn, false_fn, [x, y])

    for Module in [Foo, Bar, FooBar]:
        mod = Module()
        x = torch.randn([3, 3], requires_grad=True)
        pred = torch.tensor(x[0][0].item() < 0)
        real_result = mod.forward(pred, x)
        out_graph, _ = torch._dynamo.export(mod.forward)(pred, x)
        dynamo_result = out_graph(pred, x)
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
def test_export_with_cond_with_closed_function(self):
    """Export ``cond`` whose branches call functions from an enclosing scope."""

    def hello(x):
        return x + 1

    def hi(x):
        return x + 2

    def foo(pred, x):
        def true_fn(x):
            return hello(x)

        def false_fn(x):
            return hi(x)

        return cond(pred, true_fn, false_fn, [x])

    x = torch.randn(5)
    pred = x[0] > 0
    real_result = foo(pred, x)
    out_graph, _ = torch._dynamo.export(foo)(pred, x)
    dynamo_result = out_graph(pred, x)
    self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
def test_export_with_cond_dynamic_shape_pred(self):
    """Export ``cond`` whose predicate depends on ``x.shape[0]``.

    Checks the exported top-level graph and both branch subgraphs for operands
    given as a list (Module) and as a tuple (Module2). It then checks that a
    (3, 2) input, where the branches return different shapes, raises
    UncapturedHigherOrderOpError from both export and eager execution.
    """
    from functorch.experimental.control_flow import cond

    class Module(torch.nn.Module):
        def forward(self, x):
            def true_fn(x):
                return x + x

            def false_fn(x):
                return x[:2]

            return cond(x.shape[0] <= 2, true_fn, false_fn, [x])

    class Module2(torch.nn.Module):
        def forward(self, x):
            def true_fn(x):
                return x + x

            def false_fn(x):
                return x[:2]

            return cond(x.shape[0] <= 2, true_fn, false_fn, (x,))

    mods = [Module(), Module2()]
    for mod in mods:
        x = torch.randn(2, 2)
        out_graph, guards = torch._dynamo.export(mod)(x)
        self.assertExpectedInline(
            out_graph.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    size = l_x_.size()
    getitem = size[0]; size = None
    le = getitem <= 2; getitem = None
    cond_true_0 = self.cond_true_0
    cond_false_0 = self.cond_false_0
    cond = torch.ops.higher_order.cond(le, cond_true_0, cond_false_0, [l_x_]); le = cond_true_0 = cond_false_0 = l_x_ = None
    getitem_2 = cond[0]; cond = None
    return pytree.tree_unflatten([getitem_2], self._out_spec)""",
        )
        self.assertExpectedInline(
            out_graph.cond_true_0.code.strip(),
            """\
def forward(self, l_x_):
    l_x__1 = l_x_
    add = l_x__1 + l_x__1; l_x__1 = None
    return (add,)""",
        )
        self.assertExpectedInline(
            out_graph.cond_false_0.code.strip(),
            """\
def forward(self, l_x_):
    l_x__1 = l_x_
    getitem = l_x__1[slice(None, 2, None)]; l_x__1 = None
    return (getitem,)""",
        )

        with self.assertRaisesRegex(
            torch._dynamo.exc.UncapturedHigherOrderOpError,
            "Cond doesn't work unless it is captured completely with torch.compile",
        ):
            # True branch and false branch return tensors of different shape
            torch._dynamo.export(mod)(torch.randn(3, 2))

        with self.assertRaisesRegex(
            torch._dynamo.exc.UncapturedHigherOrderOpError,
            "Cond doesn't work unless it is captured completely with torch.compile",
        ):
            # True branch and false branch return tensors of different shape
            test_x = torch.randn(3, 2)
            mod(test_x)
def test_export_with_map_cond(self):
    """Export ``map`` whose body calls ``cond``.

    The graph is exported with ``pred=True`` and a (3, 2, 1) input, then
    compared with eager on ``pred=False`` and a (4, 3, 2) input.
    """
    from functorch.experimental.control_flow import cond, map

    class Module(torch.nn.Module):
        def inner(self, x, pred):
            def true_fn(x):
                return x + x

            def false_fn(x):
                return x * x

            return cond(pred, true_fn, false_fn, [x])

        def forward(self, pred, xs):
            def body(x, pred):
                return self.inner(x, pred)

            return map(body, xs, pred)

    mod = Module()
    x = torch.randn(3, 2, 1)
    pred_x = torch.tensor(True)

    y = torch.randn(4, 3, 2)
    pred_y = torch.tensor(False)
    real_result = mod(pred_y, y)

    out_graph, _ = torch._dynamo.export(mod)(pred_x, x)
    self.assertEqual(real_result, out_graph(pred_y, y))
def test_export_with_map_zero_sized_tensor(self):
    """Exporting ``map`` over a tensor with a zero-sized leading dim is Unsupported."""
    from functorch.experimental.control_flow import map

    class Module(torch.nn.Module):
        def forward(self, xs):
            def body(x):
                return x + 1

            return map(body, xs)

    mod = Module()

    # Leading (mapped) dimension has size 0.
    xs = torch.randn(0, 2)
    with self.assertRaisesRegex(
        torch._dynamo.exc.Unsupported,
        "zero-sized tensor",
    ):
        out_graph, _ = torch._dynamo.export(mod)(xs)
def test_export_meta_val(self):
    """Every placeholder in an aten-level exported graph has a ``"val"`` meta entry."""

    def f(x, y, z):
        return x * y + z

    example_args = (torch.ones(3, 2), torch.zeros(3, 2), torch.ones(3, 2))
    gm, _ = torch._dynamo.export(f, aten_graph=True)(*example_args)

    placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
    for node in placeholders:
        self.assertIn("val", node.meta)
def test_input_container_type(self):
    """Export a function taking a list of tensors and returning a dict."""

    def f(x: torch.Tensor, y: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        return {"a": x.sum() + sum(y).sum()}

    x_arg = torch.randn(6, 5)
    y_arg = [torch.randn(6, 5), torch.randn(6, 5)]

    gm, _ = torch._dynamo.export(f, aten_graph=True)(x_arg, y_arg)

    self.assertEqual(gm(x_arg, y_arg), f(x_arg, y_arg))
@config.patch(assume_static_by_default=False)
def test_export_symbolic_shape(self):
    """With static-by-default disabled, the aten graph reads ``x.shape[0]`` via ``aten.sym_size.int``."""

    def f(x: torch.Tensor) -> torch.Tensor:
        return torch.empty(x.shape[0] * 2)

    inp = (torch.randn(6, 5),)
    gm, _ = torch._dynamo.export(f, aten_graph=True)(*inp)

    # At least one node must call aten.sym_size.int.
    self.assertTrue(
        any(node.target is torch.ops.aten.sym_size.int for node in gm.graph.nodes)
    )
@config.patch(assume_static_by_default=False)
def test_dynamic_slicing(self):
    """Slicing with shape-dependent bounds exports correctly in aten and torch mode.

    Both exported graphs must produce outputs with the same shape as eager on a
    larger input. The test also checks how many slice / getitem nodes each
    mode emits.
    """

    def f(x):
        return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]

    gm_aten_mode, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5))

    inp = torch.randn(6, 7)
    self.assertEqual(gm_aten_mode(inp).shape, f(inp).shape)

    # The aten graph should lower the getitem calls to actual slice kernel
    # calls: one aten.slice.Tensor per sliced dimension.
    count = sum(
        1
        for node in gm_aten_mode.graph.nodes
        if node.op == "call_function" and node.target == torch.ops.aten.slice.Tensor
    )
    self.assertEqual(count, 2)

    gm_torch_mode, _ = torch._dynamo.export(f, aten_graph=False)(torch.randn(4, 5))

    # In torch mode, the graph should contain 3 getitem calls: one for
    # x.shape[0] - 2, one for x.shape[1] - 1, and one for the slice itself.
    # Tensor has its own __getitem__, which is translated to aten.slice later.
    count = sum(
        1
        for node in gm_torch_mode.graph.nodes
        if node.op == "call_function" and node.target == operator.getitem
    )
    self.assertEqual(count, 3)
    self.assertEqual(gm_torch_mode(inp).shape, f(inp).shape)
def test_dynamic_slicing_invalid(self):
    """A slice whose start is a tensor value (data-dependent) is Unsupported."""

    def g(x, y):
        return x[y : x.shape[0]]

    with self.assertRaisesRegex(
        torch._dynamo.exc.Unsupported,
        "Dynamic slicing on data-dependent value is not supported",
    ):
        torch._dynamo.export(
            g,
            aten_graph=True,
        )(
            torch.randn(4, 5),
            torch.tensor(2),
        )
@config.patch(capture_scalar_outputs=True)
def test_dynamic_slicing_simple(self):
    """A full ``slice(None, None, None)`` exports and matches eager on a new shape."""

    def f(x):
        return x[slice(None, None, None)]

    gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5))

    # Call with a shape that differs from the export-time input.
    inp = torch.randn(6, 7)
    self.assertEqual(gm(inp), f(inp))
def test_pre_dispatch_simple(self):
    """Pre-dispatch aten export keeps ``ones_like`` and ``matmul`` as top-level aten ops."""

    def f(x):
        y = torch.ones_like(x)
        return torch.matmul(x, y)

    gm, _ = torch._dynamo.export(
        f,
        aten_graph=True,
        pre_dispatch=True,
        tracing_mode="fake",
    )(
        torch.randn(5, 5),
    )

    inp = torch.randn(6, 6)
    self.assertEqual(gm(inp), f(inp))
    self.assertExpectedInline(
        gm.code.strip(),
        """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    arg0_1 = arg0
    ones_like = torch.ops.aten.ones_like.default(arg0_1, pin_memory = False)
    matmul = torch.ops.aten.matmul.default(arg0_1, ones_like); arg0_1 = ones_like = None
    return pytree.tree_unflatten([matmul], self._out_spec)""",
    )
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_export_cond_in_aten_symbolic(self):
    """Export ``cond`` with method branches to an aten graph and compare with eager."""

    class ConditionOp(torch.nn.Module):
        def true_fn(self, x, y):
            return x * y

        def false_fn(self, x, y):
            return x + y

        def forward(self, pred, x, y):
            return cond(pred, self.true_fn, self.false_fn, [x, y])

    model = ConditionOp()
    inp = (
        torch.tensor(False),
        torch.randn(4, 4),
        torch.randn(4, 4),
    )
    gm, _ = torch._dynamo.export(model, aten_graph=True)(*inp)

    # NOTE(review): prints the graph to stdout; confirm whether this is a
    # smoke test of print_readable on cond subgraphs or leftover debug output.
    gm.print_readable()

    self.assertEqual(gm(*inp), model(*inp))
def test_export_with_kwargs(self):
    """The exported graph keeps names for positional, *args, keyword and **kwargs inputs.

    ``*myargs`` entries are expected as ``myargs_<i>``, and ``**mykwargs``
    entries keep their dict keys.
    """

    def fn_with_kwargs(pos0, tuple0, *myargs, mykw0=None, **mykwargs):
        out = pos0
        for arg in tuple0:
            out *= arg
        for arg in myargs:
            out *= arg
        out *= mykw0
        out *= mykwargs["input0"] * mykwargs["input1"]
        return out

    mykwargs = {"input0": torch.randn(4), "input1": torch.randn(4)}
    tuple0 = (torch.randn(4), torch.randn(4))
    mykw0 = torch.randn(4)
    pos0 = torch.randn(4)
    myargs = [torch.randn(4), torch.randn(4)]

    expected_argument_names = [
        "pos0",
        "tuple0",
        "myargs_0",
        "myargs_1",
        "mykw0",
        "input0",
        "input1",
    ]
    self._test_export_preserving_original_signature(
        fn_with_kwargs,
        expected_argument_names,
        pos0,
        tuple0,
        *myargs,
        mykw0=mykw0,
        **mykwargs,
    )
def test_export_with_kwargs_and_empty_args(self):
    """Argument names are preserved when the function has only keyword inputs."""

    def fn_with_kwargs(mykw0=None, **mykwargs):
        out = mykw0
        out *= mykwargs["input0"] * mykwargs["input1"]
        return out

    mykwargs = {"input0": torch.randn(4), "input1": torch.randn(4)}
    mykw0 = torch.randn(4)

    expected_argument_names = ["mykw0"] + list(mykwargs.keys())
    # mykw0 is passed positionally here, which binds to the mykw0 parameter.
    self._test_export_preserving_original_signature(
        fn_with_kwargs, expected_argument_names, mykw0, **mykwargs
    )
def test_export_with_args_and_empty_kwargs(self):
    """Argument names are preserved for positional and *args-only functions."""

    def fn_with_kwargs(pos0, tuple0, *myargs):
        out = pos0
        for arg in tuple0:
            out *= arg
        for arg in myargs:
            out *= arg
        return out

    tuple0 = (torch.randn(4), torch.randn(4))
    pos0 = torch.randn(4)
    myargs = [torch.randn(4), torch.randn(4)]

    # *myargs entries are expected to be named myargs_<index>.
    expected_argument_names = ["pos0", "tuple0", "myargs_0", "myargs_1"]
    self._test_export_preserving_original_signature(
        fn_with_kwargs, expected_argument_names, pos0, tuple0, *myargs
    )
@common_utils.parametrize(
    "default_value",
    [
        common_utils.subtest(None, name="None"),
        common_utils.subtest(42.0, name="float"),
        common_utils.subtest(
            # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
            torch.randn(4),
            name="tensor",
            decorators=[unittest.expectedFailure],
        ),
        common_utils.subtest(
            # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
            (torch.randn(4),),
            name="tuple",
            decorators=[unittest.expectedFailure],
        ),
    ],
)
def test_export_with_args_with_default(self, default_value):
    """A positional parameter left at its default does not appear in the exported signature.

    Parametrized over the default's type; tensor and tuple defaults are
    currently expected failures (see FIXMEs).
    """

    def fn(pos0, pos1_default=default_value):
        out = pos0
        if pos1_default is None:
            pos1_default = torch.randn(4)
        if isinstance(pos1_default, tuple):
            pos1_default = pos1_default[0]
        out *= pos1_default
        return out

    pos0 = torch.randn(4)
    expected_argument_names = ["pos0"]
    self._test_export_preserving_original_signature(
        fn, expected_argument_names, pos0
    )
@common_utils.parametrize(
    "default_value",
    [
        common_utils.subtest(None, name="None"),
        common_utils.subtest(42.0, name="float"),
        common_utils.subtest(
            # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
            torch.randn(4),
            name="tensor",
            decorators=[unittest.expectedFailure],
        ),
        common_utils.subtest(
            # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
            (torch.randn(4),),
            name="tuple",
            decorators=[unittest.expectedFailure],
        ),
    ],
)
def test_export_with_kwargs_with_default(self, default_value):
    """A keyword-only parameter left at its default does not appear in the exported signature.

    Only the explicitly passed inputs (pos0, kw0, kw2) are expected. Tensor and
    tuple defaults are currently expected failures (see FIXMEs).
    """

    def fn(pos0, *, kw0, kw1_default=default_value, **kwargs):
        out = pos0
        out += kw0
        if kw1_default is None:
            kw1_default = torch.randn(4)
        elif isinstance(kw1_default, tuple):
            kw1_default = kw1_default[0]
        out += kw1_default
        out += kwargs["kw2"]
        return out

    pos0 = torch.randn(4)
    kw0 = torch.randn(4)
    kw2 = torch.randn(4)

    args = (pos0,)
    kwargs = {"kw0": kw0, "kw2": kw2}
    expected_argument_names = ["pos0", "kw0", "kw2"]
    self._test_export_preserving_original_signature(
        fn, expected_argument_names, *args, **kwargs
    )
def test_export_with_wrapped_fn(self):
    """Exporting a plain ``*args, **kwargs`` wrapper falls back to generic names."""
    # To ensure dynamo.export is robust to wrapped functions
    # when it cannot use `inspect` to retrieve original signature
    # info.
    def _fn(pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs):
        out = pos0
        out += pos1
        out += kw0
        out += kw1
        for arg in args:
            out += arg
        for kwarg in kwargs.values():
            out += kwarg
        return out

    def wrapped_fn(*args, **kwargs):
        return _fn(*args, **kwargs)

    pos0 = torch.randn(4)
    kw0 = torch.randn(4)
    args = (pos0, torch.randn(4), torch.randn(4))
    kwargs = {"kw0": kw0, "kw2": torch.randn(4)}
    # The wrapper's signature only exposes *args / **kwargs, so positional
    # inputs are named args_<i> and keyword inputs keep their dict keys.
    expected_argument_names = [f"args_{i}" for i in range(len(args))] + list(
        kwargs.keys()
    )

    self._test_export_preserving_original_signature(
        wrapped_fn, expected_argument_names, *args, **kwargs
    )
def test_export_with_functools_wrapped_method(self):
    """Exporting a ``functools.wraps``-decorated bound method keeps the wrapped names.

    Named positionals (pos0, pos1) keep their names, the extra positional goes
    through ``*args`` as ``args_0``, and keyword inputs keep their keys.
    """

    def test_decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapper

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return x

        @test_decorator
        def method_to_test(self, pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs):
            out = pos0
            out += pos1
            out += kw0
            out += kw1
            for arg in args:
                out += arg
            for kwarg in kwargs.values():
                out += kwarg
            return out

    pos0 = torch.randn(4)
    pos1 = torch.randn(4)
    unnamed_pos = torch.randn(4)
    kw0 = torch.randn(4)
    args = (pos0, pos1, unnamed_pos)
    kwargs = {"kw0": kw0, "kw2": torch.randn(4), "unnamed_kw": torch.randn(4)}
    expected_argument_names = [
        "pos0",
        "pos1",
        "args_0",  # 3rd unnamed positional argument
    ] + list(kwargs.keys())
    m = MyModule()

    self._test_export_preserving_original_signature(
        m.method_to_test, expected_argument_names, *args, **kwargs
    )
def test_export_with_functools_wrapped_fn(self):
    """A plain wrapper around a ``functools.wraps``-decorated function uses generic names.

    The outermost ``wrapped_fn`` only exposes ``*args, **kwargs``, so
    positionals are expected as ``args_<i>`` and keyword inputs keep their keys.
    """

    def test_decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapper

    @test_decorator
    def _fn(pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs):
        out = pos0
        out += pos1
        out += kw0
        out += kw1
        for arg in args:
            out += arg
        for kwarg in kwargs.values():
            out += kwarg
        return out

    def wrapped_fn(*args, **kwargs):
        return _fn(*args, **kwargs)

    pos0 = torch.randn(4)
    kw0 = torch.randn(4)
    args = (pos0, torch.randn(4), torch.randn(4))
    kwargs = {"kw0": kw0, "kw2": torch.randn(4)}
    expected_argument_names = [f"args_{i}" for i in range(len(args))] + list(
        kwargs.keys()
    )

    self._test_export_preserving_original_signature(
        wrapped_fn, expected_argument_names, *args, **kwargs
    )
def _test_export_preserving_original_signature(
    self, fn, expected_argument_names: Sequence[str], *args, **kwargs
):
    """Export ``fn`` and check that the graph matches eager and keeps argument names.

    Args:
        fn: callable to export with ``torch._dynamo.export`` (torch mode).
        expected_argument_names: names expected on the exported graph's
            ``forward``, excluding ``self``.
        *args, **kwargs: example inputs, used for export and for both the
            graph call and the eager call.

    NOTE(review): several callers' ``fn`` mutate their first argument in place
    (``out = pos0; out *= ...``), so the inputs are not fresh between the
    export, graph and eager calls.
    """
    torch._dynamo.reset()
    # Curried form used throughout this file: export options go to export(),
    # example inputs go to the returned callable. This keeps user kwargs out of
    # the namespace of export's own keyword options (aten_graph, ...).
    exported = torch._dynamo.export(fn, aten_graph=False)(*args, **kwargs)

    out_graph = exported[0]
    dynamo_result = out_graph(*args, **kwargs)
    real_result = fn(*args, **kwargs)
    self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    # Check that the exported graph preserves same argument names.
    self.assertEqual(
        inspect.getfullargspec(out_graph.forward).args[1:], expected_argument_names
    )
def test_dataclass_input_output(self):
    """An unregistered dataclass as export input or output raises a UserError."""
    from dataclasses import dataclass

    @dataclass
    class Tensors:
        x: torch.Tensor
        y: torch.Tensor

    # Dataclass as an input.
    def f(t):
        return t.x + t.y

    with self.assertRaisesRegex(
        UserError,
        "It looks like one of the inputs with type .*Tensors.* "
        "is not supported or pytree-flattenable",
    ):
        torch._dynamo.export(f, aten_graph=False)(
            Tensors(x=torch.randn(10), y=torch.randn(10))
        )

    # Dataclass as an output.
    def f(x, y):
        return Tensors(x=x.sin(), y=y.cos())

    with self.assertRaisesRegex(
        UserError,
        "It looks like one of the outputs with type .*Tensors.* "
        "is not supported or pytree-flattenable",
    ):
        torch._dynamo.export(f, aten_graph=False)(torch.randn(10), torch.randn(10))
def test_empty(self):
    """Export graphs with no compute: an identity function and a no-input module."""

    def f(x):
        return x

    exported = torch._dynamo.export(f)(torch.randn(3, 3))
    out_graph = exported[0]
    inp = torch.randn(3, 3)
    self.assertTrue(torch._dynamo.utils.same(inp, out_graph(inp)))

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.a = torch.ones(3, 3)

        def forward(self):
            return self.a

    # Module whose forward takes no inputs and returns a stored attribute.
    exported = torch._dynamo.export(M())()
    out_graph = exported[0]
    self.assertTrue(torch._dynamo.utils.same(torch.ones(3, 3), out_graph()))
@unittest.skipIf(not TEST_CUDA, "No CUDA available.")
def test_export_with_parameters(self):
    """Save and reload an exported CUDA program and check parameters stay nn.Parameter."""

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.features = torch.nn.Sequential(
                torch.nn.Conv2d(
                    3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
                ),
                torch.nn.ReLU(inplace=True),
            )

        def forward(self, x):
            return self.features(x)

    model = MyModule().eval().cuda()
    random_inputs = (torch.rand([32, 3, 32, 32]).to("cuda"),)
    # Batch dimension of x is dynamic in [1, 32].
    dim_x = torch.export.Dim("dim_x", min=1, max=32)
    exp_program = torch.export.export(
        model, random_inputs, dynamic_shapes={"x": {0: dim_x}}
    )

    output_buffer = io.BytesIO()
    # Tests if we can restore saved nn.Parameters when we load them again
    torch.export.save(exp_program, output_buffer)
    loaded_model = torch.export.load(output_buffer)
    self.assertTrue(
        isinstance(
            loaded_model.module().get_parameter("features.0.weight"),
            torch.nn.Parameter,
        )
    )
def test_export_fast_binary_broadcast_check(self):
    """Broadcasting ``b + a`` with a dynamic batch dim must not add a spurious guard."""
    # This test looks at the case where we erroneously create a guard
    # when checking the equality of the operands' shape and the output
    # shape during FakeTensor's binary op fast path.

    class MyModel(torch.nn.Module):
        def forward(self, a, b):
            # final shape is (dim0, 4, 8)
            # order matters since a & the output have the same shape
            return b + a

    a = torch.randn(100, 4, 8)
    b = torch.randn(4, 8)
    # Inputs are on CPU and this test is not gated on CUDA, so keep the model
    # on CPU as well. It has no parameters, so no device move is needed.
    model = MyModel().eval()
    batchsize = torch.export.Dim("dim0", min=3, max=1024)
    dynamic_shape_spec = {"a": [batchsize, None, None], "b": [None, None]}

    torch.export.export(model, (a, b), dynamic_shapes=dynamic_shape_spec)
def test_export_meta(self):
    """Export a module whose parameters and inputs are on the meta device."""

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.p = torch.nn.Parameter(torch.ones(2, 3))

        def forward(self, x):
            return self.p + x

    # Construct the module (and its parameter) directly on the meta device.
    with torch.device("meta"):
        m = MyModule()

    inp = torch.ones(2, 3, device="meta")
    exported = torch._dynamo.export(m)(inp)
    out_graph = exported[0]
    dynamo_result = out_graph(inp)
    self.assertEqual(dynamo_result, m(inp))
def test_constraint_violation_error_messages(self):
    """Check the wording of UserErrors raised for violated dynamic-shape constraints.

    - Foo: an equality between dims suggests ``dim0 = 2*dim1`` and a range for dim1.
    - Bar: a branch on ``x.shape[0] == 5`` reports the dim as inferred constant.
    - Qux: a range check reports that not all values satisfy the generated guard.
    """

    class Foo(torch.nn.Module):
        def forward(self, x):
            if x.shape[0] == x.shape[1] * 2:
                return x + 1
            else:
                return x + 2

    foo = Foo()

    t = torch.zeros([8, 4])
    dim0 = torch.export.Dim("dim0", min=3, max=10)
    dim1 = torch.export.Dim("dim1")
    dynamic_shapes = {"x": (dim0, dim1)}

    with self.assertRaisesRegex(
        torch._dynamo.exc.UserError,
        "Constraints violated .*!(.*\n)*.*"
        "by dim0 = 2\\*dim1(.*\n)*.*"
        "Not all values of dim1 .* satisfy the generated guard 2 <= .* and .* <= 5(.*\n)*.*",
    ):
        torch.export.export(foo, (t,), dynamic_shapes=dynamic_shapes)

    class Bar(torch.nn.Module):
        def forward(self, x):
            if x.shape[0] == 5:
                return x + 1
            else:
                return x + 2

    bar = Bar()

    t = torch.zeros([5])
    dim0 = torch.export.Dim("dim0", min=3, max=8)
    dynamic_shapes = {"x": (dim0,)}
    with self.assertRaisesRegex(
        torch._dynamo.exc.UserError,
        "Not all values.*valid.*inferred to be a constant",
    ):
        torch.export.export(bar, (t,), dynamic_shapes=dynamic_shapes)

    class Qux(torch.nn.Module):
        def forward(self, x):
            if x.shape[0] > 5 and x.shape[0] < 10:
                return x + 1
            else:
                return x + 2

    qux = Qux()

    t = torch.zeros([7])
    dim0 = torch.export.Dim("dim0", min=3, max=8)
    dynamic_shapes = {"x": (dim0,)}
    with self.assertRaisesRegex(
        torch._dynamo.exc.UserError,
        "Not all values.*satisfy the generated guard",
    ):
        torch.export.export(qux, (t,), dynamic_shapes=dynamic_shapes)
def test_untracked_inputs_in_constraints(self):
    """A dynamic dim on an input that forward never uses must not be specialized."""
    # Local import shadows the module-level `copy` for this test only.
    from copy import copy

    class Foo(torch.nn.Module):
        def forward(self, x, y):
            return y + 1

    foo = Foo()

    x = torch.randn(2)
    y = torch.randn(5, 4)

    dim0_x, dim0_y = torch.export.dims("dim0_x", "dim0_y")
    dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}

    example_inputs = (copy(x), y)
    ep = torch.export.export(foo, example_inputs, dynamic_shapes=dynamic_shapes)
    # A different size for the unused x must still be accepted.
    ep.module()(torch.randn(3), y)  # no specialization error
def test_export_raise_guard_full_constraint(self):
    """Marking a dim dynamic when the code specializes it (``== 3``) raises."""
    y = torch.randn([3, 3, 3])

    def my_dyn_fn(x):
        if x.shape[0] == 3:
            return x.sin()
        return x.cos()

    # Without dynamic_shapes the export succeeds.
    torch._dynamo.export(my_dyn_fn)(y)

    with self.assertRaises(ConstraintViolationError):
        torch._dynamo.export(
            my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dimx")},)
        )(y)
def test_export_module_specify_constraints_signature(self):
    """The constraint-violation message for a module suggests ``dimx = None  # 3``."""
    y = torch.randn([3, 3, 3])

    class Mod(torch.nn.Module):
        def forward(self, x):
            if x.shape[0] == 3:
                return x.sin()
            return x.cos()

    mod = Mod()
    # Without dynamic_shapes the export succeeds.
    torch._dynamo.export(mod)(y)

    with self.assertRaisesRegex(ConstraintViolationError, "dimx = None # 3"):
        torch._dynamo.export(mod, dynamic_shapes=({0: torch.export.Dim("dimx")},))(
            y
        )
def test_export_raise_guard_partial_constraint(self):
    """Marking a dim dynamic when the code guards its range (``> 3``) raises."""
    y = torch.randn([3, 3, 3])

    def my_dyn_fn(x):
        if x.shape[0] > 3:
            return x.sin()
        return x.cos()

    # Without dynamic_shapes the export succeeds.
    torch._dynamo.export(my_dyn_fn)(y)

    with self.assertRaises(ConstraintViolationError):
        torch._dynamo.export(
            my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dimx")},)
        )(y)
def test_export_raise_on_relationship(self):
    """A shared Dim must sit on the dims the code actually compares.

    The code compares a.shape[0], b.shape[1] and c.shape[2]. Sharing one Dim
    across dim 0 of all inputs raises. Sharing it across exactly the compared
    dims succeeds.
    """
    y = torch.randn([3, 3, 3])

    def my_dyn_fn(a, b, c):
        if a.shape[0] == b.shape[1] == c.shape[2]:
            return a.sin()

        return a.cos()

    torch._dynamo.export(my_dyn_fn)(y, y, y)
    dim = torch.export.Dim("dim")
    dynamic_shapes = ({0: dim}, {0: dim}, {0: dim})
    with self.assertRaises(ConstraintViolationError):
        torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(y, y, y)
    dynamic_shapes = ({0: dim}, {1: dim}, {2: dim})
    torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(y, y, y)
def test_export_no_raise(self):
    """Specializing a dim that is not marked dynamic (dim 1) does not raise."""
    y = torch.randn([3, 3, 3])

    def my_dyn_fn(a, b, c):
        if a.shape[1] == 3:
            return a.cos()
        return a * b * c

    torch._dynamo.export(my_dyn_fn)(y, y, y)
    # Only dim 0 is dynamic; the branch above only inspects dim 1.
    dim = torch.export.Dim("dim")
    dynamic_shapes = ({0: dim}, {0: dim}, {0: dim})
    torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(y, y, y)
def test_export_multi_dynamic_dim_unsafe_relationship(self):
    """Distinct Dims that the code requires to be equal raise; a shared Dim does not."""
    x = torch.randn([3, 3, 3])
    y = torch.randn([2, 2, 2])
    z = torch.randn([3, 3, 3])

    def my_dyn_fn(a, b, c):
        if a.shape[0] == c.shape[0]:
            return a.cos()
        return a * c, b

    torch._dynamo.export(my_dyn_fn)(x, y, z)
    dimx, dimy, dimz = torch.export.dims("dimx", "dimy", "dimz")
    dynamic_shapes = ({0: dimx}, {0: dimy}, {0: dimz})
    # dimx and dimz are independent, but the branch requires them equal.
    with self.assertRaises(ConstraintViolationError):
        torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z)
    # Tie a.shape[0] and c.shape[0] to the same Dim; export now succeeds.
    dimz = dimx
    dynamic_shapes = ({0: dimx}, {0: dimy}, {0: dimz})
    torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z)
def test_remove_redundant_dynamic_dim_in_error_message(self):
    """The error message suggests ``dim0_b = dim0_a`` for dims found to be equal.

    One input is nested inside a dict (``y["k"]``).
    """

    class Foo(torch.nn.Module):
        def forward(self, x, y):
            if x.shape[0] == y["k"].shape[0]:
                return x + 1
            else:
                return x - 1

    foo = Foo()

    a = torch.randn(3)
    b = torch.randn(3)
    dim0_a, dim0_b = torch.export.dims("dim0_a", "dim0_b")
    with self.assertRaisesRegex(torch._dynamo.exc.UserError, "dim0_b = dim0_a"):
        torch.export.export(
            foo,
            (a, {"k": b}),
            dynamic_shapes={"x": {0: dim0_a}, "y": {"k": {0: dim0_b}}},
        )
def test_enforce_equalities(self):
    """Reusing a Dim across positions enforces equal sizes on the example inputs.

    A mismatched example (3 vs 4 under the same ``size`` Dim) raises. With
    consistent examples the placeholders share the symbols ``s0`` and ``s1``.
    """

    class Bar(torch.nn.Module):
        def forward(self, x, y):
            return torch.matmul(x, y)

    bar = Bar()

    batch, size = torch.export.dims("batch", "size")
    dynamic_shapes = {"x": (batch, size, size), "y": (batch, size, size)}

    x = torch.randn(10, 3, 3)
    y = torch.randn(10, 3, 4)
    with self.assertRaisesRegex(
        torch._dynamo.exc.UserError,
        ".*x.*size.*1.* = 3 is not equal to .*y.*size.*2.* = 4",
    ):
        torch.export.export(
            bar,
            (x, y),
            dynamic_shapes=dynamic_shapes,
        )
    y = torch.randn(10, 3, 3)
    ebar = torch.export.export(
        bar,
        (x, y),
        dynamic_shapes=dynamic_shapes,
    )
    self.assertEqual(
        [
            str(node.meta["val"].shape)
            for node in ebar.graph_module.graph.nodes
            if node.op == "placeholder"
        ],
        ["torch.Size([s0, s1, s1])", "torch.Size([s0, s1, s1])"],
    )
@config.patch(
    capture_dynamic_output_shape_ops=True,
    specialize_int=True,
    capture_scalar_outputs=True,
)
def test_export_preserve_constraints_as_metadata_scalar(self):
    """``gm.meta["input_shape_constraints"]`` mirrors the processed Dim specs."""

    def f(x, y):
        b = x.item()
        torch._constrain_as_size(b)
        return torch.empty((b, y.shape[0]))

    example_inputs = [torch.tensor([3]), torch.randn([8, 8, 6])]
    dynamic_shapes = (None, {0: torch.export.Dim("dimy", min=6, max=10)})
    gm, _ = torch._dynamo.export(
        f,
        dynamic_shapes=dynamic_shapes,
        aten_graph=True,
        tracing_mode="symbolic",
    )(*example_inputs)

    processed = torch.export.dynamic_shapes._process_dynamic_shapes(
        f, example_inputs, dynamic_shapes=dynamic_shapes
    )
    expected_specs = [constraint.serializable_spec for constraint in processed]
    self.assertEqual(gm.meta["input_shape_constraints"], expected_specs)
@config.patch(
    capture_dynamic_output_shape_ops=True,
    specialize_int=True,
    capture_scalar_outputs=True,
)
def test_export_preserve_constraints_as_metadata_tensor(self):
    """Symbolic export traces a value-constrained data-dependent size.

    ``nonzero`` yields a leading size that depends on the data; it is
    constrained to ``[2, 5]`` and the function is exported with
    ``tracing_mode="symbolic"``.
    """

    def f(x):
        b = x.nonzero()
        torch._constrain_as_value(b.shape[0], min=2, max=5)
        return b

    y = torch.tensor([8, 8, 6])
    gm, _ = torch._dynamo.export(
        f,
        aten_graph=True,
        tracing_mode="symbolic",
    )(y)
    # Check that export actually produced a graph module.
    self.assertIsInstance(gm, torch.fx.GraphModule)
@config.patch(
    capture_dynamic_output_shape_ops=True,
    specialize_int=True,
    capture_scalar_outputs=True,
)
def test_exported_graph_serialization(self):
    """A graph module exported with constraint metadata can be ``torch.save``d.

    Only checks that saving succeeds; per the original note, the metadata
    itself is not saved in the serialized module.
    """

    def f(x, y):
        b = x.item()
        torch._constrain_as_size(b)
        return torch.empty((b, y.shape[0]))

    scalar = torch.tensor([3])
    tensor = torch.randn([8, 8, 6])
    dim_y = torch.export.Dim("dimy", min=6, max=10)
    gm, _ = torch._dynamo.export(
        f,
        dynamic_shapes=(None, {0: dim_y}),
        aten_graph=True,
        tracing_mode="symbolic",
    )(scalar, tensor)

    torch.save(gm, io.BytesIO())
def test_export_dynamic_dim_not_1(self):
    """A Dim on a size-1 axis guarded by ``!= 1`` raises a violation."""
    sample = torch.randn([1, 1, 1])

    def my_dyn_fn(a):
        if a.shape[0] != 1:
            return a.cos()
        return a * a

    torch._dynamo.export(my_dyn_fn)(sample)

    leading_dynamic = ({0: torch.export.Dim("dimx")},)
    with self.assertRaises(ConstraintViolationError):
        torch._dynamo.export(my_dyn_fn, dynamic_shapes=leading_dynamic)(sample)
def test_symbool(self):
    """A shape comparison fed to ``torch.scalar_tensor`` exports correctly.

    The graph is traced on a 6-row input and must match eager on 3 rows.
    """

    def f(x):
        a = torch.scalar_tensor(x.shape[0] > 4)
        return x.sin().sum() + a.sum()

    gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4))
    probe = torch.ones(3, 4)
    self.assertEqual(gm(probe), f(probe))
def test_export_multi_dynamic_dim_constraint(self):
    """The ``a.shape[0] == c.shape[0]`` guard needs c's dim 0 tied to ``dimx_0``.

    Making every axis of ``a`` dynamic while leaving ``c`` static is expected
    to raise; adding ``{0: dimx_0}`` for ``c`` is accepted.
    """
    x = torch.randn([3, 3, 3])
    y = torch.randn([2, 2, 2])
    z = torch.randn([3, 3, 3])

    def my_dyn_fn(a, b, c):
        if a.shape[0] == c.shape[0]:
            return a.cos()
        return a * c, b

    torch._dynamo.export(my_dyn_fn)(x, y, z)
    dimx_0, dimx_1, dimx_2 = torch.export.dims("dimx_0", "dimx_1", "dimx_2")

    def all_axes_of_a():
        # Build a fresh spec dict for each export call.
        return {0: dimx_0, 1: dimx_1, 2: dimx_2}

    with self.assertRaises(ConstraintViolationError):
        torch._dynamo.export(
            my_dyn_fn, dynamic_shapes=(all_axes_of_a(), None, None)
        )(x, y, z)
    torch._dynamo.export(
        my_dyn_fn, dynamic_shapes=(all_axes_of_a(), None, {0: dimx_0})
    )(x, y, z)
def test_export_dynamic_dim_raise_on_compound_range_constraint(self):
    """Chained comparisons on a ``dynamic_dim`` raise ``TypeError``.

    ``4 < d <= 6`` means ``(4 < d) and (d <= 6)``, which needs the truth
    value of ``4 < d``; the test expects that to fail.
    """
    sample = torch.ones(6, 4, 4)
    with self.assertRaisesRegex(TypeError, "Cannot determine truth value"):
        4 < dynamic_dim(sample, 0) <= 6  # noqa: B015
def test_export_dynamic_dim_range_constraint(self):
    """Guards must be decidable over the whole declared range ``[5, 6]``."""
    x = torch.ones(6, 4, 4)
    dynamic_shapes = ({0: torch.export.Dim("dimx", min=5, max=6)},)

    def foo(x):
        if x.shape[0] > 3:  # ok: holds for every size in [5, 6]
            return x.sin()
        return x.cos()

    def bar(x):
        if x.shape[0] > 5:  # error: false at 5, true at 6
            return x.sin()
        return x.cos()

    def export_with_range(fn):
        return torch._dynamo.export(
            fn,
            dynamic_shapes=dynamic_shapes,
            aten_graph=True,
        )(x)

    export_with_range(foo)
    with self.assertRaises(ConstraintViolationError):
        export_with_range(bar)
def test_trivial_constraint(self):
    """Divisibility guards that are too complex, trivially true, or partial.

    ``Foo``'s guard is expected to be rejected as too complex, ``Bar``'s
    holds for every size and exports, and ``Qux``'s holds only for some
    sizes in ``[.., 100]`` and is expected to be rejected.
    """

    class Foo(torch.nn.Module):
        def forward(self, x):
            # complex divisibility condition
            if (2 * x.shape[0] + 3) % (x.shape[0] - 3) == 0:
                return x + 1
            else:
                return x - 1

    class Bar(torch.nn.Module):
        def forward(self, x):
            # trivially true
            if (2 * x.shape[0] + 2) % (x.shape[0] + 1) == 0:
                return x + 1
            else:
                return x - 1

    class Qux(torch.nn.Module):
        def forward(self, x):
            # simple divisibility condition (not trivially true)
            if (3 * x.shape[0]) % 2 == 0:
                return x + 1
            else:
                return x - 1

    inputs = (torch.randn(12),)
    dim0 = torch.export.Dim("dim0", max=100)
    dynamic_shapes = {"x": (dim0,)}

    def export_module(module):
        return torch.export.export(module, inputs, dynamic_shapes=dynamic_shapes)

    with self.assertRaisesRegex(
        torch._dynamo.exc.UserError,
        "must be specialized.*guards generated.*too complex",
    ):
        export_module(Foo())
    # 2n + 2 == 2 * (n + 1), so Bar's guard holds for every n.
    export_module(Bar())
    with self.assertRaisesRegex(
        torch._dynamo.exc.UserError,
        "Not all values.*satisfy the generated guard",
    ):
        export_module(Qux())
def test_list_contains(self):
    """A size membership ``assert`` traces under both compile and export."""

    def func(x):
        assert x.size(-1) in [4, 5, 6], "bad"
        return x + x

    inps = (torch.randn(1, 5),)
    compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    expected = compiled(*inps)

    torch._dynamo.reset()
    out_graph, _ = torch._dynamo.export(func, aten_graph=True)(*inps)
    self.assertTrue(torch._dynamo.utils.same(expected, out_graph(*inps)))
def test_list_not_contains(self):
    """``not in`` asserts on a size and on constant strings survive export."""

    def func(x):
        assert x.size(0) not in [4, 5, 6], "bad1"
        assert "monkey" not in ["cow", "pig"], "bad2"
        return x + x

    inps = (torch.randn(1, 5),)
    compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
    expected = compiled(*inps)

    torch._dynamo.reset()
    out_graph, _ = torch._dynamo.export(func, aten_graph=True)(*inps)
    self.assertTrue(torch._dynamo.utils.same(expected, out_graph(*inps)))
def test_export_identity(self):
    """The exported identity function returns a value equal to its input."""
    inp = torch.tensor([0.1, 0.1])

    def func(x):
        return x

    torch._dynamo.reset()
    gm, _ = torch._dynamo.export(func)(inp)
    self.assertTrue(torch._dynamo.utils.same(inp, gm(inp)))
def test_export_specialized_int(self):
    """Only the tensor input becomes a placeholder in the exported graph.

    ``int_val`` and the LayerNorm ``eps`` read in ``forward`` are expected to
    be specialized rather than lifted as inputs.
    """

    class Foo(torch.nn.Module):
        def __init__(
            self,
            input_dim,
        ):
            super().__init__()
            self.torch_module = torch.nn.LayerNorm(
                input_dim, eps=1e-5, elementwise_affine=True
            )
            self.int_val = 100

        def forward(self, input):
            return input.cos() * self.int_val * self.torch_module.eps

    mod = Foo(128)
    inp = torch.randn(3, 128)
    # In export, int & float in forward should always be specialized
    gm, _ = torch._dynamo.export(mod, aten_graph=True)(inp)
    num_placeholders = sum(node.op == "placeholder" for node in gm.graph.nodes)
    self.assertEqual(num_placeholders, 1)
def test_export_with_nonzero_static(self):
    """Exported ``nonzero_static`` returns ``size`` rows and matches eager."""

    class BasicModule(torch.nn.Module):
        def __init__(self, static_size):
            super().__init__()
            self.static_size = static_size

        def forward(self, x):
            return torch.nonzero_static(x, size=self.static_size)

    # (input, requested static size) pairs.
    cases = [(torch.tensor([6, 8]), 3), (torch.zeros(2, 3), 4)]
    for input_tensor, static_size in cases:
        module = BasicModule(static_size)
        gm, _ = torch._dynamo.export(module, aten_graph=True)(input_tensor)
        res = gm(input_tensor)
        self.assertEqual(res.size(0), static_size)
        reference = torch.nonzero_static(input_tensor, size=static_size)
        self.assertTrue(torch._dynamo.utils.same(res, reference))
def test_export_pass_arg_by_name(self):
    """The exported graph accepts ``forward``'s argument as keyword ``x``."""

    class BasicModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.my_lin = torch.nn.Linear(3, 4, bias=True)

        def forward(self, x):
            return self.my_lin(x)

    mod = BasicModule()
    input_tensor = torch.randn(2, 3)
    gm, _ = torch._dynamo.export(mod, aten_graph=True)(input_tensor)
    eager_out = mod(x=input_tensor)
    graph_out = gm(x=input_tensor)
    self.assertTrue(torch._dynamo.utils.same(eager_out, graph_out))
def test_export_pass_arg_by_name_star_args(self):
    """A ``*args`` forward exports and matches eager on two positional inputs."""

    class BasicModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.my_lin = torch.nn.Linear(3, 4, bias=True)

        def forward(self, *args):
            return self.my_lin(args[0]) * self.my_lin(args[1])

    mod = BasicModule()
    first = torch.randn(2, 3)
    second = torch.randn(2, 3)
    gm, _ = torch._dynamo.export(mod, aten_graph=True)(first, second)
    eager_out = mod(first, second)
    graph_out = gm(first, second)
    self.assertTrue(torch._dynamo.utils.same(eager_out, graph_out))
def test_export_mark_dynamic_conflict_dynamic_dim(self):
    """``mark_dynamic`` plus a Dim on a guarded axis reports violated constraints."""
    y = torch.randn([3, 3, 3])

    def my_dyn_fn(x):
        if x.shape[0] > 3:
            return x.sin()
        return x.cos()

    torch._dynamo.mark_dynamic(y, 0)
    leading_dynamic = ({0: torch.export.Dim("dim")},)
    with self.assertRaisesRegex(RuntimeError, "Constraints violated"):
        torch._dynamo.export(my_dyn_fn, dynamic_shapes=leading_dynamic)(y)
def test_export_dynamic_dim_cleanup(self):
    """A Dim on dim 0 of a guard-free function exports without error."""
    sample = torch.randn([3, 3, 3])

    def my_dyn_fn(x):
        return x.cos()

    leading_dynamic = ({0: torch.export.Dim("dim")},)
    torch._dynamo.export(my_dyn_fn, dynamic_shapes=leading_dynamic)(sample)
@config.patch(capture_dynamic_output_shape_ops=True)
def test_export_dynamic_control_flow_error(self):
    """Branching on the result of ``nonzero`` is reported as a UserError."""

    def f(x):
        if x.nonzero() > 3:
            return x.cos()
        return x.sin()

    expected_msg = "Dynamic control flow is not supported at the moment"
    with self.assertRaisesRegex(torch._dynamo.exc.UserError, expected_msg):
        torch._dynamo.export(f, aten_graph=True)(torch.randn(5, 6))
@config.patch(assume_static_by_default=False)
def test_export_persist_assert(self):
    """A user ``assert`` becomes an ``_assert_async`` node that survives DCE.

    The node must still be present after ``eliminate_dead_code`` and must
    raise with the original message when the condition fails at runtime.
    """

    def f(x):
        assert x[0].sum() > 4, "Shape must be more than 4"
        return x.cos() + x.sin()

    gm, _ = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")(
        torch.ones(5, 4, 6)
    )
    assert_op = torch.ops.aten._assert_async.msg

    def graph_has_assert():
        return any(node.target == assert_op for node in gm.graph.nodes)

    self.assertTrue(graph_has_assert())
    gm.graph.eliminate_dead_code()
    gm.recompile()
    self.assertTrue(graph_has_assert())

    with self.assertRaisesRegex(RuntimeError, "Shape must be more than 4"):
        gm(torch.zeros(3, 4, 5))
@common_utils.parametrize(
    "type_fn",
    [
        common_utils.subtest(type, name="builtin"),
        common_utils.subtest(lambda obj: obj.__class__, name="attr"),
    ],
)
def test_access_class_method_from_user_class(self, type_fn):
    """A classmethod reached through ``type(obj)`` or ``obj.__class__`` exports."""

    class A:
        @classmethod
        def func(cls):
            return torch.Tensor([4, 5])

    def f(x):
        a = A()
        return x.sum() + type_fn(a).func().sum()

    sample = torch.ones(6, 4)
    gm, _ = torch._dynamo.export(f, aten_graph=True)(sample)
    self.assertEqual(f(sample), gm(sample))
def test_not_functionalize(self):
    """The aten graph keeps the in-place ``add_`` rather than functionalizing it."""

    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("buffer1", torch.ones(6, 2))

        def forward(self, x):
            x.add_(2)
            return x.sum() + self.buffer1.sum()

    gm, _ = torch._dynamo.export(
        Foo(),
        aten_graph=True,
        tracing_mode="symbolic",
    )(torch.ones(1, 2, 3))

    inplace_adds = [
        node for node in gm.graph.nodes if node.target == torch.ops.aten.add_.Tensor
    ]
    self.assertEqual(len(inplace_adds), 1)

    # forward mutates its input, so each side gets its own fresh tensor.
    graph_out = gm(torch.ones(1, 2, 3))
    eager_out = Foo()(torch.ones(1, 2, 3))
    self.assertEqual(graph_out, eager_out)
def test_round_dynamic_shapes(self):
    """``round`` of a size-derived value works as a slice bound."""

    def f(x):
        return x[: round(x.shape[0] / 2)]

    sample = torch.ones(6, 4)
    gm, _ = torch._dynamo.export(f, aten_graph=True)(sample)
    self.assertEqual(f(sample), gm(sample))
def test_cond_supported_pred_types(self):
    """``cond`` exports with shape-comparison, tensor, and compound predicates."""

    def true_fn(x):
        return x.cos()

    def false_fn(x):
        return x.sin()

    def f_pred_traced_as_symnode_var(x):
        return cond(x.shape[0] > 2, true_fn, false_fn, [x])

    def f_pred_traced_as_tensor_var(x):
        return cond(x.all(), true_fn, false_fn, [x])

    def f_pred_complex_expression_traced_as_symnode_var(x):
        return cond(
            x.dim() > 1 and x.shape[1] > 5 and x.shape[1] <= 10,
            true_fn,
            false_fn,
            [x],
        )

    example_inputs = (torch.rand(5, 8),)
    candidates = (
        f_pred_traced_as_symnode_var,
        f_pred_traced_as_tensor_var,
        f_pred_complex_expression_traced_as_symnode_var,
    )
    for fn in candidates:
        gm, _ = torch._dynamo.export(fn, aten_graph=True)(*example_inputs)
        self.assertEqual(gm(*example_inputs), fn(*example_inputs))
def test_mixed_real_and_fake_inputs(self):
    """A conv + batch-norm folding pattern built from module state exports.

    ``forward`` mixes the traced input with conv/bn parameters and buffers;
    the test only checks that export does not fail.
    """

    class _TestPattern(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 1, 1)
            self.bn = torch.nn.BatchNorm2d(1)

        def forward(self, input):
            running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
            scale_factor = self.bn.weight / running_std
            weight_shape = [1] * len(self.conv.weight.shape)
            weight_shape[0] = -1
            bias_shape = [1] * len(self.conv.weight.shape)
            bias_shape[1] = -1
            scaled_weight = self.conv.weight * scale_factor.reshape(weight_shape)
            zero_bias = torch.zeros_like(self.conv.bias, dtype=input.dtype)
            conv = self.conv._conv_forward(input, scaled_weight, zero_bias)
            conv_orig = conv / scale_factor.reshape(bias_shape)
            conv_orig = conv_orig + self.conv.bias.reshape(bias_shape)
            conv = self.bn(conv_orig)
            return conv

    # Draw the input before constructing the module to keep the RNG order.
    sample = torch.randn(1, 1, 3, 3)
    pattern = _TestPattern()
    torch._dynamo.export(pattern, aten_graph=True)(sample)
@config.patch(
    capture_dynamic_output_shape_ops=True,
    capture_scalar_outputs=True,
    assume_static_by_default=False,
)
def test_sym_contains(self):
    """``x.size(0) in y`` exported once matches eager for a miss and a hit."""

    def f(x, y):
        return x.size(0) in y

    gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(2), torch.ones(3))

    # x has 3 elements; y is filled with 6 (no 3 present) or with 3 (present).
    miss = (torch.Tensor([6, 4, 5]), torch.ones(6, 4).add_(5))
    hit = (torch.Tensor([6, 4, 5]), torch.ones(6, 4).add_(2))
    for args in (miss, hit):
        self.assertEqual(gm(*args), f(*args))
def test_cond_raise_user_error_on_missing_args(self):
    """Calling ``cond`` without ``operands`` raises a plain ``TypeError``."""

    def true_fn(x):
        return x.cos()

    def false_fn(x):
        return x.sin()

    def f(x):
        return cond(x.shape[0] > 10, true_fn, false_fn)

    sample = torch.rand(5)
    with self.assertRaisesRegex(
        TypeError,
        r"cond\(\) missing 1 required positional argument: 'operands'",
    ):
        f(sample)
def test_cond_raise_user_error_on_unsupported_pred(self):
    """An ``nn.Module`` used as the predicate is rejected by ``cond``."""

    def f_unsupported_pred(x):
        pred = torch.nn.Module()
        return cond(pred, lambda x: x.sin(), lambda x: x.cos(), [x])

    sample = torch.rand(5)
    expected_msg = "Expected pred to be bool or tensor, but got Module()"
    with self.assertRaisesRegex(RuntimeError, expected_msg):
        f_unsupported_pred(sample)
def test_cond_raise_user_error_on_non_list_operands(self):
    """Passing a bare tensor as ``operands`` instead of a container is rejected."""

    def f_non_list_operands(x):
        return cond(torch.tensor(True), lambda x: x.sin(), lambda x: x.cos(), x)

    sample = torch.rand(5)
    with self.assertRaisesRegex(
        RuntimeError,
        r"Expect operands to be a tuple of possibly nested dict/list/tuple",
    ):
        f_non_list_operands(sample)
def test_cond_raise_user_error_on_non_tensor_operands(self):
    """Operands containing a Python float are rejected by ``cond``."""

    def f_non_tensor_operands(x):
        a: float = 3.14
        return cond(
            torch.tensor(1234), lambda x, a: x.sin(), lambda x, a: x.cos(), [x, a]
        )

    sample = torch.rand(5)
    with self.assertRaisesRegex(
        RuntimeError,
        r"Expect operands to be a tuple of possibly nested dict/list/tuple",
    ):
        f_non_tensor_operands(sample)
def test_cond_raise_user_error_on_branch_args_mismatch(self):
    """Branches with different arities cannot be captured by export.

    ``true_fn`` takes two operands while ``false_fn`` takes one; exporting is
    expected to fail with ``UncapturedHigherOrderOpError``.
    """

    def true_fn(x, y):
        return x.sin()

    def false_fn(x):
        return x.cos()

    def f_branch_args_mismatch(x, y):
        return cond(torch.tensor([[[[True]]]]), true_fn, false_fn, [x, y])

    example_inputs = (torch.rand(5), torch.rand(2))
    with self.assertRaisesRegex(
        torch._dynamo.exc.UncapturedHigherOrderOpError,
        # Match the full phrase, as the other cond error tests do.
        "Cond doesn't work unless it is captured completely with torch.compile",
    ):
        torch._dynamo.export(
            f_branch_args_mismatch,
            aten_graph=True,
        )(
            *example_inputs,
        )
@config.patch(suppress_errors=True)
def test_uncaptured_higher_order_op_error_not_suppresed(self):
    """``suppress_errors=True`` does not hide an uncapturable ``cond``.

    The branch-arity mismatch must still surface as
    ``UncapturedHigherOrderOpError`` with that config enabled.
    """

    def true_fn(x, y):
        return x.sin()

    def false_fn(x):
        return x.cos()

    def f_branch_args_mismatch(x, y):
        return cond(torch.tensor([[[[100]]]]), true_fn, false_fn, [x, y])

    sample_x, sample_y = torch.rand(5), torch.rand(2)
    with self.assertRaisesRegex(
        torch._dynamo.exc.UncapturedHigherOrderOpError,
        "Cond doesn't work unless it is captured completely with torch.compile",
    ):
        torch._dynamo.export(f_branch_args_mismatch, aten_graph=True)(
            sample_x, sample_y
        )
def test_cond_raise_user_error_on_branch_return_non_tensor(self):
    """Branches that return a Python float cannot be captured."""

    def f_branch_return_non_tensor(x):
        return cond(x.shape[0] <= 5, lambda x: 3.14, lambda x: 3.14, [x])

    sample = torch.rand(5)
    with self.assertRaisesRegex(
        torch._dynamo.exc.UncapturedHigherOrderOpError,
        "Cond doesn't work unless it is captured completely with torch.compile",
    ):
        torch._dynamo.export(f_branch_return_non_tensor, aten_graph=True)(sample)
def test_cond_raise_user_error_on_branch_return_multiple_tensors(self):
    """Branches returning a tuple of tensors export and match eager."""

    def f_branch_return_multiple_tensors(pred, x, y):
        return cond(pred, lambda x: (x, x), lambda x: (x, x), [y])

    example_inputs = (torch.tensor(True), torch.randn(4), torch.randn(2))
    gm, _ = torch._dynamo.export(
        f_branch_return_multiple_tensors,
        aten_graph=True,
    )(*example_inputs)
    graph_out = gm(*example_inputs)
    eager_out = f_branch_return_multiple_tensors(*example_inputs)
    self.assertEqual(graph_out, eager_out)
def test_multiple_outputs_op_with_evaluator(self):
    """``topk`` (two outputs) exports on an input that requires grad."""

    class TopKModel(torch.nn.Module):
        def forward(self, x):
            values, _ = torch.topk(x, 3)
            return torch.sum(values)

    model = TopKModel()
    x = torch.arange(1.0, 6.0, requires_grad=True)
    torch._dynamo.export(model)(x)
def test_cond_raise_user_error_on_mismatch_return_length(self):
    """Branches returning different numbers of outputs cannot be captured."""

    def true_fn(x):
        return x

    def false_fn(x):
        return (x, x)

    def f_mismatch_return_length(x):
        return cond(torch.tensor(100), true_fn, false_fn, [x])

    sample = torch.rand(5)
    with self.assertRaisesRegex(
        torch._dynamo.exc.UncapturedHigherOrderOpError,
        "Cond doesn't work unless it is captured completely with torch.compile",
    ):
        torch._dynamo.export(f_mismatch_return_length, aten_graph=True)(sample)
def test_cond_raise_user_error_on_mismatch_return_tensor_meta(self):
    """Branches returning tensors with different shape and dtype cannot be captured.

    ``true_fn`` returns a (2, 1) integer tensor; ``false_fn`` returns a
    (1,) float tensor.
    """

    def true_fn(x):
        return torch.tensor([[3], [2]])

    def false_fn(x):
        return torch.tensor([3.14])

    def f_return_tensor_mismatch(x):
        return cond(x.shape[0] < 3, true_fn, false_fn, [x])

    sample = torch.rand(5)
    with self.assertRaisesRegex(
        torch._dynamo.exc.UncapturedHigherOrderOpError,
        "Cond doesn't work unless it is captured completely with torch.compile",
    ):
        torch._dynamo.export(f_return_tensor_mismatch, aten_graph=True)(sample)
def test_byte_tensor_does_not_crash(self):
    """Compiling a ``ByteTensor`` built from UTF-8 text does not crash.

    Regression test for https://github.com/pytorch/pytorch/issues/100455.
    """

    def func(text):
        tensor = torch.ByteTensor(list(bytes(text, "utf8")))
        return tensor + tensor

    source_text = "".join(chr(a % 90 + 40) for a in range(111))
    opt_func = torch._dynamo.optimize("eager", dynamic=True)(func)
    # Two lengths so the second call runs after the first has compiled.
    for length in (99, 100):
        opt_func(source_text[:length])
def test_export_defaults_ok(self):
    """Slices with size-derived and literal bounds export to aten ``slice``.

    The inline expectation pins the exact aten graph produced for the four
    slices taken from a (5, 5, 5) input.
    """

    class DynamicSliceExportMod(torch.nn.Module):
        def forward(self, x):
            results = []
            for i in range(4):
                # Each slice mixes bounds read from x.size(...) with literal
                # bounds derived from the loop index.
                results.append(x[: x.size(0) - i, i : x.size(2), i:3])
            return tuple(results)

    gm, _ = torch._dynamo.export(DynamicSliceExportMod(), aten_graph=True)(
        torch.randn(5, 5, 5),
    )
    self.assertExpectedInline(
        gm.code.strip(),
        """\
def forward(self, x):
arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
arg0_1 = arg0
slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, 3)
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sub = sym_size_int - 1
slice_2 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub); sub = None
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 2)
slice_3 = torch.ops.aten.slice.Tensor(slice_2, 1, 1, sym_size_int_1); slice_2 = None
slice_4 = torch.ops.aten.slice.Tensor(slice_3, 2, 1, 3); slice_3 = None
sub_1 = sym_size_int - 2
slice_5 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub_1); sub_1 = None
slice_6 = torch.ops.aten.slice.Tensor(slice_5, 1, 2, sym_size_int_1); slice_5 = None
slice_7 = torch.ops.aten.slice.Tensor(slice_6, 2, 2, 3); slice_6 = None
sub_2 = sym_size_int - 3; sym_size_int = None
slice_8 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub_2); arg0_1 = sub_2 = None
slice_9 = torch.ops.aten.slice.Tensor(slice_8, 1, 3, sym_size_int_1); slice_8 = sym_size_int_1 = None
slice_10 = torch.ops.aten.slice.Tensor(slice_9, 2, 3, 3); slice_9 = None
return pytree.tree_unflatten([slice_1, slice_4, slice_7, slice_10], self._out_spec)""",
    )
def test_capture_symbolic_tracing_simple_within_fake_mode(self):
    """Export runs inside a user-created ``FakeTensorMode`` for both graph kinds."""
    from torch._dynamo.output_graph import config

    def f(x):
        y = torch.randn(3)
        return x + x * y

    shape_env = ShapeEnv(
        allow_scalar_outputs=config.capture_scalar_outputs,
        allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
    )
    with fake_tensor.FakeTensorMode(shape_env=shape_env):
        x = torch.randn(3)
        for aten_graph in (True, False):
            gm, _ = torch._dynamo.export(f, aten_graph=aten_graph)(x)
            self.assertTrue(
                isinstance(gm, torch.fx.GraphModule),
                msg=f"test_capture_symbolic_tracing_simple_within_fake_mode_aten_graph_{aten_graph}",
            )
def test_export_with_symbool_inputs(self):
    """A symbolic-bool ``pred`` input picks the traced branch and is guarded on.

    ``pred`` is built by comparing the dynamic leading size of a fake copy of
    ``x`` with each entry of ``size_tests``. For each export the test checks
    the printed graph, the dynamo SHAPE_ENV guard code mentioning
    ``L['pred']``, and the guards left in the outer ``ShapeEnv``.
    """

    def f(pred: bool, x: torch.Tensor):
        if pred:
            return x.sin()
        else:
            return x.cos()

    x = torch.randn([3, 4])

    def test_symbool_guards(
        f, size_tests, exp_graph, exp_guard_code, exp_shape_env_guards
    ):
        # Export ``f`` once per entry of ``size_tests`` and compare against the
        # i-th entry of each expectation list.
        shape_env = ShapeEnv()
        with fake_tensor.FakeTensorMode(
            shape_env=shape_env,
        ) as fake_mode:
            # Every dimension of the fake copy of ``x`` is marked DYNAMIC, so
            # ``fake_x.size(0) == size`` below is symbolic.
            fake_x = fake_mode.from_tensor(
                x,
                symbolic_context=StatelessSymbolicContext(
                    dynamic_sizes=[DimDynamic.DYNAMIC for _ in range(x.dim())],
                ),
            )
            for i, size in enumerate(size_tests):
                pred = fake_x.size(0) == size
                gm, guards = torch._dynamo.export(f)(pred, x)
                actual = normalize_gm(gm.print_readable(print_output=False))
                self.assertExpectedInline(actual, exp_graph[i])
                # Exactly one SHAPE_ENV guard is expected from dynamo.
                dynamo_shape_env_guards = [
                    guard
                    for guard in guards
                    if guard.guard_types is not None
                    and "SHAPE_ENV" in guard.guard_types
                ]
                self.assertEqual(len(dynamo_shape_env_guards), 1)
                # Only the guard lines that mention the predicate input.
                guard_code_on_predicate = [
                    code
                    for code in dynamo_shape_env_guards[0].code_list
                    if "L['pred']" in code
                ]
                self.assertEqual(guard_code_on_predicate, exp_guard_code[i])
                # Guards recorded in the outer (user-created) ShapeEnv.
                outter_shape_env_guards = [
                    str(guard.expr) for guard in shape_env.guards
                ]
                self.assertEqual(outter_shape_env_guards, exp_shape_env_guards[i])

    true_graph = """\
class GraphModule(torch.nn.Module):
def forward(self, pred, x):
arg1: "f32[s1, s2]";
arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
l_x_ = arg1
sin = l_x_.sin(); l_x_ = None
return pytree.tree_unflatten([sin], self._out_spec)
"""
    false_graph = """\
class GraphModule(torch.nn.Module):
def forward(self, pred, x):
arg1: "f32[s1, s2]";
arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
l_x_ = arg1
cos = l_x_.cos(); l_x_ = None
return pytree.tree_unflatten([cos], self._out_spec)
"""
    true_guard_code = [
        "cast_symbool_to_symint_guardless(L['pred']) == 1",
    ]
    false_guard_code = [
        "Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
        "-9223372036854775808 <= cast_symbool_to_symint_guardless(L['pred'])",
    ]
    # x has 3 rows, so sizes 3, 3 give pred=True and sizes 4, 5 give pred=False.
    test_symbool_guards(
        f,
        [3, 3, 4, 5],
        [true_graph, true_graph, false_graph, false_graph],
        [true_guard_code, true_guard_code, false_guard_code, false_guard_code],
        # Outer shape env should have no guards in it because we never specialize on the outer symbool.
        [[], [], [], []],
    )
def test_invalid_input_global(self) -> None:
    """Reading a global tensor inside the exported function raises UserError.

    The munged message names the global and the source line that read it.
    """
    global bulbous_bouffant
    bulbous_bouffant = torch.randn(3)

    def f(y):
        return bulbous_bouffant + y

    self.assertExpectedInlineMunged(
        UserError,
        lambda: torch._dynamo.export(f)(torch.randn(3)),
        """\
G['bulbous_bouffant'], accessed at:
File "test_export.py", line N, in f
return bulbous_bouffant + y
""",
    )
def test_invalid_input_global_multiple_access(self) -> None:
    """A global read in both ``f`` and its callee ``g`` raises UserError.

    The expected message shows only the access reached through ``f`` ->
    ``g``, not the later read in ``f`` itself.
    """
    global macademia
    macademia = torch.randn(3)

    def g(y):
        global macademia
        y = macademia + y
        return y

    def f(y):
        global macademia
        y = g(y)
        return macademia + y

    # NB: This doesn't actually work (it only reports the first usage),
    # but I'm leaving the test here in case we fix it later
    self.assertExpectedInlineMunged(
        UserError,
        lambda: torch._dynamo.export(f)(torch.randn(3)),
        """\
G['macademia'], accessed at:
File "test_export.py", line N, in f
y = g(y)
File "test_export.py", line N, in g
y = macademia + y
""",
    )
def test_invalid_input_nonlocal(self) -> None:
    """A closed-over tensor used by the exported function raises UserError."""
    arglebargle = torch.randn(3)

    def f(y):
        return arglebargle + y

    def run_export():
        return torch._dynamo.export(f)(torch.randn(3))

    self.assertExpectedInlineMunged(
        UserError,
        run_export,
        """L['arglebargle'], a closed over free variable""",
    )
def test_invalid_input_unused_nonlocal_ok(self) -> None:
    """Reading a closed-over tensor without using it does not fail export."""
    arglebargle = torch.randn(3)

    def f(y):
        x = arglebargle  # noqa: F841 -- intentionally read but unused
        return y

    sample = torch.randn(3)
    torch._dynamo.export(f)(sample)
def test_symbolic_tracing_within_fake_mode_with_constraints(self):
    """Export under a user ``FakeTensorMode`` with a shared Dim, then run for real.

    The graph exported from fake inputs is compared with eager on freshly
    drawn real inputs outside the fake mode.
    """
    from torch._subclasses import fake_tensor

    fake_mode = fake_tensor.FakeTensorMode()

    class DynamicShapeSimpleModel(torch.nn.Module):
        def forward(self, a, b, c) -> torch.Tensor:
            d = (torch.matmul(a, b) + c) / 2
            d_s0 = d.shape[0]
            d_s1 = d.shape[1]
            d_s3 = d_s0 * d_s1
            e = d.view(d_s3)
            return torch.cat([e, e])

    with fake_mode:
        model = DynamicShapeSimpleModel()
        fake_inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
        dim = torch.export.Dim("dim")
        # ``a`` and ``c`` share their leading dimension.
        dynamic_shapes = ({0: dim}, None, {0: dim})
        for aten_graph in (True, False):
            gm = torch._dynamo.export(
                model,
                dynamic_shapes=dynamic_shapes,
                aten_graph=aten_graph,
            )(*fake_inputs).graph_module

    # Since there are no parameters we can do this
    real_inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
    self.assertEqual(model(*real_inputs), gm(*real_inputs))
def test_symbolic_tracing_within_fake_mode_with_constraints_with_parameters(self):
    """Export a parameterized model under a user fake mode with a dynamic dim 0.

    Both ``aten_graph`` settings must produce a ``torch.fx.GraphModule``.
    """
    from torch._subclasses import fake_tensor

    fake_mode = fake_tensor.FakeTensorMode()

    # TODO: Seems to choke if you don't make a fresh model and
    # just try to export Linear directly...
    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(2, 2)

        def forward(self, x):
            out = self.linear(x)
            return out

    with fake_mode:
        model = Model()
        inputs = (torch.randn(10, 2, 2),)
        dynamic_shapes = ({0: torch.export.Dim("dim")},)
        for aten_graph in [True, False]:
            gm = torch._dynamo.export(
                model,
                dynamic_shapes=dynamic_shapes,
                aten_graph=aten_graph,
            )(*inputs).graph_module
            # Check that export produced a graph module in each mode.
            self.assertIsInstance(
                gm,
                torch.fx.GraphModule,
                msg=f"aten_graph={aten_graph}",
            )
def test_capture_symbolic_tracing_within_fake_mode(self):
    """Export a model whose parameters and input were created under a fake mode."""
    from torch._dynamo.output_graph import config
    from torch._subclasses import fake_tensor
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(2, 2)
            self.linear2 = torch.nn.Linear(2, 2)

        def forward(self, x):
            out = self.linear(x)
            out = self.linear2(out)
            return out

    # User-instantiated FakeTensorMode
    shape_env = ShapeEnv(
        allow_scalar_outputs=config.capture_scalar_outputs,
        allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
    )
    fake_mode = fake_tensor.FakeTensorMode(
        allow_non_fake_inputs=False,
        allow_fallback_kernels=True,
        shape_env=shape_env,
    )

    # Fakefy input+model before exporting it
    with fake_mode:
        x = torch.rand(5, 2, 2)
        model = Model()

    # Export the model with fake inputs and parameters
    for aten_graph in (True, False):
        graph_module, _ = torch._dynamo.export(model, aten_graph=aten_graph)(x)
        self.assertTrue(
            isinstance(graph_module, torch.fx.GraphModule),
            msg=f"test_capture_symbolic_tracing_within_fake_mode_aten_graph_{aten_graph}",
        )
def test_cond_op_param_buffer_lifted(self):
    """``cond`` branches that read submodule buffers export and match eager.

    The predicate is ``x.shape[0] > 4``: 6 rows take the true branch and 3
    rows take the false branch.
    """

    class A(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("buffer1", torch.zeros(6, 4))

        def forward(self):
            return self.buffer1.sum()

    class B(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("buffer2", torch.ones(6, 4))

        def forward(self):
            return self.buffer2.sum()

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.a = A()
            self.b = B()

        def forward(self, x):
            def true_fn(x):
                return x.cos() + self.a()

            def false_fn(x):
                return x.sin() + self.b()

            return (cond(x.shape[0] > 4, true_fn, false_fn, [x]),)

    gm, _ = torch._dynamo.export(M(), aten_graph=False)(torch.ones(6, 4))
    for rows in (6, 3):
        probe = torch.ones(rows, 4)
        self.assertEqual(gm(probe), M()(probe))
def test_nested_cond_op_param_buffer_lifted(self):
    """Nested ``cond`` branches reading submodule buffers export and match eager.

    6 rows take true/true, 5 rows take true/false, and 3 rows take the outer
    false branch.
    """

    class A(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("buffer1", torch.zeros(6, 4))

        def forward(self):
            return self.buffer1.sum()

    class B(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("buffer2", torch.ones(6, 4))

        def forward(self):
            return self.buffer2.sum()

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.a = A()
            self.b = B()

        def forward(self, x):
            def true_true_fn(x):
                return x.cos() + self.a()

            def true_false_fn(x):
                return x.cos() + self.a() + 1

            def true_fn(x):
                return cond(x.shape[0] > 5, true_true_fn, true_false_fn, [x])

            def false_fn(x):
                return x.sin() + self.b()

            return (cond(x.shape[0] > 4, true_fn, false_fn, [x]),)

    gm, _ = torch._dynamo.export(M(), aten_graph=False)(torch.ones(6, 4))
    for rows in (6, 5, 3):
        probe = torch.ones(rows, 4)
        self.assertEqual(gm(probe), M()(probe))
def test_map_cond_param_buffer_lifted(self):
    """``map`` over a body using ``cond`` and submodule buffers exports.

    The graph is exported with a (3, 2, 1) input and a True predicate, then
    compared with eager on a (4, 3, 2) input and a False predicate.
    """
    from functorch.experimental.control_flow import cond, map as control_flow_map

    class A(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("buffer1", torch.zeros(6, 4))

        def forward(self):
            return self.buffer1.sum()

    class B(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("buffer2", torch.ones(6, 4))

        def forward(self):
            return self.buffer2.sum()

    class Module(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.a = A()
            self.b = B()

        def inner(self, x, pred):
            def true_fn(x):
                return x + x + self.a()

            def false_fn(x):
                return x * x + self.b()

            return cond(pred, true_fn, false_fn, [x])

        def forward(self, pred, xs):
            def body(x, pred):
                return self.inner(x, pred) + self.b()

            return control_flow_map(body, xs, pred)

    mod = Module()
    trace_xs = torch.randn(3, 2, 1)
    trace_pred = torch.tensor(True)
    check_xs = torch.randn(4, 3, 2)
    check_pred = torch.tensor(False)

    expected = mod(check_pred, check_xs)
    out_graph, _ = torch._dynamo.export(mod)(trace_pred, trace_xs)
    self.assertEqual(expected, out_graph(check_pred, check_xs))
def test_cond_free_variables_overlapping(self):
    """Free variables of both ``cond`` branches are lifted into one operand list.

    ``a`` and ``b`` are read by both branches, ``d`` only by the true branch
    and ``c`` only by the false branch. The inline expectations pin the
    lifted operand order and each branch's placeholder names.
    """
    from functorch.experimental.control_flow import cond

    class Module(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, pred, x):
            a = torch.ones(6, 4)
            b = torch.ones(6, 4)
            c = torch.ones(6, 4)
            d = torch.ones(6, 4)

            def true_fn(x):
                return x + x + a.cos() + b.cos() + d.cos()

            def false_fn(x):
                return x * x + a.sin() + b.sin() + c.sin()

            return cond(pred, true_fn, false_fn, [x])

    mod = Module()
    x = torch.ones(6, 4)
    pred_x = torch.tensor(True)
    out_graph, _ = torch._dynamo.export(mod)(pred_x, x)
    # Top-level graph: one higher_order.cond call receiving all lifted inputs.
    self.assertExpectedInline(
        out_graph.code.strip(),
        """\
def forward(self, pred, x):
arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
l_pred_ = arg0
l_x_ = arg1
a = torch.ones(6, 4)
b = torch.ones(6, 4)
c = torch.ones(6, 4)
d = torch.ones(6, 4)
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [a, b, l_x_, d, c]); l_pred_ = cond_true_0 = cond_false_0 = a = b = l_x_ = d = c = None
getitem = cond[0]; cond = None
return pytree.tree_unflatten([getitem], self._out_spec)""",  # noqa: B950,E122
    )
    # True branch uses d_true_branch and leaves c_false_branch unused.
    self.assertExpectedInline(
        out_graph.cond_true_0.code.strip(),
        """\
def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
a_1 = a
b_1 = b
l_x__1 = l_x_
add = l_x__1 + l_x__1; l_x__1 = None
cos = a_1.cos(); a_1 = None
add_1 = add + cos; add = cos = None
cos_1 = b_1.cos(); b_1 = None
add_2 = add_1 + cos_1; add_1 = cos_1 = None
cos_2 = d_true_branch.cos(); d_true_branch = None
add_3 = add_2 + cos_2; add_2 = cos_2 = None
return (add_3,)""",
    )
    # False branch uses c_false_branch and leaves d_true_branch unused.
    self.assertExpectedInline(
        out_graph.cond_false_0.code.strip(),
        """\
def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
a_1 = a
b_1 = b
l_x__1 = l_x_
mul = l_x__1 * l_x__1; l_x__1 = None
sin = a_1.sin(); a_1 = None
add = mul + sin; mul = sin = None
sin_1 = b_1.sin(); b_1 = None
add_1 = add + sin_1; add = sin_1 = None
sin_2 = c_false_branch.sin(); c_false_branch = None
add_2 = add_1 + sin_2; add_1 = sin_2 = None
return (add_2,)""",
    )
@unittest.skipIf(
    common_utils.TEST_WITH_ASAN,
    "Times out with ASAN, see https://github.com/pytorch/pytorch/issues/110416",
)
def test_retracibility(self):
    """Re-exporting an exported aten graph reproduces the original outputs.

    ``Foo`` takes a tuple of two tensors and returns a tuple of two tensors.
    ``gm`` is the export of ``Foo`` and ``gm2`` is the export of ``gm``; both
    are evaluated on a fresh input and their outputs compared element-wise.
    """

    class MyLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Plain tensor attributes rather than registered parameters.
            self.weight = torch.randn(20, 98)
            self.bias = torch.randn(20)

        def forward(self, x):
            return torch.nn.functional.linear(x, self.weight, self.bias)

    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # kernel 3 without padding shrinks the last dim 100 -> 98,
            # matching MyLinear's 98 input features.
            self.conv = torch.nn.Conv2d(16, 33, 3)
            self.linear = MyLinear()

        def forward(self, x):
            a, b = x
            a_conv = self.conv(a)
            a_linear = self.linear(a_conv)
            b_conv = self.conv(b)
            b_linear = self.linear(b_conv)
            return (
                a_linear.cos() + b_linear.sin(),
                a_linear.sin() + b_linear.cos(),
            )

    inp_container = (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100))
    gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True)
    gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True)
    inp_test = (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100))
    # Evaluate each graph once and index into the results, instead of
    # re-running both graphs on these large inputs for every compared output.
    out = gm(inp_test)
    out2 = gm2(inp_test)
    self.assertTrue(torch.allclose(out[0], out2[0]))
    self.assertTrue(torch.allclose(out[1], out2[1]))
def test_retracibility_dict_container_inp_out(self):
    """Re-exporting an exported aten graph works with dict inputs and outputs.

    ``Foo`` reads ``x["a"]`` (a pair of tensors) and ``x["b"]`` and returns a
    dict holding a list under ``"a"`` and a tensor under ``"b"``. The export
    of the export (``gm2``) must match ``gm`` on a fresh input.
    """

    class MyLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Plain tensor attributes rather than registered parameters.
            self.weight = torch.randn(20, 98)
            self.bias = torch.randn(20)

        def forward(self, x):
            return torch.nn.functional.linear(x, self.weight, self.bias)

    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(16, 33, 3)
            self.linear = MyLinear()

        def forward(self, x):
            a1, a2 = x["a"]
            b = x["b"]
            a1_conv = self.conv(a1)
            a1_linear = self.linear(a1_conv)
            a2_conv = self.conv(a2)
            a2_linear = self.linear(a2_conv)
            b_conv = self.conv(b)
            b_linear = self.linear(b_conv)
            return {
                "a": [
                    a1_linear.cos() + b_linear.sin(),
                    a1_linear.cos() + b_linear.sin(),
                ],
                "b": a2_linear.sin() + b_linear.cos(),
            }

    inp_container = {
        "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),
        "b": torch.randn(20, 16, 50, 100),
    }
    gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True)
    gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True)
    inp_test = {
        "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),
        "b": torch.randn(20, 16, 50, 100),
    }
    # Evaluate each graph once instead of re-running both per compared leaf.
    out = gm(inp_test)
    out2 = gm2(inp_test)
    self.assertTrue(torch.allclose(out["a"][0], out2["a"][0]))
    self.assertTrue(torch.allclose(out["a"][1], out2["a"][1]))
    self.assertTrue(torch.allclose(out["b"], out2["b"]))
def test_retracibility_nested_list_out(self):
    """Re-exporting an exported aten graph works with a nested-list output.

    ``Foo`` takes a dict input and returns a 2x2 list of lists of tensors;
    every leaf of ``gm2`` (the export of ``gm``) must match ``gm``.
    """

    class MyLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Plain tensor attributes rather than registered parameters.
            self.weight = torch.randn(20, 98)
            self.bias = torch.randn(20)

        def forward(self, x):
            return torch.nn.functional.linear(x, self.weight, self.bias)

    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(16, 33, 3)
            self.linear = MyLinear()

        def forward(self, x):
            a1, a2 = x["a"]
            b = x["b"]
            a1_conv = self.conv(a1)
            a1_linear = self.linear(a1_conv)
            a2_conv = self.conv(a2)
            a2_linear = self.linear(a2_conv)
            b_conv = self.conv(b)
            b_linear = self.linear(b_conv)
            return [
                [
                    a1_linear.cos() + b_linear.sin(),
                    a1_linear.cos() + b_linear.sin(),
                ],
                [
                    a2_linear.sin() + b_linear.cos(),
                    a2_linear.sin() + b_linear.cos(),
                ],
            ]

    inp_container = {
        "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),
        "b": torch.randn(20, 16, 50, 100),
    }
    gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True)
    gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True)
    inp_test = {
        "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),
        "b": torch.randn(20, 16, 50, 100),
    }
    # Evaluate each graph once instead of re-running both per compared leaf.
    out = gm(inp_test)
    out2 = gm2(inp_test)
    for i in range(2):
        for j in range(2):
            self.assertTrue(torch.allclose(out[i][j], out2[i][j]))
def test_fx_pytree(self):
    """Export traces through ``fx._pytree.tree_flatten_spec`` and matches eager."""

    def foo(args):
        leaves, treespec = torch.utils._pytree.tree_flatten(args)
        fx_leaves = torch.fx._pytree.tree_flatten_spec(args, treespec)
        return fx_leaves[0] + leaves[0]

    shape = (20, 16, 50, 100)
    inp_container = (torch.randn(*shape), torch.randn(*shape))
    gm, _ = torch._dynamo.export(foo, inp_container, aten_graph=True)
    eager_result = foo(inp_container)
    exported_result = gm(inp_container)
    self.assertTrue(torch.allclose(eager_result, exported_result))
@config.patch(suppress_errors=True)
@config.patch(verbose=True)
def test_export_with_map_zero_sized_tensor_suppress_errors(self):
    """Exporting control-flow ``map`` over a zero-sized tensor raises Unsupported.

    The error is expected even though ``suppress_errors`` is patched to True.
    """
    # Alias the import so the builtin ``map`` is not shadowed in this scope.
    from functorch.experimental.control_flow import map as control_flow_map

    class Module(torch.nn.Module):
        def forward(self, xs):
            def body(x):
                return x + 1

            return control_flow_map(body, xs)

    mod = Module()
    # Leading dimension of size 0: map has nothing to iterate over.
    xs = torch.randn(0, 2)
    with self.assertRaises(torch._dynamo.exc.Unsupported):
        torch._dynamo.export(mod, xs)
def test_param_buffer_safe_from_mutation_simple(self):
    """Export of an in-place buffer update leaves the exported buffer at zeros.

    ``forward`` calls ``buffer1.add_(1)``. After export, the graph module must
    expose exactly one buffer, named ``L__self___buffer1``, still holding its
    initial (5, 5) zeros.
    """

    class Module(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("buffer1", torch.zeros(5, 5))

        def forward(self, x):
            self.buffer1.add_(1)
            return x + self.buffer1

    gm, _ = torch._dynamo.export(Module(), torch.ones(5, 5), aten_graph=False)
    buffers = list(gm.named_buffers())
    self.assertEqual(len(buffers), 1)
    name, buffer = buffers[0]
    self.assertEqual(name, "L__self___buffer1")
    # allclose broadcasts its operands, so check the shape explicitly and
    # compare against a reference of the buffer's real (5, 5) shape.
    self.assertEqual(tuple(buffer.shape), (5, 5))
    self.assertTrue(torch.allclose(buffer, torch.zeros(5, 5)))
def test_param_buffer_safe_from_mutation_recurse(self):
    """Root and child-module buffers are unchanged after exporting in-place updates.

    ``forward`` calls ``buffer1.add_(1)`` and ``child.buffer2.add_(2)``. After
    export, every buffer on the graph module must still hold zeros.
    """

    class Child(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("buffer2", torch.zeros(5))

        def forward(self, x):
            return x.sum() + self.buffer2.sum()

    class Module(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("buffer1", torch.zeros(5))
            self.child = Child()

        def forward(self, x):
            self.buffer1.add_(1)
            self.child.buffer2.add_(2)
            return x.sum() + self.buffer1.sum() + self.child(x)

    gm, _ = torch._dynamo.export(Module(), torch.ones(5), aten_graph=False)
    buffers = list(gm.named_buffers())
    # forward reads both buffer1 and child.buffer2, so both should be captured.
    # Without this check an empty buffer list would let the loop pass vacuously.
    self.assertEqual(len(buffers), 2)
    for _name, buffer in buffers:
        self.assertTrue(torch.allclose(buffer, torch.zeros(5)))
def test_predispatch_with_higher_order(self):
    """A shape-dependent ``cond`` exported with ``pre_dispatch=True`` matches eager."""

    def f(x):
        return cond(x.shape[0] > 4, lambda x: x + 5, lambda x: x - 3, [x])

    exporter = torch._dynamo.export(f, aten_graph=True, pre_dispatch=True)
    gm, _ = exporter(torch.randn(4, 4))
    # 4 rows takes the false branch (x - 3); 6 rows takes the true branch (x + 5).
    for sample in (torch.randn(4, 4), torch.randn(6, 4)):
        self.assertTrue(torch.allclose(f(sample), gm(sample)))
def test_predispatch_with_higher_order_nested(self):
    """Nested shape-dependent ``cond`` calls exported with ``pre_dispatch=True`` match eager."""

    def f(x):
        def true_fn(x):
            return cond(x.shape[0] > 6, lambda x: x + 10, lambda x: x - 10, [x])

        return cond(x.shape[0] > 4, true_fn, lambda x: x - 3, [x])

    exporter = torch._dynamo.export(f, aten_graph=True, pre_dispatch=True)
    gm, _ = exporter(torch.randn(4, 4))
    # Row counts 4, 6 and 8 hit x - 3, x - 10 and x + 10 respectively.
    samples = (torch.randn(4, 4), torch.randn(6, 4), torch.randn(8, 4))
    for sample in samples:
        self.assertTrue(torch.allclose(f(sample), gm(sample)))
def test_predispatch_with_for_out_dtype(self):
    """An int8 ``out_dtype(mm)`` exported with ``pre_dispatch=True`` matches eager."""

    class M(torch.nn.Module):
        def __init__(self, weight):
            super().__init__()
            self.weight = weight

        def forward(self, x):
            return out_dtype(torch.ops.aten.mm.default, torch.int32, x, self.weight)

    low, high = -128, 127
    weight = torch.randint(low, high, (5, 5), dtype=torch.int8)
    m = M(weight)
    x = torch.randint(low, high, (5, 5), dtype=torch.int8)
    gm, _ = torch._dynamo.export(m, x, aten_graph=True, pre_dispatch=True)
    eager_result = m(x)
    exported_result = gm(x)
    self.assertTrue(torch.allclose(eager_result, exported_result))
def test_predispatch_with_for_out_dtype_nested(self):
    """``out_dtype`` inside both ``cond`` branches round-trips a pre_dispatch export.

    The true branch wraps ``aten.mm`` and the false branch ``aten.mul``. Since
    the predicate is ``x.sum() != 0``, an all-ones input selects the mm branch
    and an all-zeros input selects the mul branch.
    """

    class M(torch.nn.Module):
        def __init__(self, weight):
            super().__init__()
            self.weight = weight

        def true_fn(self, x):
            return out_dtype(
                torch.ops.aten.mm.default, torch.int32, x, self.weight
            ).sum()

        def false_fn(self, x):
            return out_dtype(
                torch.ops.aten.mul.Tensor, torch.int32, x, self.weight
            ).sum()

        def forward(self, x):
            return cond(x.sum() != 0, self.true_fn, self.false_fn, [x])

    weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
    m = M(weight)
    ones = torch.ones((5, 5), dtype=torch.int8)
    gm, _ = torch._dynamo.export(m, ones, aten_graph=True, pre_dispatch=True)
    for sample in (ones, torch.zeros((5, 5), dtype=torch.int8)):
        self.assertTrue(torch.allclose(m(sample), gm(sample)))
    self.assertExpectedInline(
        gm.true_graph_0.code.strip(),
        """\
def forward(self, arg0_1, arg1_1):
out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mm.default, torch.int32, arg1_1, arg0_1); arg1_1 = arg0_1 = None
sum_1 = torch.ops.aten.sum.default(out_dtype); out_dtype = None
return (sum_1,)""",
    )
    self.assertExpectedInline(
        gm.false_graph_0.code.strip(),
        """\
def forward(self, arg0_1, arg1_1):
out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mul.Tensor, torch.int32, arg1_1, arg0_1); arg1_1 = arg0_1 = None
sum_1 = torch.ops.aten.sum.default(out_dtype); out_dtype = None
return (sum_1,)""",
    )
def test_export_nn_module_stack_patched_module(self):
    """``call_function`` nodes carry ``nn_module_stack`` even with a patched forward.

    ``M.forward`` (x + y) is replaced on the instance by a free function
    (x * y); the export must follow the patch and keep module-stack metadata.
    """

    def forward(self, x, y):
        return x * y

    class Toplevel(torch.nn.Module):
        def __init__(self, m):
            super().__init__()
            self.m = m

        def forward(self, x, y):
            return self.m(x, y)

    class M(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    t = Toplevel(M())
    # Bind the free function as an instance method, overriding M.forward.
    t.m.forward = forward.__get__(t.m, M)
    x, y = torch.rand(3), torch.rand(3)
    gm, _ = torch._dynamo.export(t, x, y)
    self.assertTrue(torch.allclose(forward(None, x, y), gm(x, y)))
    call_nodes = (n for n in gm.graph.nodes if n.op == "call_function")
    for call_node in call_nodes:
        self.assertIn("nn_module_stack", call_node.meta)
def test_preserve_fx_node_metadata(self):
    """Re-exporting an edited graph keeps metadata of the untouched nodes.

    ``gm`` is exported from ``Module2``. A deep copy has its ``torch.relu`` node
    retargeted to ``fn`` (which calls ``torch.abs``) with that node's meta
    cleared, and is exported again as ``gm2``. The ``sin``/``cos`` nodes of
    ``gm`` and ``gm2`` must agree on ``nn_module_stack`` and ``stack_trace``.
    """

    class Module1(torch.nn.Module):
        def forward(self, x):
            return torch.sin(x)

    class Module2(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mod1 = Module1()

        def forward(self, x):
            x = torch.cos(x)
            x = self.mod1(x)
            x = torch.relu(x)
            return x

    def fn(x):
        return torch.abs(x)

    def _contains_op(graph_module, target):
        """Return True if any node of ``graph_module`` has ``target`` as its target."""
        return any(nd.target == target for nd in graph_module.graph.nodes)

    mod = Module2()
    inp = torch.randn(3, 3)
    gm, _ = torch._dynamo.export(mod)(inp)

    # Replace the (single) relu node with fn and drop its metadata.
    gm_edit = copy.deepcopy(gm)
    for nd in gm_edit.graph.nodes:
        if nd.target == torch.relu:
            nd.target = fn
            nd.meta.clear()
            break
    gm_edit.recompile()
    gm2, _ = torch._dynamo.export(gm_edit)(inp)

    self.assertExpectedInline(
        gm.code.strip(),
        """\
def forward(self, x):
arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
l_x_ = arg0
x = torch.cos(l_x_); l_x_ = None
x_1 = torch.sin(x); x = None
x_2 = torch.relu(x_1); x_1 = None
return pytree.tree_unflatten([x_2], self._out_spec)""",
    )

    self.assertTrue(_contains_op(gm_edit, torch.cos))
    self.assertTrue(_contains_op(gm_edit, torch.sin))
    self.assertFalse(_contains_op(gm_edit, torch.relu))

    self.assertExpectedInline(
        gm2.code.strip(),
        """\
def forward(self, x):
arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
l_x_ = arg0
x = torch.cos(l_x_); l_x_ = None
x_1 = torch.sin(x); x = None
x_2 = torch.abs(x_1); x_1 = None
return pytree.tree_unflatten([x_2], self._out_spec)""",
    )

    # Metadata of the untouched nodes must survive the edit + re-export.
    for op in (torch.sin, torch.cos):
        nd1 = next(nd for nd in gm.graph.nodes if nd.target == op)
        nd2 = next(nd for nd in gm2.graph.nodes if nd.target == op)
        self.assertEqual(
            "nn_module_stack" in nd1.meta, "nn_module_stack" in nd2.meta
        )
        if "nn_module_stack" in nd1.meta:
            self.assertEqual(
                nd1.meta["nn_module_stack"], nd2.meta["nn_module_stack"]
            )
        self.assertEqual(nd1.meta["stack_trace"], nd2.meta["stack_trace"])
def test_preserve_fx_node_metadata_recompile(self):
    """A reused export callable yields the same ``sin`` graph across shapes.

    One ``torch._dynamo.export(gm)`` callable is invoked with a 3x3 and then a
    5x3 input, after ``fn`` itself was run through ``optimize("eager")``. Both
    resulting graphs must have the same code.
    """

    def fn(x):
        return torch.sin(x)

    gm, _ = torch._dynamo.export(fn)(torch.randn(3, 3))
    reexport = torch._dynamo.export(gm)
    eager_fn = torch._dynamo.optimize("eager")(fn)
    eager_fn(torch.randn(3, 3))
    gm1, _ = reexport(torch.randn(3, 3))
    gm2, _ = reexport(torch.randn(5, 3))
    self.assertExpectedInline(
        gm1.code.strip(),
        """\
def forward(self, x):
arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
l_x_ = arg0
sin = torch.sin(l_x_); l_x_ = None
return pytree.tree_unflatten([sin], self._out_spec)""",
    )
    self.assertExpectedInline(
        gm2.code.strip(),
        """\
def forward(self, x):
arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
l_x_ = arg0
sin = torch.sin(l_x_); l_x_ = None
return pytree.tree_unflatten([sin], self._out_spec)""",
    )
def test_preserve_fx_node_metadata_inline(self):
    """An exported graph module called from another function is inlined on export.

    ``f2`` applies ``cos`` and then calls ``gm`` (the export of ``f1``); the
    export of ``f2`` must show the ``sin`` from ``gm`` directly after ``cos``.
    """

    def f1(x):
        return torch.sin(x)

    sample = torch.randn(3, 3)
    gm, _ = torch._dynamo.export(f1)(sample)

    def f2(x):
        x = torch.cos(x)
        return gm(x)

    gm2, _ = torch._dynamo.export(f2)(torch.randn(3, 3))
    self.assertExpectedInline(
        gm2.code.strip(),
        """\
def forward(self, x):
arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
l_x_ = arg0
x = torch.cos(l_x_); l_x_ = None
sin = torch.sin(x); x = None
return pytree.tree_unflatten([sin], self._out_spec)""",
    )
def test_preserve_fx_node_metadata_graph_break(self):
    """Compiling an edited graph that graph-breaks hands the backend two graphs.

    The ``torch.abs`` node of the exported graph is retargeted to ``bad_fn``,
    which calls ``torch._dynamo.graph_break()``. Compiling the edited module
    must pass the backend a graph containing the ``sin`` and then one
    containing the ``cos``, in that order, and nothing else.
    """

    def fn(x):
        x = torch.sin(x)
        x = torch.abs(x)
        return torch.cos(x)

    def bad_fn(x):
        torch._dynamo.graph_break()
        return x

    gm, _ = torch._dynamo.export(fn)(torch.randn(3, 3))

    # Replace the abs node with a call that forces a graph break.
    gm_edit = copy.deepcopy(gm)
    for nd in gm_edit.graph.nodes:
        if nd.target == torch.abs:
            nd.target = bad_fn
            nd.meta.clear()
            break
    gm_edit.recompile()

    # One snippet per graph the backend should receive, in order.
    expected = [
        "x = torch.sin(l_x_)",
        "cos = torch.cos(l_stack0_)",
    ]

    def test_backend(graph_module: torch.fx.GraphModule, example_inputs):
        # Fail cleanly (not with IndexError) if an unexpected extra graph arrives.
        self.assertTrue(expected)
        self.assertIn(expected[0], graph_module.print_readable(print_output=False))
        expected.pop(0)
        return graph_module.forward

    torch._dynamo.reset()
    opt_gm_edit = torch.compile(gm_edit, backend=test_backend)
    opt_gm_edit(torch.randn(3, 3))
    # Every expected graph must have reached the backend; otherwise a run that
    # never produced the post-break graph would pass silently.
    self.assertEqual(expected, [])
def test_torch_inference_mode_ctx(self):
    """Export records inference_mode enter/exit and outputs follow the mode.

    Three forms are covered: the ``@torch.inference_mode()`` decorator, the
    ``@torch.inference_mode(False)`` decorator, and an in-body
    ``with torch.inference_mode():`` block.
    """

    # Form 1: decorator, inference mode on.
    @torch.inference_mode()
    def fn(x):
        return x + 1

    gm, _ = torch._dynamo.export(fn, torch.rand(2, 2))
    out = gm(torch.randn(2, 2))
    self.assertExpectedInline(
        gm.code.strip(),
        """\
def forward(self, x):
arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
l_args_0_ = arg0
_enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True)
add = l_args_0_ + 1; l_args_0_ = None
_exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = None
return pytree.tree_unflatten([add], self._out_spec)""",
    )
    self.assertFalse(out.requires_grad)
    # The output is an inference tensor, so enabling grad on it must fail.
    with self.assertRaisesRegex(
        RuntimeError,
        "Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.",
    ):
        out.requires_grad = True

    # Form 2: decorator, inference mode explicitly off.
    @torch.inference_mode(False)
    def fn_no_inference(x):
        return x + 1

    gm_no_inference, _ = torch._dynamo.export(fn_no_inference, torch.rand(2, 2))
    self.assertExpectedInline(
        gm_no_inference.code.strip(),
        """\
def forward(self, x):
arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
l_args_0_ = arg0
_enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(False)
add = l_args_0_ + 1; l_args_0_ = None
_exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = None
return pytree.tree_unflatten([add], self._out_spec)""",
    )
    out = gm_no_inference(torch.randn(2, 2))
    self.assertFalse(out.requires_grad)
    # Not an inference tensor here, so this assignment must not raise.
    out.requires_grad = True

    # Form 3: context manager inside the function body.
    def fn_ctx(x):
        with torch.inference_mode():
            return x + 1

    gm_ctx, _ = torch._dynamo.export(fn_ctx)(torch.rand(2, 2))
    self.assertExpectedInline(
        gm_ctx.code.strip(),
        """\
def forward(self, x):
arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
l_x_ = arg0
_enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True)
add = l_x_ + 1; l_x_ = None
_exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = None
return pytree.tree_unflatten([add], self._out_spec)""",
    )
    grad_input = torch.randn(2, 2, requires_grad=True)
    out = gm_ctx(grad_input)
    self.assertFalse(out.requires_grad)
def test_export_masking_with_no_grad(self):
    """Boolean-mask setitem on grad-requiring inputs exports only under a no-grad mode.

    Wrapping the setitem in ``torch.no_grad()`` or ``torch.inference_mode()``
    exports cleanly. Calling it bare is expected to raise ``Unsupported``
    mentioning boolean masking setitem backwards.
    """

    def fn(x, b, y):
        x = x.clone()
        x[b] = y
        return x

    def fn_no_grad(x, b, y):
        with torch.no_grad():
            return fn(x, b, y)

    def fn_inference_mode(x, b, y):
        with torch.inference_mode():
            return fn(x, b, y)

    x = torch.randn(4, requires_grad=True)
    b = torch.tensor([True, False, True, False])
    y = torch.randn(2, requires_grad=True)
    args = (x, b, y)

    gm, _ = torch._dynamo.export(fn_no_grad)(*args)
    self.assertExpectedInline(
        gm.code.strip(),
        """\
def forward(self, x, b, y):
arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(([x, b, y], {}), self._in_spec)
l_x_ = arg0
l_b_ = arg1
l_y_ = arg2
_set_grad_enabled = torch._C._set_grad_enabled(False)
x = l_x_.clone(); l_x_ = None
x[l_b_] = l_y_; setitem = x; l_b_ = l_y_ = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True)
return pytree.tree_unflatten([x], self._out_spec)""",
    )

    gm, _ = torch._dynamo.export(fn_inference_mode)(*args)
    self.assertExpectedInline(
        gm.code.strip(),
        """\
def forward(self, x, b, y):
arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(([x, b, y], {}), self._in_spec)
l_x_ = arg0
l_b_ = arg1
l_y_ = arg2
_enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True)
x = l_x_.clone(); l_x_ = None
x[l_b_] = l_y_; setitem = x; l_b_ = l_y_ = None
_exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = None
return pytree.tree_unflatten([x], self._out_spec)""",
    )

    with self.assertRaisesRegex(
        torch._dynamo.exc.Unsupported, "boolean masking setitem backwards"
    ):
        torch._dynamo.export(fn)(*args)
def test_dynamo_list_index(self):
    """Export handles ``list.index`` on a Python-list argument and matches eager."""

    def fn(x, in_list):
        return x + in_list.index(2)

    tensor_arg, list_arg = torch.ones(2, 2), [1, 2]
    graph, _ = torch._dynamo.export(fn)(tensor_arg, list_arg)
    result = graph(tensor_arg, list_arg)
    # 2 sits at index 1 of [1, 2], so the expected output is ones + 1.
    self.assertEqual(result, torch.ones(2, 2) + 1)
# Let torch's test utilities instantiate any parametrized variants of ExportTests.
common_utils.instantiate_parametrized_tests(ExportTests)

if __name__ == "__main__":
    torch._dynamo.test_case.run_tests()