[dynamo, 3.12] remove LOAD_METHOD, update LOAD_ATTR (#122356)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122356
Approved by: https://github.com/jansel
ghstack dependencies: #122146, #122335, #122354, #122355
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index c7deb56..3a67854 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -3842,7 +3842,8 @@
         fn = locals["fn"]
         orig_inst_str = "\n".join(list(map(str, dis.get_instructions(fn))))
         self.assertIn("EXTENDED_ARG", orig_inst_str)
-        self.assertIn("LOAD_METHOD", orig_inst_str)
+        load_method_str = "LOAD_ATTR" if sys.version_info >= (3, 12) else "LOAD_METHOD"
+        self.assertIn(load_method_str, orig_inst_str)
         keys = bytecode_transformation.get_code_keys()
         code_options = {k: getattr(fn.__code__, k) for k in keys}
         result = bytecode_transformation.clean_and_assemble_instructions(
@@ -3852,7 +3853,7 @@
         )
         new_inst_str = "\n".join(list(map(str, result[0])))
         self.assertIn("EXTENDED_ARG", new_inst_str)
-        self.assertIn("LOAD_METHOD", new_inst_str)
+        self.assertIn(load_method_str, new_inst_str)
         l1, l2 = list(fn.__code__.co_positions()), list(result[1].co_positions())
         self.assertEqual(len(l1), len(l2))
         for p1, p2 in zip(l1, l2):
diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py
index 9856a09..89e7ee2 100644
--- a/torch/_dynamo/bytecode_transformation.py
+++ b/torch/_dynamo/bytecode_transformation.py
@@ -218,6 +218,28 @@
     return [create_instruction("CALL_METHOD", arg=nargs)]
 
 
+def create_load_attr(name) -> Instruction:
+    # in 3.12, create a LOAD_ATTR instruction with the low bit unset
+    return Instruction(
+        opcode=dis.opmap["LOAD_ATTR"],
+        opname="LOAD_ATTR",
+        arg=False,  # lowbit for 3.12
+        argval=name,
+    )
+
+
+def create_load_method(name) -> Instruction:
+    if sys.version_info >= (3, 12):
+        # in 3.12, create a LOAD_ATTR instruction with the low bit set
+        return Instruction(
+            opcode=dis.opmap["LOAD_ATTR"],
+            opname="LOAD_ATTR",
+            arg=True,  # lowbit for 3.12
+            argval=name,
+        )
+    return create_instruction("LOAD_METHOD", argval=name)
+
+
 def lnotab_writer(
     lineno: int, byteno: int = 0
 ) -> Tuple[List[int], Callable[[int, int], None]]:
@@ -762,6 +784,7 @@
 
 def remove_load_call_method(instructions: List[Instruction]) -> List[Instruction]:
     """LOAD_METHOD puts a NULL on the stack which causes issues, so remove it"""
+    assert sys.version_info < (3, 11)
     rewrites = {"LOAD_METHOD": "LOAD_ATTR", "CALL_METHOD": "CALL_FUNCTION"}
     for inst in instructions:
         if inst.opname in rewrites:
@@ -963,6 +986,16 @@
                 )
             else:
                 instructions[i].arg = names[instructions[i].argval]
+        elif instructions[i].opname == "LOAD_ATTR":
+            # 3.12 LOAD_ATTR requires both arg and argval, like LOAD_GLOBAL
+            assert instructions[i].arg is not None
+            assert instructions[i].argval is not _NotProvided
+            if sys.version_info >= (3, 12):
+                instructions[i].arg = (names[instructions[i].argval] << 1) + (
+                    cast(int, instructions[i].arg) % 2
+                )
+            else:
+                instructions[i].arg = names[instructions[i].argval]
         elif instructions[i].opcode in HAS_LOCAL:
             if should_compute_arg():
                 instructions[i].arg = varnames[instructions[i].argval]
diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py
index e97c856..4666a48 100644
--- a/torch/_dynamo/codegen.py
+++ b/torch/_dynamo/codegen.py
@@ -12,7 +12,9 @@
     create_call_function,
     create_dup_top,
     create_instruction,
+    create_load_attr,
     create_load_global,
+    create_load_method,
     create_rot_n,
     Instruction,
 )
@@ -261,12 +263,12 @@
 
     def create_load_method(self, name):
         self.tx.output.update_co_names(name)
-        return create_instruction("LOAD_METHOD", argval=name)
+        return create_load_method(name)
 
     def create_load_attr(self, name) -> Instruction:
         if name not in self.code_options["co_names"]:
             self.code_options["co_names"] += (name,)
-        return create_instruction("LOAD_ATTR", argval=name)
+        return create_load_attr(name)
 
     def load_attr(self, name):
         self.append_output(self.create_load_attr(name))
diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py
index b0b565a..3df6616 100644
--- a/torch/_dynamo/resume_execution.py
+++ b/torch/_dynamo/resume_execution.py
@@ -10,6 +10,7 @@
     create_dup_top,
     create_instruction,
     create_jump_absolute,
