[dynamo 3.11] support prefix instructions MAKE_CELL, COPY_FREE_VARS, RETURN_GENERATOR, RESUME (#96506)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96506
Approved by: https://github.com/jansel
diff --git a/torch/_dynamo/bytecode_analysis.py b/torch/_dynamo/bytecode_analysis.py
index 38700c2..efb7116 100644
--- a/torch/_dynamo/bytecode_analysis.py
+++ b/torch/_dynamo/bytecode_analysis.py
@@ -111,6 +111,8 @@
                         state.reads.add(inst.argval)
                 elif "STORE" in inst.opname:
                     state.writes.add(inst.argval)
+                elif inst.opname == "MAKE_CELL":
+                    pass
                 else:
                     raise NotImplementedError(f"unhandled {inst.opname}")
             if inst.opcode in JUMP_OPCODES:
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index dbc217a..e6e08b5 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -2,6 +2,7 @@
 
 import contextlib
 import dataclasses
+import dis
 import functools
 import inspect
 import logging
@@ -341,11 +342,21 @@
         super().__init__(callback=None)
 
 
+def first_real_inst_idx(code):  # index of the first RESUME instruction in `code`; always 0 before Python 3.11
+    if sys.version_info < (3, 11):
+        return 0  # pre-3.11: no scan is done, index 0 is returned unconditionally
+    for inst in dis.get_instructions(code):
+        if inst.opname == "RESUME":
+            return inst.offset // 2  # byte offset -> instruction index; assumes 2-byte code units - TODO confirm
+    raise RuntimeError("RESUME instruction not found in code")  # 3.11+ code with no RESUME is rejected
+
+
 def catch_errors_wrapper(callback, hooks: Hooks):
     @functools.wraps(callback)
     def catch_errors(frame, cache_size):
         if (
-            frame.f_lasti >= 0
+            # TODO: the first condition is not covered by any test
+            frame.f_lasti >= first_real_inst_idx(frame.f_code)
             or skipfiles.check(frame.f_code.co_filename)
             or config.disable
         ):
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 342fa4b1..9cef09d 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -5,6 +5,7 @@
 import logging
 import operator
 import re
+import sys
 import traceback
 from dataclasses import dataclass
 from typing import Any, Dict, List, NamedTuple, Optional, OrderedDict, Set, Union
@@ -524,6 +525,27 @@
         if not all(block.can_restore() for block in tx.block_stack):
             unimplemented("compile_subgraph with block_depth != 0")
 
+        prefix_insts: List[Instruction] = []
+        if sys.version_info >= (3, 11):
+            # prefix instructions (Python 3.11+)
+            for inst in tx.prefix_insts:
+                if inst.opname == "MAKE_CELL":
+                    prefix_insts.append(
+                        create_instruction("MAKE_CELL", argval=inst.argval)
+                    )
+                elif inst.opname == "COPY_FREE_VARS":
+                    prefix_insts.append(
+                        create_instruction(
+                            "COPY_FREE_VARS", len(tx.code_options["co_freevars"])
+                        )
+                    )
+                else:
+                    prefix_insts.append(inst)
+
+        def append_prefix_insts():  # hand the pending prefix instructions to add_output_instructions
+            self.add_output_instructions(prefix_insts)
+            prefix_insts.clear()  # a repeated call passes an empty list; NOTE(review): assumes the callee keeps no reference to this list
+
         for block in reversed(tx.block_stack):
             block.exit(tx)
 
@@ -551,6 +573,7 @@
 
         # to handle random calls
         if len(tx.random_calls) > 0:
+            append_prefix_insts()
             random_calls_instructions = []
             self.random_values_var = self.new_var("random_values")
             rand_fn_name = unique_id("__gen_rand_values")
@@ -583,6 +606,7 @@
             and len(set(stack_values)) == len(stack_values)
             and self.side_effects.is_empty()
         ):
+            append_prefix_insts()
             # optimization to generate better code in a common case
             self.add_output_instructions(
                 self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
@@ -616,6 +640,7 @@
                     output.append(pass2.create_store(graph_output_var))
                 else:
                     output.append(create_instruction("POP_TOP"))
+            append_prefix_insts()
             self.add_output_instructions(output + pass2.get_instructions())
 
         # restore all the live local vars
diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py
index 14f051b..e51c0b6 100644
--- a/torch/_dynamo/resume_execution.py
+++ b/torch/_dynamo/resume_execution.py
@@ -265,6 +265,11 @@
             (target,) = [i for i in instructions if i.offset == offset]
 
             prefix = []
+            if sys.version_info >= (3, 11):
+                if freevars:
+                    prefix.append(create_instruction("COPY_FREE_VARS", len(freevars)))
+                prefix.append(create_instruction("RESUME", 0))
+
             cleanup = []
             hooks = {fn.stack_index: fn for fn in setup_fns}
             null_idxes_i = 0
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 313aacb..f533214 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -420,6 +420,8 @@
     lineno: int
     mutated_closure_cell_contents: Set[str]
     kw_names: Optional[ConstantVariable]
+    accept_prefix_inst: bool
+    prefix_insts: List[Instruction]
 
     checkpoint: Optional[Tuple[Instruction, InstructionTranslatorGraphState]]
     random_calls: List[
@@ -1502,9 +1504,12 @@
     INPLACE_OR = stack_op(operator.ior)
 
     # 3.11 opcodes
-    # note: passed opcodes are intentional
     def RESUME(self, inst):
-        pass
+        if inst.arg == 0:  # arg 0 is treated as the last prefix instruction (presumably function entry - see dis docs)
+            self.append_prefix_inst(inst)
+            self.accept_prefix_inst = False  # any later append_prefix_inst call will now fail its assert
+        else:
+            assert not self.accept_prefix_inst  # a non-zero arg is only expected after the prefix was closed
 
     def BINARY_OP(self, inst):
         if sys.version_info >= (3, 11):
@@ -1600,6 +1605,19 @@
         self.push(exit)
         self.push(ctx.enter(self))
 
+    def append_prefix_inst(self, inst):  # record `inst` at the end of self.prefix_insts
+        assert self.accept_prefix_inst  # flag starts True in __init__ and is cleared by RESUME with arg 0
+        self.prefix_insts.append(inst)
+
+    def MAKE_CELL(self, inst):  # handler only records the instruction as a prefix instruction
+        self.append_prefix_inst(inst)
+
+    def COPY_FREE_VARS(self, inst):  # handler only records the instruction as a prefix instruction
+        self.append_prefix_inst(inst)
+
+    def RETURN_GENERATOR(self, inst):  # handler only records the instruction as a prefix instruction
+        self.append_prefix_inst(inst)
+
     def copy_graphstate(self) -> InstructionTranslatorGraphState:
         """Create a checkpoint of the current state by copying everything"""
         return InstructionTranslatorGraphState(
@@ -1699,6 +1717,8 @@
         self.block_stack = []
         self.lineno = code_options["co_firstlineno"]
         self.kw_names = None
+        self.accept_prefix_inst = True
+        self.prefix_insts = []
 
         # Properties of the input/output code
         self.instructions: List[Instruction] = instructions