[Dynamo] Fix nested function resume execution (#100426)
Fixes #99665
Let me explain the root cause using the unit test I added:
* This bug is triggered when:
* ```wrapped``` is a nested function.
* ```wrapped``` is in another module which is different from the main function ```fn```.
* There is a graph break inside of ```wrapped```.
* The root cause is when resuming nested function, actually we are using the outermost function(```fn``` in my example)'s global variables, but ```wrapped``` calls ```inner_func``` which is not part of ```fn```'s globals, so we have to set correct globals when nested function resume execution.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100426
Approved by: https://github.com/jansel
diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh
index 4d34f5b..69ba42e 100755
--- a/.ci/pytorch/test.sh
+++ b/.ci/pytorch/test.sh
@@ -233,6 +233,7 @@
--exclude-distributed-tests \
--exclude \
test_autograd \
+ test_jit \
test_proxy_tensor \
test_quantization \
test_public_bindings \
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 811bc48..a20962e 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -5157,6 +5157,26 @@
self.assertTrue(isinstance(compile_out, torch.Size))
self.assertEqual(eager_out, compile_out)
+ def test_nested_function_resuming_with_correct_globals(self):
+ # https://github.com/pytorch/pytorch/issues/99665
+ try:
+ from .utils import outer_func
+ except ImportError:
+ from utils import outer_func
+
+ def gn(x, y):
+ return x + y
+
+ def fn(x, y):
+ return outer_func(gn)(x, y)
+
+ x = torch.rand([3])
+ y = torch.rand([3])
+ opt_fn = torch.compile(backend="eager")(fn)
+ ref = fn(x, y)
+ res = opt_fn(x, y)
+ self.assertTrue(same(ref, res))
+
class CustomFunc1(torch.autograd.Function):
@staticmethod
diff --git a/test/dynamo/utils.py b/test/dynamo/utils.py
new file mode 100644
index 0000000..54cacd0
--- /dev/null
+++ b/test/dynamo/utils.py
@@ -0,0 +1,17 @@
+# Owner(s): ["module: dynamo"]
+
+import torch
+import torch._dynamo
+
+
+def inner_func():
+ return torch.is_grad_enabled()
+
+
+def outer_func(func):
+ def wrapped(*args):
+ a = func(*args)
+ torch._dynamo.graph_break()
+ return torch.sin(a + 1), inner_func()
+
+ return wrapped
diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py
index a88b266..5df1918 100644
--- a/test/jit/test_tracer.py
+++ b/test/jit/test_tracer.py
@@ -35,6 +35,7 @@
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
+@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
class TestTracer(JitTestCase):
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
def test_large_nbr_kernel_args(self):
@@ -1990,6 +1991,7 @@
self.assertEqual(model(**input_dict), traced_model(**input_dict))
+@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
class TestMixTracingScripting(JitTestCase):
def test_trace_script(self):
@torch.jit.script
diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py
index 28ca5aa..ec21547 100644
--- a/test/test_binary_ufuncs.py
+++ b/test/test_binary_ufuncs.py
@@ -27,6 +27,7 @@
numpy_to_torch_dtype_dict,
TEST_SCIPY,
set_default_dtype,
+ skipIfTorchDynamo,
)
from torch.testing._internal.common_device_type import (
expectedFailureMeta,
@@ -1852,6 +1853,7 @@
_scalar_helper(lambda a, b: math.floor(a / b), torch.floor_divide)
@onlyNativeDeviceTypes
+ @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_div_and_floordiv_script_vs_python(self, device):
# Creates jitted functions of two tensors
def _wrapped_div(a, b):
@@ -1924,6 +1926,7 @@
self.assertEqual(5 // a, scripted_rfloordiv_scalar(a_t))
@onlyNativeDeviceTypes
+ @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_idiv_and_ifloordiv_vs_python(self, device):
def _wrapped_idiv_tensor(a, b):
a /= b
diff --git a/test/test_indexing.py b/test/test_indexing.py
index 38bddda..551327c 100644
--- a/test/test_indexing.py
+++ b/test/test_indexing.py
@@ -12,7 +12,7 @@
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
- TestCase, run_tests, TEST_WITH_TORCHDYNAMO)
+ TestCase, run_tests, skipIfTorchDynamo)
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, onlyCUDA, dtypes, dtypesIfCPU, dtypesIfCUDA,
onlyNativeDeviceTypes, skipXLA)
@@ -738,10 +738,7 @@
self.assertEqual(y, torch.ones(size=(10, 10), device=device))
self.assertEqual(len(w), 2)
- @unittest.skipIf(
- TEST_WITH_TORCHDYNAMO,
- "This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472"
- )
+ @skipIfTorchDynamo("This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472")
def test_index_put_accumulate_large_tensor(self, device):
# This test is for tensors with number of elements >= INT_MAX (2^31 - 1).
N = (1 << 31) + 5
@@ -839,6 +836,7 @@
self.assertEqual(out_cuda.cpu(), out_cpu)
@onlyCUDA
+ @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_index_put_accumulate_with_optional_tensors(self, device):
# TODO: replace with a better solution.
# Currently, here using torchscript to put None into indices.
@@ -935,6 +933,7 @@
r = v[c > 0]
self.assertEqual(r.shape, (num_ones, 3))
+ @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_jit_indexing(self, device):
def fn1(x):
x[x < 50] = 1.0
diff --git a/test/test_native_functions.py b/test/test_native_functions.py
index ba7889e..c95b4a2 100644
--- a/test/test_native_functions.py
+++ b/test/test_native_functions.py
@@ -2,7 +2,7 @@
from typing import Optional, List
import torch
-from torch.testing._internal.common_utils import TestCase, run_tests
+from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo
# End-to-end tests of features in native_functions.yaml
@@ -81,6 +81,7 @@
return torch._C._nn._test_optional_floatlist(values, const)
return torch.jit.trace(wrapper, torch.tensor([1.5, 2.5], dtype=torch.float))
+ @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_optional_floatlist(self):
self.do_test_optional_floatlist_with_module(FloatListWrapperModule())
self.do_test_optional_floatlist_with_module(torch.jit.script(FloatListWrapperModule()))
@@ -134,6 +135,7 @@
return torch._C._nn._test_optional_intlist(values, const)
return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))
+ @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_optional_intlist(self):
self.do_test_optional_intlist_with_module(IntListWrapperModule())
self.do_test_optional_intlist_with_module(torch.jit.script(IntListWrapperModule()))
@@ -187,6 +189,7 @@
return torch._C._nn._test_optional_filled_intlist(values, const)
return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))
+ @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_optional_filled_intlist(self):
def f(n: int):
diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py
index 9af810f..280f467 100644
--- a/torch/_dynamo/variables/functions.py
+++ b/torch/_dynamo/variables/functions.py
@@ -3,7 +3,6 @@
import functools
import inspect
import itertools
-import sys
import types
from typing import Dict, List
@@ -11,11 +10,7 @@
from .. import variables
from ..allowed_functions import is_allowed, is_builtin_callable
-from ..bytecode_transformation import (
- create_call_function,
- create_instruction,
- create_rot_n,
-)
+from ..bytecode_transformation import create_call_function, create_rot_n
from ..exc import unimplemented
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
from ..utils import istensor, istype, make_cell
@@ -89,6 +84,26 @@
return closure_cells
+def _create_nested_fn(
+ code, f_globals, name, defaults, closure, kwdefaults, annotations
+):
+ from types import FunctionType
+
+ func = FunctionType(code, f_globals, name, defaults, closure)
+ func.__kwdefaults__ = kwdefaults
+
+ if isinstance(annotations, tuple):
+ from itertools import pairwise
+
+ annotations = dict(pairwise(annotations))
+
+ # TypeError: __annotations__ must be set to a dict object
+ assert annotations is None or isinstance(annotations, dict)
+ func.__annotations__ = annotations
+
+ return func
+
+
class BaseUserFunctionVariable(VariableTracker):
def get_filename(self):
return self.get_code().co_filename
@@ -460,17 +475,27 @@
parent.symbolic_locals[var] = child.symbolic_locals[var]
def reconstruct(self, codegen):
- flags = 0x00
+ codegen.load_import_from(__name__, "_create_nested_fn")
+ codegen(self.code)
+ codegen.extend_output([codegen._create_load_const(self.f_globals)])
+ codegen(self.fn_name)
+
if self.defaults:
- flags |= 0x01
codegen(self.defaults)
+ else:
+ codegen.extend_output([codegen.create_load_const(None)])
+
+ if self.closure:
+ codegen(self.closure)
+ else:
+ codegen.extend_output([codegen.create_load_const(None)])
+
if self.kwdefaults:
- flags |= 0x02
codegen(self.kwdefaults)
- if isinstance(
- self.annotations, (variables.ConstDictVariable, variables.TupleVariable)
- ):
- flags |= 0x04
+ else:
+ codegen.extend_output([codegen.create_load_const(None)])
+
+ if self.annotations:
try:
if isinstance(self.annotations, variables.ConstDictVariable):
annotations = {
@@ -484,13 +509,10 @@
codegen.extend_output([codegen._create_load_const(annotations)])
except NotImplementedError:
codegen(self.annotations)
- if self.closure:
- flags |= 0x08
- codegen(self.closure)
- codegen(self.code)
- if sys.version_info < (3, 11):
- codegen(self.fn_name)
- codegen.extend_output([create_instruction("MAKE_FUNCTION", arg=flags)])
+ else:
+ codegen.extend_output([codegen.create_load_const(None)])
+
+ codegen.extend_output(create_call_function(7, push_null=True))
if self.wraps_source:
codegen.load_import_from("functools", "wraps")