[dynamo, 3.12] Allocate Dynamo shadow frames by mimicking CPython (#122146)
Python 3.12 changed a few things with how `_PyInterpreterFrame`s are allocated and freed:
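- Frames are now required to be placed on the Python frame stack (the per-thread data stack tracked by `tstate->datastack_chunk`/`datastack_top`/`datastack_limit`). In 3.11, we could allocate frames anywhere in memory. In 3.12, we now need to allocate them with `THP_PyThreadState_BumpFramePointerSlow`/`push_chunk`/`allocate_chunk` and free them with `THP_PyThreadState_PopFrame`, which are Dynamo's copies of the corresponding CPython routines. This method of allocating/freeing frames is also compatible with 3.11.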
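- The eval frame function is now responsible for clearing and popping the frame it was called with (see https://docs.python.org/3/whatsnew/changelog.html#id128, the point about "...which now clear the frame."). Dynamo does this for the original frame via `clear_old_frame_if_python_312_plus` on the paths that do not hand that frame to the default eval frame function.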
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122146
Approved by: https://github.com/jansel
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 8a2ae73..c7deb56 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -47,6 +47,7 @@
expectedFailureDynamic,
same,
skipIfNotPy311,
+ skipIfPy312,
unsupported,
xfailIfPy311,
)
@@ -6604,6 +6605,7 @@
for inst in insts:
self.assertNotIn("_NONE", inst.opname)
+ @skipIfPy312
@skipIfNotPy311
def test_py311_jump_offset(self):
new_inst = bytecode_transformation.create_instruction
diff --git a/torch/__init__.py b/torch/__init__.py
index 476684e..85ba4fa 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -1867,7 +1867,6 @@
"""
_C._log_api_usage_once("torch.compile")
- # Temporary until we get proper support for python 3.12
if sys.version_info >= (3, 12):
raise RuntimeError("Dynamo is not supported on Python 3.12+")
diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py
index fac20cf..0a0e7f6 100644
--- a/torch/_dynamo/testing.py
+++ b/torch/_dynamo/testing.py
@@ -348,6 +348,12 @@
return fn
+def skipIfPy312(fn):
+ if sys.version_info >= (3, 12):
+ return unittest.skip(fn)
+ return fn
+
+
# Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py
# and test/dynamo/test_dynamic_shapes.py
def expectedFailureDynamic(fn):
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 43ac713..50bfe3a 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -2814,5 +2814,7 @@
# compile workers created not being able to be shut down inside
# shutdown_compile_workers(). This may cause significant QPS drop.
log.info("Do not call AsyncCompile.warm_pool() because TorchTNT is in use.")
+elif sys.version_info >= (3, 12):
+ log.info("AsyncCompile.warm_pool() is broken on 3.12+.")
else:
AsyncCompile.warm_pool()
diff --git a/torch/csrc/dynamo/cpython_defs.c b/torch/csrc/dynamo/cpython_defs.c
index cd7756b..41f8145 100644
--- a/torch/csrc/dynamo/cpython_defs.c
+++ b/torch/csrc/dynamo/cpython_defs.c
@@ -65,6 +65,188 @@
return 0;
}
+#if IS_PYTHON_3_12_PLUS
+
+// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1136
+// Initialize frame free variables if needed
+static void
+frame_init_get_vars(_PyInterpreterFrame *frame)
+{
+ // COPY_FREE_VARS has no quickened forms, so no need to use _PyOpcode_Deopt
+ // here:
+ PyCodeObject *co = frame->f_code;
+ int lasti = _PyInterpreterFrame_LASTI(frame);
+ if (!(lasti < 0 && _PyCode_CODE(co)->op.code == COPY_FREE_VARS
+ && PyFunction_Check(frame->f_funcobj)))
+ {
+ /* Free vars are initialized */
+ return;
+ }
+
+ /* Free vars have not been initialized -- Do that */
+ PyObject *closure = ((PyFunctionObject *)frame->f_funcobj)->func_closure;
+ int offset = PyCode_GetFirstFree(co);
+ for (int i = 0; i < co->co_nfreevars; ++i) {
+ PyObject *o = PyTuple_GET_ITEM(closure, i);
+ frame->localsplus[offset + i] = Py_NewRef(o);
+ }
+ // COPY_FREE_VARS doesn't have inline CACHEs, either:
+ frame->prev_instr = _PyCode_CODE(frame->f_code);
+}
+
+// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1162
+static int
+frame_get_var(_PyInterpreterFrame *frame, PyCodeObject *co, int i,
+ PyObject **pvalue)
+{
+ _PyLocals_Kind kind = _PyLocals_GetKind(co->co_localspluskinds, i);
+
+ /* If the namespace is unoptimized, then one of the
+ following cases applies:
+ 1. It does not contain free variables, because it
+ uses import * or is a top-level namespace.
+ 2. It is a class namespace.
+ We don't want to accidentally copy free variables
+ into the locals dict used by the class.
+ */
+ if (kind & CO_FAST_FREE && !(co->co_flags & CO_OPTIMIZED)) {
+ return 0;
+ }
+
+ PyObject *value = frame->localsplus[i];
+ if (frame->stacktop) {
+ if (kind & CO_FAST_FREE) {
+ // The cell was set by COPY_FREE_VARS.
+ CHECK(value != NULL && PyCell_Check(value));
+ value = PyCell_GET(value);
+ }
+ else if (kind & CO_FAST_CELL) {
+ // Note that no *_DEREF ops can happen before MAKE_CELL
+ // executes. So there's no need to duplicate the work
+ // that MAKE_CELL would otherwise do later, if it hasn't
+ // run yet.
+ if (value != NULL) {
+ if (PyCell_Check(value) &&
+ THP_PyFrame_OpAlreadyRan(frame, MAKE_CELL, i)) {
+ // (likely) MAKE_CELL must have executed already.
+ value = PyCell_GET(value);
+ }
+ // (likely) Otherwise it is an arg (kind & CO_FAST_LOCAL),
+ // with the initial value set when the frame was created...
+ // (unlikely) ...or it was set to some initial value by
+ // an earlier call to PyFrame_LocalsToFast().
+ }
+ }
+ }
+ else {
+ CHECK(value == NULL);
+ }
+ *pvalue = value;
+ return 1;
+}
+
+// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1213
+static PyObject *
+THP_PyFrame_GetLocals(_PyInterpreterFrame *frame, int include_hidden)
+{
+ /* Merge fast locals into f->f_locals */
+ PyObject *locals = frame->f_locals;
+ if (locals == NULL) {
+ locals = frame->f_locals = PyDict_New();
+ if (locals == NULL) {
+ return NULL;
+ }
+ }
+ PyObject *hidden = NULL;
+
+ /* If include_hidden, "hidden" fast locals (from inlined comprehensions in
+ module/class scopes) will be included in the returned dict, but not in
+ frame->f_locals; the returned dict will be a modified copy. Non-hidden
+ locals will still be updated in frame->f_locals. */
+ if (include_hidden) {
+ hidden = PyDict_New();
+ if (hidden == NULL) {
+ return NULL;
+ }
+ }
+
+ frame_init_get_vars(frame);
+
+ PyCodeObject *co = frame->f_code;
+ for (int i = 0; i < co->co_nlocalsplus; i++) {
+ PyObject *value; // borrowed reference
+ if (!frame_get_var(frame, co, i, &value)) {
+ continue;
+ }
+
+ PyObject *name = PyTuple_GET_ITEM(co->co_localsplusnames, i);
+ _PyLocals_Kind kind = _PyLocals_GetKind(co->co_localspluskinds, i);
+ if (kind & CO_FAST_HIDDEN) {
+ if (include_hidden && value != NULL) {
+ if (PyObject_SetItem(hidden, name, value) != 0) {
+ goto error;
+ }
+ }
+ continue;
+ }
+ if (value == NULL) {
+ if (PyObject_DelItem(locals, name) != 0) {
+ if (PyErr_ExceptionMatches(PyExc_KeyError)) {
+ PyErr_Clear();
+ }
+ else {
+ goto error;
+ }
+ }
+ }
+ else {
+ if (PyObject_SetItem(locals, name, value) != 0) {
+ goto error;
+ }
+ }
+ }
+
+ if (include_hidden && PyDict_Size(hidden)) {
+ PyObject *innerlocals = PyDict_New();
+ if (innerlocals == NULL) {
+ goto error;
+ }
+ if (PyDict_Merge(innerlocals, locals, 1) != 0) {
+ Py_DECREF(innerlocals);
+ goto error;
+ }
+ if (PyDict_Merge(innerlocals, hidden, 1) != 0) {
+ Py_DECREF(innerlocals);
+ goto error;
+ }
+ locals = innerlocals;
+ }
+ else {
+ Py_INCREF(locals);
+ }
+ Py_CLEAR(hidden);
+
+ return locals;
+
+ error:
+ Py_XDECREF(hidden);
+ return NULL;
+}
+
+// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1301
+int
+THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame)
+{
+ PyObject *locals = THP_PyFrame_GetLocals(frame, 0);
+ if (locals == NULL) {
+ return -1;
+ }
+ Py_DECREF(locals);
+ return 0;
+}
+
+#else
+
// https://github.com/python/cpython/blob/a7715ccfba5b86ab09f86ec56ac3755c93b46b48/Objects/frameobject.c#L1182
int
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame) {
@@ -164,6 +346,8 @@
return 0;
}
+#endif
+
// e.g. COPY_FIELD(op, o, globals) becomes
// PY_XINCREF((o)->func_globals);
// (op)->func_globals = (o)->func_globals;
@@ -201,6 +385,9 @@
op->func_weakreflist = NULL;
COPY_FIELD(op, o, module);
COPY_FIELD(op, o, annotations);
+ #if IS_PYTHON_3_12_PLUS
+ COPY_FIELD(op, o, typeparams);
+ #endif
op->vectorcall = o->vectorcall;
op->func_version = o->func_version;
PyObject_GC_Track(op);
@@ -351,6 +538,7 @@
}
Py_XDECREF(frame->frame_obj);
Py_XDECREF(frame->f_locals);
+ // DYNAMO: additional field for 3.12
#if IS_PYTHON_3_12_PLUS
Py_DECREF(frame->f_funcobj);
#else
@@ -359,4 +547,104 @@
Py_DECREF(frame->f_code);
}
+// https://github.com/python/cpython/blob/051b8a2589ff28f0194c3701b21f729444691752/Python/pystate.c#L728
+static _PyStackChunk*
+allocate_chunk(int size_in_bytes, _PyStackChunk* previous)
+{
+ CHECK(size_in_bytes % sizeof(PyObject **) == 0);
+ // DYNAMO: _PyStackChunk is a regular C struct, so
+ // it should be safe to use system malloc over Python malloc, e.g. _PyObject_VirtualAlloc
+ _PyStackChunk *res = malloc(size_in_bytes);
+ if (res == NULL) {
+ return NULL;
+ }
+ res->previous = previous;
+ res->size = size_in_bytes;
+ res->top = 0;
+ return res;
+}
+
+#define DATA_STACK_CHUNK_SIZE (16*1024)
+#define MINIMUM_OVERHEAD 1000
+
+// https://github.com/python/cpython/blob/051b8a2589ff28f0194c3701b21f729444691752/Python/pystate.c#L2182
+static PyObject **
+push_chunk(PyThreadState *tstate, int size)
+{
+ int allocate_size = DATA_STACK_CHUNK_SIZE;
+ while (allocate_size < (int)sizeof(PyObject*)*(size + MINIMUM_OVERHEAD)) {
+ allocate_size *= 2;
+ }
+ _PyStackChunk *new = allocate_chunk(allocate_size, tstate->datastack_chunk);
+ if (new == NULL) {
+ return NULL;
+ }
+ if (tstate->datastack_chunk) {
+ tstate->datastack_chunk->top = tstate->datastack_top -
+ &tstate->datastack_chunk->data[0];
+ }
+ tstate->datastack_chunk = new;
+ tstate->datastack_limit = (PyObject **)(((char *)new) + allocate_size);
+ // When new is the "root" chunk (i.e. new->previous == NULL), we can keep
+ // _PyThreadState_PopFrame from freeing it later by "skipping" over the
+ // first element:
+ PyObject **res = &new->data[new->previous == NULL];
+ tstate->datastack_top = res + size;
+ return res;
+}
+
+// https://github.com/python/cpython/blob/051b8a2589ff28f0194c3701b21f729444691752/Include/internal/pycore_frame.h#L199
+static inline bool
+THP_PyThreadState_HasStackSpace(PyThreadState *tstate, size_t size)
+{
+ CHECK(
+ (tstate->datastack_top == NULL && tstate->datastack_limit == NULL)
+ ||
+ (tstate->datastack_top != NULL && tstate->datastack_limit != NULL)
+ );
+ return tstate->datastack_top != NULL &&
+ size < (size_t)(tstate->datastack_limit - tstate->datastack_top);
+}
+
+// https://github.com/python/cpython/blob/051b8a2589ff28f0194c3701b21f729444691752/Python/pystate.c#L2207
+_PyInterpreterFrame *
+THP_PyThreadState_BumpFramePointerSlow(PyThreadState *tstate, size_t size)
+{
+ if (THP_PyThreadState_HasStackSpace(tstate, size)) {
+ _PyInterpreterFrame *res = (_PyInterpreterFrame *)tstate->datastack_top;
+ tstate->datastack_top += size;
+ return res;
+ }
+ if (size > INT_MAX/2) {
+ PyErr_NoMemory();
+ return NULL;
+ }
+ return (_PyInterpreterFrame *)push_chunk(tstate, (int)size);
+}
+
+// https://github.com/python/cpython/blob/051b8a2589ff28f0194c3701b21f729444691752/Python/pystate.c#L2222
+void
+THP_PyThreadState_PopFrame(PyThreadState *tstate, _PyInterpreterFrame * frame)
+{
+ CHECK(tstate->datastack_chunk);
+ PyObject **base = (PyObject **)frame;
+ if (base == &tstate->datastack_chunk->data[0]) {
+ _PyStackChunk *chunk = tstate->datastack_chunk;
+ _PyStackChunk *previous = chunk->previous;
+ // push_chunk ensures that the root chunk is never popped:
+ CHECK(previous);
+ tstate->datastack_top = &previous->data[previous->top];
+ tstate->datastack_chunk = previous;
+ // DYNAMO: free instead of _PyObject_VirtualFree
+ free(chunk);
+ tstate->datastack_limit = (PyObject **)(((char *)previous) + previous->size);
+ }
+ else {
+ CHECK(tstate->datastack_top);
+ CHECK(tstate->datastack_top >= base);
+ tstate->datastack_top = base;
+ }
+}
+
+
#endif
diff --git a/torch/csrc/dynamo/cpython_defs.h b/torch/csrc/dynamo/cpython_defs.h
index d9d3efa..a580de7 100644
--- a/torch/csrc/dynamo/cpython_defs.h
+++ b/torch/csrc/dynamo/cpython_defs.h
@@ -18,4 +18,12 @@
void THP_PyFrame_Clear(_PyInterpreterFrame* frame);
+_PyInterpreterFrame* THP_PyThreadState_BumpFramePointerSlow(
+ PyThreadState* tstate,
+ size_t size);
+
+void THP_PyThreadState_PopFrame(
+ PyThreadState* tstate,
+ _PyInterpreterFrame* frame);
+
#endif
diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c
index 828e492..72817d4 100644
--- a/torch/csrc/dynamo/eval_frame.c
+++ b/torch/csrc/dynamo/eval_frame.c
@@ -268,6 +268,17 @@
return res;
}
+static inline void clear_old_frame_if_python_312_plus(
+ PyThreadState* tstate,
+ THP_EVAL_API_FRAME_OBJECT* frame) {
+#if IS_PYTHON_3_12_PLUS
+
+ THP_PyFrame_Clear(frame);
+ THP_PyThreadState_PopFrame(tstate, frame);
+
+#endif
+}
+
inline static PyObject* eval_custom_code_impl(
PyThreadState* tstate,
THP_EVAL_API_FRAME_OBJECT* frame,
@@ -278,27 +289,24 @@
DEBUG_NULL_CHECK(frame);
DEBUG_NULL_CHECK(code);
- #if IS_PYTHON_3_11_PLUS
+#if IS_PYTHON_3_11_PLUS
// Generate Python function object and _PyInterpreterFrame in a way similar to
// https://github.com/python/cpython/blob/e715da6db1d1d70cd779dc48e1ba8110c51cc1bf/Python/ceval.c#L1130
- #if IS_PYTHON_3_12_PLUS
- // Most of these don't exist in 3.12 anymore.
- // _PyFunction_CopyWithNewCode and _PyFrame_InitializeSpecials in particular
- PyFunctionObject* func;
- PyErr_SetString(PyExc_RuntimeError, "Dynamo is not supported in Python 3.12 yet");
- return NULL;
- #else
- PyFunctionObject* func = _PyFunction_CopyWithNewCode((PyFunctionObject*) frame->f_func, code);
+#if IS_PYTHON_3_12_PLUS
+ PyFunctionObject* old_func = (PyFunctionObject*) frame->f_funcobj;
+ size_t size = code->co_framesize;
+#else
+ PyFunctionObject* old_func = frame->f_func;
+ size_t size = code->co_nlocalsplus + code->co_stacksize + FRAME_SPECIALS_SIZE;
+#endif
+
+ PyFunctionObject* func = _PyFunction_CopyWithNewCode(old_func, code);
if (func == NULL) {
return NULL;
}
- #endif
- size_t size = code->co_nlocalsplus + code->co_stacksize + FRAME_SPECIALS_SIZE;
- // THP_EVAL_API_FRAME_OBJECT (_PyInterpreterFrame) is a regular C struct, so
- // it should be safe to use system malloc over Python malloc, e.g. PyMem_Malloc
- THP_EVAL_API_FRAME_OBJECT* shadow = malloc(size * sizeof(PyObject*));
+ THP_EVAL_API_FRAME_OBJECT* shadow = THP_PyThreadState_BumpFramePointerSlow(tstate, size);
if (shadow == NULL) {
Py_DECREF(func);
return NULL;
@@ -306,9 +314,11 @@
Py_INCREF(func);
// consumes reference to func
- #if !(IS_PYTHON_3_12_PLUS)
+#if IS_PYTHON_3_12_PLUS
+ _PyFrame_Initialize(shadow, func, NULL, code, 0);
+#else
_PyFrame_InitializeSpecials(shadow, func, NULL, code->co_nlocalsplus);
- #endif
+#endif
PyObject** fastlocals_old = frame->localsplus;
PyObject** fastlocals_new = shadow->localsplus;
@@ -316,11 +326,14 @@
Py_ssize_t n_new = code->co_nlocalsplus;
// localsplus are XINCREF'd by default eval frame, so all values must be valid.
+#if !(IS_PYTHON_3_12_PLUS)
+ // _PyFrame_Initialize in 3.12 already does this
for (int i = 0; i < code->co_nlocalsplus; i++) {
fastlocals_new[i] = NULL;
}
+#endif
- #else
+#else
THP_EVAL_API_FRAME_OBJECT* shadow = PyFrame_New(tstate, code, frame->f_globals, NULL);
if (shadow == NULL) {
@@ -332,7 +345,7 @@
Py_ssize_t n_old = frame->f_code->co_nlocals + PyCode_GetNFreevars(frame->f_code) + PyCode_GetNCellvars(frame->f_code);
Py_ssize_t n_new = code->co_nlocals + PyCode_GetNFreevars(code) + PyCode_GetNCellvars(code);
- #endif
+#endif
// ============== Initialize new frame from old frame ============
// Python internal for executing a function:
@@ -399,17 +412,17 @@
// conditional test to tell if a variable is not a cell variable
// this is straightforward in Python 3.11 and higher, as there are bit flags in `co_localspluskinds` to tell if a variable is a cell variable.
// in Python 3.10 and lower, essentially we are checking if a variable is a new local variable (because of the layout mentioned above, the first variable that is not cell variable is the first new local variable). the corresponding slot in `flocalsplus` is NULL for new local variables.
- #if IS_PYTHON_3_11_PLUS
+#if IS_PYTHON_3_11_PLUS
if(!(_PyLocals_GetKind(frame->f_code->co_localspluskinds, i) & CO_FAST_CELL))
{
break;
}
- #else
+#else
if(fastlocals_old[i] == NULL)
{
break;
}
- #endif
+#endif
Py_XINCREF(fastlocals_old[i]);
fastlocals_new[j] = fastlocals_old[i];
@@ -417,17 +430,27 @@
PyObject* result = eval_frame_default(tstate, shadow, throw_flag);
- #if IS_PYTHON_3_11_PLUS
+#if IS_PYTHON_3_12_PLUS
- THP_PyFrame_Clear(shadow);
- free(shadow);
+ // In 3.12, the frame evaluation function is responsible for
+ // clearing and popping the frame, so we manually do that on the
+ // old frame.
+ clear_old_frame_if_python_312_plus(tstate, frame);
Py_DECREF(func);
- #else
+#elif IS_PYTHON_3_11_PLUS
+
+ // In 3.11, shadow has is_entry set to true, which means _PyEvalFrameClearAndPop
+ // is not called for it; we therefore manually clear and pop the shadow frame.
+ THP_PyFrame_Clear(shadow);
+ THP_PyThreadState_PopFrame(tstate, shadow);
+ Py_DECREF(func);
+
+#else
Py_DECREF(shadow);
- #endif
+#endif
return result;
}
@@ -467,19 +490,23 @@
return _custom_eval_frame(tstate, frame, throw_flag, callback);
}
+// NOTE: In 3.12+, the eval frame function is responsible for clearing/popping
+// the frame it was given, so any return NULL; statements must be preceded by
+// clear_old_frame_if_python_312_plus(tstate, frame);. Paths that return through
+// eval_frame_default/eval_custom_code need no extra call: those clear/pop the frame.
static PyObject* _custom_eval_frame(
PyThreadState* tstate,
THP_EVAL_API_FRAME_OBJECT* frame,
int throw_flag,
PyObject* callback) {
- #if IS_PYTHON_3_11_PLUS
+#if IS_PYTHON_3_11_PLUS
DEBUG_TRACE(
"begin %s %s %i %i",
get_frame_name(frame),
PyUnicode_AsUTF8(frame->f_code->co_filename),
frame->f_code->co_firstlineno,
_PyInterpreterFrame_LASTI(frame));
- #else
+#else
DEBUG_TRACE(
"begin %s %s %i %i %i",
get_frame_name(frame),
@@ -487,7 +514,7 @@
frame->f_lineno,
frame->f_lasti,
frame->f_iblock);
- #endif
+#endif
if (throw_flag) {
// When unwinding generators, eval frame is called with throw_flag ==
@@ -527,6 +554,7 @@
// TODO(jansel): investigate directly using the "fast" representation
if (THP_PyFrame_FastToLocalsWithError(frame) < 0) {
DEBUG_TRACE("error %s", get_frame_name(frame));
+ clear_old_frame_if_python_312_plus(tstate, frame);
return NULL;
}
@@ -542,6 +570,7 @@
if (maybe_cached_code == NULL) {
// guard eval failed, keep propagating
+ clear_old_frame_if_python_312_plus(tstate, frame);
return NULL;
} else if (maybe_cached_code == Py_None) {
DEBUG_TRACE("cache miss %s", get_frame_name(frame));
@@ -566,6 +595,7 @@
_pytorch_record_function_exit(rf);
if (maybe_cached_code == NULL) {
// Python error
+ clear_old_frame_if_python_312_plus(tstate, frame);
return NULL;
} else if (maybe_cached_code != Py_None) {
PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code;
@@ -588,6 +618,7 @@
// cascading failure from internal exceptions. The upshot is if
// Dynamo barfs, that's it for Dynamo, even if you catch the exception
// inside the torch.compile block we won't try to Dynamo anything else.
+ clear_old_frame_if_python_312_plus(tstate, frame);
return NULL;
} else if (result != Py_None) {
DEBUG_TRACE("create cache %s", get_frame_name(frame));
@@ -598,6 +629,7 @@
// extra->cache_entry. extra wont be NULL here.
CacheEntry* new_cache_entry = create_cache_entry(extra, result, backend);
Py_DECREF(result);
+
// Update the existing cache_entry on the extra object. This extra object is
// sitting on the extra scratch space, we are just changing the cache_entry
// ptr. As a result, extra now becomes the owner of CacheEntry object. This
@@ -732,6 +764,9 @@
-1,
_methods};
+#if IS_PYTHON_3_12_PLUS
+#define _PyEval_RequestCodeExtraIndex PyUnstable_Eval_RequestCodeExtraIndex
+#endif
PyObject* torch_c_dynamo_eval_frame_init(void) {
extra_index = _PyEval_RequestCodeExtraIndex(destroy_extra_state);