add a debug api to extract cache entry from code (#106673)

Per the discussion with @jansel in https://dev-discuss.pytorch.org/t/how-are-guards-installed-on-frames-that-are-transient-objects/1415/7, guards and compiled code live in the `co_extra` field of the `PyCodeObject`, which cannot be accessed in a trivial way. This PR adds a debug API to extract the data from that field, which can make debugging torchdynamo much easier.

The API is intended to be used for debugging only, and should have no compatibility issues with the current system.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106673
Approved by: https://github.com/jansel
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 2377290..9a7b1b1 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -27,6 +27,7 @@
 import torch.onnx.operators
 from torch._C import FileCheck
 from torch._dynamo import allow_in_graph, bytecode_analysis, bytecode_transformation
+from torch._dynamo.eval_frame import _debug_get_cache_entry_list
 from torch._dynamo.exc import Unsupported
 from torch._dynamo.source import GetItemSource, LocalSource
 from torch._dynamo.testing import (
@@ -82,6 +83,25 @@
 
 
 class MiscTests(torch._dynamo.test_case.TestCase):
+    def test_get_cache_entry(self):
+        # Compiling f must leave at least one cache entry on its code object.
+        def f(x):
+            return x + 1
+
+        torch.compile(f)(torch.randn(5, 5, 5))
+        entries = _debug_get_cache_entry_list(f.__code__)
+        self.assertGreater(len(entries), 0)
+
+        def g(x):
+            return x + 2
+
+        entries = _debug_get_cache_entry_list(g.__code__)
+        self.assertEqual(len(entries), 0)
+
+        # A bare try/except would pass silently if nothing were raised.
+        with self.assertRaisesRegex(TypeError, "expected a code object!"):
+            _debug_get_cache_entry_list(1)
+
     def test_boolarg(self):
         def boolarg(aa, bb, flag):
             if flag:
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index f738320..44d713b 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -14,6 +14,7 @@
 import types
 import warnings
 import weakref
+from collections import namedtuple
 from enum import Enum
 from os.path import dirname, join
 from typing import (
@@ -99,6 +100,17 @@
 }
 
 
+CacheEntry = namedtuple("CacheEntry", "check_fn, code")
+
+
+def _debug_get_cache_entry_list(code: types.CodeType) -> List[CacheEntry]:
+    """
+    Debug helper: return the ``CacheEntry`` (check_fn, code) records for ``code``.
+    """
+    raw_entries = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(code)
+    return [CacheEntry._make(raw) for raw in raw_entries]
+
+
 class OptimizedModule(torch.nn.Module):
     """
     Wraps the original nn.Module object and later patches its
diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c
index da44bdd..aa0a97d 100644
--- a/torch/csrc/dynamo/eval_frame.c
+++ b/torch/csrc/dynamo/eval_frame.c
@@ -320,6 +320,40 @@
   return extra;
 }
 
+PyObject* _debug_get_cache_entry_list(PyObject* self, PyObject* args) {
+  // Debug-only: returns a new list of (check_fn, code) tuples for a code object.
+  PyObject* object;
+  if (!PyArg_ParseTuple(args, "O", &object)) {
+    return NULL;
+  }
+  if (!PyCode_Check(object)) {
+    PyErr_SetString(PyExc_TypeError, "expected a code object!");
+    return NULL;
+  }
+  PyCodeObject* code = (PyCodeObject*)object;
+
+  CacheEntry* current_node = get_cache_entry(code);
+
+  PyObject* outer_list = PyList_New(0);
+  if (!outer_list) {
+    return NULL;  // Return NULL if failed to create list
+  }
+  while (current_node != NULL && current_node != SKIP_CODE) {
+    // Pack check_fn and code of the current CacheEntry into a new 2-tuple
+    PyObject* entry = PyTuple_Pack(2, current_node->check_fn, current_node->code);
+    // PyTuple_Pack returns NULL on failure; never hand NULL to PyList_Append
+    int flag = entry ? PyList_Append(outer_list, entry) : -1;
+    Py_XDECREF(entry);  // Drop our own reference (PyList_Append does not steal it)
+    if (flag < 0) {
+      Py_DECREF(outer_list);  // Clean up if packing or appending failed
+      return NULL;
+    }
+    // Move to the next node in the linked list
+    current_node = current_node->next;
+  }
+  return outer_list;
+}
+
 inline static void set_cache_entry(PyCodeObject* code, CacheEntry* extra) {
   // TODO(jansel): would it be faster to bypass this?
   _PyCode_SetExtra((PyObject*)code, cache_entry_extra_index, extra);
@@ -867,6 +901,7 @@
     {"set_guard_error_hook", set_guard_error_hook, METH_O, NULL},
     {"set_profiler_hooks", set_profiler_hooks, METH_VARARGS, NULL},
     {"clear_profiler_hooks", clear_profiler_hooks, METH_NOARGS, NULL},
+    {"_debug_get_cache_entry_list", _debug_get_cache_entry_list, METH_VARARGS, NULL},
     {NULL, NULL, 0, NULL}};
 
 static struct PyModuleDef _module = {