[dynamo, 3.12] add LOAD_SUPER_ATTR (#122738)

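CPython 3.12 compiles super().attr and super().method() to a single new
LOAD_SUPER_ATTR opcode instead of an actual call to super() followed by an
ordinary attribute load, so dynamo's explicit_super rewrite no longer applies
there. Per the 3.12 dis documentation, the opcode pops three values (the
global super, the enclosing class, and self), performs the super(cls, self)
lookup in one step, and packs the co_names index into the oparg above two
flag bits: bit 0 requests a LOAD_METHOD-style load and bit 1 records a
two-argument call to super(). This PR teaches dynamo's symbolic interpreter
to execute the opcode, re-encodes its oparg when instructions are
reassembled, forbids creating it through the generic create_instruction
helper, and skips explicit_super on 3.12+.

To see the opcode, run something like the following on a 3.12 interpreter
(illustrative; exact offsets, opargs, and CACHE entries vary):

    import dis

    class Base:
        def f(self):
            return "Base.f"

    class Child(Base):
        def f(self):
            return super().f()

    dis.dis(Child.f)
    # ... LOAD_GLOBAL super; LOAD_DEREF __class__; LOAD_FAST self;
    #     LOAD_SUPER_ATTR f (low oparg bit set -> method load); CALL 0 ...
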
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122738
Approved by: https://github.com/jansel
ghstack dependencies: #122146, #122335, #122354, #122355, #122356, #122449, #122455, #122456, #122530, #122737
diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py
index 07f42b3..868d52c 100644
--- a/torch/_dynamo/bytecode_transformation.py
+++ b/torch/_dynamo/bytecode_transformation.py
@@ -97,8 +97,12 @@
     `argval` or `target`.
 
     Do not use for LOAD_GLOBAL - use create_load_global instead.
+    Do not use for LOAD_ATTR - use create_load_attr instead.
+    Do not use for LOAD_SUPER_ATTR - if you need to create this instruction,
+        implement a create_load_super_attr function.
     """
-    assert name != "LOAD_GLOBAL"
+    if name in ("LOAD_GLOBAL", "LOAD_ATTR", "LOAD_SUPER_ATTR"):
+        raise RuntimeError(f"cannot create_instruction with {name}")
     cnt = (arg is not None) + (argval is not _NotProvided) + (target is not None)
     if cnt > 1:
         raise RuntimeError(
@@ -191,7 +195,7 @@
     (assume `math` is available in the global scope),
 
     create_load_global("math", True)  # pushes a null
-    create_instruction("LOAD_ATTR", argval="sqrt")
+    create_load_attr("sqrt")
     create_instruction("LOAD_CONST", argval=25)
     create_call_function(1, False)
     """
@@ -999,6 +1003,12 @@
                 )
             else:
                 instructions[i].arg = names[instructions[i].argval]
+        elif instructions[i].opname == "LOAD_SUPER_ATTR":
+            assert instructions[i].arg is not None
+            assert instructions[i].argval is not _NotProvided
+            instructions[i].arg = (names[instructions[i].argval] << 2) + (
+                cast(int, instructions[i].arg) % 4
+            )
         elif instructions[i].opcode in HAS_LOCAL:
             if should_compute_arg():
                 instructions[i].arg = varnames[instructions[i].argval]
@@ -1130,7 +1140,8 @@
             remove_jump_if_none(instructions)
             update_offsets(instructions)
             devirtualize_jumps(instructions)
-        explicit_super(code, instructions)
+        if sys.version_info < (3, 12):
+            explicit_super(code, instructions)
     return instructions
 
 
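The LOAD_SUPER_ATTR re-encoding above follows the oparg layout documented
for 3.12: the co_names index is shifted left by two bits (rather than the
one bit LOAD_ATTR uses), bit 0 asks for a LOAD_METHOD-style load, and bit 1
marks a two-argument super() call; the modulo-4 keeps both flag bits intact.
A minimal sketch of that layout, with hypothetical helper names that are not
part of dynamo's API:

    def encode_load_super_attr_arg(name_index, is_method_load, is_two_arg):
        # the co_names index lives above the two low flag bits
        return (name_index << 2) | (int(is_two_arg) << 1) | int(is_method_load)

    def decode_load_super_attr_arg(arg):
        return arg >> 2, bool(arg & 1), bool(arg & 2)

    # e.g. name index 1, method load requested, zero-argument super():
    assert encode_load_super_attr_arg(1, True, False) == 5
    assert decode_load_super_attr_arg(5) == (1, True, False)
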
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index df3d2ed..c4a8332 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -1894,6 +1894,14 @@
             self.LOAD_FAST(inst)
         self.symbolic_locals[inst.argval] = NullVariable()
 
+    def LOAD_SUPER_ATTR(self, inst):
+        super_vt, cls_vt, self_vt = self.popn(3)
+        self.call_function(super_vt, [cls_vt, self_vt], {})
+        if inst.arg & 1:
+            self.LOAD_METHOD(inst)
+        else:
+            self._load_attr(inst)
+
     def is_non_empty_graph(self):
         if self.output.count_calls() > 1:
             # perf optimization only
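
End to end, the new LOAD_SUPER_ATTR handler mirrors the interpreter: it pops
super, the class, and self, traces the super(cls, self) call, then dispatches
on the low oparg bit to either a method-style or a plain attribute load. A
quick smoke test of the pattern this enables on 3.12 (illustrative only, not
taken from the PR's test suite):

    import torch

    class Base(torch.nn.Module):
        def forward(self, x):
            return x + 1

    class Child(Base):
        def forward(self, x):
            # zero-argument super() now traces through LOAD_SUPER_ATTR on 3.12
            return super().forward(x) * 2

    compiled = torch.compile(Child())
    print(compiled(torch.ones(3)))  # tensor([4., 4., 4.])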