| # Owner(s): ["module: dynamo"] |
| import collections |
| import copy |
| import dataclasses |
| import dis |
| import enum |
| import logging |
| import math |
| import os |
| import sys |
| import typing |
| import unittest |
| import weakref |
| from unittest.mock import patch |
| |
| import numpy as np |
| import torch |
| |
| import torch._dynamo.test_case |
| import torch._dynamo.testing |
| import torch.onnx.operators |
| from torch._dynamo import bytecode_transformation, graph_break |
| from torch._dynamo.testing import ( |
| CompileCounter, |
| requires_static_shapes, |
| same, |
| unsupported, |
| ) |
| from torch.testing._internal.common_utils import freeze_rng_state |
| from torch.testing._internal.jit_utils import JitTestCase |
| |
| mytuple = collections.namedtuple("mytuple", ["a", "b", "ab"]) |
| |
| |
| def my_custom_function(x): |
| return x + 1 |
| |
| |
| class MiscTests(torch._dynamo.test_case.TestCase): |
| def test_boolarg(self): |
| def boolarg(aa, bb, flag): |
| if flag: |
| return aa - bb |
| else: |
| return bb - aa |
| |
| a = torch.randn(10, 10) |
| b = torch.randn(10, 10) |
| correct1 = boolarg(a, b, True) |
| correct2 = boolarg(a, b, False) |
| correct3 = boolarg(a, b, None) |
| counter = CompileCounter() |
| opt_boolarg = torch._dynamo.optimize_assert(counter)(boolarg) |
| val1 = opt_boolarg(a, b, True) |
| val2 = opt_boolarg(a, b, False) |
| val3 = opt_boolarg(a, b, None) |
| val4 = opt_boolarg(a, b, True) |
| self.assertTrue(same(val1, correct1)) |
| self.assertTrue(same(val2, correct2)) |
| self.assertTrue(same(val3, correct3)) |
| self.assertTrue(same(val4, correct1)) |
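| # flag=True, flag=False, and flag=None each specialize the compiled frame; |
| # the repeated True call reuses the cache, so three frames are expected. |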
| self.assertEqual(counter.frame_count, 3) |
| |
| def test_callpacked(self): |
| def call_packed(args): |
| a, b, c = args |
| return a - b * c |
| |
| counter = CompileCounter() |
| a = torch.randn(10, 10) |
| b = torch.randn(10, 10) |
| c = torch.randn(10, 10) |
| correct = call_packed([a, b, c]) |
| opt_call_packed = torch._dynamo.optimize_assert(counter)(call_packed) |
| val1 = opt_call_packed([a, b, c]) |
| val2 = opt_call_packed((a, b, c)) |
| val3 = opt_call_packed([a, b, c]) |
| val4 = opt_call_packed((a, b, c)) |
| self.assertTrue(same(val1, correct)) |
| self.assertTrue(same(val2, correct)) |
| self.assertTrue(same(val3, correct)) |
| self.assertTrue(same(val4, correct)) |
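| # list and tuple arguments are guarded on their type, so the same code |
| # compiles once for each, giving two frames. |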
| self.assertEqual(counter.frame_count, 2) |
| |
| def test_raises(self): |
| def fn(a, b, c, cls): |
| x = a + b - c * 10 |
| raise cls(str(x)) |
| |
| counter = CompileCounter() |
| a = torch.randn(10, 10) |
| b = torch.randn(10, 10) |
| c = torch.randn(10, 10) |
| opt_fn = torch._dynamo.optimize(counter)(fn) |
| self.assertRaises(AssertionError, lambda: opt_fn(a, b, c, AssertionError)) |
| self.assertEqual(counter.frame_count, 1) |
| self.assertEqual(counter.op_count, 3) |
| |
| def test_inplace(self): |
| def inplace1(a, b): |
| o = torch.empty((10, 10)) |
| o.copy_(a) |
| o -= b |
| return o |
| |
| torch._dynamo.testing.standard_test(self, inplace1, 2, expected_ops=3) |
| |
| def test_unpack4(self): |
| def unpack4(a, b): |
| a = a[:5, :] |
| b = b[:5, :] |
| x, y = a.size() |
| o = torch.empty((x, y)) |
| o.copy_(a / b) |
| return o |
| |
| torch._dynamo.testing.standard_test( |
| self, unpack4, 2, expected_ops=5, expected_ops_dynamic=8 |
| ) |
| |
| def test_unpack5(self): |
| def unpack5(a, b): |
| a = a[:5, :] |
| b = b[:5, :] |
| x, y = a.shape |
| o = torch.empty((x, y)) |
| o.copy_(a / b) |
| return o |
| |
| torch._dynamo.testing.standard_test( |
| self, unpack5, 2, expected_ops=5, expected_ops_dynamic=8 |
| ) |
| |
| def test_matmul1(self): |
| def matmul_op1(a, b): |
| return a @ b |
| |
| # TODO(jansel): FX doesn't support this, should add upstream support |
| torch._dynamo.testing.standard_test(self, matmul_op1, 2, expected_ops=1) |
| |
| def test_builtin_isinstance(self): |
| def fn(x): |
| t = torch.arange(1, 3) |
| a = isinstance(x, torch.Tensor) |
| b = isinstance(t, torch.Tensor) |
| c = isinstance(x, int) |
| d = isinstance(3, int) |
| e = isinstance([1, 2, 3], list) |
| f = isinstance({"foo": 1, "bar": 2}, dict) |
| res = [a, b, c, d, e, f] |
| # Can't run yet due to other unimplemented instructions |
| # res += [isinstance(torch.nn.LazyLinear(2, 3), torch.nn.Linear)] |
| return res |
| |
| torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1) |
| |
| def test_fold(self): |
| def fn(a): |
| return a + math.sqrt(63) |
| |
| torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1) |
| |
| def test_shape_unpack(self): |
| def fn(x): |
| a, b = x.size() |
| return x * b |
| |
| i = torch.randn(5, 10) |
| r1 = fn(i) |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| r2 = opt_fn(i) |
| self.assertTrue(same(r1, r2)) |
| |
| def test_empty_list(self): |
| def fn(x, ll): |
| if len(ll) == 0 and not ll and ll is not None: |
| return x + 1 |
| |
| i = torch.randn(5, 10) |
| r1 = fn(i, []) |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| r2 = opt_fn(i, []) |
| r3 = opt_fn(i, tuple()) |
| self.assertTrue(same(r1, r2)) |
| self.assertTrue(same(r1, r3)) |
| |
| def test_config_obj(self): |
| class Cfg: |
| def __init__(self): |
| self.val = 0.5 |
| self.count = 3 |
| |
| def fn(x, cfg): |
| for i in range(cfg.count): |
| x = x + cfg.val |
| return x |
| |
| cfg1 = Cfg() |
| cfg1.val = 1.0 |
| cfg2 = Cfg() |
| v = torch.zeros(1) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| v = opt_fn(v, cfg1) # 3 |
| v = opt_fn(v, cfg2) # 4.5 |
| cfg2.count = 1 |
| v = opt_fn(v, cfg2) # 5 |
| cfg2.val = 2.0 |
| v = opt_fn(v, cfg2) # 7 |
| self.assertEqual(v[0], 7) |
| self.assertEqual(cnts.op_count, 8) |
| |
| def test_config_getattr_default(self): |
| class Cfg: |
| def __init__(self): |
| self.val = 0.5 |
| self.count = 10 |
| |
| def fn(x, cfg): |
| if getattr(cfg, "just_add_7", False): |
| return x + 7 |
| for i in range(cfg.count): |
| x = x + cfg.val |
| return x |
| |
| cfg1 = Cfg() |
| v = torch.zeros(1) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertEqual(opt_fn(v, cfg1)[0], 5) |
| self.assertEqual(opt_fn(v, cfg1)[0], 5) |
| cfg1.just_add_7 = True |
| self.assertEqual(opt_fn(v, cfg1)[0], 7) |
| self.assertEqual(opt_fn(v, cfg1)[0], 7) |
| cfg1.just_add_7 = False |
| self.assertEqual(opt_fn(v, cfg1)[0], 5) |
| self.assertEqual(opt_fn(v, cfg1)[0], 5) |
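| # adding just_add_7 and later flipping it to False each invalidate the |
| # guards, so three frames are compiled in total. |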
| self.assertEqual(cnts.frame_count, 3) |
| |
| def test_size_input(self): |
| def fn(x, s): |
| a, b = s |
| return x + (a - b) |
| |
| v = torch.zeros(10, 20) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertEqual(opt_fn(v, v.size())[0, 0], -10) |
| self.assertEqual(opt_fn(v, (10, 20))[0, 0], -10) |
| self.assertEqual(opt_fn(v, [10, 20])[0, 0], -10) |
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_cell_output1(self): |
| out = None |
| |
| def fn(a, b): |
| nonlocal out |
| out = a + b * 10 |
| |
| v = torch.Tensor([100]) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertIsNone(opt_fn(v, v)) |
| self.assertEqual(out[0], 1100) |
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_cell_output2(self): |
| out = None |
| |
| def fn(a, b): |
| nonlocal out |
| c = unsupported(a, b) |
| out = a + b * 10 + c |
| |
| v = torch.Tensor([100]) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertIsNone(opt_fn(v, v)) |
| self.assertEqual(out[0], 1200) |
| self.assertEqual(cnts.op_count, 3) |
| |
| def test_return_nested_function(self): |
| out = None |
| |
| def fn(a, b): |
| nonlocal out |
| c = a + b |
| d = a + 1.0 |
| |
| def fn2(f: int = 7, g: float = 9.0): |
| nonlocal out |
| out = a + b * 10 |
| return c * f - d * g |
| |
| return fn2 |
| |
| v1 = torch.Tensor([100]) |
| v2 = torch.Tensor([200]) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| opt_fn_ret = torch._dynamo.optimize(cnts)(opt_fn(v1, v2)) |
| self.assertEqual(opt_fn_ret(1.5)[0], -459) |
| self.assertEqual(out[0], 2100) |
| self.assertEqual(cnts.frame_count, 2) |
| self.assertEqual(cnts.op_count, 7) |
| |
| def test_tensor_dict1(self): |
| def fn(inputs): |
| return inputs["a"] - inputs["b"] * 1.5 |
| |
| v1 = torch.Tensor([100]) |
| v2 = torch.Tensor([200]) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertEqual(opt_fn({"a": v1, "b": v2})[0], -200) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_tensor_dict2(self): |
| def fn1(inputs): |
| total = torch.zeros(1) |
| for k, v in inputs.items(): |
| total += v |
| return total |
| |
| def fn2(inputs): |
| total = torch.zeros(1) |
| for v in inputs.values(): |
| total += v |
| return total |
| |
| def fn3(inputs): |
| total = torch.zeros(1) |
| for k in inputs.keys(): |
| total += inputs[k] |
| return total |
| |
| v1 = torch.Tensor([100]) |
| v2 = torch.Tensor([200]) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn1 = torch._dynamo.optimize(cnts)(fn1) |
| opt_fn2 = torch._dynamo.optimize(cnts)(fn2) |
| opt_fn3 = torch._dynamo.optimize(cnts)(fn3) |
| self.assertEqual(opt_fn1({"a": v1, "b": v2})[0], 300) |
| self.assertEqual(opt_fn2({"a": v1, "b": v2})[0], 300) |
| self.assertEqual(opt_fn3({"a": v1, "b": v2})[0], 300) |
| self.assertEqual(cnts.frame_count, 3) |
| self.assertEqual(cnts.op_count, 9) |
| |
| def test_dictcomp(self): |
| def fn1(inputs): |
| return {k: v + 1 for k, v in inputs.items()} |
| |
| v1 = torch.Tensor([100]) |
| v2 = torch.Tensor([200]) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn1 = torch._dynamo.optimize(cnts)(fn1) |
| self.assertEqual(opt_fn1({"a": v1, "b": v2})["a"], 101) |
| self.assertEqual(opt_fn1({"a": v1, "b": v2})["b"], 201) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_listcomp(self): |
| def fn2(inputs): |
| return torch.sum(torch.cat([v + 1 for k, v in inputs.items()], 0)) |
| |
| v1 = torch.Tensor([100]) |
| v2 = torch.Tensor([200]) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn2 = torch._dynamo.optimize(cnts)(fn2) |
| self.assertEqual(opt_fn2({"a": v1, "b": v2}), 302) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 4) |
| |
| def test_is_floating_point(self): |
| def fn(a, b): |
| x = a + 1.0 |
| if torch.is_floating_point(b): |
| x = x + b |
| return x + 2.0 |
| |
| return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3) |
| |
| def test_is_floating_point2(self): |
| def fn(a, b): |
| x = a + 1.0 |
| if b.is_floating_point(): |
| x = x + b |
| return x + 2.0 |
| |
| return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3) |
| |
| def test_is_tensor(self): |
| def fn(a, b): |
| x = a + 1.0 |
| if torch.is_tensor(b): |
| x = x + b |
| return x + 2.0 |
| |
| return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3) |
| |
| def test_numel(self): |
| def fn(a): |
| return a + a.numel() + torch.numel(a) |
| |
| return torch._dynamo.testing.standard_test( |
| self, fn=fn, nargs=1, expected_ops=2, expected_ops_dynamic=4 |
| ) |
| |
| def test_pair(self): |
| def fn(a): |
| return ( |
| torch.zeros(torch.nn.modules.utils._pair(a.size())) |
| + a |
| + torch.ones(torch.nn.modules.utils._ntuple(3)(3)).sum() |
| ) |
| |
| return torch._dynamo.testing.standard_test( |
| self, fn=fn, nargs=1, expected_ops=5, expected_ops_dynamic=8 |
| ) |
| |
| @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) |
| def test_tensor_item_capture(self): |
| def fn(a, b): |
| return (a + b).sum().item() |
| |
| v1 = torch.randn((10, 10)) |
| v2 = torch.randn((10, 10)) |
| correct = fn(v1, v2) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertEqual(opt_fn(v1, v2), correct) |
| self.assertEqual(cnts.frame_count, 1) |
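| # with capture_scalar_outputs the .item() call stays in the graph, |
| # so add, sum, and item give three ops |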
| self.assertEqual(cnts.op_count, 3) |
| |
| @patch.object(torch._dynamo.config, "capture_scalar_outputs", False) |
| def test_tensor_item_no_capture(self): |
| def fn(a, b): |
| return (a + b).sum().item() |
| |
| v1 = torch.randn((10, 10)) |
| v2 = torch.randn((10, 10)) |
| correct = fn(v1, v2) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertEqual(opt_fn(v1, v2), correct) |
| self.assertEqual(cnts.frame_count, 1) |
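| # without scalar capture the trailing .item() falls out of the graph, |
| # leaving only add and sum |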
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_namedtuple1(self): |
| def fn(a, b): |
| tmp = mytuple(a, b, a + b) |
| return mytuple(tmp.a, tmp[1], tmp.ab + b) |
| |
| v1 = torch.Tensor([10]) |
| v2 = torch.Tensor([20]) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertEqual(opt_fn(v1, v2).ab, 50) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_namedtuple2(self): |
| def fn(packed): |
| a, b, c = packed |
| if hasattr(packed, "b"): |
| b = packed.b + 1 |
| c = packed[2] |
| return a + b + c |
| |
| v1 = torch.Tensor([1]) |
| v2 = torch.Tensor([2]) |
| v3 = torch.Tensor([3]) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertEqual(opt_fn(mytuple(v1, v2, v3))[0], 7) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 3) |
| |
| def test_range_input(self): |
| def fn(a, rng): |
| x = a |
| for i in rng: |
| x = x + i |
| return x |
| |
| def fn1(a): |
| return fn(a, rng=range(3)) |
| |
| return torch._dynamo.testing.standard_test( |
| self, fn=fn1, nargs=1, expected_ops=3 |
| ) |
| |
| def test_no_grad(self): |
| def fn1(a, b): |
| x = a + 1 |
| # redundant no_grad should get ignored |
| with torch.no_grad(): |
| x = x + b |
| x = x + 2 |
| return x |
| |
| def fn2(a, b): |
| x = a + 1 |
| with torch.set_grad_enabled(False): |
| x = x + b |
| x = x + 2 |
| return x |
| |
| def fn3(a, b): |
| x = a + 1 |
| with torch.enable_grad(): |
| x = x + b |
| x = x + 2 |
| return x |
| |
| def fn4(a, b): |
| x = a + 1 |
| with torch.set_grad_enabled(True): |
| if torch.is_grad_enabled(): |
| x = x + b |
| x = x + 2 |
| return x |
| |
| with torch.no_grad(): |
| torch._dynamo.testing.standard_test(self, fn=fn1, nargs=2, expected_ops=5) |
| torch._dynamo.testing.standard_test(self, fn=fn2, nargs=2, expected_ops=5) |
| torch._dynamo.testing.standard_test(self, fn=fn3, nargs=2, expected_ops=5) |
| torch._dynamo.testing.standard_test(self, fn=fn4, nargs=2, expected_ops=5) |
| with torch.enable_grad(): |
| torch._dynamo.testing.standard_test(self, fn=fn1, nargs=2, expected_ops=5) |
| torch._dynamo.testing.standard_test(self, fn=fn2, nargs=2, expected_ops=5) |
| torch._dynamo.testing.standard_test(self, fn=fn3, nargs=2, expected_ops=5) |
| torch._dynamo.testing.standard_test(self, fn=fn4, nargs=2, expected_ops=5) |
| |
| def test_grad_mode_guard(self): |
| def fn(a, b): |
| prev_grad = torch.is_grad_enabled() |
| torch.set_grad_enabled(False) |
| a = a + 1 |
| a.tolist() # graph break |
| ret = a + b |
| torch.set_grad_enabled(prev_grad) |
| return ret |
| |
| a = torch.randn([3, 4]) |
| b = torch.randn([3, 4]) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| for _ in range(10): |
| opt_fn(a, b) |
| self.assertEqual(cnts.frame_count, 2) |
| |
| def test_build_tuple_unpack(self): |
| def fn1(a, b, c): |
| return a - b / c |
| |
| def fn2(a, b, c): |
| tmp1 = (a,) |
| tmp2 = (b, c) |
| args = (*tmp1, *tmp2) |
| return fn1(*args) |
| |
| def fn3(a, *args): |
| return fn1(a, *args) |
| |
| torch._dynamo.testing.standard_test(self, fn=fn2, nargs=3, expected_ops=2) |
| torch._dynamo.testing.standard_test(self, fn=fn3, nargs=3, expected_ops=2) |
| |
| def test_list_mul(self): |
| def fn(count): |
| head_mask = count * [None] * count |
| return head_mask |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertEqual(opt_fn(2), [None] * 4) |
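| # there are no tensor operations here, so Dynamo has nothing to compile |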
| self.assertEqual(cnts.frame_count, 0) |
| self.assertEqual(cnts.op_count, 0) |
| |
| # KeyError: '__name__' |
| @patch.object(torch._dynamo.config, "suppress_errors", True) |
| def test_user_getattr1(self): |
| class MyConfig(dict): |
| def __getattr__(self, name): |
| return self[name] |
| |
| def fn(cfg, x, y): |
| return x + y + cfg.offset |
| |
| x = torch.randn(10) |
| cfg = MyConfig(offset=5) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertTrue(same(opt_fn(cfg, x, x), 2 * x + 5)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_user_getattr2(self): |
| class MyConfig: |
| defined_on_class = 1 |
| |
| def __init__(self): |
| self.defined_on_object = 2 |
| |
| def __getattr__(self, name): |
| return 3 |
| |
| def fn(cfg, x): |
| return x + cfg.defined_on_class - cfg.defined_on_object + cfg.not_defined |
| |
| x = torch.randn(10) |
| cfg = MyConfig() |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertTrue(same(opt_fn(cfg, x), x + 1 - 2 + 3)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 3) |
| |
| def test_user_property(self): |
| class MyConfig: |
| @property |
| def prop5(self): |
| return 5 |
| |
| def fn(cfg, x, y): |
| return x + y + cfg.prop5 |
| |
| x = torch.randn(10) |
| cfg = MyConfig() |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertTrue(same(opt_fn(cfg, x, x), 2 * x + 5)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_dataclass_fields(self): |
| @dataclasses.dataclass |
| class MyDataClass: |
| a: torch.Tensor |
| b: torch.Tensor = None |
| c: torch.Tensor = None |
| d: torch.Tensor = None |
| e: torch.Tensor = None |
| |
| def fn(obj): |
| class_fields = dataclasses.fields(obj) |
| assert len(class_fields) |
| assert all(field.default is None for field in class_fields[1:]) |
| other_fields_are_none = all( |
| getattr(obj, field.name) is None for field in class_fields[1:] |
| ) |
| assert not other_fields_are_none |
| |
| total = getattr(obj, class_fields[0].name) |
| for field in class_fields[1:]: |
| v = getattr(obj, field.name) |
| if v is not None: |
| total += v |
| |
| return total |
| |
| obj1 = MyDataClass(torch.randn(10), torch.randn(10), torch.randn(10)) |
| obj2 = MyDataClass(torch.randn(10), e=torch.randn(10)) |
| correct1 = fn(obj1) |
| correct2 = fn(obj2) |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertTrue(same(opt_fn(obj1), correct1)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 2) |
| |
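| # obj2 has a single non-None optional field, so recompiling for it |
| # produces a graph with one add |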
| torch._dynamo.reset() |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertTrue(same(opt_fn(obj2), correct2)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 1) |
| |
| @requires_static_shapes |
| def test_tensor_build_list_unpack(self): |
| def fn(x): |
| # seen in fastNLP_Bert |
| return torch.cat([*x], dim=-1) |
| |
| val = torch.randn([1, 1, 473, 768]) |
| correct = fn(val) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertTrue(same(opt_fn(val), correct)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_numpy_int_constant(self): |
| def fn(x, a, b): |
| return x + (a % b) |
| |
| args = [torch.randn(10), 4096, np.int64(8)] |
| correct = fn(*args) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertTrue(same(opt_fn(*args), correct)) |
| self.assertTrue(same(opt_fn(*args), correct)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_dict_mutation_side_effect(self): |
| def fn(d): |
| d["c"] = d["a"] + d.pop("b") |
| return d |
| |
| args1 = {"a": torch.randn(10), "b": torch.randn(10)} |
| args2 = dict(args1) |
| assert fn(args1) is args1 |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertIs(opt_fn(args2), args2) |
| self.assertTrue(same(args1, args2)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 1) |
| |
| def test_module_deepcopy(self): |
| m1 = torch.nn.Sequential( |
| torch.nn.Linear(10, 10), |
| torch.nn.ReLU(), |
| torch.nn.Linear(10, 10), |
| torch.nn.ReLU(), |
| ) |
| m2 = torch.nn.Sequential( |
| torch.nn.Linear(10, 10), |
| torch.nn.ReLU(), |
| torch.nn.Linear(10, 10), |
| torch.nn.ReLU(), |
| ) |
| |
| def fn(m, x): |
| m_copy = copy.deepcopy(m) |
| return m_copy(x) |
| |
| v = torch.randn(10) |
| correct1 = fn(m1, v) |
| correct2 = fn(m2, v) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| for _ in range(10): |
| self.assertTrue(same(opt_fn(m1, v), correct1)) |
| for _ in range(10): |
| self.assertTrue(same(opt_fn(m2, v), correct2)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 4) |
| |
| def test_type_copy(self): |
| def fn(seq): |
| a, b = seq |
| return type(seq)([a + 1, b + 2, a + b]) |
| |
| args1 = [torch.randn(10), torch.randn(10)] |
| args2 = (torch.randn(10), torch.randn(10)) |
| correct1 = fn(args1) |
| correct2 = fn(args2) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertTrue(same(opt_fn(args1), correct1)) |
| self.assertTrue(same(opt_fn(args2), correct2)) |
| self.assertIsInstance(opt_fn(args1), list) |
| self.assertIsInstance(opt_fn(args2), tuple) |
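| # list and tuple inputs compile separately, three ops each |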
| self.assertEqual(cnts.frame_count, 2) |
| self.assertEqual(cnts.op_count, 6) |
| |
| def test_setattr_mutation1(self): |
| class MyObj: # noqa: B903 |
| def __init__(self, a, b): |
| self.a = a |
| self.b = b |
| |
| def fn(obj): |
| obj.c = obj.a * obj.b + 1 |
| obj.b = obj.a * obj.c + 2 |
| obj.a = obj.b * obj.c + 3 |
| obj.c = obj.a * obj.b + 4 |
| obj.b = obj.a * obj.c + 5 |
| obj.a = obj.b * obj.c + 6 |
| return obj |
| |
| x1 = torch.randn(10) |
| x2 = torch.randn(10) |
| obj1 = MyObj(x1, x2) |
| obj2 = MyObj(x1, x2) |
| fn(obj2) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertIs(opt_fn(obj1), obj1) |
| self.assertTrue(same(obj1.a, obj2.a)) |
| self.assertTrue(same(obj1.b, obj2.b)) |
| self.assertTrue(same(obj1.c, obj2.c)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 12) |
| |
| def test_setattr_mutation2(self): |
| class MyObj: |
| def __init__(self, x): |
| self.a = x + 1 |
| self.b = x + 2 |
| |
| def fn(x): |
| x = x / 3.0 |
| obj = MyObj(x) |
| obj.c = obj.a * obj.b + 1 |
| obj.b = obj.a * obj.c + 2 |
| obj.a = obj.b * obj.c + 3 |
| return obj |
| |
| x1 = torch.randn(10) |
| obj2 = fn(x1) |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| obj1 = opt_fn(x1) |
| self.assertTrue(same(obj1.a, obj2.a)) |
| self.assertTrue(same(obj1.b, obj2.b)) |
| self.assertTrue(same(obj1.c, obj2.c)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 9) |
| |
| def test_setattr_mutation3(self): |
| # TODO(jansel): dead code eliminate the object creation |
| class MyObj: |
| def __init__(self, x): |
| super().__init__() |
| self.a = x + 1 |
| self.b = x + 2 |
| |
| def fn(x): |
| x = x / 3.0 |
| obj = MyObj(x) |
| obj.c = obj.a * obj.b + 1 |
| obj.b = obj.a * obj.c + 2 |
| obj.a = obj.b * obj.c + 3 |
| return obj.a, obj.b, obj.c |
| |
| x1 = torch.randn(10) |
| obj2 = fn(x1) |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| obj1 = opt_fn(x1) |
| self.assertTrue(same(obj1, obj2)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 9) |
| |
| def test_user_defined_class_name(self): |
| class MyClassFoo: |
| pass |
| |
| def fn1(a, b, c): |
| tmp = MyClassFoo() |
| if tmp.__class__.__name__ == "MyClassFoo": |
| return a - b / c |
| |
| torch._dynamo.testing.standard_test(self, fn=fn1, nargs=3) |
| |
| def test_manual_seed(self): |
| def fn(a, b): |
| x = a + b |
| torch.manual_seed(9000) |
| return x + 1 |
| |
| torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3) |
| |
| def test_usr_cls_staticmethod(self): |
| class Foo: |
| @staticmethod |
| def bar(a, b): |
| return a + b |
| |
| def fn(a, b): |
| return Foo.bar(a, b) - 1 |
| |
| torch._dynamo.testing.standard_test(self, fn=fn, nargs=2) |
| |
| def test_usr_cls_classmethod(self): |
| class Foo: |
| @classmethod |
| def bar(cls, a, b): |
| return a + b |
| |
| def fn(a, b): |
| return Foo.bar(a, b) - 1 |
| |
| torch._dynamo.testing.standard_test(self, fn=fn, nargs=2) |
| |
| def test_dunder_methods(self): |
| class Foo: |
| def __init__(self, val): |
| super().__init__() |
| self.val = val |
| |
| def __add__(self, other): |
| return Foo(self.val + other.val) |
| |
| def __mul__(self, other): |
| return Foo(self.val * other.val) |
| |
| def __truediv__(self, other): |
| return Foo(self.val / other.val) |
| |
| def __sub__(self, other): |
| return Foo(self.val - other.val) |
| |
| def fn(a, b, c): |
| return Foo(a) + Foo(b) * Foo(c) / Foo(a) - Foo(b) |
| |
| torch._dynamo.testing.standard_test(self, fn=fn, nargs=3, expected_ops=4) |
| |
| def test_function_annotation(self): |
| class Variable: |
| pass |
| |
| def fn(x): |
| x = x / 3.0 |
| |
| def inner(y: typing.List[Variable]): |
| return x + 1 |
| |
| return inner |
| |
| x1 = torch.randn(10) |
| obj2 = fn(x1)([]) |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize_assert(cnts)(fn) |
| opt_fn_inner = torch._dynamo.optimize_assert(cnts)(opt_fn(x1)) |
| obj1 = opt_fn_inner([]) |
| self.assertTrue(same(obj1, obj2)) |
| self.assertEqual(cnts.frame_count, 2) |
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_nested_closure(self): |
| v0 = torch.randn(10) |
| |
| def fn1(): |
| v1 = torch.randn(10) |
| |
| def fn2(*args, **kwargs): |
| assert len(args) == 1 |
| assert len(kwargs) == 1 |
| v2 = torch.randn(10) + args[0] + kwargs["b"] |
| |
| def fn3(v3=torch.randn(10)): |
| def fn4(): |
| return v0 + v1 + v2 + v3 + 1 |
| |
| return fn4 |
| |
| return fn3 |
| |
| return fn2(1, b=2)() |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn1 = torch._dynamo.optimize_assert(cnts)(fn1) |
| tmp1 = torch._dynamo.optimize_assert(cnts)(opt_fn1()) |
| tmp2 = torch._dynamo.optimize_assert(cnts)(opt_fn1()) |
| self.assertEqual(tmp1().shape, (10,)) |
| self.assertTrue(same(tmp1(), tmp1())) |
| self.assertFalse(same(tmp1(), tmp2())) |
| self.assertEqual(cnts.frame_count, 2) |
| self.assertEqual(cnts.op_count, 9) |
| |
| def test_nested_closure_mutation(self): |
| def fn1(): |
| v1 = torch.randn(10) |
| |
| def fn2(): |
| v2 = torch.randn(10) |
| |
| def fn3(): |
| nonlocal v1, v2 |
| v1 += 1 |
| v2 += 2 |
| return v1 + v2 |
| |
| return fn3 |
| |
| rv = fn2() |
| rv() |
| rv() |
| return rv |
| |
| torch.manual_seed(9000) |
| counter1 = fn1() |
| result1 = [counter1(), counter1(), counter1()] |
| |
| torch.manual_seed(9000) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn1 = torch._dynamo.optimize_assert(cnts)(fn1) |
| counter2 = torch._dynamo.optimize_assert(cnts)(opt_fn1()) |
| result2 = [counter2(), counter2(), counter2()] |
| result1.append(counter1()) |
| result2.append(counter2()) |
| |
| self.assertTrue(same(result1, result2)) |
| self.assertEqual(cnts.frame_count, 2) |
| self.assertEqual(cnts.op_count, 11) |
| |
| def test_write_to_closures_in_inlining(self): |
| out = [] |
| for use_dynamo in [False, True]: |
| |
| def make_counter(): |
| x = torch.randn(10) |
| |
| def counter(): |
| nonlocal x |
| x = x + 1 |
| return x |
| |
| return counter |
| |
| torch.manual_seed(0) |
| counter = make_counter() |
| if not use_dynamo: |
| out.append(counter() + counter()) |
| else: |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.optimize(cnts, nopython=True) |
| def fn(counter): |
| return counter() + counter() |
| |
| out.append(fn(counter)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 3) |
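| # the compiled call above already advanced the closure state, so two |
| # fresh eager calls produce a different sum |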
| self.assertFalse(same(counter() + counter(), out[-1])) |
| |
| self.assertTrue(same(out[0], out[1])) |
| |
| def test_top_package_import(self): |
| def fn(x): |
| import torch.fx |
| |
| assert not isinstance(x, torch.fx.Proxy) |
| return torch.sin(x) |
| |
| x = torch.randn(4, 5) |
| ref = fn(x) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize_assert(cnts)(fn) |
| res = opt_fn(x) |
| self.assertTrue(same(ref, res)) |
| |
| def test_optimize_on_module(self): |
| class MockModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.relu = torch.nn.ReLU() |
| |
| def custom_member(self): |
| # Just for checking that the mod object returned by Dynamo can |
| # redirect calls to this method |
| pass |
| |
| def forward(self, x): |
| return self.relu(x) |
| |
| cnts1 = torch._dynamo.testing.CompileCounter() |
| mod = MockModule() |
| optimized_mod = torch._dynamo.optimize(cnts1, nopython=True)(mod) |
| |
| a = torch.randn(10) |
| ref = mod(a) |
| res = optimized_mod(a) |
| |
| optimized_mod.custom_member() |
| |
| self.assertTrue(same(ref, res)) |
| |
| def test_nested_optimize_decorator(self): |
| cnts2 = torch._dynamo.testing.CompileCounter() |
| cnts3 = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.run() |
| def fn1(x): |
| return torch.sin(x) * 10 |
| |
| @torch._dynamo.optimize(cnts2, nopython=True) |
| def fn2(x): |
| return fn1(x) + 1 |
| |
| @torch._dynamo.optimize(cnts3, nopython=True) |
| def fn3(x): |
| return torch.relu(fn2(x)) |
| |
| fn3(torch.randn(4, 5)) |
| self.assertEqual(cnts2.frame_count, 0) |
| self.assertEqual(cnts3.frame_count, 1) |
| self.assertEqual(cnts3.op_count, 4) |
| |
| def test_nested_optimize_run(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.optimize(cnts, nopython=True) |
| def fn(x): |
| return torch.relu(torch.cos(x) + torch.sin(x)) |
| |
| fn(torch.randn(4)) |
| self.assertEqual(cnts.frame_count, 1) |
| |
| fn(torch.randn(4, 4)) |
| self.assertEqual(cnts.frame_count, 2) |
| |
| # Test that run works on a decorated fn |
| fn = torch._dynamo.run(fn) |
| fn(torch.randn(4, 4, 4)) |
| self.assertEqual(cnts.frame_count, 2) |
| |
| def test_nested_optimize(self): |
| cnts1 = torch._dynamo.testing.CompileCounter() |
| cnts2 = torch._dynamo.testing.CompileCounter() |
| |
| def fn(x): |
| return torch.relu(torch.cos(x) + torch.sin(x)) |
| |
| fn1 = torch._dynamo.optimize(cnts1, nopython=True)(fn) |
| fn2 = torch._dynamo.optimize(cnts2, nopython=True)(fn1) |
| |
| # The first optimize in the nesting should be ignored |
| fn2(torch.randn(4)) |
| self.assertEqual(cnts2.frame_count, 1) |
| self.assertEqual(cnts1.frame_count, 0) |
| |
| # Since the fn code object is already compiled, calling fn1 should |
| # directly call the compiled_fn callable. |
| torch._dynamo.run()(fn1)(torch.randn(4)) |
| self.assertEqual(cnts1.frame_count, 0) |
| |
| # Test same behavior by reversing the calls |
| torch._dynamo.reset() |
| cnts1 = torch._dynamo.testing.CompileCounter() |
| cnts2 = torch._dynamo.testing.CompileCounter() |
| fn1 = torch._dynamo.optimize(cnts1, nopython=True)(fn) |
| fn2 = torch._dynamo.optimize(cnts2, nopython=True)(fn1) |
| fn1(torch.randn(4)) |
| self.assertEqual(cnts1.frame_count, 1) |
| torch._dynamo.run()(fn2)(torch.randn(4)) |
| self.assertEqual(cnts2.frame_count, 0) |
| |
| def test_nested_disable_decorator(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.disable() |
| def fn1(x): |
| return torch.sin(x) * 10 |
| |
| @torch._dynamo.optimize(cnts) |
| def fn2(x): |
| x = x + 1 |
| x = x + 1 |
| x = fn1(x) # graph break |
| x = x + 1 |
| x = x + 1 |
| return x |
| |
| @torch._dynamo.optimize(cnts, nopython=True) |
| def fn3(x): |
| return fn2(x) |
| |
| fn2(torch.randn(4, 5)) |
| self.assertEqual(cnts.frame_count, 2) |
| self.assertEqual(cnts.op_count, 4) |
| |
| try: |
| fn3(torch.randn(4, 5)) |
| self.fail("expected fn3 to raise Unsupported") |
| except torch._dynamo.exc.Unsupported as e: |
| self.assertIn("call torch._dynamo.disable() wrapped function", str(e)) |
| |
| def test_graph_break(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.optimize(cnts) |
| def fn(x): |
| x = torch.cos(x) |
| x = torch.cos(x) |
| torch._dynamo.graph_break() |
| x = torch.cos(x) |
| x = torch.cos(x) |
| graph_break() |
| x = torch.cos(x) |
| x = torch.cos(x) |
| return x |
| |
| fn(torch.randn(4, 5)) |
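| # the two explicit graph breaks split fn into three graphs of two ops each |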
| self.assertEqual(cnts.frame_count, 3) |
| self.assertEqual(cnts.op_count, 6) |
| |
| def test_torch_size(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| def fn(x): |
| output_size = torch.Size([10, 10]) |
| x = x.view(*output_size) |
| return (x,) |
| |
| x = torch.randn(100, requires_grad=True) |
| x_clone = x.clone() |
| ref = fn(x) |
| |
| opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) |
| res = opt_fn(x_clone) |
| |
| self.assertTrue(same(ref, res)) |
| |
| def test_torch_seed(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| def fn(x): |
| attention_seed = int(torch.seed() % sys.maxsize) |
| torch.manual_seed(attention_seed) |
| return (x,) |
| |
| x = torch.randn(100, requires_grad=True) |
| ref = fn(x) |
| |
| opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) |
| res = opt_fn(x) |
| |
| self.assertTrue(same(ref, res)) |
| |
| def test_is_tensor_like(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| def f(x): |
| if torch.overrides.is_tensor_like(x): |
| return (x * 2,) |
| return (torch.ones(10) + x,) |
| |
| x = torch.randn(10) |
| ref0 = f(x) |
| ref1 = f(4) |
| opt_f = torch._dynamo.optimize(cnts, nopython=True)(f) |
| res0 = opt_f(x) |
| res1 = opt_f(4) |
| self.assertTrue(same(ref0, res0)) |
| self.assertTrue(same(ref1, res1)) |
| |
| def test_version_ci(self): |
| # temporary test to check that the CI torch version is set correctly |
| self.assertTrue(hasattr(torch, "_subclasses")) |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") |
| def test_rand(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| device = "cuda" |
| |
| def fn(): |
| return torch.randn(10, device=device) |
| |
| torch.manual_seed(10) |
| ref_run1 = fn() |
| |
| torch.manual_seed(10) |
| ref_run2 = fn() |
| self.assertTrue(same(ref_run1, ref_run2)) |
| |
| torch.manual_seed(10) |
| opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) |
| res = opt_fn() |
| |
| self.assertTrue(same(res, ref_run1)) |
| |
| def test_slice_input(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| def getitem(a, idx): |
| if isinstance(idx, slice): |
| return ( |
| torch.zeros(1), |
| a[idx] |
| + [ |
| 100, |
| ], |
| ) |
| else: |
| return (torch.zeros(1), a[idx]) |
| |
| layers = list(range(10)) |
| ref0 = getitem(layers, slice(0, 2, 1)) |
| ref1 = getitem(layers, 2) |
| ref2 = getitem(layers, slice(3, 8, 2)) |
| opt_getitem = torch._dynamo.optimize(cnts, nopython=True)(getitem) |
| res0 = opt_getitem(layers, slice(0, 2, 1)) |
| res1 = opt_getitem(layers, 2) |
| res2 = opt_getitem(layers, slice(3, 8, 2)) |
| |
| self.assertTrue(ref0 == res0) |
| self.assertTrue(ref1 == res1) |
| self.assertTrue(ref2 == res2) |
| |
| def test_grad(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| def fn(a, b): |
| out = a * b |
| out.sum().backward() |
| real_out = torch.sigmoid(a.grad + b) |
| return real_out |
| |
| inps = [torch.randn(4, requires_grad=True) for _ in range(2)] |
| for inp in inps: |
| inp.grad = None |
| ref = fn(*inps) |
| |
| for inp in inps: |
| inp.grad = None |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| res = opt_fn(*inps) |
| |
| self.assertTrue(same(ref, res)) |
| |
| @unittest.skipIf(sys.version_info < (3, 10), "use linetable when python >= 3.10") |
| def test_linetable_writer(self): |
| def fn(): |
| a = 10 |
| b = 20 |
| c = a + b |
| f = "linetable_writer" |
| return f"Test if {f} generates correct co_linetable: {c}" |
| |
| inst = dis.get_instructions(fn) |
| result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno) |
| self.assertTrue(result[1] == fn.__code__.co_linetable) |
| |
| @unittest.skipIf(sys.version_info >= (3, 10), "use lnotab when python < 3.10") |
| def test_lnotab_writer(self): |
| def fn(): |
| a = 10 |
| b = 20 |
| c = a + b |
| f = "lnotab_writer" |
| return f"Test if {f} generates correct co_lnotab: {c}" |
| |
| inst = dis.get_instructions(fn) |
| result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno) |
| self.assertTrue(result[1] == fn.__code__.co_lnotab) |
| |
| def test_torch_profiler(self): |
| # torch.profiler.* is wrapped as ProfilerContextWrapperVariable and treated as a no-op |
| def fn(x): |
| y = x**2 |
| with torch.profiler.profile(): |
| y = y + 2 |
| with torch.profiler.record_function("my_function"): |
| z = y**3 |
| z.tolist() # graph break |
| z = z + 1 |
| return z |
| |
| x = torch.randn((2, 2), requires_grad=True) |
| ref = fn(x) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| res = opt_fn(x) |
| self.assertTrue(same(ref, res)) |
| self.assertEqual(cnts.frame_count, 2) |
| |
| def test_autograd_profiler(self): |
| # torch.autograd.profiler.* is wrapped as ProfilerContextWrapperVariable and treated as a no-op |
| def fn(x): |
| y = x**2 |
| with torch.autograd.profiler.profile(): |
| y = y + 2 |
| with torch.autograd.profiler.record_function("my_function"): |
| z = y**3 |
| z.tolist() # graph break |
| z = z + 1 |
| return z |
| |
| x = torch.randn((2, 2), requires_grad=True) |
| ref = fn(x) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| res = opt_fn(x) |
| self.assertTrue(same(ref, res)) |
| self.assertEqual(cnts.frame_count, 2) |
| |
| def test_python_slice(self): |
| def f1(input): |
| y = 0 |
| for i, x in enumerate(input[2:], 1): |
| y = y + x |
| return y |
| |
| def f2(input): |
| y = 0 |
| for i, x in enumerate(input.shape[2:], 1): |
| y = y + x |
| return y |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_f1 = torch._dynamo.optimize(cnts)(f1) |
| opt_f2 = torch._dynamo.optimize(cnts)(f2) |
| res1 = opt_f1([1, 2, 3, 5]) |
| res2 = opt_f2(torch.rand([2, 3, 4, 5])) |
| |
| self.assertEqual(res1, 8) |
| self.assertEqual(res2, 9) |
| |
| def test_const_dict_variable_python_type(self): |
| from torch._dynamo.variables import ConstDictVariable |
| |
| d1 = {"a": 10, "b": 20} |
| d2 = collections.OrderedDict([("x", 12), ("y", 22)]) |
| self.assertEqual(ConstDictVariable(d1, dict).python_type(), dict) |
| self.assertEqual( |
| ConstDictVariable(d2, collections.OrderedDict).python_type(), |
| collections.OrderedDict, |
| ) |
| |
| def test_builtin_subclasses_as_method_on_class_type(self): |
| class Foo: |
| def __init__(self, name): |
| self.name_ = name |
| |
| def get_name(self): |
| return "Foo " + self.name_ |
| |
| class Bar(Foo): |
| def __init__(self, name): |
| self.name_ = name |
| |
| def get_name(self): |
| return "Bar " + self.name_ |
| |
| class Baz(Foo): |
| def __init__(self, name): # noqa: B903 |
| self.name_ = name |
| |
| def get_name(self): |
| return "Baz " + self.name_ |
| |
| subs_of_foo_reg = Foo.__subclasses__() |
| |
| counter = CompileCounter() |
| |
| @torch._dynamo.optimize_assert(counter) |
| def fn(): |
| return Foo.__subclasses__() |
| |
| subs_of_foo_optim = fn() |
| |
| self.assertEqual(len(subs_of_foo_reg), 2) |
| self.assertEqual(subs_of_foo_reg, subs_of_foo_optim) |
| |
| def test_builtin_subclasses_as_method_on_var(self): |
| class Foo: |
| def __init__(self, name): |
| self.name_ = name |
| |
| def get_name(self): |
| return "Foo " + self.name_ |
| |
| class Bar(Foo): |
| def __init__(self, name): |
| self.name_ = name |
| |
| def get_name(self): |
| return "Bar " + self.name_ |
| |
| class Baz(Bar): |
| def __init__(self, name): |
| self.name_ = name |
| |
| def get_name(self): |
| return "Baz " + self.name_ |
| |
| subs_of_foo_reg = Foo.__subclasses__() |
| sub_of_foo_subclass_var_reg = subs_of_foo_reg[0].__subclasses__() |
| |
| sub_of_foo_subclass_var_optim = list() |
| counter = CompileCounter() |
| |
| @torch._dynamo.optimize_assert(counter) |
| def fn(): |
| return Foo.__subclasses__() |
| |
| @torch._dynamo.optimize_assert(counter) |
| def fn_single(subs_of_foo_optim): |
| return subs_of_foo_optim[0].__subclasses__() |
| |
| subs_of_foo_optim = fn() |
| sub_of_foo_subclass_var_optim = fn_single(subs_of_foo_optim) |
| |
| self.assertEqual(len(sub_of_foo_subclass_var_optim), 1) |
| self.assertEqual(sub_of_foo_subclass_var_optim, sub_of_foo_subclass_var_reg) |
| |
| def test_enum_no_graphbreaks(self): |
| class Foo(enum.Enum): |
| FOO = 0 |
| BAR = 1 |
| |
| def fn(x, foo): |
| if foo is Foo.FOO: |
| x = torch.add(x, 1.0) |
| x = torch.mul(x, 1.0) |
| return x |
| |
| x = torch.randn(1) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) |
| opt_fn(x, Foo.FOO) |
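| # the enum comparison is resolved at compile time, so both the add and |
| # the mul land in a single graph with no break |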
| self.assertEqual(cnts.op_count, 2) |
| |
| torch._dynamo.reset() |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) |
| opt_fn(x, Foo.BAR) |
| self.assertEqual(cnts.op_count, 1) |
| |
| def test_id_of_nn_module(self): |
| class M(torch.nn.Module): |
| def forward(self, x, ref_id): |
| self_id = id(self) |
| if self_id == ref_id: |
| x = torch.mul(x, 1.0) |
| x = torch.add(x, 1.0) |
| return x |
| |
| m = M().eval() |
| data = torch.randn(1) |
| cnts = torch._dynamo.testing.CompileCounter() |
| correct_ref_id = id(m) |
| opt_m = torch._dynamo.optimize(cnts, nopython=True)(m) |
| opt_m(data, correct_ref_id) |
| self.assertEqual(cnts.op_count, 2) |
| |
| torch._dynamo.reset() |
| cnts = torch._dynamo.testing.CompileCounter() |
| incorrect_ref_id = id(m) + 1 |
| opt_m = torch._dynamo.optimize(cnts, nopython=True)(m) |
| opt_m(data, incorrect_ref_id) |
| self.assertEqual(cnts.op_count, 1) |
| |
| def test_inline_func_jump_on_tensor_condition(self): |
| def f1(input): |
| if input == 0: |
| return input + 1 |
| else: |
| return input + 2 |
| |
| def f2(input): |
| return f1(input) |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_f2 = torch._dynamo.optimize(cnts)(f2) |
| res1 = opt_f2(torch.tensor([1.0])) |
| res2 = opt_f2(torch.tensor([0.0])) |
| |
| self.assertEqual(res1, 3) |
| self.assertEqual(res2, 1) |
| |
| def test_frozenset_torch_func_contains(self): |
| funcs = frozenset([torch.add]) |
| |
| def fn(x, func): |
| if func in funcs: |
| x = torch.add(x, 1.0) |
| x = torch.mul(x, 1.0) |
| return x |
| |
| x = torch.randn(1) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) |
| opt_fn(x, torch.add) |
| self.assertEqual(cnts.op_count, 2) |
| |
| torch._dynamo.reset() |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) |
| opt_fn(x, torch.mul) |
| self.assertEqual(cnts.op_count, 1) |
| |
| @patch.object(torch._dynamo.config, "fake_tensor_propagation", True) |
| def test_unsupported_fake_tensor(self): |
| def f(x): |
| return torch.quantize_per_tensor(x, 0.1, 10, torch.quint8) |
| |
| x = torch.randn(2, 2) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_f = torch._dynamo.optimize(cnts)(f) |
| opt_f(x) |
| self.assertEqual(cnts.op_count, 0) |
| |
| torch._dynamo.reset() |
| with patch.object(torch._dynamo.config, "fake_tensor_propagation", False): |
| opt_f = torch._dynamo.optimize_assert( |
| torch._dynamo.testing.CompileCounter() |
| )(f) |
| opt_f(x) |
| |
| def test_inline_list_mutation(self): |
| def f1(x): |
| x.append(torch.ones(8)) |
| return x |
| |
| def f2(): |
| x = [torch.ones(6)] |
| f1(x) |
| return x |
| |
| res1 = f2() |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_f2 = torch._dynamo.optimize(cnts)(f2) |
| res2 = opt_f2() |
| self.assertTrue(same(res1, res2)) |
| |
| def test_inline_dict_mutation(self): |
| def f1(d): |
| d["c"] = d["a"] + d.pop("b") |
| return d |
| |
| def f2(): |
| d = {"a": torch.ones(5), "b": torch.ones(5)} |
| f1(d) |
| return d |
| |
| res1 = f2() |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_f2 = torch._dynamo.optimize(cnts)(f2) |
| res2 = opt_f2() |
| self.assertTrue(same(res1, res2)) |
| |
| def test_recursive_inline_list_mutation(self): |
| def f1(x, y): |
| x.append(torch.tensor([1.1])) |
| y.append(torch.tensor([1.2])) |
| return x, y |
| |
| def f2(x, y): |
| x.append(torch.tensor([2.1])) |
| y.append(torch.tensor([2.2])) |
| f1(x, y) |
| return x, y |
| |
| def f3(x): |
| x.append(torch.tensor([3.1])) |
| y = [torch.tensor([3.2])] |
| f2(x, y) |
| return x, y |
| |
| def f4(): |
| x = [torch.tensor([4.1])] |
| return f3(x) |
| |
| res1 = f4() |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_f4 = torch._dynamo.optimize(cnts)(f4) |
| res2 = opt_f4() |
| self.assertTrue(same(res1, res2)) |
| |
| def test_disallow_in_graph(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.optimize(cnts) |
| def fn(a): |
| x = torch.add(a, 1) |
| x = torch.add(x, 1) |
| x = torch.sub(x, 1) |
| x = torch.add(x, 1) |
| x = torch.add(x, 1) |
| return x |
| |
| torch._dynamo.disallow_in_graph(torch.sub) |
| fn(torch.randn(10)) |
| torch._dynamo.allow_in_graph(torch.sub) |
| |
| # check for graph break on sub |
| self.assertEqual(cnts.frame_count, 2) |
| self.assertEqual(cnts.op_count, 4) |
| |
| def test_allow_in_graph(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.optimize(cnts) |
| def fn(a): |
| x = torch.add(a, 1) |
| x = torch.add(x, 1) |
| x = my_custom_function(x) |
| x = torch.add(x, 1) |
| x = torch.add(x, 1) |
| return x |
| |
| torch._dynamo.allow_in_graph(my_custom_function) |
| fn(torch.randn(10)) |
| torch._dynamo.disallow_in_graph(my_custom_function) |
| |
| # check for no graph break |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 5) |
| |
| def test_sample_input(self): |
| from torch.testing._internal.common_methods_invocations import SampleInput |
| |
| def fn(sample): |
| if isinstance(sample.input, torch.Tensor): |
| return sample.input * 2 |
| return torch.zeros(()) |
| |
| sample = SampleInput(torch.ones(2)) |
| ref = fn(sample) |
| |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| res = opt_fn(sample) |
| |
| self.assertTrue(same(ref, res)) |
| |
| def test_release_input_memory(self): |
| x = torch.rand([4]) |
| x_ref = weakref.ref(x) |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.optimize(cnts) |
| def foo(x): |
| return x + x |
| |
| out = foo(x) |
| self.assertTrue(same(out, x + x)) |
| del x |
| self.assertIs(x_ref(), None) |
| |
| def test_release_module_memory(self): |
| mod = torch.nn.Linear(10, 10) |
| x = torch.rand([10, 10]) |
| mod_weight_ref = weakref.ref(mod.weight) |
| mod_ref = weakref.ref(mod) |
| |
| # Modules passed into torch._dynamo optimized functions are normally |
| # kept alive through the generated GraphModule, which holds references |
| # to them. Remove that reference in this backend and test that no |
| # additional references are being held. |
| class NoLeakBackend: |
| def __call__(self, gm: torch.fx.GraphModule, example_inputs): |
| gm.mod = None |
| |
| def foo(*args, **kwargs): |
| return (1,) |
| |
| return foo |
| |
| no_leak_backend = NoLeakBackend() |
| |
| @torch._dynamo.optimize(no_leak_backend) |
| def foo(mod, x): |
| return mod(x) |
| |
| foo(mod, x) |
| del mod |
| del x |
| self.assertIsNone(mod_ref()) |
| self.assertIsNone(mod_weight_ref()) |
| |
| def test_update_locals_and_stack_uses_shared_cache(self): |
| def fn(x): |
| perm = [0, 3, 5] |
| perm = list(range(min(perm))) + perm |
| perm.extend(i for i in range(x.dim()) if i not in perm) |
| return perm |
| |
| x = torch.rand([2, 2, 2, 2, 2, 2]) |
| res1 = fn(x) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| res2 = opt_fn(x) |
| self.assertTrue(same(res1, res2)) |
| |
| def test_dict_reconstruct_keeps_original_order(self): |
| def fn(): |
| modules = collections.OrderedDict([("act", torch.nn.ReLU())]) |
| module_dict = torch.nn.ModuleDict(modules) |
| |
| next_modules = {"fc4": torch.nn.Linear(5, 6), "act3": torch.nn.Sigmoid()} |
| modules.update(next_modules.items()) |
| module_dict.update(next_modules) |
| return modules, module_dict |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| modules, module_dict = opt_fn() |
| |
| self.assertEqual(len(module_dict), len(modules)) |
| for k1, m2 in zip(modules, module_dict.children()): |
| self.assertTrue(modules[k1] is m2) |
| |
| def test_side_effects_codegen_update_mutated(self): |
| # codegen that updates mutated variables (side effects) should run |
| # after the codegen for the stack values |
| def f1(x): |
| alist = [x] |
| alist.append(x + 1) |
| alist[0].sum().item() # graph break |
| res = alist.pop() |
| res.sum().item() # graph break |
| return res |
| |
| def f2(a, b): |
| d = {"a": a + 1, "b": b + 2} |
| x = d.pop("b") |
| x.sum().item() # graph break |
| y = d["a"] + x |
| y.sum().item() # graph break |
| d["c"] = y |
| return d |
| |
| x = torch.rand([2, 3]) |
| a = torch.rand([5, 6]) |
| b = torch.rand([5, 6]) |
| res11 = f1(x) |
| res21 = f2(a, b) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_f1 = torch._dynamo.optimize(cnts)(f1) |
| opt_f2 = torch._dynamo.optimize(cnts)(f2) |
| res12 = opt_f1(x) |
| res22 = opt_f2(a, b) |
| self.assertTrue(same(res11, res12)) |
| self.assertTrue(same(res21, res22)) |
| |
| def test_list_append_return_none(self): |
| def fn(x): |
| alist = [] |
| blist = alist.append(x + 1) |
| return alist, blist |
| |
| x = torch.tensor([2.3]) |
| res = fn(x) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| res2 = opt_fn(x) |
| self.assertEqual(res, res2) |
| |
| def test_tensor_types(self): |
| def fn(dtype, tensor_type): |
| x = torch.empty(4, dtype=dtype) |
| assert isinstance(x, tensor_type) |
| |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| opt_fn(torch.float32, torch.FloatTensor) |
| opt_fn(torch.float64, torch.DoubleTensor) |
| opt_fn(torch.float16, torch.HalfTensor) |
| opt_fn(torch.bfloat16, torch.BFloat16Tensor) |
| opt_fn(torch.uint8, torch.ByteTensor) |
| opt_fn(torch.int8, torch.CharTensor) |
| opt_fn(torch.int64, torch.LongTensor) |
| opt_fn(torch.int, torch.IntTensor) |
| opt_fn(torch.int16, torch.ShortTensor) |
| opt_fn(torch.bool, torch.BoolTensor) |
| |
| def test_nan(self): |
| def f(x, n): |
| return x * 2 + n |
| |
| x = torch.randn(4) |
| n = float("nan") |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_f = torch._dynamo.optimize(cnts)(f) |
| opt_f(x, n) |
| opt_f(x, n) |
| self.assertEqual(cnts.frame_count, 1) |
| |
| @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) |
| def test_item(self): |
| class MyMod(torch.nn.Module): |
| def forward(self, x): |
| z = torch.max(x) |
| return z.int().item() |
| |
| x = torch.tensor([[10.6763, 11.7445, -2.2369]]) |
| model = MyMod() |
| y = torch._dynamo.optimize("eager", nopython=True)(model)(x) |
| |
| self.assertEqual(y, 11) |
| |
| @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) |
| def test_item_changes(self): |
| class MyMod(torch.nn.Module): |
| def forward(self, x): |
| z = torch.max(x) |
| return z.int().item() |
| |
| x = torch.tensor([[10.6763, 11.7445, -2.2369]]) |
| model = MyMod() |
| opt_model = torch._dynamo.optimize("eager", nopython=True)(model) |
| y = opt_model(x) |
| z = opt_model(torch.tensor([[y - 5, y + 10, y + 50]])) |
| |
| self.assertEqual(y, 11) |
| self.assertEqual(z, 61) |
| |
| @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) |
| def test_item_changes_new_shape(self): |
| class MyMod(torch.nn.Module): |
| def forward(self, x): |
| z = torch.max(x) |
| return z.int().item() |
| |
| x = torch.tensor([[10.6763, 11.7445, -2.2369]]) |
| model = MyMod() |
| opt_model = torch._dynamo.optimize("eager", nopython=True)(model) |
| y = opt_model(x) |
| z = opt_model(torch.tensor([[y - 5, y + 50], [y + 5, y - 50]])) |
| |
| self.assertEqual(y, 11) |
| self.assertEqual(z, 61) |
| |
| def test_cross_entropy_loss_fancy_ctor(self): |
| output = None |
| rand_5 = torch.randn(5) |
| rand_3_5 = torch.randn(3, 5) |
| target = torch.empty(3, dtype=torch.long).random_(5) |
| |
| loss = torch.nn.CrossEntropyLoss( |
| weight=rand_5, reduce=False, label_smoothing=0.5 |
| ) |
| opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss) |
| input = rand_3_5 |
| dynamo_output = opt_loss(input, target) |
| |
| loss = torch.nn.CrossEntropyLoss( |
| weight=rand_5, reduce=False, label_smoothing=0.5 |
| ) |
| input = rand_3_5 |
| output = loss(input, target) |
| |
| self.assertTrue(torch.allclose(dynamo_output, output)) |
| |
| def test_cross_entropy_loss_simple_ctor(self): |
| output = None |
| rand_3_5 = torch.randn(3, 5) |
| target = torch.empty(3, dtype=torch.long).random_(5) |
| |
| loss = torch.nn.CrossEntropyLoss() |
| opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss) |
| input = rand_3_5 |
| dynamo_output = opt_loss(input, target) |
| |
| loss = torch.nn.CrossEntropyLoss() |
| input = rand_3_5 |
| output = loss(input, target) |
| |
| self.assertTrue(torch.allclose(dynamo_output, output)) |
| |
| def test_large_reduction_list(self): |
| dtype = torch.float32 |
| device = "cpu" |
| |
| def check_sum_all(tensor: torch.Tensor) -> None: |
| pylist = tensor.reshape(-1).tolist() |
| self.assertTrue(same(tensor.sum(), torch.tensor(sum(pylist)))) |
| |
| check_sum_all(torch.randn(200000, dtype=dtype, device=device)) |
| |
| def test_raise_on_backend_error(self): |
| def my_compiler(gm, _): |
| raise RuntimeError("duck!") |
| |
| @torch._dynamo.optimize(my_compiler) |
| def fn(a, b): |
| return a + b / (a - b) |
| |
| self.assertRaises( |
| torch._dynamo.exc.BackendCompilerFailed, |
| lambda: fn(torch.randn(10), torch.randn(10)), |
| ) |
| |
| def test_named_parameters(self): |
| n_embd = 768 |
| block_size = 128 |
| vocab_size = 65 |
| embd_pdrop = 0.1 |
| |
| class MyModel2(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.tok_emb = torch.nn.Embedding(vocab_size, n_embd) |
| self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd)) |
| self.drop = torch.nn.Dropout(embd_pdrop) |
| |
| def forward(self, x): |
| return x |
| |
| class MyModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.tok_emb = torch.nn.Embedding(vocab_size, n_embd) |
| self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd)) |
| self.drop = torch.nn.Dropout(embd_pdrop) |
| self.submod2 = MyModel2() |
| |
| def forward(self, x): |
| return x |
| |
| # Regular |
| params = [] |
| mod = MyModel() |
| actual_params = list(mod.named_parameters()) |
| |
| @torch._dynamo.optimize("eager", nopython=True) |
| def fn(): |
| return list(mod.named_parameters()) |
| |
| params = fn() |
| |
| self.assertEqual(len(actual_params), len(params)) |
| for idx in range(len(params)): |
| k_a, v_a = actual_params[idx] |
| k, v = params[idx] |
| self.assertEqual(k_a, k) |
| self.assertTrue(torch.allclose(v_a, v)) |
| |
| # Prefix |
| params = [] |
| mod = MyModel() |
| actual_params = list(mod.named_parameters(prefix="foo")) |
| |
| @torch._dynamo.optimize("eager", nopython=True) |
| def fn1(): |
| return list(mod.named_parameters(prefix="foo")) |
| |
| params = fn1() |
| |
| self.assertEqual(len(actual_params), len(params)) |
| for idx in range(len(params)): |
| k_a, v_a = actual_params[idx] |
| k, v = params[idx] |
| self.assertEqual(k_a, k) |
| self.assertTrue(torch.allclose(v_a, v)) |
| |
| def test_module_complex_iter(self): |
| n_embd = 768 |
| block_size = 128 |
| vocab_size = 65 |
| embd_pdrop = 0.1 |
| |
| class FakeGPT(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.tok_emb = torch.nn.Embedding(vocab_size, n_embd) |
| self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd)) |
| self.drop = torch.nn.Dropout(embd_pdrop) |
| self.ln_f = torch.nn.LayerNorm(n_embd) |
| self.head = torch.nn.Linear(n_embd, vocab_size, bias=False) |
| |
| self.block_size = block_size |
| self.names = [] |
| |
| def forward(self, idx, targets=None): |
| from torch.nn import functional as F |
| |
| b, t = idx.size() |
| assert ( |
| t <= self.block_size |
| ), "Cannot forward, model block size is exhausted." |
| |
| # forward the GPT model |
| token_embeddings = self.tok_emb( |
| idx |
| ) # each index maps to a (learnable) vector |
| position_embeddings = self.pos_emb[ |
| :, :t, : |
| ] # each position maps to a (learnable) vector |
| x = self.drop(token_embeddings + position_embeddings) |
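                # NOTE: self.blocks is never defined on this module; forward()
                # is not exercised by this test (only foo() below is called).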
| x = self.blocks(x) |
| x = self.ln_f(x) |
| logits = self.head(x) |
| |
| # if we are given some desired targets also calculate the loss |
| loss = None |
| if targets is not None: |
| loss = F.cross_entropy( |
| logits.view(-1, logits.size(-1)), targets.view(-1) |
| ) |
| |
| return logits, loss |
| |
| def foo(self, memo=None, prefix="", remove_duplicate=False): |
| for mn, m in self.named_modules( |
| memo=memo, prefix=prefix, remove_duplicate=remove_duplicate |
| ): |
| for pn, p in self.named_parameters(): |
| fpn = "%s.%s" % (mn, pn) if mn else pn |
| self.names.append(fpn) |
| |
| # Test plain recurse |
| model_a = FakeGPT() |
| model_a.foo() |
| a_names = model_a.names |
| |
| model_b = FakeGPT() |
| opt_model_b = torch._dynamo.optimize("eager", nopython=True)(model_b) |
| opt_model_b.foo() |
| |
| self.assertEqual(a_names, model_b.names) |
| |
| # Test with prefix |
| model_a = FakeGPT() |
| model_a.foo(prefix="abc") |
| a_names = model_a.names |
| |
| model_b = FakeGPT() |
| opt_model_b = torch._dynamo.optimize("eager", nopython=True)(model_b) |
| opt_model_b.foo(prefix="abc") |
| |
| self.assertEqual(a_names, model_b.names) |
| |
| def test_numpy_variable_isinstance(self): |
| def fn(x, m): |
| if isinstance(m, np.ndarray): |
| return x + 1 |
| else: |
| return x - 1 |
| |
| x = torch.tensor([2.3]) |
| m = np.array([1, 2, 3]) |
| ref = fn(x, m) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| res = opt_fn(x, m) |
| self.assertEqual(ref, res) |
| |
| def test_tensor_dot_grad_no_graph_break(self): |
| def fn(a, b): |
| y = 3 * a**3 - b**2 |
| y.backward(gradient=torch.tensor([1.0, 1.0])) |
| b.grad.zero_() |
| return a.grad, b.grad |
| |
| a = torch.tensor([2.0, 3.0], requires_grad=True) |
| b = torch.tensor([6.0, 4.0], requires_grad=True) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| _, b_grad = opt_fn(a, b) |
| self.assertTrue(same(b_grad, torch.tensor([0.0, 0.0]))) |
| self.assertEqual(cnts.frame_count, 2) |
| |
| def test_torch_nn_parameter_isinstance(self): |
| def fn(x): |
| a = torch.nn.Parameter(torch.rand(2, 3)) |
| if isinstance(a, torch.Tensor): |
| return x + 1 |
| else: |
| return x - 1 |
| |
| x = torch.tensor([2.5]) |
| ref = fn(x) |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| res = opt_fn(x) |
| self.assertEqual(ref, res) |
| |
| def test_change_backends(self): |
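        # Switching to a different backend without calling torch._dynamo.reset()
        # is expected to raise ResetRequired; after a reset the new backend works.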
| @torch._dynamo.optimize("eager", nopython=True) |
| def fn1(): |
| return x + 1 |
| |
| @torch._dynamo.optimize("ts") |
| def fn2(): |
| return x + 2 |
| |
| @torch._dynamo.optimize("eager", nopython=False) |
| def fn3(): |
| return x + 1 |
| |
| x = torch.tensor([3, 5]) |
| |
| fn1() |
| fn1() |
| fn3() |
| self.assertRaises(torch._dynamo.exc.ResetRequired, fn2) |
| fn1() |
| torch._dynamo.reset() |
| fn2() |
| fn2() |
| self.assertRaises(torch._dynamo.exc.ResetRequired, fn1) |
| self.assertRaises(torch._dynamo.exc.ResetRequired, fn3) |
| fn2() |
| |
| def test_dynamo_min_operator_with_shape(self): |
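        # min() between a shape element and a Python int must compile fully
        # (nopython=True) and return the smaller value.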
| @torch._dynamo.optimize("eager", nopython=True) |
| def f(x, a): |
| return min(x.shape[0], a) |
| |
| result = f(torch.ones(6), 3) |
| self.assertEqual(result, 3) |
| |
| @patch.object(torch._dynamo.config, "dynamic_shapes", True) |
| def test_onnx_shape_as_tensor(self): |
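        # torch._shape_as_tensor and torch.onnx.operators.shape_as_tensor are
        # expected to behave identically under dynamic shapes.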
| @torch._dynamo.optimize("eager", nopython=True) |
| def f(x): |
| return 1 + torch._shape_as_tensor(x)[0] |
| |
| gm, _ = torch._dynamo.export(f, torch.ones(6)) |
| |
| input_one_dim = torch.ones(6) |
| input_two_dims = torch.ones(7, 4) |
| self.assertEqual(f(input_one_dim), 7) |
| self.assertEqual(f(input_two_dims), 8) |
| self.assertEqual(f(input_two_dims), 8) |
| |
| @torch._dynamo.optimize("eager", nopython=True) |
| def f_onnx(x): |
| return 1 + torch.onnx.operators.shape_as_tensor(x)[0] |
| |
| self.assertEqual(f_onnx(input_one_dim), 7) |
| self.assertEqual(f_onnx(input_two_dims), 8) |
| self.assertEqual(f_onnx(input_two_dims), 8) |
| |
| def test_cond(self): |
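        # functorch's experimental cond picks true_fn or false_fn based on a
        # boolean tensor predicate; both branches are checked below.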
| from functorch.experimental.cond import cond |
| |
| def true_fn(x): |
| return x.sin() |
| |
| def false_fn(x): |
| return x.cos() |
| |
| def f(pred, x): |
| return cond(pred, true_fn, false_fn, [x]) |
| |
| opt_fn = torch._dynamo.optimize("eager")(f) |
| a = opt_fn(torch.tensor(False), torch.tensor([0.25, 0.25])) |
| self.assertTrue(same(torch.cos(torch.tensor([0.25, 0.25])), a)) |
| b = opt_fn(torch.tensor(True), torch.tensor([0.25, 0.25])) |
| self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), b)) |
| |
| def test_cond_nested(self): |
| from functorch.experimental.cond import cond |
| |
| def true_fn_nested(x): |
| return x * 10 |
| |
| def false_fn_nested(x): |
| return x * -1 |
| |
| def true_fn(pred2, x): |
| return x.sin() |
| |
| def false_fn(pred2, x): |
| return x + cond(pred2, true_fn_nested, false_fn_nested, [x]) |
| |
| def f(pred, pred2, x): |
| return cond(pred, true_fn, false_fn, [pred2, x]) |
| |
| cc = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cc)(f) |
| true_true_sin = opt_fn( |
| torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25]) |
| ) |
| self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin)) |
| |
| true_false_sin = opt_fn( |
| torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25]) |
| ) |
| self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin)) |
| |
| false_true_sum_mult = opt_fn( |
| torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25]) |
| ) |
| self.assertTrue( |
| same(torch.tensor([2.75, 2.75]), false_true_sum_mult) |
| ) # * 10 then add x |
| |
| false_false_sum_neg = opt_fn( |
| torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25]) |
| ) |
| self.assertTrue( |
| same(torch.tensor([0.0, 0.0]), false_false_sum_neg) |
| ) # * -1 then add x |
        self.assertEqual(cc.frame_count, 2)
| |
| @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) |
| def test_cond_nested_fake_tensor_off(self): |
| from functorch.experimental.cond import cond |
| |
| def true_fn_nested(x): |
| return x * 10 |
| |
| def false_fn_nested(x): |
| return x * -1 |
| |
| def true_fn(pred2, x): |
| return x.sin() |
| |
| def false_fn(pred2, x): |
| return x + cond(pred2, true_fn_nested, false_fn_nested, [x]) |
| |
| def f(pred, pred2, x): |
| return cond(pred, true_fn, false_fn, [pred2, x]) |
| |
| cc = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cc)(f) |
| true_true_sin = opt_fn( |
| torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25]) |
| ) |
| self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin)) |
| |
| true_false_sin = opt_fn( |
| torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25]) |
| ) |
| self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin)) |
| |
| false_true_sum_mult = opt_fn( |
| torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25]) |
| ) |
| self.assertTrue( |
| same(torch.tensor([2.75, 2.75]), false_true_sum_mult) |
| ) # * 10 then add x |
| |
| false_false_sum_neg = opt_fn( |
| torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25]) |
| ) |
| self.assertTrue( |
| same(torch.tensor([0.0, 0.0]), false_false_sum_neg) |
| ) # * -1 then add x |
        self.assertEqual(cc.frame_count, 1)
| |
| @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) |
| def test_cond_export(self): |
| from functorch.experimental.cond import cond |
| |
| def true_fn_nested(x): |
| return x * 10 |
| |
| def false_fn_nested(x): |
| return x * -1 |
| |
| def true_fn(pred2, x): |
| return x.sin() |
| |
| def false_fn(pred2, x): |
| return x + cond(pred2, true_fn_nested, false_fn_nested, [x]) |
| |
| def f(pred, pred2, x): |
| return cond(pred, true_fn, false_fn, [pred2, x]) |
| |
| graph, guard = torch._dynamo.export( |
| f, torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25]) |
| ) |
| true_true_sin = graph( |
| torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25]) |
| ) |
| self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin)) |
| |
| true_false_sin = graph( |
| torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25]) |
| ) |
| self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin)) |
| |
| false_true_sum_mult = graph( |
| torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25]) |
| ) |
| self.assertTrue( |
| same(torch.tensor([2.75, 2.75]), false_true_sum_mult) |
| ) # * 10 then add x |
| |
| false_false_sum_neg = graph( |
| torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25]) |
| ) |
| self.assertTrue( |
| same(torch.tensor([0.0, 0.0]), false_false_sum_neg) |
| ) # * -1 then add x |
| |
| @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) |
| def test_cond_export_single_arg(self): |
| from functorch.experimental.cond import cond |
| |
| def true_fn(x): |
| return x |
| |
| def false_fn(x): |
| return x.sin() |
| |
| def f(pred, x): |
| return cond(pred, true_fn, false_fn, [x]) |
| |
| graph, guard = torch._dynamo.export( |
| f, torch.tensor(False), torch.tensor([0.25, 0.25]) |
| ) |
| true_mirror = graph(torch.tensor(True), torch.tensor([0.25, 0.25])) |
| self.assertTrue(same(torch.tensor([0.25, 0.25]), true_mirror)) |
| true_mirror_2 = graph(torch.tensor(True), torch.tensor([0.33, 0.33, 0.33])) |
| self.assertTrue(same(torch.tensor([0.33, 0.33, 0.33]), true_mirror_2)) |
| |
| false_sin = graph(torch.tensor(False), torch.tensor([0.5, 0.5])) |
| self.assertTrue(same(torch.sin(torch.tensor([0.5, 0.5])), false_sin)) |
| |
| def test_disable_optimize(self): |
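        # Compilation is skipped (frame_count stays 0) both via the disable=True
        # kwarg and via the TORCHDYNAMO_DISABLE environment variable.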
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.optimize(cnt, disable=True) |
| def f1(x): |
| return x + 1 |
| |
| f1(torch.ones(6)) |
| self.assertEqual(cnt.frame_count, 0) |
| |
| @torch._dynamo.optimize(cnt, disable=True) |
| def f2(x): |
| return x + 1 |
| |
| f2(torch.ones(6)) |
| self.assertEqual(cnt.frame_count, 0) |
| |
| with patch.dict(os.environ, {"TORCHDYNAMO_DISABLE": "1"}): |
| |
| @torch._dynamo.optimize(cnt) |
| def f3(x): |
| return x + 1 |
| |
| f3(torch.ones(6)) |
| self.assertEqual(cnt.frame_count, 0) |
| |
| def test_config_log_level(self): |
| @torch._dynamo.optimize("eager") |
| def fn(a, b): |
| return a + b |
| |
| with self.assertLogs(logger="torch._dynamo", level=logging.DEBUG) as log: |
| torch._dynamo.config.log_level = logging.DEBUG |
| fn(torch.randn(10), torch.randn(10)) |
            cur_len = len(log.records)
| self.assertGreater(cur_len, 0) |
| |
| torch._dynamo.config.log_level = logging.WARNING |
| fn(torch.randn(10), torch.randn(10)) |
            self.assertEqual(cur_len, len(log.records))
| |
| @patch.object(torch._dynamo.config, "print_graph_breaks", True) |
| def test_duplicate_graph_break_warning(self): |
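        # With config.verbose=True the same graph break is reported multiple
        # times; with verbose=False duplicates are collapsed to a single message.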
| @torch._dynamo.optimize("eager") |
| def f1(a, b): |
| f2(a, b) |
| |
| def f2(a, b): |
| c = a + b |
| print("break") |
| return a + b + c |
| |
| @torch._dynamo.optimize("eager") |
| def g1(a, b): |
| g2(a, b) |
| |
| def g2(a, b): |
| c = a + b |
| print("break") |
| return a + b + c |
| |
| def count_graph_break_msgs(msgs): |
            return sum("Graph break" in msg for msg in msgs)
| |
| with self.assertLogs(logger="torch._dynamo", level=logging.WARNING) as log: |
| torch._dynamo.config.verbose = True |
| f1(torch.randn(10), torch.randn(10)) |
| self.assertGreater(count_graph_break_msgs(log.output), 1) |
| |
| with self.assertLogs(logger="torch._dynamo", level=logging.WARNING) as log: |
| torch._dynamo.config.verbose = False |
| g1(torch.randn(10), torch.randn(10)) |
| self.assertEqual(count_graph_break_msgs(log.output), 1) |
| |
| def test_inplace_param_update(self): |
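        # Toggling grad mode around an in-place Parameter update should compile
        # as a single frame without graph breaks (nopython=True below).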
| def fn(param, y): |
| prev_grad = torch.is_grad_enabled() |
| try: |
| torch.set_grad_enabled(False) |
| torch.set_grad_enabled(True) |
| torch.set_grad_enabled(False) |
| param.add_(y) |
| finally: |
| torch.set_grad_enabled(prev_grad) |
| |
| y = torch.randn(4) |
| x = torch.nn.Parameter(torch.randn(4)) |
| fn(x, y) |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) |
| opt_fn(x, y) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 5) |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") |
| def test_autocast(self): |
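        # The exported graph should preserve the autocast dtype (bfloat16) and
        # the CUDA device observed in eager mode.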
| if not torch.cuda.is_bf16_supported(): |
| raise unittest.SkipTest("requires bf16") |
| |
| class MyModule(torch.nn.Module): |
| def forward(self, x): |
| a_float32 = torch.rand((8, 8), device="cuda") |
| b_float32 = torch.rand((8, 8), device="cuda") |
| d_float32 = torch.rand((8, 8), device="cuda") |
| |
| with torch.autocast(device_type="cuda", dtype=torch.bfloat16): |
| e_float16 = torch.mm(a_float32, b_float32) |
| f_float16 = torch.mm(d_float32, e_float16) |
| return f_float16 |
| |
| module = MyModule() |
| real = module(torch.tensor([0.5])) |
| real_device = real.device |
| real_dtype = real.dtype |
| |
| graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]])) |
| exported = graph(torch.tensor([0.5])) |
| self.assertEqual(exported.device, real_device) |
| self.assertEqual(exported.dtype, real_dtype) |
| |
| self.assertEqual(exported.device.type, "cuda") |
| self.assertEqual(exported.device.index, 0) |
| self.assertEqual(exported.dtype, torch.bfloat16) |
| |
| def test_autocast_cpu(self): |
| class MyModule(torch.nn.Module): |
| def forward(self, x): |
| a_float32 = torch.rand((8, 8), device="cpu") |
| b_float32 = torch.rand((8, 8), device="cpu") |
| d_float32 = torch.rand((8, 8), device="cpu") |
| |
| with torch.autocast(device_type="cpu", dtype=torch.bfloat16): |
| e_float16 = torch.mm(a_float32, b_float32) |
| f_float16 = torch.mm(d_float32, e_float16) |
| return f_float16 |
| |
| module = MyModule() |
| real = module(torch.tensor([0.5])) |
| real_device = real.device |
| real_dtype = real.dtype |
| |
| graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]])) |
| exported = graph(torch.tensor([0.5])) |
| self.assertEqual(exported.device, real_device) |
| self.assertEqual(exported.dtype, real_dtype) |
| |
| self.assertEqual(exported.device.type, "cpu") |
| self.assertEqual(exported.dtype, torch.bfloat16) |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") |
| def test_autocast_float64(self): |
| class MyModule(torch.nn.Module): |
| def forward(self, x): |
| a_float32 = torch.rand((8, 8), device="cuda") |
| b_float32 = torch.rand((8, 8), device="cuda") |
| d_float32 = torch.rand((8, 8), device="cuda") |
| |
| with torch.autocast(device_type="cuda", dtype=torch.float64): |
| e_float64 = torch.mm(a_float32, b_float32) |
| f_float64 = torch.mm(d_float32, e_float64) |
| return f_float64 |
| |
| module = MyModule() |
| real = module(torch.tensor([0.5])) |
| real_device = real.device |
| real_dtype = real.dtype |
| |
| graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]])) |
| exported = graph(torch.tensor([0.5])) |
| self.assertEqual(exported.device, real_device) |
| self.assertEqual(exported.dtype, real_dtype) |
| |
| self.assertEqual(exported.device.index, 0) |
| self.assertEqual(exported.dtype, torch.float64) |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") |
| def test_autocast_device(self): |
| class MyModule(torch.nn.Module): |
| def forward(self, x): |
| a_float32 = torch.rand((8, 8), device="cuda") |
| b_float32 = torch.rand((8, 8), device="cuda") |
| d_float32 = torch.rand((8, 8), device="cuda") |
| |
| with torch.autocast(device_type="cuda"): |
| e_float64 = torch.mm(a_float32, b_float32) |
| f_float64 = torch.mm(d_float32, e_float64) |
| return f_float64 |
| |
| module = MyModule() |
| real = module(torch.tensor([0.5])) |
| real_device = real.device |
| real_dtype = real.dtype |
| |
| graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]])) |
| exported = graph(torch.tensor([0.5])) |
| self.assertEqual(exported.device, real_device) |
| self.assertEqual(exported.dtype, real_dtype) |
| |
| self.assertEqual(exported.device.index, 0) |
        self.assertEqual(exported.dtype, torch.float16)
| |
| def test_generate_tensor_from_list_of_numpy_primitive_type(self): |
        # Test something like torch.LongTensor([np.int64, np.int64, ...]),
        # i.e. building a tensor from a list of numpy scalars.
| def fn(): |
| x = np.array([1, 2, 3, 4, 5, 6], dtype=np.int64) |
| y = [x[0], x[2], x[4]] |
| z = torch.LongTensor(y) |
| return z |
| |
| ref = fn() |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| res = opt_fn() |
| self.assertTrue(same(ref, res)) |
| |
| def test_autograd_function_equivalence(self): |
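        # Module1 calls CustomFunc().apply while Module2 stores CustomFunc.apply;
        # both forms (defined at the bottom of this file) should compile the same.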
| m1 = Module1() |
| |
| @torch._dynamo.optimize("eager", nopython=True) |
| def f1(): |
| return m1(torch.ones(2, 3)) |
| |
| self.assertTrue(torch.allclose(f1(), torch.tensor([2.0]))) |
| |
| m2 = Module2() |
| |
| @torch._dynamo.optimize("eager", nopython=True) |
| def f2(): |
| return m2(torch.ones(2, 3)) |
| |
| self.assertTrue(torch.allclose(f2(), torch.tensor([2.0]))) |
| |
| def test_object_classmethod(self): |
| class C: |
| @classmethod |
| def fn(cls, x): |
| return x + x |
| |
| @torch._dynamo.optimize("eager", nopython=True) |
| def f(): |
| return C().fn(torch.ones(2, 3)) |
| |
| self.assertTrue(torch.allclose(f(), torch.tensor([2.0]))) |
| |
| def test_object_staticmethod(self): |
| class C: |
| @staticmethod |
| def fn(x): |
| return x + x |
| |
| @torch._dynamo.optimize("eager", nopython=True) |
| def f(): |
| return C().fn(torch.ones(2, 3)) |
| |
| self.assertTrue(torch.allclose(f(), torch.tensor([2.0]))) |
| |
| def test_user_function_variable_supports_enum_argument(self): |
| class Foo(enum.Enum): |
| FOO = 0 |
| BAR = 1 |
| |
| def gn(x, y=Foo.FOO): |
| if y is Foo.FOO: |
| return x |
| else: |
| return x + 1 |
| |
| def fn(x): |
| return gn(x) |
| |
| x = torch.randn(2, 3) |
| ref = fn(x) |
| opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) |
| res = opt_fn(x) |
| self.assertTrue(torch.allclose(ref, res)) |
| |
| def test_repro_graph_breaks_in__get_item_by_idx(self): |
| class Mod(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mod = torch.nn.Sequential( |
| torch.nn.Linear(3, 3), torch.nn.Linear(3, 3) |
| ) |
| |
| def forward(self, x): |
| return self.mod[0](x) |
| |
| m = Mod() |
| graph, _ = torch._dynamo.export(m, torch.randn(3, 3)) |
| |
| def test_nn_sequential_invocation(self): |
| with freeze_rng_state(): |
| |
| class TestModel(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.linears = torch.nn.Sequential( |
| torch.nn.Linear(2, 2), |
| torch.nn.Linear(2, 2), |
| torch.nn.Linear(2, 2), |
| torch.nn.Linear(2, 2), |
| ) |
| |
| def forward(self, x): |
| all_but_last = self.linears[:-1] |
| return all_but_last(x) |
| |
| m = TestModel() |
| x = torch.rand((2, 2)) |
| real = m(x) |
| graph, _ = torch._dynamo.export(m, x) |
| dynamo_result = graph(x) |
| self.assertTrue(same(real, dynamo_result)) |
| |
| def test_nn_sequential_invocation_reposition_indices(self): |
| with freeze_rng_state(): |
| |
| class TestModel(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.linears = torch.nn.Sequential( |
| torch.nn.Linear(2, 2), |
| torch.nn.Linear(2, 2), |
| torch.nn.Linear(2, 2), |
| torch.nn.Linear(2, 2), |
| ) |
| |
| def forward(self, x): |
| all_but_last = self.linears[1:3] |
| return all_but_last(x) |
| |
| m = TestModel() |
| x = torch.rand((2, 2)) |
| real = m(x) |
| graph, _ = torch._dynamo.export(m, x) |
| dynamo_result = graph(x) |
| self.assertTrue(same(real, dynamo_result)) |
| |
| |
| class CustomFunc(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, foo): |
| return foo + foo |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| return grad_output |
| |
| |
| class Module1(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, foo): |
| return CustomFunc().apply(foo) |
| |
| |
| class Module2(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.fn = CustomFunc.apply |
| |
| def forward(self, foo): |
| return self.fn(foo) |
| |
| |
| class TestTracer(JitTestCase): |
| def test_jit_save(self): |
| def fn(): |
| class Foo(torch.nn.Module): |
| def __init__(self): |
                    super().__init__()
| self.a = 3 |
| |
| @torch.jit.export |
| def __getstate__(self): |
| return (3, self.training) |
| |
| @torch.jit.export |
| def __setstate__(self, state): |
| self.a = state[0] |
| self.training = state[1] |
| |
| def forward(self, x): |
| return x + self.a |
| |
| f = Foo() |
| |
| return torch.jit.trace(f, (torch.rand(3, 4),)) |
| |
| fn() |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| opt_fn() |
| |
| |
| if __name__ == "__main__": |
| from torch._dynamo.test_case import run_tests |
| |
| run_tests() |