blob: 40691482bc1e27d8d83f266d860953a91f226458 [file] [log] [blame]
# Owner(s): ["module: dynamo"]
import operator
from enum import Enum
from typing import Dict, List
from unittest.mock import patch
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from functorch.experimental.control_flow import cond
from torch._dynamo import config
from torch.fx.experimental.proxy_tensor import make_fx
class ExportTests(torch._dynamo.test_case.TestCase):
# TODO(voz): Refactor to a shared test function.
# The tests in this file are a little redundant:
# they all take a func, run it with eager, then export it, then compare
# the two results.
def test_export(self):
    # Export a zero-argument function that returns a 4x3 nested list of
    # tensors; the exported graph must reproduce the eager-optimized result.
    def pre_attention_state_ops(input, mems, state):
        lc_key = state[0]
        lc_val = state[1]
        bar = []
        for _row in range(4):
            bar2 = []
            for _col in range(3):
                bar2.append(
                    lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1])
                )
            bar.append(bar2)
        return bar

    def func():
        mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]])
        state = [
            torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
            torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
        ]
        i = torch.tensor(
            [
                [0.0313, -0.1487, -0.3846, -0.5321],
                [-1.7073, 1.3331, -0.0890, -1.4935],
                [-0.8314, -0.1862, -0.5935, 1.5232],
            ]
        )
        return pre_attention_state_ops(i, mems, state)

    expected = torch._dynamo.optimize("eager", nopython=True)(func)()
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module()))
def test_export_mismatched_out(self):
    # Export a function whose output mixes a list and a tuple of tensors;
    # the exported graph must reproduce the eager-optimized structure.
    def func(x):
        y = x + 1
        return ([x, x], (y, y))

    def sample():
        # A fresh input tensor for each stage of the comparison.
        return torch.tensor([[[1.3737, 0.1]]])

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(sample())
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, sample())
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(sample())))
@config.patch(dynamic_shapes=True)
def test_export_shape_control_flow_1(self):
    # With dynamic shapes, exporting a function that branches on x.shape[0]
    # must record the taken branch as a SHAPE_ENV guard. The sample input has
    # 6 rows, so the traced branch is the `x.shape[0] <= 10` (sin) path.
    from torch._guards import GuardSource

    def func(x):
        if x.shape[0] > 10:
            return x.cos()
        return x.sin()

    opt_func = torch._dynamo.optimize("eager")(func)
    real_result = opt_func(torch.ones(6, 4))
    torch._dynamo.reset()
    out_graph, out_guards = torch._dynamo.export(func, torch.ones(6, 4))
    dynamo_result = out_graph(torch.ones(6, 4))
    self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    shape_env_guards = [g for g in out_guards if g.source == GuardSource.SHAPE_ENV]
    self.assertTrue(shape_env_guards, "expected at least one SHAPE_ENV guard")
    for guard in shape_env_guards:
        # assertIn reports the actual guard code on failure, unlike
        # assertTrue(... in ...), which only reports "False is not true".
        self.assertIn("x.size()[0] <= 10", guard.code_list)
def test_export_control_flow_with_getattr(self):
    # forward() branches on a module attribute compared against an Enum
    # member's value; the exported graph must match a direct module call.
    class Animal(Enum):
        COW = "moo"

    class MyModule(torch.nn.Module):
        def __init__(self, a):
            super().__init__()
            self.a = a

        def forward(self, x):
            if self.a == Animal.COW.value:
                return x * x
            else:
                raise ValueError("bad")

    module = MyModule("moo")
    args = (torch.ones(4, 3),)
    expected = module(*args)
    graph, _ = torch._dynamo.export(module, *args)
    self.assertTrue(torch._dynamo.utils.same(expected, graph(*args)))
def test_export_graph_bypass(self):
    # The function reads the same list element twice and multiplies the two
    # reads; the exported graph must match the eager-optimized result.
    def func(x):
        first = x[2]
        second = x[2]
        return first * second

    inp = [
        torch.tensor([0.1, 0.1]),
        torch.tensor([0.2, 0.2]),
        torch.tensor([0.3, 0.3]),
    ]
    expected = torch._dynamo.optimize("eager", nopython=True)(func)(inp)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, inp)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(inp)))
def test_list_unpack(self):
    # Output mixes untouched list elements with a computed tensor; the
    # exported graph must return the same 4-tuple as the eager-optimized run.
    def func(x):
        first = x[2]
        second = x[2]
        return x[0], first * second, x[1], x[2]

    inp = [
        torch.tensor([0.1, 0.1]),
        torch.tensor([0.2, 0.2]),
        torch.tensor([0.3, 0.3]),
    ]
    expected = torch._dynamo.optimize("eager", nopython=True)(func)(inp)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, inp)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(inp)))
def test_export_mismatched_out_2(self):
    # Export a function whose output mixes a list and a tuple of tensors.
    # NOTE(review): this body is currently identical to
    # test_export_mismatched_out; confirm whether a variant was intended.
    def func(x):
        y = x + 1
        return ([x, x], (y, y))

    def sample():
        # A fresh input tensor for each stage of the comparison.
        return torch.tensor([[[1.3737, 0.1]]])

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(sample())
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, sample())
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(sample())))
def test_export_graph_with_list(self):
    # The function returns a computed tensor together with its whole input
    # list; the exported graph must reproduce that structure.
    def func(x):
        first = x[2]
        second = x[2]
        return first * second, x

    inp = [
        torch.tensor([0.1, 0.1]),
        torch.tensor([0.2, 0.2]),
        torch.tensor([0.3, 0.3]),
        torch.tensor([0.4, 0.4]),
    ]
    expected = torch._dynamo.optimize("eager", nopython=True)(func)(inp)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, inp)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(inp)))
def test_export_graph_with_complex_reorder(self):
    # Outputs are input elements in a permuted order interleaved with computed
    # products; the exported graph must preserve that order.
    def func(x):
        first = x[0]
        second = x[1]
        third = x[2]
        return third, first, second, first * second, first * third

    inp = [
        torch.tensor([0.1, 0.1]),
        torch.tensor([0.2, 0.2]),
        torch.tensor([0.3, 0.3]),
        torch.tensor([0.4, 0.4]),
    ]
    expected = torch._dynamo.optimize("eager", nopython=True)(func)(inp)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, inp)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(inp)))
