| # Owner(s): ["module: dynamo"] |
| import io |
| import warnings |
| from unittest.mock import patch |
| |
| import torch |
| import torch._dynamo |
| import torch._dynamo.test_case |
| import torch._dynamo.testing |
| from torch._dynamo.testing import same |
| from torch._dynamo.utils import counters |
| |
| |
| class ReorderLogsTests(torch._dynamo.test_case.TestCase): |
| def test_dont_reorder_print(self): |
| def f(x): |
| x = x + x |
| print("moo") |
| x = x * x |
| return x |
| |
| counters.clear() |
| x = torch.randn(3, 3) |
| opt_f = torch.compile(backend="eager")(f) |
| with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: |
| opt_out = opt_f(x) |
| printed_output = mock_stdout.getvalue().strip() |
| orig_out = f(x) |
| |
| self.assertTrue(same(orig_out, opt_out)) |
| self.assertEqual(printed_output, "moo") |
| self.assertEqual(len(counters["graph_break"]), 1) |
| |
| @torch._dynamo.config.patch(reorderable_logging_functions={print}) |
| def test_reorder_print(self): |
| def f(x): |
| print("moo") |
| x1 = x + x |
| print(x1) |
| x2 = x1 * x1 |
| print(1, 2, 3) |
| x3 = x2 + x2 |
| return (x1, x3) |
| |
| x = torch.ones(3, 3) |
| opt_f = torch.compile(backend="eager", fullgraph=True)(f) |
| with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: |
| opt_out = opt_f(x) |
| printed_output = mock_stdout.getvalue().strip() |
| orig_out = f(x) |
| |
| self.assertEqual(printed_output, f"moo\n{torch.ones(3, 3) * 2}\n1 2 3") |
| self.assertTrue(same(orig_out, opt_out)) |
| |
| @torch._dynamo.config.patch(reorderable_logging_functions={warnings.warn}) |
| def test_reorder_warnings(self): |
| import warnings |
| |
| def f(x): |
| x1 = x + x |
| warnings.warn("moo") |
| x2 = x1 * x1 |
| warnings.warn(f"{x2}") |
| x3 = x2 + x2 |
| return x3 |
| |
| x = torch.ones(3, 3) |
| opt_f = torch.compile(backend="eager", fullgraph=True)(f) |
| with warnings.catch_warnings(record=True) as w: |
| opt_out = opt_f(x) |
| warning_messages = [str(i.message) for i in w] |
| orig_out = f(x) |
| |
| self.assertTrue(same(orig_out, opt_out)) |
| self.assertIn("moo", warning_messages) |
| |
| @torch._dynamo.config.patch(reorderable_logging_functions={print}) |
| def test_reorder_print_graph_break(self): |
| def f(x): |
| x1 = x + x |
| print(f"res: {x1}") |
| x2 = x1 * x1 |
| torch._dynamo.graph_break() |
| x3 = x2 + x2 |
| print(1, 2, 3) |
| return x3 |
| |
| x = torch.ones(3, 3) |
| opt_f = torch.compile(backend="eager")(f) |
| with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: |
| opt_out = opt_f(x) |
| printed_output = mock_stdout.getvalue().strip() |
| orig_out = f(x) |
| |
| self.assertEqual(printed_output, f"res: {torch.ones(3, 3) * 2}\n1 2 3") |
| self.assertTrue(same(orig_out, opt_out)) |
| |
| def test_reorder_custom_log_fn(self): |
| custom_logs = [] |
| |
| def custom_log(s: str): |
| torch._dynamo.graph_break() |
| custom_logs.append(s) |
| |
| def f(x): |
| custom_log("moo") |
| x1 = x + x |
| custom_log(f"{x1}") |
| return x + x |
| |
| x = torch.ones(3, 3) |
| counters.clear() |
| with torch._dynamo.config.patch(reorderable_logging_functions={custom_log}): |
| opt_f = torch.compile(backend="eager")(f) |
| opt_out = opt_f(x) |
| |
| self.assertEqual(sum(counters["graph_break"].values()), 1) |
| self.assertEqual(custom_logs[0], "moo") |
| self.assertEqual(custom_logs[1], f"{torch.ones(3, 3) * 2}") |
| |
| @torch._dynamo.config.patch(reorderable_logging_functions={print}) |
| def test_constant_mutation(self): |
| def f(x): |
| alist = [x] |
| alist.append(x + 1) |
| print(alist[-1]) |
| alist[0].sum().item() # graph break |
| res = alist.pop() |
| print(alist[-1]) |
| res.sum().item() # graph break |
| return res |
| |
| inputs = (torch.tensor([1]),) |
| counters.clear() |
| opt_f = torch.compile(backend="eager")(f) |
| with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: |
| opt_out = opt_f(*inputs) |
| printed_output = mock_stdout.getvalue().strip() |
| orig_out = f(*inputs) |
| |
| self.assertEqual(printed_output, "tensor([2])\ntensor([1])") |
| self.assertTrue(same(orig_out, opt_out)) |
| |
| graph_break_key = counters["graph_break"].keys() |
| self.assertEqual(len(graph_break_key), 1) |
| self.assertEqual(next(iter(graph_break_key)), "Tensor.item") |
| |
| |
if __name__ == "__main__":
    # Executed as a script: hand control to Dynamo's test runner.
    torch._dynamo.test_case.run_tests()