[Dynamo] Fix nested function resume execution (#100426)

Fixes #99665

Let me explain the root cause using the unit test I added:
* This bug is triggered when:
  * `wrapped` is a nested function.
  * `wrapped` is defined in a different module from the main function `fn`.
  * There is a graph break inside `wrapped`.
* The root cause: when resuming a nested function, we were actually using the outermost function's (`fn` in my example) globals. But `wrapped` calls `inner_func`, which is not part of `fn`'s globals, so we have to set the correct globals when the nested function resumes execution (see the sketch below).
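A minimal sketch of the triggering pattern, condensed from the `test/dynamo/utils.py` helper this PR adds (the module split is the essential ingredient; everything else matches the new test):

```python
# utils.py -- a different module from the one that defines fn
import torch
import torch._dynamo


def inner_func():
    # Resolvable only through utils.py's globals, not through fn's.
    return torch.is_grad_enabled()


def outer_func(func):
    def wrapped(*args):  # nested function
        a = func(*args)
        torch._dynamo.graph_break()  # forces Dynamo to generate a resume function
        # If the resume function is built with fn's globals, the call
        # below fails with NameError: name 'inner_func' is not defined.
        return torch.sin(a + 1), inner_func()

    return wrapped
```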

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100426
Approved by: https://github.com/jansel
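For context on the fix (the `torch/_dynamo/variables/functions.py` hunk below): `NestedUserFunctionVariable.reconstruct` used to emit a raw `MAKE_FUNCTION` bytecode, which binds the rebuilt function to the globals of whichever frame executes it. It now emits a call to the new `_create_nested_fn` helper, which constructs the function via `types.FunctionType` with the originally captured `f_globals`. A standalone sketch (not Dynamo code) of why the globals argument matters:

```python
import types

# Compile a tiny function whose body references a global name.
module_code = compile("def f():\n    return helper()\n", "<demo>", "exec")
f_code = next(c for c in module_code.co_consts if isinstance(c, types.CodeType))

right_globals = {"helper": lambda: "ok"}  # the defining module's globals
wrong_globals = {}                        # simulates the caller's globals

# The globals passed to FunctionType decide which names f can resolve,
# no matter which frame later calls it.
rebuilt = types.FunctionType(f_code, right_globals, "f")
print(rebuilt())  # "ok"

broken = types.FunctionType(f_code, wrong_globals, "f")
# broken()  # NameError: name 'helper' is not defined
```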
diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh
index 4d34f5b..69ba42e 100755
--- a/.ci/pytorch/test.sh
+++ b/.ci/pytorch/test.sh
@@ -233,6 +233,7 @@
     --exclude-distributed-tests \
     --exclude \
       test_autograd \
+      test_jit \
       test_proxy_tensor \
       test_quantization \
       test_public_bindings \
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 811bc48..a20962e 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -5157,6 +5157,26 @@
         self.assertTrue(isinstance(compile_out, torch.Size))
         self.assertEqual(eager_out, compile_out)
 
+    def test_nested_function_resuming_with_correct_globals(self):
+        # https://github.com/pytorch/pytorch/issues/99665
+        try:
+            from .utils import outer_func
+        except ImportError:
+            from utils import outer_func
+
+        def gn(x, y):
+            return x + y
+
+        def fn(x, y):
+            return outer_func(gn)(x, y)
+
+        x = torch.rand([3])
+        y = torch.rand([3])
+        opt_fn = torch.compile(backend="eager")(fn)
+        ref = fn(x, y)
+        res = opt_fn(x, y)
+        self.assertTrue(same(ref, res))
+
 
 class CustomFunc1(torch.autograd.Function):
     @staticmethod
diff --git a/test/dynamo/utils.py b/test/dynamo/utils.py
new file mode 100644
index 0000000..54cacd0
--- /dev/null
+++ b/test/dynamo/utils.py
@@ -0,0 +1,17 @@
+# Owner(s): ["module: dynamo"]
+
+import torch
+import torch._dynamo
+
+
+def inner_func():
+    return torch.is_grad_enabled()
+
+
+def outer_func(func):
+    def wrapped(*args):
+        a = func(*args)
+        torch._dynamo.graph_break()
+        return torch.sin(a + 1), inner_func()
+
+    return wrapped
diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py
index a88b266..5df1918 100644
--- a/test/jit/test_tracer.py
+++ b/test/jit/test_tracer.py
@@ -35,6 +35,7 @@
                        "\tpython test/test_jit.py TESTNAME\n\n"
                        "instead.")
 
+@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
 class TestTracer(JitTestCase):
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     def test_large_nbr_kernel_args(self):
@@ -1990,6 +1991,7 @@
         self.assertEqual(model(**input_dict), traced_model(**input_dict))
 
 
+@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
 class TestMixTracingScripting(JitTestCase):
     def test_trace_script(self):
         @torch.jit.script
diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py
index 28ca5aa..ec21547 100644
--- a/test/test_binary_ufuncs.py
+++ b/test/test_binary_ufuncs.py
@@ -27,6 +27,7 @@
     numpy_to_torch_dtype_dict,
     TEST_SCIPY,
     set_default_dtype,
+    skipIfTorchDynamo,
 )
 from torch.testing._internal.common_device_type import (
     expectedFailureMeta,
@@ -1852,6 +1853,7 @@
             _scalar_helper(lambda a, b: math.floor(a / b), torch.floor_divide)
 
     @onlyNativeDeviceTypes
+    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
     def test_div_and_floordiv_script_vs_python(self, device):
         # Creates jitted functions of two tensors
         def _wrapped_div(a, b):
@@ -1924,6 +1926,7 @@
                     self.assertEqual(5 // a, scripted_rfloordiv_scalar(a_t))
 
     @onlyNativeDeviceTypes
+    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
     def test_idiv_and_ifloordiv_vs_python(self, device):
         def _wrapped_idiv_tensor(a, b):
             a /= b
diff --git a/test/test_indexing.py b/test/test_indexing.py
index 38bddda..551327c 100644
--- a/test/test_indexing.py
+++ b/test/test_indexing.py
@@ -12,7 +12,7 @@
 
 from torch.testing import make_tensor
 from torch.testing._internal.common_utils import (
-    TestCase, run_tests, TEST_WITH_TORCHDYNAMO)
+    TestCase, run_tests, skipIfTorchDynamo)
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, onlyCUDA, dtypes, dtypesIfCPU, dtypesIfCUDA,
     onlyNativeDeviceTypes, skipXLA)
@@ -738,10 +738,7 @@
             self.assertEqual(y, torch.ones(size=(10, 10), device=device))
             self.assertEqual(len(w), 2)
 
-    @unittest.skipIf(
-        TEST_WITH_TORCHDYNAMO,
-        "This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472"
-    )
+    @skipIfTorchDynamo("This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472")
     def test_index_put_accumulate_large_tensor(self, device):
         # This test is for tensors with number of elements >= INT_MAX (2^31 - 1).
         N = (1 << 31) + 5
@@ -839,6 +836,7 @@
         self.assertEqual(out_cuda.cpu(), out_cpu)
 
     @onlyCUDA
+    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
     def test_index_put_accumulate_with_optional_tensors(self, device):
         # TODO: replace with a better solution.
         # Currently, here using torchscript to put None into indices.
@@ -935,6 +933,7 @@
         r = v[c > 0]
         self.assertEqual(r.shape, (num_ones, 3))
 
+    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
     def test_jit_indexing(self, device):
         def fn1(x):
             x[x < 50] = 1.0
diff --git a/test/test_native_functions.py b/test/test_native_functions.py
index ba7889e..c95b4a2 100644
--- a/test/test_native_functions.py
+++ b/test/test_native_functions.py
@@ -2,7 +2,7 @@
 
 from typing import Optional, List
 import torch
-from torch.testing._internal.common_utils import TestCase, run_tests
+from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo
 
 # End-to-end tests of features in native_functions.yaml
 
@@ -81,6 +81,7 @@
             return torch._C._nn._test_optional_floatlist(values, const)
         return torch.jit.trace(wrapper, torch.tensor([1.5, 2.5], dtype=torch.float))
 
+    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
     def test_optional_floatlist(self):
         self.do_test_optional_floatlist_with_module(FloatListWrapperModule())
         self.do_test_optional_floatlist_with_module(torch.jit.script(FloatListWrapperModule()))
@@ -134,6 +135,7 @@
             return torch._C._nn._test_optional_intlist(values, const)
         return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))
 
+    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
     def test_optional_intlist(self):
         self.do_test_optional_intlist_with_module(IntListWrapperModule())
         self.do_test_optional_intlist_with_module(torch.jit.script(IntListWrapperModule()))
@@ -187,6 +189,7 @@
             return torch._C._nn._test_optional_filled_intlist(values, const)
         return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))
 
+    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
     def test_optional_filled_intlist(self):
 
         def f(n: int):
diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py
index 9af810f..280f467 100644
--- a/torch/_dynamo/variables/functions.py
+++ b/torch/_dynamo/variables/functions.py
@@ -3,7 +3,6 @@
 import functools
 import inspect
 import itertools
-import sys
 import types
 from typing import Dict, List
 
@@ -11,11 +10,7 @@
 
 from .. import variables
 from ..allowed_functions import is_allowed, is_builtin_callable
-from ..bytecode_transformation import (
-    create_call_function,
-    create_instruction,
-    create_rot_n,
-)
+from ..bytecode_transformation import create_call_function, create_rot_n
 from ..exc import unimplemented
 from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
 from ..utils import istensor, istype, make_cell
@@ -89,6 +84,26 @@
     return closure_cells
 
 
+def _create_nested_fn(
+    code, f_globals, name, defaults, closure, kwdefaults, annotations
+):
+    from types import FunctionType
+
+    func = FunctionType(code, f_globals, name, defaults, closure)
+    func.__kwdefaults__ = kwdefaults
+
+    if isinstance(annotations, tuple):
+        from itertools import pairwise
+
+        annotations = dict(pairwise(annotations))
+
+    # TypeError: __annotations__ must be set to a dict object
+    assert annotations is None or isinstance(annotations, dict)
+    func.__annotations__ = annotations
+
+    return func
+
+
 class BaseUserFunctionVariable(VariableTracker):
     def get_filename(self):
         return self.get_code().co_filename
@@ -460,17 +475,27 @@
                 parent.symbolic_locals[var] = child.symbolic_locals[var]
 
     def reconstruct(self, codegen):
-        flags = 0x00
+        codegen.load_import_from(__name__, "_create_nested_fn")
+        codegen(self.code)
+        codegen.extend_output([codegen._create_load_const(self.f_globals)])
+        codegen(self.fn_name)
+
         if self.defaults:
-            flags |= 0x01
             codegen(self.defaults)
+        else:
+            codegen.extend_output([codegen.create_load_const(None)])
+
+        if self.closure:
+            codegen(self.closure)
+        else:
+            codegen.extend_output([codegen.create_load_const(None)])
+
         if self.kwdefaults:
-            flags |= 0x02
             codegen(self.kwdefaults)
-        if isinstance(
-            self.annotations, (variables.ConstDictVariable, variables.TupleVariable)
-        ):
-            flags |= 0x04
+        else:
+            codegen.extend_output([codegen.create_load_const(None)])
+
+        if self.annotations:
             try:
                 if isinstance(self.annotations, variables.ConstDictVariable):
                     annotations = {
@@ -484,13 +509,10 @@
                 codegen.extend_output([codegen._create_load_const(annotations)])
             except NotImplementedError:
                 codegen(self.annotations)
-        if self.closure:
-            flags |= 0x08
-            codegen(self.closure)
-        codegen(self.code)
-        if sys.version_info < (3, 11):
-            codegen(self.fn_name)
-        codegen.extend_output([create_instruction("MAKE_FUNCTION", arg=flags)])
+        else:
+            codegen.extend_output([codegen.create_load_const(None)])
+
+        codegen.extend_output(create_call_function(7, push_null=True))
 
         if self.wraps_source:
             codegen.load_import_from("functools", "wraps")