def test_dupes(self):
    # The function returns the same computed tensor twice; the exported graph
    # must return an equal 2-tuple.
    def func(x):
        y = x + 1
        return y, y

    inp = torch.tensor([0.1, 0.1])
    expected = torch._dynamo.optimize("eager", nopython=True)(func)(inp)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, inp)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(inp)))
def test_dupes_2(self):
    # The function returns the same computed tensor twice.
    # NOTE(review): this body is currently identical to test_dupes; confirm
    # whether a variant was intended.
    def func(x):
        y = x + 1
        return y, y

    inp = torch.tensor([0.1, 0.1])
    expected = torch._dynamo.optimize("eager", nopython=True)(func)(inp)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, inp)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(inp)))
def test_dupes_and_bypass(self):
    # Output repeats a computed tensor and passes the second input through
    # unchanged; the exported graph must match the eager-optimized run.
    def func(x, z):
        y = x + 1
        return y, y, z

    args = (torch.tensor([0.1, 0.1]), torch.tensor([0.4, 0.4]))
    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*args)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, *args)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*args)))
def test_dupes_and_bypass_with_non_tensor_arg(self):
    # Like the dupes/bypass case, but the added offset is a Python int
    # argument rather than a tensor.
    def func(x, z, k):
        y = x + k
        return y, y, z

    args = (torch.tensor([0.1, 0.1]), torch.tensor([0.1, 0.1]), 4)
    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*args)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, *args)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*args)))
def test_dupes_and_bypass_reorder_with_non_tensor_arg(self):
    # The passed-through input comes first and the duplicated computed tensor
    # follows; the offset is a Python int argument.
    def func(x, z, k):
        y = x + k
        return z, y, y

    args = (torch.tensor([0.1, 0.1]), torch.tensor([0.1, 0.1]), 4)
    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*args)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, *args)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*args)))
@config.patch(capture_scalar_outputs=True, dynamic_shapes=True)
def test_dupes_and_bypass_with_non_tensor_output(self):
    # One output is a Python scalar produced by .item(); scalar capture is
    # enabled via the config patch above.
    def func(x, z, k):
        y = x + k
        return y[0].item(), y, z

    args = (torch.tensor([0.1, 0.1]), torch.tensor([0.1, 0.1]), 4)
    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*args)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, *args)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*args)))
def test_zeroes_in_and_out_different_shape_on_test(self):
    # Export is traced with all-zero inputs, then the graph is run on random
    # inputs and compared against an eager-optimized run on those inputs.
    def func(a, b, c):
        return [[a], [b, c], [a + b], [[c + c]]]

    export_args = [torch.zeros(10) for _ in range(3)]
    run_args = [torch.randn(10) for _ in range(3)]

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*run_args)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, *export_args)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*run_args)))
@config.patch(capture_scalar_outputs=True, dynamic_shapes=True)
def test_zeroes_in_new_shape_scalar_out(self):
    # The output is a sum of .item() scalars. Export uses zeros; the graph is
    # then run on random inputs and compared with an eager-optimized run.
    def func(a, b, c):
        return a[0].item() + b[0].item() + c[0].item()

    export_args = [torch.zeros(10) for _ in range(3)]
    run_args = [torch.randn(10) for _ in range(3)]

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*run_args)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, *export_args)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*run_args)))
@config.patch(capture_scalar_outputs=True, dynamic_shapes=True)
def test_zeroes_in_new_shape_scalar_out_permute(self):
    # Same as the scalar-out case, but the inputs are read out of order and
    # `a` is read twice.
    def func(a, b, c):
        return b[0].item() + c[0].item() + a[0].item() + a[0].item()

    export_args = [torch.zeros(10) for _ in range(3)]
    run_args = [torch.randn(10) for _ in range(3)]

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*run_args)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, *export_args)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*run_args)))
@config.patch(capture_scalar_outputs=True, dynamic_shapes=True)
def test_zeroes_in_new_shape_scalar_out_permute_dupe_and_bypass(self):
    # A scalar sum of permuted .item() reads, surrounded by two pass-throughs
    # of the same input tensor `a`.
    def func(a, b, c):
        return a, b[0].item() + c[0].item() + a[0].item() + a[0].item(), a

    export_args = [torch.zeros(10) for _ in range(3)]
    run_args = [torch.randn(10) for _ in range(3)]

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*run_args)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, *export_args)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*run_args)))
def test_func_return(self):
    # The result comes from a nested closure over a local tensor; export is
    # traced with zeros and checked on random inputs.
    def func(a, b, c):
        x = a + b + c

        def func2(y):
            return x * y

        return func2(x)

    export_args = [torch.zeros(10) for _ in range(3)]
    run_args = [torch.randn(10) for _ in range(3)]

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*run_args)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, *export_args)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*run_args)))
def test_dict_return(self):
    # The function returns a dict holding a computed tensor; export is traced
    # with zeros and checked on random inputs.
    def func(a, b, c):
        x = a + b + c
        return {"a": x}

    export_args = [torch.zeros(10) for _ in range(3)]
    run_args = [torch.randn(10) for _ in range(3)]

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*run_args)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, *export_args)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*run_args)))
def test_export_with_aten_graph(self):
    # aten_graph=True variant: export a zero-argument function returning a
    # 4x3 nested list of tensors and compare with the eager-optimized result.
    def pre_attention_state_ops(input, mems, state):
        lc_key = state[0]
        lc_val = state[1]
        bar = []
        for _row in range(4):
            bar2 = []
            for _col in range(3):
                bar2.append(
                    lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1])
                )
            bar.append(bar2)
        return bar

    def func():
        mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]])
        state = [
            torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
            torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
        ]
        i = torch.tensor(
            [
                [0.0313, -0.1487, -0.3846, -0.5321],
                [-1.7073, 1.3331, -0.0890, -1.4935],
                [-0.8314, -0.1862, -0.5935, 1.5232],
            ]
        )
        return pre_attention_state_ops(i, mems, state)

    expected = torch._dynamo.optimize("eager", nopython=True)(func)()
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, aten_graph=True)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module()))
def test_export_mismatched_out_with_aten_graph(self):
    # aten_graph=True variant: output mixes a list and a tuple of tensors.
    def func(x):
        y = x + 1
        return ([x, x], (y, y))

    def sample():
        # A fresh input tensor for each stage of the comparison.
        return torch.tensor([[[1.3737, 0.1]]])

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(sample())
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, sample(), aten_graph=True)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(sample())))
def test_export_graph_bypass_with_aten_graph(self):
    # aten_graph=True variant: the same list element is read twice and the
    # two reads are multiplied.
    def func(x):
        first = x[2]
        second = x[2]
        return first * second

    inp = [
        torch.tensor([0.1, 0.1]),
        torch.tensor([0.2, 0.2]),
        torch.tensor([0.3, 0.3]),
    ]
    expected = torch._dynamo.optimize("eager", nopython=True)(func)(inp)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, inp, aten_graph=True)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(inp)))
