| # Owner(s): ["module: dynamo"] |
| import unittest |
| from unittest.mock import patch |
| |
| import torch |
| |
| import torch._dynamo.test_case |
| import torch._dynamo.testing |
| from torch._dynamo import config |
| from torch._dynamo.testing import skipIfPy311, unsupported |
| from torch._dynamo.utils import disable_cache_limit, ifunspec |
| |
| globalmod = torch.nn.ReLU() |
| |
| |
| def indirectly_unsupported(a, b): |
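    # `unsupported` always triggers a graph break; routing it through this
    # helper makes the break happen one call level deep, so callers pick up
    # an extra compiled frame.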
| c = a + b |
| return unsupported(a, c) |
| |
| |
| class SubGraphTests(torch._dynamo.test_case.TestCase): |
| def _common(self, fn, frame_count, op_count): |
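        # Compile `fn` with a CompileCounter backend, check it matches eager
        # on two inputs, then assert how many frames were compiled (one per
        # graph, so extra frames indicate graph breaks) and the total number
        # of ops captured across all graphs.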
| torch._dynamo.reset() |
| v1 = torch.ones(10) |
| v2 = torch.ones(10) * -2.0 |
| correct1 = fn(v1, v2) |
| correct2 = fn(v2, v1) |
| cnt = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnt)(fn) |
| r1 = opt_fn(v1, v2) |
| r2 = opt_fn(v2, v1) |
| self.assertTrue(torch._dynamo.testing.same(r1, correct1)) |
| self.assertTrue(torch._dynamo.testing.same(r2, correct2)) |
| self.assertEqual( |
| cnt.frame_count, |
| frame_count, |
| f"actual {cnt.frame_count} != expected {frame_count}", |
| ) |
| self.assertEqual(cnt.op_count, op_count) |
| |
| def test_control_flow1(self): |
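        # Branching on a tensor comparison forces a graph break; the five ops
        # before the branch compile into one graph, and the trivial resume
        # paths contain no tensor ops, so only one frame is counted.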
| def fn(a, b): |
| c1 = a - b |
| c2 = b - a |
| if c1.sum() > c2.sum(): |
| return c1 |
| else: |
| return c2 |
| |
| self._common(fn, 1, 5) |
| |
| def test_control_flow2(self): |
| def fn(a, b): |
| if a.sum() > b.sum(): |
| return 1 |
| else: |
| return 2 |
| |
| self._common(fn, 1, 3) |
| |
| def test_control_flow3(self): |
| def fn(a, b): |
| c1 = a - b |
| c2 = b - a |
| m = globalmod |
| if c1.sum() > c2.sum(): |
| return m(c1) |
| else: |
| return m(c2) |
| |
| self._common(fn, 3, 7) |
| |
| def test_control_flow4(self): |
| def fn(a, b): |
| tmp1 = a.sum() > b.sum() and a.sum() > 0 |
| if tmp1: |
| return 1 |
| else: |
| return 2 |
| |
| self._common(fn, 3, 5) |
| |
| def test_control_flow5(self): |
| def fn(a, b): |
| tmp1 = a.sum() > b.sum() and a.sum() > 0 |
| tmp2 = a.sum() < b.sum() or b.sum() > 0 |
| if tmp1 and tmp2: |
| return 1, tmp1, tmp2 |
| else: |
| return 2, tmp1, tmp2 |
| |
| self._common(fn, 6, 13) |
| |
| def test_capi_call1(self): |
| def fn(a, b): |
| c1 = a - b |
| c2 = b - a |
| return unsupported(c1, c2) |
| |
| self._common(fn, 1, 2) |
| |
| def test_capi_call2(self): |
| def fn(a, b): |
| c1 = a - b |
| c2 = b - a |
| return a - (b - unsupported(c1, c2)) |
| |
| self._common(fn, 2, 4) |
| |
| def test_capi_call3(self): |
| def fn(a, b): |
| c1 = a - b |
| c2 = b - a |
| return torch._dynamo.testing.unsupported(c1, c2) |
| |
| self._common(fn, 1, 2) |
| |
| def test_indirect_unsupported1(self): |
| def fn(a, b): |
| c1 = a - b |
| c2 = b - a |
| return indirectly_unsupported(c1, c2) |
| |
| self._common(fn, 2, 3) |
| |
| def test_indirect_unsupported2(self): |
| def fn(a, b): |
| local_const1 = 7 |
| local_const2 = 22 |
| c1 = a - b |
| c2 = b - a |
| return local_const1 / (local_const2 - indirectly_unsupported(c1, c2)) |
| |
| self._common(fn, 3, 5) |
| |
| def test_indirect_unsupported3(self): |
| def fn(a, b): |
| args = [a - b, b - a] |
| return indirectly_unsupported(*args) |
| |
| self._common(fn, 2, 3) |
| |
| def test_stack_state1(self): |
| def fn(a, b): |
| t1 = 1.23 * a |
| t2 = 4.56 * a |
| c1 = a - b |
| c2 = b - a |
| return t1 / (t2 - unsupported(c1, c2)) |
| |
| self._common(fn, 2, 6) |
| |
| def test_stack_state2(self): |
| def fn(a, b): |
| t1 = 1.23 * a |
| t2 = 4.56 * a |
| c1 = a - b |
| c2 = b - a |
| return t1 / (t2 - indirectly_unsupported(c1, c2)) |
| |
| self._common(fn, 3, 7) |
| |
| def test_multigraph(self): |
| def fn(a, b): |
| x = a + b |
| x = x / 2.0 |
| if x.sum() < 0: |
| return x * -1.0 |
| return x |
| |
| self._common(fn, 2, 5) |
| |
| def test_extended_args(self): |
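        # 512 chained operands push the compiled bytecode's jump targets and
        # argument indices past one byte, forcing EXTENDED_ARG instructions
        # that Dynamo's bytecode handling must decode and emit correctly.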
| too_many_adds = "+".join(["a", "b"] * 256) |
| source = ( |
| f"lambda a, b: ({too_many_adds}+a if a.sum() > 0 else {too_many_adds} - b)" |
| ) |
| self._common(eval(source), 3, 1026) |
| |
| def test_resume1(self): |
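        # A graph break in the middle of the function splits it in two: one
        # graph for the ops before `unsupported` and one resume frame after.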
| def fn(a, b): |
| x = a + b |
| x = x / 2.0 |
| x = x + 2.0 |
| x = unsupported(x, a) |
| x = x + 2.0 |
| x = x + 2.0 |
| x = x + 2.0 |
| return x |
| |
| self._common(fn, 2, 6) |
| |
| def test_resume2(self): |
| def fn(a, b): |
| x = a + b |
| x = x / 2.0 |
| x = x + 2.0 |
| x = indirectly_unsupported(x, a) |
| x = x + 2.0 |
| x = x + 2.0 |
| x = x + 2.0 |
| return x |
| |
| self._common(fn, 3, 7) |
| |
| def test_resume3(self): |
| def fn(a, b): |
| x = a + b |
| x = x / 2.0 |
| x = x + 2.0 |
| x = indirectly_unsupported(x, b=a) |
| x = x + 2.0 |
| x = x + 2.0 |
| x = x + 2.0 |
| return x |
| |
| self._common(fn, 3, 7) |
| |
| def test_resume4(self): |
| def fn(a, b): |
| x = a + b |
| x = x / 2.0 |
| x = x + 2.0 |
| x = indirectly_unsupported(a=x, b=a) |
| x = x + 2.0 |
| x = x + 2.0 |
| x = x + 2.0 |
| return x |
| |
| self._common(fn, 3, 7) |
| |
| def test_resume5(self): |
| def fn(a, b): |
| x = a + b |
| x = x / 2.0 |
| x = x + 2.0 |
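            # builtin print() is unsupported by Dynamo and causes a graph break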
| print(x) |
| x = x + 2.0 |
| x = x + 2.0 |
| x = x + 2.0 |
| return x |
| |
| self._common(fn, 2, 6) |
| |
| def test_start1(self): |
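        # Here the graph break lands on the very first statement, so no ops
        # precede it; only the resume frame after print() compiles, giving a
        # single graph.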
| def fn(a, b): |
| print(a) |
| x = a + b |
| x = x + 2.0 |
| x = x + 2.0 |
| return x |
| |
| self._common(fn, 1, 3) |
| |
| def test_start2(self): |
| def fn(a, b): |
| x = indirectly_unsupported(a, b) |
| x = x + 2.0 |
| x = x + 2.0 |
| x = x + 2.0 |
| return x |
| |
| self._common(fn, 2, 4) |
| |
| def test_start3(self): |
| def fn(a, b): |
| x = unsupported(a, b) |
| x = x + 2.0 |
| x = x + 2.0 |
| x = x + 2.0 |
| return x |
| |
| self._common(fn, 1, 3) |
| |
| def test_start4(self): |
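        # `fn` takes three args, so frames and ops are counted by hand rather
        # than via _common; the data-dependent branch on the tensor `check`
        # splits the function into multiple graphs (3 frames, 4 ops total).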
| def fn(a, b, check): |
| if check: |
| return a + b + 10 |
| else: |
| return a + b - 10 |
| |
| v1 = torch.randn(10) |
| v2 = torch.randn(10) |
| f = torch.zeros(1, dtype=torch.int32) |
| t = torch.ones(1, dtype=torch.int32) |
| correct1 = fn(v1, v2, t) |
| correct2 = fn(v1, v2, f) |
| cnt = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnt)(fn) |
| r1 = opt_fn(v1, v2, t) |
| r2 = opt_fn(v1, v2, f) |
| self.assertTrue(torch._dynamo.testing.same(r1, correct1)) |
| self.assertTrue(torch._dynamo.testing.same(r2, correct2)) |
| self.assertEqual(cnt.frame_count, 3) |
| self.assertEqual(cnt.op_count, 4) |
| |
| def test_resume_freevars(self): |
| c1 = torch.randn(10) |
| c2 = torch.randn(10) |
| |
| def fn(a, b): |
| x = a + b + (c1 - c2) |
| x = unsupported(x, x) |
| return x + (c1 - c2) |
| |
| self._common(fn, 2, 5) |
| |
| def test_restore_state(self): |
| def fn(a, b): |
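            # Alias a builtin into a local so the resume frame must restore
            # non-tensor local state across the graph break.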
| len_ = len |
| x = a + b |
| x = torch.add(unsupported(x, x), 1) |
| return a * x + len_(b) |
| |
| if config.dynamic_shapes: |
| self._common(fn, 2, 5) |
| else: |
| self._common(fn, 2, 4) |
| |
| def test_restore_range(self): |
| def fn(a, b): |
| x = a + b |
| rng = range(3, 8, 2) |
| x = unsupported(x, x) |
| for i in rng: |
| x = x + i |
| return x |
| |
| # We don't specialize on range with dynamic shapes, which |
| # means we fail to unroll the loop. |
        # TODO: Consider forcing specialization when we iterate over
        # the range.
| self._common(fn, 2, ifunspec(1, 4)) |
| |
| def test_restore_range_iter(self): |
| def fn(a, b): |
| x = a + b |
| rng = iter(range(3, 8, 2)) |
| x = unsupported(x, x) |
| x += next(rng) |
| return x, list(rng) |
| |
| self._common(fn, 2, 2) |
| |
| def test_pop_after_resume(self): |
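        # The list `tmp` is built before the graph break and mutated after
        # it; the resume frame must see the correctly reconstructed list.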
| def fn(a, b): |
| tmp = [a + 1, b + 2, a + b] |
| x = a |
| x = unsupported(x, x) |
| for i in range(3): |
| x += tmp.pop(-1) |
| return x |
| |
| self._common(fn, 2, 6) |
| |
| @disable_cache_limit() |
| def test_dynamic_shapes(self): |
| if config.assume_static_by_default: |
            self.skipTest("Already covered identically in test_dynamic_kwarg")
| |
| def fn(a, b): |
| return a - b * 10 |
| |
| torch._dynamo.reset() |
| cnt_static = torch._dynamo.testing.CompileCounter() |
| with patch("torch._dynamo.config.dynamic_shapes", False): |
| opt_fn = torch._dynamo.optimize(cnt_static)(fn) |
| for i in range(2, 12): |
| opt_fn(torch.randn(i), torch.randn(i)) |
| self.assertEqual(cnt_static.frame_count, 10) |
| |
| torch._dynamo.reset() |
| cnt_dynamic = torch._dynamo.testing.CompileCounter() |
| with patch("torch._dynamo.config.dynamic_shapes", True): |
| opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn) |
            # NB: must not use sizes 0 or 1, as those are specialized by default
| for i in range(2, 12): |
| opt_fn(torch.randn(i), torch.randn(i)) |
| # just one graph now rather than 10 |
| self.assertEqual(cnt_dynamic.frame_count, 1) |
| |
| @patch("torch._dynamo.config.dynamic_shapes", True) |
| @patch("torch._dynamo.config.assume_static_by_default", False) |
| def test_dynamic_getitem(self): |
| def fn(a, b): |
| return a[b.size(0) - 1] |
| |
| cnt = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnt)(fn) |
| for i in range(3, 12): |
| opt_fn(torch.randn(i), torch.randn(i)) |
| # just one graph |
| self.assertEqual(cnt.frame_count, 1) |
| |
| def test_dynamic_kwarg(self): |
| def fn(a, b): |
| return a - b * 10 |
| |
| torch._dynamo.reset() |
| cnt_dynamic = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn) |
| start = 2 |
| end = 12 |
| for i in range(start, end): |
| opt_fn(torch.randn(i), torch.randn(i)) |
| |
| if config.assume_static_by_default: |
            # two compiled graphs: one static, then one recompiled as dynamic
            # via automatic dynamic shapes
| self.assertEqual(cnt_dynamic.frame_count, 2) |
| else: |
| # just one graph |
| self.assertEqual(cnt_dynamic.frame_count, 1) |
| |
| def test_dynamic_duck_size(self): |
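        # "Duck sizing": inputs that happen to have equal sizes initially
        # share one symbolic size, so inputs with unequal sizes force a
        # recompile with independent symbols.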
| def fn(a, b): |
| if a.size(0) == b.size(0): |
| return a + b |
| else: |
| return a.sum() + b.sum() |
| |
| torch._dynamo.reset() |
| cnt_dynamic = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn) |
| x = torch.randn(2) |
| y = torch.randn(3) |
| self.assertEqual(opt_fn(x, x), fn(x, x)) |
| self.assertEqual(opt_fn(x, y), fn(x, y)) |
| self.assertEqual(cnt_dynamic.frame_count, 2) |
| |
| def test_dynamic_order_dependence(self): |
| def fn(a, b): |
| return a.sum() + b.sum() |
| |
| torch._dynamo.reset() |
| cnt_dynamic = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn) |
| x = torch.randn(2) |
| y = torch.randn(3) |
| self.assertEqual(opt_fn(x, y), fn(x, y)) |
| self.assertEqual(opt_fn(x, x), fn(x, x)) |
        # NB: This COULD validly be 2, but we don't test disjointness in the
        # guards for the case where x and y didn't duck-size together, so we
        # end up with a generic graph that also works when x and y happen to
        # duck-size together.
| if config.assume_static_by_default: |
| self.assertEqual(cnt_dynamic.frame_count, 2) |
| else: |
| self.assertEqual(cnt_dynamic.frame_count, 1) |
| |
| torch._dynamo.reset() |
| cnt_dynamic.frame_count = 0 |
| self.assertEqual(opt_fn(x, x), fn(x, x)) # this overspecializes! |
| self.assertEqual(opt_fn(x, y), fn(x, y)) |
| self.assertEqual(cnt_dynamic.frame_count, 2) |
| |
| def test_dynamic_zero_inference(self): |
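        # Sizes 0 and 1 are specialized even in dynamic mode, so the
        # zero-size input compiles a second graph rather than reusing the
        # dynamic one.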
| def fn(a): |
| if a.size(0) != 0: |
| return a * 2 |
| else: |
| return a + 1 |
| |
| torch._dynamo.reset() |
| cnt_dynamic = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn) |
| x = torch.randn(0) |
| y = torch.randn(2) |
| self.assertEqual(opt_fn(y), fn(y)) |
| self.assertEqual(opt_fn(x), fn(x)) |
| self.assertEqual(cnt_dynamic.frame_count, 2) |
| |
| @patch.object(torch._dynamo.config, "dynamic_shapes", True) |
| @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) |
| def test_no_graph_break_on_item(self): |
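        # With capture_scalar_outputs=True, .item() is traced into the graph
        # instead of forcing a graph break (contrast with the next test).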
| def fn(a, b): |
| x = a + b - 1.5 |
| x = x.sum() |
| x.item() |
| x = x / (a + b) |
| return x |
| |
| self._common(fn, 1, 6) |
| |
| @patch.object(torch._dynamo.config, "dynamic_shapes", True) |
| @patch.object(torch._dynamo.config, "capture_scalar_outputs", False) |
| def test_graph_break_on_item(self): |
| def fn(a, b): |
| x = a + b - 1.5 |
| x = x.sum() |
| x.item() |
| x = x / (a + b) |
| return x |
| |
| self._common(fn, 2, 5) |
| |
| def test_resume_paths_join(self): |
| def fn(x, c1, c2, c3): |
| x = x + 1 |
| if c1: |
| x = x + 2 |
| x = x + 3 |
| if c2: |
| x = x + 4 |
| x = x + 5 |
| if c3: |
| x = x + 6 |
| return x + 7 |
| |
| v1 = torch.randn(10) |
        t = torch.tensor([True])
        f = torch.tensor([False])
| cnt = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnt)(fn) |
| for a in (t, f): |
| for b in (t, f): |
| for c in (t, f): |
| opt_fn(v1, a, b, c) |
| |
        # check that we don't create 2^n graphs for the 2^3 input combinations
| self.assertEqual(cnt.frame_count, 7) |
| self.assertEqual(cnt.op_count, 10) |
| |
| @skipIfPy311 |
| def test_resume_with_no_grad1(self): |
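        # The graph break happens inside a no_grad() block; the generated
        # resume function must re-enter the grad-mode context before running
        # the remainder of the body.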
| def fn(a, b): |
| x = a + b |
| with torch.no_grad(): |
| x = x + 1 |
| x.sum().tolist() # graph break |
| x = x + 2 |
| x = x + 3 |
| return x |
| |
| self._common(fn, 2, 9) |
| torch._dynamo.reset() |
| with torch.no_grad(): |
| self._common(fn, 2, 9) |
| |
| @skipIfPy311 |
| def test_resume_with_no_grad2(self): |
| def fn(a, b): |
| x = a + b |
| with torch.no_grad(): |
| x = x + 1 |
| x.sum().tolist() # graph break |
| x = x + 2 |
| x.sum().tolist() # graph break |
| x = x + 3 |
| x = x + 4 |
| return x |
| |
| self._common(fn, 3, 13) |
| |
| @skipIfPy311 |
| def test_resume_with_no_grad3(self): |
| def fn(a, b): |
| x = a + b |
| with torch.no_grad(): |
| with torch.no_grad(): |
| x = x + 1 |
| with torch.enable_grad(): |
| x.sum().tolist() # graph break |
| x = x[0] + 2 |
| x = x + 3 |
| x = x + 4 |
| return x |
| |
| self._common(fn, 2, 19) |
| |
| def test_resume_tuple_iterator(self): |
| def fn(a, b): |
| x = a + b |
| it = iter(tuple(range(10))) |
| x = x + next(it) |
| x = x + next(it) |
| x = x + next(it) |
| x = unsupported(x, x) |
| x = x + next(it) |
| x = x + next(it) |
| x = x + next(it) |
| x = x + next(it) |
| return x |
| |
| self._common(fn, 2, 8) |
| |
| def test_tuple_iterator_return(self): |
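        # The partially consumed iterator escapes the function, so Dynamo
        # must reconstruct its remaining state correctly across both graph
        # breaks; checked by comparing list(it2) with list(it3) below.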
| def fn(x): |
| it = iter(tuple(range(10))) |
| x = x + next(it) |
| x = x + next(it) |
| x = unsupported(x, x) |
| x = x + next(it) |
| x = x + next(it) |
| x = unsupported(x, x) |
| x = x + next(it) |
| x = x + next(it) |
| return x, it |
| |
| v1 = torch.randn(10) |
| v2, it2 = fn(v1) |
| cnt = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnt)(fn) |
| v3, it3 = opt_fn(v1) |
| v4, it4 = opt_fn(v1) |
| self.assertEqual(v2.tolist(), v3.tolist()) |
| self.assertEqual(v2.tolist(), v4.tolist()) |
| self.assertEqual(list(it2), list(it3)) |
| self.assertEqual(cnt.frame_count, 3) |
| self.assertEqual(cnt.op_count, 6) |
| |
| @unittest.skip("not working yet") |
| def test_tuple_iterator_mutate(self): |
| def fn(x, it): |
| x = x + next(it) |
| x = x + next(it) |
| x = x + next(it) |
| x = x + next(it) |
| return x |
| |
| v1 = torch.randn(10) |
| it1 = iter(tuple(range(10))) |
| cnt = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnt)(fn) |
| self.assertEqual(opt_fn(v1, it1).tolist(), (v1 + 1 + 2 + 3).tolist()) |
| self.assertEqual(list(it1), [4, 5, 6, 7, 8, 9]) |
| |
| def test_enumerate_not_break_graph(self): |
| def fn(a, b): |
| for i, x in enumerate(a.shape): |
| b = b + x |
| for i, x in enumerate(b.shape, 8): |
| b = b + x * i |
| return b |
| |
| self._common(fn, 1, 2) |
| |
| |
| if __name__ == "__main__": |
| from torch._dynamo.test_case import run_tests |
| |
| run_tests() |