blob: ad0382f3d24ecbd84f65239b3cfd417c5e2dcafc [file] [log] [blame]
# Owner(s): ["module: dynamo"]
from unittest.mock import patch
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
class RecompileTests(torch._dynamo.test_case.TestCase):
def test_automatic_dynamic_reduce_recompiles(self):
# Test the counterfactual, lots of recompiles without this config
def foo(x, y):
return x * y
def run_foo_6_times_and_count_recompiles():
cnt = torch._dynamo.testing.CompileCounter()
x = torch.randn([2])
y = torch.randn([2])
opt = torch._dynamo.optimize(cnt)(foo)
opt(x, y)
x = torch.randn([3])
y = torch.randn([3])
opt(x, y)
x = torch.randn([4])
y = torch.randn([4])
opt(x, y)
opt(x, y)
x = torch.randn([5])
y = torch.randn([5])
opt(x, y)
opt(x, y)
x = torch.randn([6])
y = torch.randn([6])
opt(x, y)
return cnt
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
@patch.object(torch._dynamo.config, "dynamic_shapes", False)
@patch.object(torch._dynamo.config, "assume_static_by_default", False)
def run_without_automatic():
return run_foo_6_times_and_count_recompiles()
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_with_automatic():
return run_foo_6_times_and_count_recompiles()
without = run_without_automatic()
self.assertEqual(without.frame_count, 5)
self.assertEqual(without.op_count, 5)
torch._dynamo.reset()
with_automatic = run_with_automatic()
self.assertEqual(with_automatic.frame_count, 2)
self.assertEqual(with_automatic.op_count, 2)
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def test_recompiles_true_false_flop(self):
# Test the counterfactual, lots of recompiles without this config
def foo(x, y):
if x:
return y * 2
else:
return y * y
def run_foo_6_times_and_count_recompiles():
cnt = torch._dynamo.testing.CompileCounter()
opt = torch._dynamo.optimize(cnt, nopython=True)(foo)
x = True
y = torch.randn([2])
opt(x, y)
x = False
y = torch.randn([2])
opt(x, y)
x = True
y = torch.randn([3])
opt(x, y)
x = True
y = torch.randn([4])
opt(x, y)
x = True
y = torch.randn([5])
opt(x, y)
return cnt
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
@patch.object(torch._dynamo.config, "dynamic_shapes", False)
@patch.object(torch._dynamo.config, "assume_static_by_default", False)
def run_without_automatic():
return run_foo_6_times_and_count_recompiles()
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_with_automatic():
return run_foo_6_times_and_count_recompiles()
without = run_without_automatic()
self.assertEqual(without.frame_count, 5)
self.assertEqual(without.op_count, 5)
torch._dynamo.reset()
with_automatic = run_with_automatic()
self.assertEqual(with_automatic.frame_count, 3)
self.assertEqual(with_automatic.op_count, 3)