def test_list_unpack_with_aten_graph(self):
    # aten_graph=True variant: untouched list elements interleaved with a
    # computed tensor in the output tuple.
    def func(x):
        first = x[2]
        second = x[2]
        return x[0], first * second, x[1], x[2]

    inp = [
        torch.tensor([0.1, 0.1]),
        torch.tensor([0.2, 0.2]),
        torch.tensor([0.3, 0.3]),
    ]
    expected = torch._dynamo.optimize("eager", nopython=True)(func)(inp)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, inp, aten_graph=True)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(inp)))
def test_export_mismatched_out_2_with_aten_graph(self):
    # aten_graph=True variant: output mixes a list and a tuple of tensors.
    # NOTE(review): identical to test_export_mismatched_out_with_aten_graph;
    # confirm whether a variant was intended.
    def func(x):
        y = x + 1
        return ([x, x], (y, y))

    def sample():
        # A fresh input tensor for each stage of the comparison.
        return torch.tensor([[[1.3737, 0.1]]])

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(sample())
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, sample(), aten_graph=True)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(sample())))
def test_export_graph_with_list_with_aten_graph(self):
    # aten_graph=True variant: a computed tensor is returned together with
    # the whole input list.
    def func(x):
        first = x[2]
        second = x[2]
        return first * second, x

    inp = [
        torch.tensor([0.1, 0.1]),
        torch.tensor([0.2, 0.2]),
        torch.tensor([0.3, 0.3]),
        torch.tensor([0.4, 0.4]),
    ]
    expected = torch._dynamo.optimize("eager", nopython=True)(func)(inp)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, inp, aten_graph=True)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(inp)))
def test_export_graph_with_complex_reorder_with_aten_graph(self):
    # aten_graph=True variant: permuted input elements interleaved with
    # computed products in the output tuple.
    def func(x):
        first = x[0]
        second = x[1]
        third = x[2]
        return third, first, second, first * second, first * third

    inp = [
        torch.tensor([0.1, 0.1]),
        torch.tensor([0.2, 0.2]),
        torch.tensor([0.3, 0.3]),
        torch.tensor([0.4, 0.4]),
    ]
    expected = torch._dynamo.optimize("eager", nopython=True)(func)(inp)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, inp, aten_graph=True)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(inp)))
def test_dupes_with_aten_graph(self):
    # aten_graph=True variant: the same computed tensor is returned twice.
    def func(x):
        y = x + 1
        return y, y

    inp = torch.tensor([0.1, 0.1])
    expected = torch._dynamo.optimize("eager", nopython=True)(func)(inp)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, inp, aten_graph=True)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(inp)))
def test_dupes_2_with_aten_graph(self):
    # aten_graph=True variant: the same computed tensor is returned twice.
    # NOTE(review): identical to test_dupes_with_aten_graph; confirm whether a
    # variant was intended.
    def func(x):
        y = x + 1
        return y, y

    inp = torch.tensor([0.1, 0.1])
    expected = torch._dynamo.optimize("eager", nopython=True)(func)(inp)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, inp, aten_graph=True)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(inp)))
def test_dupes_and_bypass_with_aten_graph(self):
    # aten_graph=True variant: a duplicated computed tensor plus a
    # passed-through second input.
    def func(x, z):
        y = x + 1
        return y, y, z

    args = (torch.tensor([0.1, 0.1]), torch.tensor([0.4, 0.4]))
    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*args)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, *args, aten_graph=True)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*args)))
def test_dupes_and_bypass_with_non_tensor_arg_with_aten_graph(self):
    # aten_graph=True variant: dupes/bypass with a Python int offset argument.
    def func(x, z, k):
        y = x + k
        return y, y, z

    args = (torch.tensor([0.1, 0.1]), torch.tensor([0.1, 0.1]), 4)
    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*args)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, *args, aten_graph=True)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*args)))
def test_dupes_and_bypass_reorder_with_non_tensor_arg_with_aten_graph(self):
    # aten_graph=True variant: the passed-through input comes first, followed
    # by the duplicated computed tensor; the offset is a Python int.
    def func(x, z, k):
        y = x + k
        return z, y, y

    args = (torch.tensor([0.1, 0.1]), torch.tensor([0.1, 0.1]), 4)
    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*args)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, *args, aten_graph=True)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*args)))
@config.patch(capture_scalar_outputs=True, dynamic_shapes=True)
def test_dupes_and_bypass_with_non_tensor_output_with_aten_graph(self):
    # One output is a Python scalar produced by .item(); scalar capture is
    # enabled via the config patch above.
    # NOTE(review): unlike the other *_with_aten_graph tests, export() is NOT
    # called with aten_graph=True here, so this currently duplicates
    # test_dupes_and_bypass_with_non_tensor_output. Confirm whether the flag
    # was left out on purpose (e.g. .item() unsupported in aten tracing).
    def func(x, z, k):
        y = x + k
        return y[0].item(), y, z

    args = (torch.tensor([0.1, 0.1]), torch.tensor([0.1, 0.1]), 4)
    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*args)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, *args)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*args)))