+    create_load_method,
     Instruction,
     InstructionExnTabEntry,
     transform_code_object,
@@ -70,7 +71,7 @@
             *create_call_function(len(load_args), True),
             create_instruction("STORE_FAST", argval=ctx_name),
             create_instruction("LOAD_FAST", argval=ctx_name),
-            create_instruction("LOAD_METHOD", argval="__enter__"),
+            create_load_method("__enter__"),
             *create_call_method(0),
             create_instruction("POP_TOP"),
         ]
@@ -94,7 +95,7 @@
         def create_reset():
             return [
                 create_instruction("LOAD_FAST", argval=ctx_name),
-                create_instruction("LOAD_METHOD", argval="__exit__"),
+                create_load_method("__exit__"),
                 create_instruction("LOAD_CONST", argval=None),
                 create_dup_top(),
                 create_dup_top(),
diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py
index 3918476..26f4c31 100644
--- a/torch/_dynamo/side_effects.py
+++ b/torch/_dynamo/side_effects.py
@@ -9,6 +9,7 @@
     create_call_function,
     create_call_method,
     create_instruction,
+    create_load_method,
 )
 from .codegen import PyCodegen
 from .exc import unimplemented
@@ -359,9 +360,7 @@
 
         for ctx, args in self.save_for_backward:
             cg(ctx.source)
-            cg.extend_output(
-                [create_instruction("LOAD_METHOD", argval="save_for_backward")]
-            )
+            cg.extend_output([create_load_method("save_for_backward")])
             for arg in args:
                 cg(arg)
             cg.extend_output(
@@ -460,11 +459,11 @@
                 cg.tx.output.update_co_names("update")
 
                 cg(var.mutable_local.source)  # type: ignore[attr-defined]
-                cg.extend_output([create_instruction("LOAD_METHOD", argval="update")])
+                cg.extend_output([create_load_method("update")])
                 cg(var, allow_cache=False)
 
                 cg(var.mutable_local.source)  # type: ignore[attr-defined]
-                cg.extend_output([create_instruction("LOAD_METHOD", argval="clear")])
+                cg.extend_output([create_load_method("clear")])
 
                 suffixes.append(
                     [
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 2a72e40..45862a4 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -1099,7 +1099,7 @@
 
     def IMPORT_FROM(self, inst):
         self.DUP_TOP(inst)
-        self.LOAD_ATTR(inst)
+        self._load_attr(inst)
 
     def load_builtin(self, inst):
         if inst.argval not in self.f_builtins:
@@ -1263,7 +1263,7 @@
         arg = inst.argval[0]
         argval = self.code_options["co_names"][arg]
         if sys.version_info < (3, 11):
-            self.LOAD_ATTR(dataclasses.replace(inst, argval=argval))
+            self._load_attr(dataclasses.replace(inst, argval=argval))
         else:
             self.LOAD_METHOD(dataclasses.replace(inst, argval=argval))
 
@@ -1271,10 +1271,10 @@
         self.CALL_FUNCTION(dataclasses.replace(inst, argval=2))
         arg = inst.argval[0]
         argval = self.code_options["co_names"][arg]
-        self.LOAD_ATTR(dataclasses.replace(inst, argval=argval))
+        self._load_attr(dataclasses.replace(inst, argval=argval))
 
     def LOAD_METHOD(self, inst):
-        self.LOAD_ATTR(inst)
+        self._load_attr(inst)
         obj = self.pop()
         if sys.version_info >= (3, 11):
             # always follow the NULL + fn convention, since if obj
@@ -1293,13 +1293,20 @@
         fn = self.pop()
         self.call_function(fn, args, {})
 
-    def LOAD_ATTR(self, inst):
+    def _load_attr(self, inst):
         obj = self.pop()
         result = BuiltinVariable(getattr).call_function(
             self, [obj, ConstantVariable.create(inst.argval)], {}
         )
         self.push(result)
 
+    def LOAD_ATTR(self, inst):
+        if sys.version_info >= (3, 12):
+            if inst.arg % 2:
+                self.LOAD_METHOD(inst)
+                return
+        self._load_attr(inst)
+
     def STORE_ATTR(self, inst):
         speculation = self.speculate()
         if speculation.failed:
diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py
index 1dcbce4..07df439 100644
--- a/torch/_dynamo/variables/dicts.py
+++ b/torch/_dynamo/variables/dicts.py
@@ -14,6 +14,7 @@
     create_call_function,
     create_call_method,
     create_instruction,
+    create_load_method,
 )
 from ..eval_frame import skip_code
 
@@ -428,7 +429,7 @@
         codegen(self.dv_dict)
         codegen.extend_output(
             [
-                create_instruction("LOAD_METHOD", argval=self.kv),
+                create_load_method(self.kv),
                 *create_call_method(0),
             ]
         )