| # Owner(s): ["module: dynamo"] | 
 |  | 
 | import re | 
 | import sys | 
 | from io import StringIO | 
 |  | 
 | import torch._dynamo.test_case | 
 | import torch._dynamo.testing | 
 | from torch._dynamo.comptime import comptime | 
 |  | 
# Because we don't support free variables in comptime at the moment,
# we have to communicate via globals.  This also means these tests cannot
# be run in parallel in a single process (not that you'd... ever want
# to do that?)

# Output sink: each printing test rebinds this to a fresh StringIO and passes
# it as `file=` to the ctx.print_* call inside its comptime block.
FILE = None
# The running test case: test_get_local stores `self` here so that its
# comptime block can call assertion methods on it.
SELF = None
 |  | 
 |  | 
 | class ComptimeTests(torch._dynamo.test_case.TestCase): | 
    def test_print_graph(self):
        """Check that ``ctx.print_graph`` writes the graph text to ``FILE``.

        The comptime block in ``f`` sits between ``y = x * 2`` and
        ``return y + 3``; the expected text contains only the multiply.
        """
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.print_graph(verbose=False, file=FILE)

            # Test the compact notation doesn't error or graph break;
            # you'll have to visually inspect to see that it printed
            comptime.print_graph()

            return y + 3

        f(torch.randn(2))
        # Exactly one compiled frame is expected for `f`.
        self.assertEqual(cnt.frame_count, 1)
        # NOTE(review): the expected text mentions `L_x_` and `y`, which look
        # derived from the local names used in `f` -- keep the two in sync.
        self.assertExpectedInline(
            FILE.getvalue().strip(),
            """\
def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    y = l_x_ * 2;  l_x_ = None""",
        )
 |  | 
 |     def test_print_disas(self): | 
 |         global FILE | 
 |         FILE = StringIO() | 
 |         cnt = torch._dynamo.testing.CompileCounter() | 
 |  | 
 |         @torch._dynamo.optimize(cnt) | 
 |         def f(x): | 
 |             y = x * 2 | 
 |  | 
 |             @comptime | 
 |             def _(ctx): | 
 |                 ctx.print_disas(file=FILE) | 
 |  | 
 |             comptime.print_disas() | 
 |  | 
 |             return y + 3 | 
 |  | 
 |         def munge_disas(s): | 
 |             re.sub( | 
 |                 r"^(?: +\d+)?(?: +(-->)) \+\d+ ([A-Za-z0-9_]+)", | 
 |                 "\1 \3", | 
 |                 s, | 
 |                 flags=re.MULTILINE, | 
 |             ) | 
 |  | 
 |         f(torch.randn(2)) | 
 |         self.assertEqual(cnt.frame_count, 1) | 
 |         out = FILE.getvalue() | 
 |         # Check that the instruction offset is working | 
 |         self.assertIn("-->", out) | 
 |         # Check that the bytecode resembles what we expect | 
 |         self.assertIn("STORE_FAST", out) | 
 |         if sys.version_info < (3, 11): | 
 |             self.assertIn("BINARY_MULTIPLY", out) | 
 |         else: | 
 |             self.assertIn("BINARY_OP", out) | 
 |  | 
    def test_print_value_stack(self):
        """Check the text that ``ctx.print_value_stack`` writes to ``FILE``.

        ``g`` is called from ``f`` in the middle of evaluating ``x + g(x)``,
        and its comptime block prints with ``stacklevel=1``.  The expected
        output is a single ``- TensorVariable()`` entry.
        """
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        def g(x):
            @comptime
            def _(ctx):
                # NOTE(review): stacklevel=1 presumably selects the caller's
                # frame (`f`) rather than `g` itself -- confirm against the
                # comptime API.
                ctx.print_value_stack(file=FILE, stacklevel=1)

            return x

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x + g(x)

            # Compact notation: no `file=` is passed here, so nothing from
            # this call lands in FILE.
            return y + comptime.print_value_stack_and_return(y * 2)

        f(torch.randn(2))
        # Exactly one compiled frame is expected for `f` (including the call to `g`).
        self.assertEqual(cnt.frame_count, 1)
        self.assertExpectedInline(
            FILE.getvalue(),
            """\
- TensorVariable()
""",
        )
 |  | 
    def test_print_locals(self):
        """Check the text that ``ctx.print_locals`` writes to ``FILE``.

        At the comptime block, ``f`` has the locals ``x`` and ``y``; the
        expected output lists both as ``TensorVariable()``.
        """
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.print_locals(file=FILE)

            # Compact notation: no `file=` is passed, so its output is not
            # captured in FILE.
            comptime.print_locals()

            return y + 3

        f(torch.randn(2))
        # Exactly one compiled frame is expected for `f`.
        self.assertEqual(cnt.frame_count, 1)
        # The names on the left-hand side are the local names used in `f`
        # above; renaming them there requires updating this expectation.
        self.assertExpectedInline(
            FILE.getvalue(),
            """\
x = TensorVariable()
y = TensorVariable()
""",
        )
 |  | 
 |     def test_print_bt(self): | 
 |         global FILE | 
 |         FILE = StringIO() | 
 |         cnt = torch._dynamo.testing.CompileCounter() | 
 |  | 
 |         def g(x): | 
 |             @comptime | 
 |             def _(ctx): | 
 |                 ctx.print_bt(file=FILE) | 
 |  | 
 |             comptime.print_bt() | 
 |  | 
 |             return x + 3 | 
 |  | 
 |         @torch._dynamo.optimize(cnt) | 
 |         def f(x): | 
 |             y = x * 2 | 
 |             y = g(y) | 
 |             return y + 3 | 
 |  | 
 |         def munge_filenames(s): | 
 |             return re.sub(r'File "[^"]+", line \d+', 'File "X", line X', s) | 
 |  | 
 |         f(torch.randn(2)) | 
 |         self.assertEqual(cnt.frame_count, 1) | 
 |         bt = FILE.getvalue() | 
 |         self.assertIn("y = g(y)", bt) | 
 |  | 
    def test_print_guards(self):
        """Check the text that ``ctx.print_guards`` writes to ``FILE``.

        The expected listing has one entry for the local ``L['x']``
        (``TENSOR_MATCH``), six ``global`` entries and one ``shape_env`` entry,
        each followed by a brace-delimited block of four fields shown as
        ``None``.
        """
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.print_guards(file=FILE)

            # Compact notation: no `file=` is passed, so its output is not
            # captured in FILE.
            comptime.print_guards()

            return y + 3

        f(torch.randn(2))
        # Exactly one compiled frame is expected for `f`.
        self.assertEqual(cnt.frame_count, 1)
        # Trailing whitespace is removed (rstrip for the end of the text, the
        # MULTILINE regex for line ends) so the comparison does not depend on
        # it.  The expected text intentionally starts with an empty line.
        self.assertExpectedInline(
            re.sub(r"\s+$", "", FILE.getvalue().rstrip(), flags=re.MULTILINE),
            """\

        local "L['x']" TENSOR_MATCH
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        global '' GRAD_MODE
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        global '' DETERMINISTIC_ALGORITHMS
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        global '' TORCH_FUNCTION_STATE
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        global '' DEFAULT_DEVICE
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        global '' BACKEND_MATCH
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        shape_env '' SHAPE_ENV
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }""",
        )
 |  | 
    def test_graph_break(self):
        """Check the compiled-frame count with and without graph breaks.

        ``f`` contains a comptime block that does nothing and is expected to
        compile as one frame.  ``g`` requests a graph break twice (once via
        ``ctx.graph_break()``, once via ``comptime.graph_break()``) and is
        expected to produce three frames.
        """
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                pass

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        # Reset the counter so the second half is counted from zero.
        cnt.frame_count = 0

        @torch._dynamo.optimize(cnt)
        def g(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.graph_break()

            y = y + 2

            comptime.graph_break()

            return y * 3

        g(torch.randn(2))
        # Two break requests with tensor work before, between and after them.
        self.assertEqual(cnt.frame_count, 3)
 |  | 
    def test_get_local(self):
        """Check the accessors on the objects returned by ``ctx.get_local``.

        Inside the comptime block the tensor local ``y`` is inspected through
        ``as_fake()``, ``size()``, ``as_proxy()`` and ``python_type()``, and
        the integer local ``lit`` through ``as_python_constant()``.  The
        assertions run on the test case published through the ``SELF`` global.
        """
        global SELF, FILE
        SELF = self
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2
            lit = 2

            @comptime
            def _(ctx):
                y = ctx.get_local("y")
                # The input below is torch.randn(2), hence the expected size 2.
                SELF.assertEqual(y.as_fake().size(0), 2)
                SELF.assertEqual(y.size(0), 2)
                # Trigger a graph write (TODO: this is not so
                # useful right now as there's no way to make use
                # of the output proxy; maybe it's useful for inserting
                # side-effectful operations into the graph)
                y.as_proxy() + 4
                ctx.print_graph(verbose=False, file=FILE)
                SELF.assertIs(y.python_type(), torch.Tensor)
                lit = ctx.get_local("lit")
                SELF.assertEqual(lit.as_python_constant(), 2)

            return y + 3

        f(torch.randn(2))
        # Exactly one compiled frame is expected for `f`.
        self.assertEqual(cnt.frame_count, 1)
        # The last line of the expected graph (`add = y + 4`) is the node
        # written by `y.as_proxy() + 4` in the comptime block.
        self.assertExpectedInline(
            FILE.getvalue().strip(),
            """\
def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    y = l_x_ * 2;  l_x_ = None
    add = y + 4;  y = None""",
        )
 |  | 
 |  | 
 | if __name__ == "__main__": | 
 |     from torch._dynamo.test_case import run_tests | 
 |  | 
 |     run_tests() |