def test_zeroes_in_and_out_different_shape_on_test_with_aten_graph(self):
    # aten_graph=True variant: traced with zeros, checked on random inputs
    # against an eager-optimized run.
    def func(a, b, c):
        return [[a], [b, c], [a + b], [[c + c]]]

    export_args = [torch.zeros(10) for _ in range(3)]
    run_args = [torch.randn(10) for _ in range(3)]

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*run_args)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, *export_args, aten_graph=True)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*run_args)))
def test_func_return_with_aten_graph(self):
    # aten_graph=True variant: result computed by a nested closure over a
    # local tensor.
    def func(a, b, c):
        x = a + b + c

        def func2(y):
            return x * y

        return func2(x)

    export_args = [torch.zeros(10) for _ in range(3)]
    run_args = [torch.randn(10) for _ in range(3)]

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*run_args)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, *export_args, aten_graph=True)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*run_args)))
def test_dict_return_with_aten_graph(self):
    # aten_graph=True variant: the function returns a dict holding a computed
    # tensor.
    def func(a, b, c):
        x = a + b + c
        return {"a": x}

    export_args = [torch.zeros(10) for _ in range(3)]
    run_args = [torch.randn(10) for _ in range(3)]

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*run_args)
    # Reset dynamo state between the eager-optimized run and export().
    torch._dynamo.reset()
    graph_module, _ = torch._dynamo.export(func, *export_args, aten_graph=True)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*run_args)))
def test_export_with_stack_trace(self):
    # Exported graphs must carry provenance metadata on their compute nodes:
    # a stack trace plus non-None "nn_module_stack" and "source_fn" entries in
    # node.meta. Checked on the torch-level graph (aten_graph=False: every node
    # except placeholders/output) and on the aten-level graph
    # (aten_graph=True: every call_function node).
    inp = torch.randn(4, 4)

    class MyBlock(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            x = torch.nn.functional.linear(x, torch.randn(4, 4))
            return torch.cos(x).relu() + 1

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.block = MyBlock()

        def forward(self, x):
            out = self.block(x)
            return out

    def check_provenance(node):
        # Dedicated asserts report the offending value, and a missing meta key
        # becomes a test failure instead of an uncaught KeyError.
        self.assertIsNotNone(node.stack_trace)
        for key in ("nn_module_stack", "source_fn"):
            self.assertIn(key, node.meta)
            self.assertIsNotNone(node.meta[key])

    exported = torch._dynamo.export(MyModule(), inp, aten_graph=False)
    out_graph = exported[0]
    for node in out_graph.graph.nodes:
        if node.op not in {"placeholder", "output"}:
            check_provenance(node)

    torch._dynamo.reset()
    exported = torch._dynamo.export(MyModule(), inp, aten_graph=True)
    out_graph = exported[0]
    for node in out_graph.graph.nodes:
        if node.op == "call_function":
            check_provenance(node)
def test_export_compare_optimize_with_make_fx(self):
    # The aten-graph export result must match make_fx applied two ways:
    # inside a dynamo backend, and directly to the Python function.
    inp = torch.tensor([0.1, 0.1])
    linear = torch.nn.Linear(2, 2)

    def func(x):
        x = x + 1
        y = x.t()
        y = y.relu()
        y = linear(y)
        return y

    exported_graph, _ = torch._dynamo.export(func, inp, aten_graph=True)
    export_result = exported_graph(inp)
    torch._dynamo.reset()

    def make_fx_backend(gm, sample_inputs):
        # At call time, re-trace the dynamo graph with make_fx and run it.
        def run(*args):
            return make_fx(gm)(*args)(*args)

        return run

    backend_result = torch._dynamo.optimize(make_fx_backend, nopython=True)(func)(inp)
    direct_result = make_fx(func)(inp)(inp)
    for candidate in (backend_result, direct_result):
        self.assertTrue(torch._dynamo.utils.same(candidate, export_result))
def test_export_with_constant_method_on_module(self):
    # helper_fn is an assume_constant_result method. The graph is exported
    # from an all-zero input and then probed with two other inputs; the test
    # expects both probes to equal a plain eager call on a fresh module.
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 2))
            self.linear = torch.nn.Linear(2, 2)

        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            return torch.nonzero(x)

        def forward(self, x):
            y = torch.sin(x)
            x = self.linear(x)
            y = self.helper_fn(x)
            return y

    expected = MyModule()(torch.tensor([[1.0, 0], [0, 0]]))
    graph, _ = torch._dynamo.export(MyModule(), torch.tensor([[0.0, 0], [0, 0]]))
    probes = (
        torch.tensor([[1.0, 0.0], [0, 0]]),
        torch.tensor([[1, 0], [0.25, 0.25]]),
    )
    for probe in probes:
        self.assertTrue(torch._dynamo.utils.same(graph(probe), expected))
def test_export_with_constant_method_on_module_invoke_twice(self):
    # Like the constant-method test, but forward() calls the
    # assume_constant_result method twice and adds the results.
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 2))
            self.linear = torch.nn.Linear(2, 2)

        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            return torch.nonzero(x)

        def forward(self, x):
            y = torch.sin(x)
            x = self.linear(x)
            y = self.helper_fn(x) + self.helper_fn(x)
            return y

    expected = MyModule()(torch.tensor([[1.0, 0], [0, 0]]))
    graph, _ = torch._dynamo.export(MyModule(), torch.tensor([[0.0, 0], [0, 0]]))
    probes = (
        torch.tensor([[1.0, 0.0], [0, 0]]),
        torch.tensor([[1, 0], [0.25, 0.25]]),
    )
    for probe in probes:
        self.assertTrue(torch._dynamo.utils.same(graph(probe), expected))
def test_export_with_constant_free_function(self):
    # forward() adds the results of an assume_constant_result free function
    # and an assume_constant_result method; both probes must equal the eager
    # result.
    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        return torch.nonzero(x)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 2))
            self.linear = torch.nn.Linear(2, 2)

        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            return torch.nonzero(x)

        def forward(self, x):
            y = torch.sin(x)
            x = self.linear(x)
            y = helper_fn(x) + self.helper_fn(x)
            return y

    expected = MyModule()(torch.tensor([[1.0, 0], [0, 0]]))
    graph, _ = torch._dynamo.export(MyModule(), torch.tensor([[0.0, 0], [0, 0]]))
    probes = (
        torch.tensor([[1.0, 0.0], [0, 0]]),
        torch.tensor([[1, 0], [0.25, 0.25]]),
    )
    for probe in probes:
        self.assertTrue(torch._dynamo.utils.same(graph(probe), expected))
