[dynamo, 3.12] add LOAD_SUPER_ATTR (#122738)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122738
Approved by: https://github.com/jansel
ghstack dependencies: #122146, #122335, #122354, #122355, #122356, #122449, #122455, #122456, #122530, #122737
diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py
index 07f42b3..868d52c 100644
--- a/torch/_dynamo/bytecode_transformation.py
+++ b/torch/_dynamo/bytecode_transformation.py
@@ -97,8 +97,12 @@
`argval` or `target`.
Do not use for LOAD_GLOBAL - use create_load_global instead.
+ Do not use for LOAD_ATTR - use create_load_attr instead.
+ Do not use for LOAD_SUPER_ATTR - if you need to create this instruction,
+ implement a create_load_super_attr function.
"""
- assert name != "LOAD_GLOBAL"
+ if name in ("LOAD_GLOBAL", "LOAD_ATTR", "LOAD_SUPER_ATTR"):
+ raise RuntimeError(f"cannot create_instruction with {name}")
cnt = (arg is not None) + (argval is not _NotProvided) + (target is not None)
if cnt > 1:
raise RuntimeError(
@@ -191,7 +195,7 @@
(assume `math` is available in the global scope),
create_load_global("math", True) # pushes a null
- create_instruction("LOAD_ATTR", argval="sqrt")
+ create_load_attr("sqrt")
create_instruction("LOAD_CONST", argval=25)
create_call_function(1, False)
"""
@@ -999,6 +1003,12 @@
)
else:
instructions[i].arg = names[instructions[i].argval]
+ elif instructions[i].opname == "LOAD_SUPER_ATTR":
+ assert instructions[i].arg is not None
+ assert instructions[i].argval is not _NotProvided
+ instructions[i].arg = (names[instructions[i].argval] << 2) + (
+ cast(int, instructions[i].arg) % 4
+ )
elif instructions[i].opcode in HAS_LOCAL:
if should_compute_arg():
instructions[i].arg = varnames[instructions[i].argval]
@@ -1130,7 +1140,8 @@
remove_jump_if_none(instructions)
update_offsets(instructions)
devirtualize_jumps(instructions)
- explicit_super(code, instructions)
+ if sys.version_info < (3, 12):
+ explicit_super(code, instructions)
return instructions
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index df3d2ed..c4a8332 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -1894,6 +1894,14 @@
self.LOAD_FAST(inst)
self.symbolic_locals[inst.argval] = NullVariable()
+ def LOAD_SUPER_ATTR(self, inst):
+ super_vt, cls_vt, self_vt = self.popn(3)
+ self.call_function(super_vt, [cls_vt, self_vt], {})
+ if inst.arg & 1:
+ self.LOAD_METHOD(inst)
+ else:
+ self._load_attr(inst)
+
def is_non_empty_graph(self):
if self.output.count_calls() > 1:
# perf optimization only