[dynamo, 3.12] remove LOAD_METHOD, update LOAD_ATTR (#122356)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122356
Approved by: https://github.com/jansel
ghstack dependencies: #122146, #122335, #122354, #122355
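
CPython 3.12 removes the LOAD_METHOD opcode: the compiler now emits LOAD_ATTR
for both plain attribute loads and method-call loads, packing the co_names
index into the high bits of the oparg and a "method-call" flag into the low
bit. This PR rewires Dynamo's bytecode helpers around that encoding. A minimal
sketch (not part of this patch) of the interpreter behavior being adapted to:

import dis
import sys

def f(x):
    x.append(1)    # attribute load in call position
    return x.real  # plain attribute load

for inst in dis.get_instructions(f):
    if inst.opname in ("LOAD_ATTR", "LOAD_METHOD"):
        if sys.version_info >= (3, 12):
            # on 3.12 both lines disassemble to LOAD_ATTR; the low bit of
            # the oparg distinguishes the method-call form
            print(inst.opname, inst.argval, "lowbit:", inst.arg % 2)
        else:
            print(inst.opname, inst.argval)
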
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index c7deb56..3a67854 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -3842,7 +3842,8 @@
fn = locals["fn"]
orig_inst_str = "\n".join(list(map(str, dis.get_instructions(fn))))
self.assertIn("EXTENDED_ARG", orig_inst_str)
- self.assertIn("LOAD_METHOD", orig_inst_str)
+ load_method_str = "LOAD_ATTR" if sys.version_info >= (3, 12) else "LOAD_METHOD"
+ self.assertIn(load_method_str, orig_inst_str)
keys = bytecode_transformation.get_code_keys()
code_options = {k: getattr(fn.__code__, k) for k in keys}
result = bytecode_transformation.clean_and_assemble_instructions(
@@ -3852,7 +3853,7 @@
)
new_inst_str = "\n".join(list(map(str, result[0])))
self.assertIn("EXTENDED_ARG", new_inst_str)
- self.assertIn("LOAD_METHOD", new_inst_str)
+ self.assertIn(load_method_str, new_inst_str)
l1, l2 = list(fn.__code__.co_positions()), list(result[1].co_positions())
self.assertEqual(len(l1), len(l2))
for p1, p2 in zip(l1, l2):
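
Since LOAD_METHOD no longer exists as an opcode on 3.12, the test selects the
expected opname by interpreter version. A quick check of that premise
(illustrative, not part of the test):

import dis
import sys

# LOAD_METHOD was merged into LOAD_ATTR in 3.12, so it is absent from dis.opmap
assert ("LOAD_METHOD" in dis.opmap) == (sys.version_info < (3, 12))
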
diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py
index 9856a09..89e7ee2 100644
--- a/torch/_dynamo/bytecode_transformation.py
+++ b/torch/_dynamo/bytecode_transformation.py
@@ -218,6 +218,28 @@
return [create_instruction("CALL_METHOD", arg=nargs)]
+def create_load_attr(name) -> Instruction:
+    # LOAD_ATTR with the low bit unset: a plain attribute load on 3.12
+    return Instruction(
+        opcode=dis.opmap["LOAD_ATTR"],
+        opname="LOAD_ATTR",
+        arg=False,  # low bit placeholder for 3.12; real oparg computed later
+        argval=name,
+    )
+
+
+def create_load_method(name) -> Instruction:
+    if sys.version_info >= (3, 12):
+        # in 3.12, LOAD_METHOD becomes LOAD_ATTR with the low bit set
+        return Instruction(
+            opcode=dis.opmap["LOAD_ATTR"],
+            opname="LOAD_ATTR",
+            arg=True,  # low bit set for 3.12 (method-call form)
+            argval=name,
+        )
+    return create_instruction("LOAD_METHOD", argval=name)
+
+
def lnotab_writer(
lineno: int, byteno: int = 0
) -> Tuple[List[int], Callable[[int, int], None]]:
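
The helpers above temporarily store only the low bit in `arg`, as a bool, and
leave the final oparg to the arg-fixing pass later in this patch. That is
sound because bool subclasses int; a small check (illustrative only):

# the bool low-bit survives the later `% 2` arithmetic unchanged
assert isinstance(True, int)
assert True % 2 == 1 and False % 2 == 0
assert (3 << 1) + (True % 2) == 7
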
@@ -762,6 +784,7 @@
def remove_load_call_method(instructions: List[Instruction]) -> List[Instruction]:
"""LOAD_METHOD puts a NULL on the stack which causes issues, so remove it"""
+ assert sys.version_info < (3, 11)
rewrites = {"LOAD_METHOD": "LOAD_ATTR", "CALL_METHOD": "CALL_FUNCTION"}
for inst in instructions:
if inst.opname in rewrites:
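
The pre-3.11 rewrite above is behavior-preserving because a LOAD_METHOD /
CALL_METHOD pair evaluates to the same result as a LOAD_ATTR / CALL_FUNCTION
pair. A sketch of the source-level equivalence:

class C:
    def m(self, x):
        return x + 1

obj = C()
# obj.m(1) is the method-call form; getattr(obj, "m")(1) forces the plain
# attribute-load form -- both produce the same bound-method call
assert obj.m(1) == getattr(obj, "m")(1) == 2
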
@@ -963,6 +986,16 @@
)
else:
instructions[i].arg = names[instructions[i].argval]
+        elif instructions[i].opname == "LOAD_ATTR":
+            # 3.12 LOAD_ATTR requires both arg (low bit) and argval, like LOAD_GLOBAL
+            assert instructions[i].arg is not None
+            assert instructions[i].argval is not _NotProvided
+            if sys.version_info >= (3, 12):
+                instructions[i].arg = (names[instructions[i].argval] << 1) + (
+                    cast(int, instructions[i].arg) % 2
+                )
+            else:
+                instructions[i].arg = names[instructions[i].argval]
elif instructions[i].opcode in HAS_LOCAL:
if should_compute_arg():
instructions[i].arg = varnames[instructions[i].argval]
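
A worked example of the 3.12 oparg packing performed above (co_names index in
the high bits, method flag in the low bit); the pack/unpack helpers here are
illustrative, not part of the patch:

def pack_load_attr_arg(name_index: int, is_method: bool) -> int:
    return (name_index << 1) | int(is_method)

def unpack_load_attr_arg(arg: int) -> tuple[int, bool]:
    return arg >> 1, bool(arg & 1)

assert pack_load_attr_arg(3, True) == 7    # index 3, method-call form
assert pack_load_attr_arg(3, False) == 6   # index 3, plain attribute load
assert unpack_load_attr_arg(7) == (3, True)
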
diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py
index e97c856..4666a48 100644
--- a/torch/_dynamo/codegen.py
+++ b/torch/_dynamo/codegen.py
@@ -12,7 +12,9 @@
create_call_function,
create_dup_top,
create_instruction,
+ create_load_attr,
create_load_global,
+ create_load_method,
create_rot_n,
Instruction,
)
@@ -261,12 +263,12 @@
def create_load_method(self, name):
self.tx.output.update_co_names(name)
- return create_instruction("LOAD_METHOD", argval=name)
+ return create_load_method(name)
def create_load_attr(self, name) -> Instruction:
if name not in self.code_options["co_names"]:
self.code_options["co_names"] += (name,)
- return create_instruction("LOAD_ATTR", argval=name)
+ return create_load_attr(name)
def load_attr(self, name):
self.append_output(self.create_load_attr(name))
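
PyCodegen.create_load_attr registers the name into co_names before emitting,
since the LOAD_ATTR oparg ultimately indexes that tuple. A minimal sketch of
the invariant (ensure_name is hypothetical, for illustration only):

def ensure_name(co_names: tuple, name: str) -> tuple:
    # LOAD_ATTR's oparg indexes co_names, so the name must be present
    return co_names if name in co_names else co_names + (name,)

assert ensure_name(("x",), "y") == ("x", "y")
assert ensure_name(("x", "y"), "y") == ("x", "y")
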
diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py
index b0b565a..3df6616 100644
--- a/torch/_dynamo/resume_execution.py
+++ b/torch/_dynamo/resume_execution.py
@@ -10,6 +10,7 @@
create_dup_top,
create_instruction,
create_jump_absolute,
+ create_load_method,
Instruction,
InstructionExnTabEntry,
transform_code_object,
@@ -70,7 +71,7 @@
*create_call_function(len(load_args), True),
create_instruction("STORE_FAST", argval=ctx_name),
create_instruction("LOAD_FAST", argval=ctx_name),
- create_instruction("LOAD_METHOD", argval="__enter__"),
+ create_load_method("__enter__"),
*create_call_method(0),
create_instruction("POP_TOP"),
]
@@ -94,7 +95,7 @@
def create_reset():
return [
create_instruction("LOAD_FAST", argval=ctx_name),
- create_instruction("LOAD_METHOD", argval="__exit__"),
+ create_load_method("__exit__"),
create_instruction("LOAD_CONST", argval=None),
create_dup_top(),
create_dup_top(),
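
The bytecode emitted above re-enters a live context manager when execution
resumes inside a `with` block. Roughly the source-level shape (a sketch; Ctx
and the layout are illustrative, not from the PR):

class Ctx:
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc, tb):
        return False

ctx = Ctx()
ctx.__enter__()  # emitted via create_load_method("__enter__") + create_call_method(0)
try:
    pass         # resumed user code continues here
finally:
    # the reset path loads __exit__ the same way and calls it with None args
    ctx.__exit__(None, None, None)
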
diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py
index 3918476..26f4c31 100644
--- a/torch/_dynamo/side_effects.py
+++ b/torch/_dynamo/side_effects.py
@@ -9,6 +9,7 @@
create_call_function,
create_call_method,
create_instruction,
+ create_load_method,
)
from .codegen import PyCodegen
from .exc import unimplemented
@@ -359,9 +360,7 @@
for ctx, args in self.save_for_backward:
cg(ctx.source)
- cg.extend_output(
- [create_instruction("LOAD_METHOD", argval="save_for_backward")]
- )
+ cg.extend_output([create_load_method("save_for_backward")])
for arg in args:
cg(arg)
cg.extend_output(
@@ -460,11 +459,11 @@
cg.tx.output.update_co_names("update")
cg(var.mutable_local.source) # type: ignore[attr-defined]
- cg.extend_output([create_instruction("LOAD_METHOD", argval="update")])
+ cg.extend_output([create_load_method("update")])
cg(var, allow_cache=False)
cg(var.mutable_local.source) # type: ignore[attr-defined]
- cg.extend_output([create_instruction("LOAD_METHOD", argval="clear")])
+ cg.extend_output([create_load_method("clear")])
suffixes.append(
[
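
The side-effect codegen above replays dict mutations on the original object
through its bound update/clear methods, so existing aliases observe the new
contents. A rough source-level sketch (ordering here is illustrative):

d = {"a": 1}
alias = d                 # aliases must see the mutation
new_contents = {"b": 2}   # the reconstructed final value
d.clear()
d.update(new_contents)
assert alias == {"b": 2} and alias is d
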
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 2a72e40..45862a4 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -1099,7 +1099,7 @@
def IMPORT_FROM(self, inst):
self.DUP_TOP(inst)
- self.LOAD_ATTR(inst)
+ self._load_attr(inst)
def load_builtin(self, inst):
if inst.argval not in self.f_builtins:
@@ -1263,7 +1263,7 @@
arg = inst.argval[0]
argval = self.code_options["co_names"][arg]
if sys.version_info < (3, 11):
- self.LOAD_ATTR(dataclasses.replace(inst, argval=argval))
+ self._load_attr(dataclasses.replace(inst, argval=argval))
else:
self.LOAD_METHOD(dataclasses.replace(inst, argval=argval))
@@ -1271,10 +1271,10 @@
self.CALL_FUNCTION(dataclasses.replace(inst, argval=2))
arg = inst.argval[0]
argval = self.code_options["co_names"][arg]
- self.LOAD_ATTR(dataclasses.replace(inst, argval=argval))
+ self._load_attr(dataclasses.replace(inst, argval=argval))
def LOAD_METHOD(self, inst):
- self.LOAD_ATTR(inst)
+ self._load_attr(inst)
obj = self.pop()
if sys.version_info >= (3, 11):
# always follow the NULL + fn convention, since if obj
@@ -1293,13 +1293,20 @@
fn = self.pop()
self.call_function(fn, args, {})
- def LOAD_ATTR(self, inst):
+ def _load_attr(self, inst):
obj = self.pop()
result = BuiltinVariable(getattr).call_function(
self, [obj, ConstantVariable.create(inst.argval)], {}
)
self.push(result)
+    def LOAD_ATTR(self, inst):
+        if sys.version_info >= (3, 12):
+            if inst.arg % 2:
+                self.LOAD_METHOD(inst)
+                return
+        self._load_attr(inst)
+
def STORE_ATTR(self, inst):
speculation = self.speculate()
if speculation.failed:
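
The new LOAD_ATTR handler dispatches on the oparg's low bit: an odd arg marks
the 3.12 method-call form, which reuses the LOAD_METHOD path. A hedged check
of that low-bit convention (illustrative, not part of the patch):

import dis
import sys

def g(x):
    return x.upper()  # attribute load in call position

if sys.version_info >= (3, 12):
    attrs = [i for i in dis.get_instructions(g) if i.opname == "LOAD_ATTR"]
    # call-position loads carry the low bit on 3.12
    assert attrs and all(i.arg % 2 == 1 for i in attrs)
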
diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py
index 1dcbce4..07df439 100644
--- a/torch/_dynamo/variables/dicts.py
+++ b/torch/_dynamo/variables/dicts.py
@@ -14,6 +14,7 @@
create_call_function,
create_call_method,
create_instruction,
+ create_load_method,
)
from ..eval_frame import skip_code
@@ -428,7 +429,7 @@
codegen(self.dv_dict)
codegen.extend_output(
[
- create_instruction("LOAD_METHOD", argval=self.kv),
+ create_load_method(self.kv),
*create_call_method(0),
]
)
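
The reconstruction above loads the dict, loads the view method named by
self.kv (e.g. "keys" or "values"), and calls it with zero args. A source-level
sketch of what that bytecode evaluates to (kv here is illustrative):

d = {"a": 1, "b": 2}
kv = "keys"              # stands in for self.kv
view = getattr(d, kv)()  # what create_load_method(kv) + create_call_method(0) computes
assert list(view) == ["a", "b"]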