def test_export_with_constant_free_function_and_class_method(self):
    # forward() calls only an assume_constant_result free function.
    # NOTE(review): despite the name, no class method is involved here;
    # test_export_with_constant_free_function exercises both.
    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        return torch.nonzero(x)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 2))
            self.linear = torch.nn.Linear(2, 2)

        def forward(self, x):
            y = torch.sin(x)
            x = self.linear(x)
            y = helper_fn(x)
            return y

    expected = MyModule()(torch.tensor([[1.0, 0], [0, 0]]))
    graph, _ = torch._dynamo.export(MyModule(), torch.tensor([[0.0, 0], [0, 0]]))
    probes = (
        torch.tensor([[1.0, 0.0], [0, 0]]),
        torch.tensor([[1, 0], [0.25, 0.25]]),
    )
    for probe in probes:
        self.assertTrue(torch._dynamo.utils.same(graph(probe), expected))
def test_export_with_constant_free_function_and_class_method_multiarg(self):
    # Two-input module: the assume_constant_result free function is applied to
    # both inputs and the results are summed; both probe pairs must equal the
    # eager result.
    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        return torch.nonzero(x)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 2))
            self.linear = torch.nn.Linear(2, 2)

        def forward(self, x, z):
            y = torch.sin(x)
            x = self.linear(x)
            y = helper_fn(x) + helper_fn(z)
            return y

    expected = MyModule()(
        torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
    )
    graph, _ = torch._dynamo.export(
        MyModule(), torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
    )
    probe_pairs = (
        (torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[1.0, 0.0], [0, 0]])),
        (torch.tensor([[1, 0], [0.25, 0.25]]), torch.tensor([[1, 0], [0.25, 0.25]])),
    )
    for pair in probe_pairs:
        self.assertTrue(torch._dynamo.utils.same(graph(*pair), expected))
def test_export_with_constant_free_function_and_class_method_multiarg_diff(self):
    # Two-input module with no parameters: the assume_constant_result free
    # function is applied to each input. Export and probe inputs differ from
    # the eager inputs; the test expects both probes to match the eager result.
    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        return torch.nonzero(x)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, z):
            y = helper_fn(x) + helper_fn(z)
            return y

    expected = MyModule()(
        torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
    )
    graph, _ = torch._dynamo.export(
        MyModule(), torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[0.0, 0], [0.5, 0]])
    )
    probe_pairs = (
        (torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[0.0, 1.0], [0, 0]])),
        (
            torch.tensor([[1, 0], [0.25, 0.25]]),
            torch.tensor([[0.33, 0.33], [0.25, 0.25]]),
        ),
    )
    for pair in probe_pairs:
        self.assertTrue(torch._dynamo.utils.same(graph(*pair), expected))
def test_export_with_constant_tuple_nonzero(self):
    # The assume_constant_result method returns a tuple of tensors that
    # forward() iterates over.
    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            return (torch.nonzero(x), torch.nonzero(x))

        def forward(self, x):
            y = torch.tensor([0.5])
            elements = self.helper_fn(x)
            all_y = []
            for element in elements:
                for item in element:
                    all_y.append(y * item)
            return all_y

    module = MyModule()
    sample = torch.tensor([1.0, 1.0])
    expected = module(sample)
    graph, _ = torch._dynamo.export(module, torch.tensor([1.0, 1.0]))
    # The graph's result reflects what was made constant at export time, so
    # the probe input can have a completely different shape.
    probe = torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])
    self.assertTrue(torch._dynamo.utils.same(graph(probe), expected))
def test_export_with_constant_list_nonzero(self):
    # The assume_constant_result method returns a list of tensors that
    # forward() iterates over.
    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            return [torch.nonzero(x), torch.nonzero(x)]

        def forward(self, x):
            y = torch.tensor([0.5])
            elements = self.helper_fn(x)
            all_y = []
            for element in elements:
                for item in element:
                    all_y.append(y * item)
            return all_y

    module = MyModule()
    sample = torch.tensor([1.0, 1.0])
    expected = module(sample)
    graph, _ = torch._dynamo.export(module, torch.tensor([1.0, 1.0]))
    # The graph's result reflects what was made constant at export time, so
    # the probe input can have a completely different shape.
    probe = torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])
    self.assertTrue(torch._dynamo.utils.same(graph(probe), expected))
def test_export_with_constant_list_nonzero_free_function(self):
    # Same as the list-nonzero test, but helper_fn is a free function rather
    # than a module method.
    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        return [torch.nonzero(x), torch.nonzero(x)]

    class MyModule(torch.nn.Module):
        def forward(self, x):
            y = torch.tensor([0.5])
            elements = helper_fn(x)
            all_y = []
            for element in elements:
                for item in element:
                    all_y.append(y * item)
            return all_y

    module = MyModule()
    sample = torch.tensor([1.0, 1.0])
    expected = module(sample)
    graph, _ = torch._dynamo.export(module, torch.tensor([1.0, 1.0]))
    # The graph's result reflects what was made constant at export time, so
    # the probe input can have a completely different shape.
    probe = torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])
    self.assertTrue(torch._dynamo.utils.same(graph(probe), expected))
def test_export_with_constant_dict_values(self):
    # The assume_constant_result method returns a dict whose values forward()
    # multiplies into y.
    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            return {"x": x, "x^2": x * x}

        def forward(self, x):
            y = torch.tensor([0.5])
            elements = self.helper_fn(x)
            y = y * elements["x"]
            y = y * elements["x^2"]
            return y

    module = MyModule()
    sample = torch.tensor([2.0, 2.0])
    expected = module(sample)
    graph, _ = torch._dynamo.export(module, torch.tensor([2.0, 2.0]))
    # The graph's result reflects what was made constant at export time, so
    # the probe input can have a completely different shape.
    probe = torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])
    self.assertTrue(torch._dynamo.utils.same(graph(probe), expected))
def test_export_with_constant_none_control_flow(self):
    # helper_fn returns None for negative input. It is exported on a negative
    # input, so the graph keeps returning y even for positive input.
    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            if x.item() < 0:
                return None
            else:
                return x

        def forward(self, x):
            y = torch.tensor([0.5])
            x = self.helper_fn(x)
            if x is None:
                return y
            return y * x

    module = MyModule()
    eager_out = module(torch.tensor([-1]))
    # x.item() < 0, so helper_fn returns None and forward returns y as-is.
    self.assertEqual(eager_out, torch.tensor([0.5]))
    graph, _ = torch._dynamo.export(module, torch.tensor([-1]))
    # x is positive now, but helper_fn's None result was fixed at export time.
    self.assertTrue(torch._dynamo.utils.same(graph(torch.tensor([2])), eager_out))
