# Owner(s): ["module: dynamo"]
import functools
import torch
import torch._dynamo
import torch._dynamo.test_case
from torch._dynamo.optimizations.training import is_aot_autograd_safe_to_run
from torch._dynamo.testing import CompileCounter, rand_strided
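
# Minimal test "backend" for torch._dynamo.optimize: it records whether
# is_aot_autograd_safe_to_run accepts the captured graph, then returns the
# unmodified forward so the graph simply runs eagerly.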
def compiler_safe_fn(gm, example_inputs, is_safe):
is_safe[0] = is_aot_autograd_safe_to_run(gm, example_inputs)
return gm.forward
class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
def test_LSTM(self):
# https://github.com/pytorch/torchdynamo/issues/1147
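        # The captured LSTM graph is expected to be reported as unsafe by
        # is_aot_autograd_safe_to_run.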
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()
self.self_mod_model_lstm_lstm = torch.nn.LSTM(
64, 64, num_layers=2, bidirectional=True
)
def forward(self, permute: torch.Tensor):
self_mod_model_lstm_lstm = self.self_mod_model_lstm_lstm(permute)
return (self_mod_model_lstm_lstm,)
is_safe = [True]
mod = Repro()
compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe)
aot_mod = torch._dynamo.optimize(compiler_fn)(mod)
args = [((92, 4, 64), (1, 5888, 92), torch.float32, "cpu", False)]
args = [
rand_strided(sh, st, dt, dev).requires_grad_(rg)
for (sh, st, dt, dev, rg) in args
]
aot_mod(*args)
        self.assertFalse(is_safe[0])
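
    # In-place mutation of an input Parameter (param.add_) should be flagged as
    # unsafe for AOT Autograd.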
def test_mutation(self):
# https://github.com/pytorch/torchdynamo/issues/1301
def fn(param, y):
prev_grad = torch.is_grad_enabled()
try:
torch.set_grad_enabled(False)
param.add_(y)
finally:
torch.set_grad_enabled(prev_grad)
return y
y = torch.randn(4)
x = torch.nn.Parameter(torch.randn(4))
is_safe = [True]
compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe)
aot_fn = torch._dynamo.optimize(compiler_fn)(fn)
aot_fn(x, y)
        self.assertFalse(is_safe[0])
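
    # Writing into a slice of a graph input mutates it in place, so this graph
    # should also be flagged as unsafe.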
def test_mutation1(self):
def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor):
getitem = diagonal_chunked_attention_scores[
(
slice(None, None, None),
slice(None, None, None),
slice(None, 256, None),
slice(None, 257, None),
)
]
_stack0[
(
slice(None, None, None),
slice(None, -1, None),
slice(None, None, None),
slice(256, None, None),
)
] = getitem
view = _stack0.view(1, 12, 1024, 513)
return (view,)
x = torch.randn(torch.Size([12, 4, 256, 513]))
y = torch.randn(torch.Size([12, 3, 512, 513]))
is_safe = [True]
compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe)
aot_fn = torch._dynamo.optimize(compiler_fn)(fn)
aot_fn(x, y)
        self.assertFalse(is_safe[0])
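
    # Unlike the test above, the slice assignment targets the result of
    # torch.sin (an intermediate), not a graph input, so the graph should be
    # considered safe.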
def test_negative_testing_mutation(self):
def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor):
getitem = diagonal_chunked_attention_scores[
(
slice(None, None, None),
slice(None, None, None),
slice(None, 256, None),
slice(None, 257, None),
)
]
_stack0 = torch.sin(_stack0)
_stack0[
(
slice(None, None, None),
slice(None, -1, None),
slice(None, None, None),
slice(256, None, None),
)
] = getitem
view = _stack0.view(1, 12, 1024, 513)
return (view,)
x = torch.randn(torch.Size([12, 4, 256, 513]))
y = torch.randn(torch.Size([12, 3, 512, 513]))
is_safe = [True]
compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe)
aot_fn = torch._dynamo.optimize(compiler_fn)(fn)
aot_fn(x, y)
self.assertTrue(is_safe[0])
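
    # add_ here mutates the output of torch.sin, an intermediate rather than a
    # graph input, so the graph should be considered safe.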
def test_negative_testing(self):
def fn(x, y):
return torch.sin(x).add_(y)
y = torch.randn(4)
x = torch.randn(4)
is_safe = [True]
compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe)
aot_fn = torch._dynamo.optimize(compiler_fn)(fn)
aot_fn(x, y)
self.assertTrue(is_safe[0])
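
    # The module calls a helper (Conv2d._conv_forward) with non-constant tensor
    # inputs; the exported graph should still be accepted by the safety check.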
def test_call_fn_with_non_const_inputs_aot_safe(self):
class ModuleSpecialFwd(torch.nn.Module):
def __init__(self):
                super().__init__()
self.conv = torch.nn.Conv2d(
in_channels=3, out_channels=20, kernel_size=(5, 5)
)
def _conv_forward(self, x):
return self.conv._conv_forward(x, self.conv.weight, self.conv.bias)
def forward(self, x):
return self._conv_forward(x)
# Init mod
mod = ModuleSpecialFwd()
rx = torch.randn([3, 10, 10])
# Run it for real
real = mod(rx)
# Run it in export
graph, _ = torch._dynamo.export(mod, rx)
# Run exported graph with AOT
is_safe = [True]
self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))
compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe)
aot_fn = torch._dynamo.optimize(compiler_fn)(graph)
aot_fn(rx)
self.assertTrue(is_safe[0])
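
    # As above, but the helper mutates a Parameter in-place under no_grad, so
    # the exported graph should be flagged as unsafe.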
def test_call_fn_with_non_const_inputs_aot_unsafe(self):
class ModuleSpecialFwd(torch.nn.Module):
def __init__(self):
                super().__init__()
def _some_bad_fwd(self, param, y):
prev_grad = torch.is_grad_enabled()
try:
torch.set_grad_enabled(False)
param.add_(y)
finally:
torch.set_grad_enabled(prev_grad)
return y
def forward(self, x, y):
return self._some_bad_fwd(x, y)
# Init mod
mod = ModuleSpecialFwd()
x = torch.nn.Parameter(torch.randn(4))
y = torch.randn([4])
# Run it for real
real = mod(x, y)
# Run it in export
graph, _ = torch._dynamo.export(mod, x, y)
# Assert equal
self.assertTrue(torch._dynamo.testing.same(real, graph(x, y)))
# Run exported graph with AOT
is_safe = [True]
compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe)
aot_fn = torch._dynamo.optimize(compiler_fn)(graph)
aot_fn(x, y)
        self.assertFalse(is_safe[0])
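
    # Data-dependent control flow in the helper causes graph breaks (four
    # frames are compiled, see the counter assertions below); none of the
    # resulting graphs mutate their inputs, so they should remain AOT-safe.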
def test_call_fn_with_non_const_inputs_aot_unsafe_control_flow(self):
class ModuleSpecialFwd(torch.nn.Module):
def __init__(self):
                super().__init__()
def _some_bad_fwd(self, param, y):
if y[0][0] < 3:
return y + param
return param * y
def forward(self, x, y):
a = x * y
a = self._some_bad_fwd(a, a)
b = x + y
return a * b
# Init mod
mod = ModuleSpecialFwd()
x = torch.nn.Parameter(torch.randn([2, 2]))
y = torch.randn([2, 2])
# Run it for real
real = mod(x, y)
# Run it through optimize, with our capturing fn
gms = []
counter = CompileCounter()
def capturing_fn(gm, inputs):
nonlocal gms
gms.append(gm)
return counter(gm, inputs)
optimized_mod = torch._dynamo.optimize(capturing_fn)(mod)
# Assert equal
self.assertTrue(torch._dynamo.testing.same(real, optimized_mod(x, y)))
        # Uncomment the loop below to print the captured graphs; they are
        # reproduced in the comments that follow.
# for gm in gms:
# print("GM CODE", gm.code)
self.assertEqual(counter.frame_count, 4)
self.assertEqual(counter.op_count, 7)
# Graph 1
# def forward(self, x : torch.nn.parameter.Parameter, y : torch.Tensor):
# mul = x * y; x = y = None
# return (mul,)
# BREAK
# Graph 2
# def forward(self, y : torch.Tensor):
# getitem = y[0]; y = None
# getitem_1 = getitem[0]; getitem = None
# lt = getitem_1 < 3; getitem_1 = None
# return (lt,)
# BREAK
# Graph 3
# def forward(self, param : torch.Tensor, y : torch.Tensor):
# add = y + param; y = param = None
# return (add,)
# BREAK
# Graph 4
# def forward(self, _stack0 : torch.Tensor, x : torch.nn.parameter.Parameter, y : torch.Tensor):
# add = x + y; x = y = None
# mul = _stack0 * add; _stack0 = add = None
# return (mul,)
# Run fn with AOT
torch._dynamo.reset()
is_safe = [True]
compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe)
aot_fn = torch._dynamo.optimize(compiler_fn)(optimized_mod)
aot_fn(x, y)
self.assertTrue(is_safe[0])
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()