[dynamo 3.11] support prefix instructions MAKE_CELL, COPY_FREE_VARS, RETURN_GENERATOR, RESUME (#96506)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96506
Approved by: https://github.com/jansel
diff --git a/torch/_dynamo/bytecode_analysis.py b/torch/_dynamo/bytecode_analysis.py
index 38700c2..efb7116 100644
--- a/torch/_dynamo/bytecode_analysis.py
+++ b/torch/_dynamo/bytecode_analysis.py
@@ -111,6 +111,8 @@
state.reads.add(inst.argval)
elif "STORE" in inst.opname:
state.writes.add(inst.argval)
+ elif inst.opname == "MAKE_CELL":
+ pass
else:
raise NotImplementedError(f"unhandled {inst.opname}")
if inst.opcode in JUMP_OPCODES:
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index dbc217a..e6e08b5 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -2,6 +2,7 @@
import contextlib
import dataclasses
+import dis
import functools
import inspect
import logging
@@ -341,11 +342,21 @@
super().__init__(callback=None)
+def first_real_inst_idx(code):
+ if sys.version_info < (3, 11):
+ return 0
+ for inst in dis.get_instructions(code):
+ if inst.opname == "RESUME":
+ return inst.offset // 2
+ raise RuntimeError("RESUME instruction not found in code")
+
+
def catch_errors_wrapper(callback, hooks: Hooks):
@functools.wraps(callback)
def catch_errors(frame, cache_size):
if (
- frame.f_lasti >= 0
+ # TODO: the first condition (the frame.f_lasti check below) is not covered by any test
+ frame.f_lasti >= first_real_inst_idx(frame.f_code)
or skipfiles.check(frame.f_code.co_filename)
or config.disable
):
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 342fa4b1..9cef09d 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -5,6 +5,7 @@
import logging
import operator
import re
+import sys
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, OrderedDict, Set, Union
@@ -524,6 +525,27 @@
if not all(block.can_restore() for block in tx.block_stack):
unimplemented("compile_subgraph with block_depth != 0")
+ prefix_insts: List[Instruction] = []
+ if sys.version_info >= (3, 11):
+ # collect the prefix instructions recorded in tx.prefix_insts so they can be added to the output (Python 3.11+)
+ for inst in tx.prefix_insts:
+ if inst.opname == "MAKE_CELL":
+ prefix_insts.append(
+ create_instruction("MAKE_CELL", argval=inst.argval)
+ )
+ elif inst.opname == "COPY_FREE_VARS":
+ prefix_insts.append(
+ create_instruction(
+ "COPY_FREE_VARS", len(tx.code_options["co_freevars"])
+ )
+ )
+ else:
+ prefix_insts.append(inst)
+
+ def append_prefix_insts():
+ self.add_output_instructions(prefix_insts)
+ prefix_insts.clear()
+
for block in reversed(tx.block_stack):
block.exit(tx)
@@ -551,6 +573,7 @@
# to handle random calls
if len(tx.random_calls) > 0:
+ append_prefix_insts()
random_calls_instructions = []
self.random_values_var = self.new_var("random_values")
rand_fn_name = unique_id("__gen_rand_values")
@@ -583,6 +606,7 @@
and len(set(stack_values)) == len(stack_values)
and self.side_effects.is_empty()
):
+ append_prefix_insts()
# optimization to generate better code in a common case
self.add_output_instructions(
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
@@ -616,6 +640,7 @@
output.append(pass2.create_store(graph_output_var))
else:
output.append(create_instruction("POP_TOP"))
+ append_prefix_insts()
self.add_output_instructions(output + pass2.get_instructions())
# restore all the live local vars
diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py
index 14f051b..e51c0b6 100644
--- a/torch/_dynamo/resume_execution.py
+++ b/torch/_dynamo/resume_execution.py
@@ -265,6 +265,11 @@
(target,) = [i for i in instructions if i.offset == offset]
prefix = []
+ if sys.version_info >= (3, 11):
+ if freevars:
+ prefix.append(create_instruction("COPY_FREE_VARS", len(freevars)))
+ prefix.append(create_instruction("RESUME", 0))
+
cleanup = []
hooks = {fn.stack_index: fn for fn in setup_fns}
null_idxes_i = 0
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 313aacb..f533214 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -420,6 +420,8 @@
lineno: int
mutated_closure_cell_contents: Set[str]
kw_names: Optional[ConstantVariable]
+ accept_prefix_inst: bool
+ prefix_insts: List[Instruction]
checkpoint: Optional[Tuple[Instruction, InstructionTranslatorGraphState]]
random_calls: List[
@@ -1502,9 +1504,12 @@
INPLACE_OR = stack_op(operator.ior)
# 3.11 opcodes
- # note: passed opcodes are intentional
def RESUME(self, inst):
- pass
+ if inst.arg == 0:
+ self.append_prefix_inst(inst)
+ self.accept_prefix_inst = False
+ else:
+ assert not self.accept_prefix_inst
def BINARY_OP(self, inst):
if sys.version_info >= (3, 11):
@@ -1600,6 +1605,19 @@
self.push(exit)
self.push(ctx.enter(self))
+ def append_prefix_inst(self, inst):
+ assert self.accept_prefix_inst
+ self.prefix_insts.append(inst)
+
+ def MAKE_CELL(self, inst):
+ self.append_prefix_inst(inst)
+
+ def COPY_FREE_VARS(self, inst):
+ self.append_prefix_inst(inst)
+
+ def RETURN_GENERATOR(self, inst):
+ self.append_prefix_inst(inst)
+
def copy_graphstate(self) -> InstructionTranslatorGraphState:
"""Create a checkpoint of the current state by copying everything"""
return InstructionTranslatorGraphState(
@@ -1699,6 +1717,8 @@
self.block_stack = []
self.lineno = code_options["co_firstlineno"]
self.kw_names = None
+ self.accept_prefix_inst = True
+ self.prefix_insts = []
# Properties of the input/output code
self.instructions: List[Instruction] = instructions