def test_export_with_constant_not_none_control_flow(self):
    # helper_fn returns x for non-negative input. It is exported on a positive
    # input, so the graph keeps taking the `y * x` path even for negative input.
    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            if x.item() < 0:
                return None
            else:
                return x

        def forward(self, x):
            y = torch.tensor([0.5])
            x = self.helper_fn(x)
            if x is None:
                return y
            return y * x

    module = MyModule()
    eager_out = module(torch.tensor([2]))
    # x.item() >= 0, so helper_fn returns x and forward returns y * x.
    self.assertEqual(eager_out, torch.tensor([1.0]))
    graph, _ = torch._dynamo.export(module, torch.tensor([2]))
    # x is negative now, but helper_fn's result was fixed at export time.
    self.assertTrue(torch._dynamo.utils.same(graph(torch.tensor([-0.5])), eager_out))
def test_export_with_constant_none_control_flow_free_func(self):
    # Free-function variant: helper_fn returns None for negative input and is
    # exported on a negative input, so the graph keeps returning y.
    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        if x.item() < 0:
            return None
        else:
            return x

    class MyModule(torch.nn.Module):
        def forward(self, x):
            y = torch.tensor([0.5])
            x = helper_fn(x)
            if x is None:
                return y
            return y * x

    module = MyModule()
    eager_out = module(torch.tensor([-1]))
    # x.item() < 0, so helper_fn returns None and forward returns y as-is.
    self.assertEqual(eager_out, torch.tensor([0.5]))
    graph, _ = torch._dynamo.export(module, torch.tensor([-1]))
    # x is positive now, but helper_fn's None result was fixed at export time.
    self.assertTrue(torch._dynamo.utils.same(graph(torch.tensor([2])), eager_out))
def test_export_with_constant_not_none_control_flow_pos(self):
    # NOTE(review): this body is identical to
    # test_export_with_constant_not_none_control_flow; consider giving it a
    # distinct scenario or dropping one of the two.
    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            if x.item() >= 0:
                return x
            return None

        def forward(self, x):
            y = torch.tensor([0.5])
            x = self.helper_fn(x)
            if x is not None:
                return y * x
            return y

    module = MyModule()
    expected = module(torch.tensor([2]))
    # Positive input: .item() >= 0, so forward returns y * x = 0.5 * 2.
    self.assertEqual(expected, torch.tensor([1.0]))
    graph, _guards = torch._dynamo.export(module, torch.tensor([2]))
    # Negative input at run time, but helper_fn's result was captured as a
    # constant during export, so the graph still yields the exported y * x.
    actual = graph(torch.tensor([-0.5]))
    self.assertTrue(torch._dynamo.utils.same(actual, expected))
def test_export_with_constant_not_none_control_flow_free_func(self):
    # Free-function variant: the constant-result helper lives outside the
    # module and forward() closes over it.
    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        return None if x.item() < 0 else x

    class MyModule(torch.nn.Module):
        def forward(self, x):
            y = torch.tensor([0.5])
            x = helper_fn(x)
            return y if x is None else y * x

    module = MyModule()
    expected = module(torch.tensor([2]))
    # Positive input: .item() >= 0, so forward returns y * x = 0.5 * 2.
    self.assertEqual(expected, torch.tensor([1.0]))
    graph, _guards = torch._dynamo.export(module, torch.tensor([2]))
    # Negative input at run time still yields the exported y * x result,
    # because helper_fn's return value was frozen during export.
    actual = graph(torch.tensor([-0.5]))
    self.assertTrue(torch._dynamo.utils.same(actual, expected))
def test_export_with_constant_not_return_const(self):
    # helper_fn returns a non-tensor (a string) that is assumed constant;
    # changing the module attribute after export must not affect the graph.
    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            # Ignores x and reports whatever self.val holds.
            return self.val

        def forward(self, x):
            y = torch.tensor([0.5])
            tag = self.helper_fn(x)
            return y if tag == "A" else -1

    module = MyModule()
    module.val = "A"
    res_a = module(torch.tensor([2]))
    graph, _guards = torch._dynamo.export(module, torch.tensor([2]))
    # "A" was baked in at export time, so switching val to "B" is invisible
    # to the exported graph.
    module.val = "B"
    res_b = graph(torch.tensor([2]))
    self.assertTrue(torch._dynamo.utils.same(res_a, res_b))
def test_export_decomp(self):
    # Checks that a user-supplied decomposition_table replaces aten.t when
    # exporting to an aten graph, and that aten.t is kept when no table is given.
    def f(x):
        return x.t() + x.t()

    def nop(x):
        # Stand-in decomposition for aten.t; the test only checks that aten.t
        # nodes disappear from the graph when this is registered.
        return x.cos()

    def count_t_nodes(graph_module):
        return sum(
            1 for n in graph_module.graph.nodes if n.target == torch.ops.aten.t.default
        )

    # export() takes example inputs as *args, so the single tensor is passed
    # directly (not wrapped in an args tuple).
    graph, _ = torch._dynamo.export(
        f,
        torch.randn(5),
        aten_graph=True,
        decomposition_table={torch.ops.aten.t.default: nop},
    )
    self.assertEqual(count_t_nodes(graph), 0)

    graph, _ = torch._dynamo.export(
        f, torch.randn(5), aten_graph=True, decomposition_table=None
    )
    self.assertEqual(count_t_nodes(graph), 2)
def test_export_decomp_asserts_bad_args(self):
    # Passing a decomposition_table while aten_graph=False is expected to make
    # export raise AssertionError.
    def f(x):
        return x.t() + x.t()

    def nop(x):
        return x.cos()

    with self.assertRaises(AssertionError):
        # Only the raise matters; the result is intentionally not bound.
        torch._dynamo.export(
            f,
            torch.randn(5),
            aten_graph=False,
            decomposition_table={torch.ops.aten.t.default: nop},
        )
def test_export_decomp_asserts_bad_args_mode(self):
    # Requesting tracing_mode="symbolic" together with aten_graph=False is
    # expected to make export raise AssertionError.
    def f(x):
        return x.t() + x.t()

    with self.assertRaises(AssertionError):
        torch._dynamo.export(
            f, torch.randn(5), aten_graph=False, tracing_mode="symbolic"
        )
@config.patch(capture_scalar_outputs=True, dynamic_shapes=True)
def test_export_with_module_layer(self):
    # cond() whose branches call into a submodule (self.linear); the exported
    # graph must follow the runtime predicate rather than specializing.
    class Module(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, pred, x):
            def true_fn(val):
                return self.linear(val) * torch.tensor(2)

            def false_fn(val):
                return self.linear(val) * torch.tensor(-1)

            return cond(pred, true_fn, false_fn, [x])

    def make_pred(t):
        return torch.tensor(t[0][0].item() < 0)

    mod = Module()
    x = torch.randn([3, 3])
    pred = make_pred(x)
    expected = mod.forward(pred, x)

    torch._dynamo.reset()
    out_graph = torch._dynamo.export(mod.forward, pred, x)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, out_graph(pred, x)))

    # Negating x flips the predicate, showing the graph did not specialize on
    # the branch taken at export time.
    x = x * -1
    pred = make_pred(x)
    self.assertTrue(
        torch._dynamo.utils.same(mod.forward(pred, x), out_graph(pred, x))
    )
@config.patch(dynamic_shapes=True)
def test_export_with_cond_dynamic_shape_pred(self):
    # The cond predicate depends on an input dimension; a graph exported with
    # a 2-row input must agree with eager on a 3-row input.
    class Module(torch.nn.Module):
        def forward(self, x):
            def true_fn(x):
                return x + x

            def false_fn(x):
                return x[:2]

            return cond(x.shape[0] <= 2, true_fn, false_fn, [x])

    mod = Module()
    out_graph, _ = torch._dynamo.export(mod, torch.randn(2, 2))
    test_x = torch.randn(3, 2)
    self.assertEqual(out_graph(test_x), mod(test_x))
@config.patch(dynamic_shapes=True)
def test_export_with_map_cond(self):
    # Alias the control-flow map so the builtin map is not shadowed here.
    from functorch.experimental.control_flow import cond, map as control_flow_map

    class Module(torch.nn.Module):
        def inner(self, x, pred):
            def true_fn(x):
                return x + x

            def false_fn(x):
                return x * x

            return cond(pred, true_fn, false_fn, [x])

        def forward(self, pred, xs):
            def body(x, pred):
                return self.inner(x, pred)

            return control_flow_map(body, xs, pred)

    mod = Module()
    # Export with one predicate/shape, then compare against eager on a
    # different predicate and shape.
    x, pred_x = torch.randn(3, 2, 1), torch.tensor(True)
    y, pred_y = torch.randn(4, 3, 2), torch.tensor(False)
    expected = mod(pred_y, y)
    out_graph, _ = torch._dynamo.export(mod, pred_x, x)
    self.assertEqual(expected, out_graph(pred_y, y))
@config.patch(dynamic_shapes=True)
def test_export_with_map_zero_sized_tensor(self):
    # Mapping over a tensor whose leading dimension is 0 is expected to be
    # rejected as Unsupported with a "zero-sized tensor" message.
    from functorch.experimental.control_flow import map as control_flow_map

    class Module(torch.nn.Module):
        def forward(self, xs):
            def body(x):
                return x + 1

            return control_flow_map(body, xs)

    mod = Module()
    empty_xs = torch.randn(0, 2)
    with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "zero-sized tensor"):
        torch._dynamo.export(mod, empty_xs)
def test_export_meta_val(self):
    # Every placeholder of a symbolically traced aten graph should carry a
    # "val" entry in node.meta.
    def f(x, y, z):
        return x * y + z

    gm, _ = torch._dynamo.export(
        f,
        torch.ones(3, 2),
        torch.zeros(3, 2),
        torch.ones(3, 2),
        aten_graph=True,
        tracing_mode="symbolic",
    )
    placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
    # Guard against a vacuous pass: with no placeholders the loop below would
    # check nothing. f takes three tensor inputs.
    self.assertEqual(len(placeholders), 3)
    for node in placeholders:
        self.assertIn("val", node.meta)
def test_input_container_type(self):
    # Export with a list-of-tensors argument and a dict return value.
    def f(x: torch.Tensor, y: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        total = x.sum() + sum(y).sum()
        return {"a": total}

    args = (torch.randn(6, 5), [torch.randn(6, 5), torch.randn(6, 5)])
    gm, _ = torch._dynamo.export(f, *args, aten_graph=True, tracing_mode="symbolic")
    self.assertEqual(gm(*args), f(*args))
def test_export_symbolic_shape(self):
    # An output size derived from an input dimension should appear in the
    # aten graph as an aten.sym_size call.
    def f(x: torch.Tensor) -> torch.Tensor:
        return torch.empty(x.shape[0] * 2)

    example = torch.randn(6, 5)
    gm, _ = torch._dynamo.export(f, example, aten_graph=True, tracing_mode="symbolic")
    self.assertTrue(
        any(node.target is torch.ops.aten.sym_size for node in gm.graph.nodes)
    )
@config.patch(dynamic_shapes=True)
def test_dynamic_slicing(self):
    # Slicing with bounds derived from input dimensions, exported both as an
    # aten graph and as a torch-level graph.
    def f(x):
        return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]

    def count_calls(gm, target):
        return sum(
            1
            for node in gm.graph.nodes
            if node.op == "call_function" and node.target == target
        )

    gm_aten_mode, _ = torch._dynamo.export(
        f, torch.randn(4, 5), aten_graph=True, tracing_mode="symbolic"
    )
    inp = torch.randn(6, 7)
    self.assertEqual(gm_aten_mode(inp).shape, f(inp).shape)
    # The aten graph should lower the getitem calls to actual slice kernel
    # calls; two aten.slice.Tensor nodes are expected.
    self.assertEqual(count_calls(gm_aten_mode, torch.ops.aten.slice.Tensor), 2)

    gm_torch_mode, _ = torch._dynamo.export(f, torch.randn(4, 5), aten_graph=False)
    # In torch mode the graph should keep three operator.getitem calls: one for
    # x.shape[0] - 2, one for x.shape[1] - 1, and one for the slice itself.
    # Tensor has its own __getitem__, which is lowered to aten.slice later.
    self.assertEqual(count_calls(gm_torch_mode, operator.getitem), 3)
    self.assertEqual(gm_torch_mode(inp).shape, f(inp).shape)
@config.patch(dynamic_shapes=True)
def test_dynamic_slicing_invalid(self):
    # The slice start comes from a tensor's data rather than its shape;
    # export is expected to raise Unsupported for this.
    def g(x, y):
        return x[y : x.shape[0]]

    example_args = (torch.randn(4, 5), torch.tensor(2))
    expected_msg = "Dynamic slicing on data-dependent value is not supported"
    with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, expected_msg):
        torch._dynamo.export(
            g, *example_args, aten_graph=True, tracing_mode="symbolic"
        )
@config.patch(capture_scalar_outputs=True, dynamic_shapes=True)
def test_dynamic_slicing_simple(self):
    # A full explicit slice should export symbolically and still match eager
    # on an input of a different shape.
    def f(x):
        return x[slice(None, None, None)]

    gm, _ = torch._dynamo.export(
        f, torch.randn(4, 5), aten_graph=True, tracing_mode="symbolic"
    )
    other = torch.randn(6, 7)
    self.assertEqual(gm(other), f(other))
@config.patch(capture_scalar_outputs=True, dynamic_shapes=True)
def test_export_cond_in_aten_symbolic(self):
    # cond() whose branches are bound methods must survive symbolic aten
    # export and match eager execution.
    class ConditionOp(torch.nn.Module):
        def true_fn(self, x, y):
            return x * y

        def false_fn(self, x, y):
            return x + y

        def forward(self, pred, x, y):
            return cond(pred, self.true_fn, self.false_fn, [x, y])

    model = ConditionOp()
    inp = (
        torch.tensor(False),
        torch.randn(4, 4),
        torch.randn(4, 4),
    )
    gm, _ = torch._dynamo.export(
        model, *inp, aten_graph=True, tracing_mode="symbolic"
    )
    # No gm.print_readable() here: it only dumped the graph to stdout on
    # every run and asserted nothing.
    self.assertEqual(gm(*inp), model(*inp))
def test_export_with_kwargs(self):
    # Exercises export with a positional arg, a tuple arg, *varargs, a
    # keyword-only arg and **kwargs in a single call.
    def fn_with_kwargs(pos0, tuple0, *myargs, mykw0=None, **mykwargs):
        # Multiply out-of-place so the inputs are never mutated. With in-place
        # `*=` every call would mutate and return the caller's pos0 tensor, so
        # the eager and exported results would alias each other and the
        # comparison below would prove nothing.
        out = pos0
        for arg in tuple0:
            out = out * arg
        for arg in myargs:
            out = out * arg
        out = out * mykw0
        out = out * (mykwargs["input0"] * mykwargs["input1"])
        return out

    mykwargs = {"input0": torch.randn(4), "input1": torch.randn(4)}
    tuple0 = (torch.randn(4), torch.randn(4))
    mykw0 = torch.randn(4)
    pos0 = torch.randn(4)
    myargs = [torch.randn(4), torch.randn(4)]
    torch._dynamo.reset()
    exported = torch._dynamo.export(
        fn_with_kwargs,
        pos0,
        tuple0,
        *myargs,
        aten_graph=False,
        mykw0=mykw0,
        **mykwargs,
    )
    out_graph = exported[0]
    dynamo_result = out_graph(pos0, tuple0, *myargs, mykw0=mykw0, **mykwargs)
    real_result = fn_with_kwargs(pos0, tuple0, *myargs, mykw0=mykw0, **mykwargs)
    self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
def test_export_with_kwargs_and_empty_args(self):
    # Export a function that receives only keyword arguments.
    def fn_with_kwargs(mykw0=None, **mykwargs):
        # Out-of-place multiply: in-place `*=` would mutate and return the
        # caller's mykw0, so the eager and exported results would alias each
        # other and the comparison below would be vacuous.
        out = mykw0 * (mykwargs["input0"] * mykwargs["input1"])
        return out

    mykwargs = {"input0": torch.randn(4), "input1": torch.randn(4)}
    mykw0 = torch.randn(4)
    torch._dynamo.reset()
    exported = torch._dynamo.export(
        fn_with_kwargs,
        aten_graph=False,
        mykw0=mykw0,
        **mykwargs,
    )
    out_graph = exported[0]
    dynamo_result = out_graph(mykw0=mykw0, **mykwargs)
    real_result = fn_with_kwargs(mykw0=mykw0, **mykwargs)
    self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
def test_export_with_args_and_empty_kwargs(self):
    # Export a function that takes positional args and *varargs but no kwargs.
    def fn_with_kwargs(pos0, tuple0, *myargs):
        # Out-of-place multiply: in-place `*=` would mutate and return the
        # caller's pos0, so the eager and exported results would alias each
        # other and the comparison below would be vacuous.
        out = pos0
        for arg in tuple0:
            out = out * arg
        for arg in myargs:
            out = out * arg
        return out

    tuple0 = (torch.randn(4), torch.randn(4))
    pos0 = torch.randn(4)
    myargs = [torch.randn(4), torch.randn(4)]
    torch._dynamo.reset()
    exported = torch._dynamo.export(
        fn_with_kwargs, pos0, tuple0, *myargs, aten_graph=False
    )
    out_graph = exported[0]
    dynamo_result = out_graph(pos0, tuple0, *myargs)
    real_result = fn_with_kwargs(pos0, tuple0, *myargs)
    self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
def test_export_meta(self):
    # Export a module whose parameter and input both live on the meta device.
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.p = torch.nn.Parameter(torch.ones(2, 3))

        def forward(self, x):
            return self.p + x

    with torch.device("meta"):
        m = MyModule()
    inp = torch.ones(2, 3, device="meta")
    out_graph = torch._dynamo.export(m, inp)[0]
    self.assertEqual(out_graph(inp), m(inp))
if __name__ == "__main__":
    # Run this file's tests through dynamo's test-case runner.
    torch._dynamo.test_case.run_tests()