blob: 99bbc0a2d0c0531f099dbc08af5b8dfe04295d78 [file] [log] [blame]
#include <torch/csrc/dynamo/extra_state.h>
#include <torch/csrc/dynamo/cache_entry.h>
#include <torch/csrc/dynamo/debug_macros.h>
// Index handed to _PyCode_GetExtra / _PyCode_SetExtra in this file to locate
// the ExtraState slot attached to a code object. It is defined here as -1.
// NOTE(review): presumably a real index is assigned elsewhere during
// initialization (not visible in this file) — confirm before relying on it.
Py_ssize_t extra_index = -1;
// Returns a pointer to the entry at the head of cache_entry_list, or a null
// pointer when the list contains no entries.
CacheEntry* ExtraState::get_first_entry() {
  auto& entries = this->cache_entry_list;
  return entries.empty() ? nullptr : &entries.front();
}
// Relinks `cache_entry` so that it becomes the first element of
// cache_entry_list. The entry must belong to this ExtraState and its stored
// iterator (_owner_loc) must refer to the entry itself; both are verified
// with CHECK before the list is touched.
void ExtraState::move_to_front(CacheEntry* cache_entry) {
  auto& entries = this->cache_entry_list;
  CHECK(cache_entry->_owner == this);
  CHECK(!entries.empty());
  CHECK(cache_entry == &*cache_entry->_owner_loc);
  // Splicing a list into itself moves the node without copying the entry.
  const auto position = cache_entry->_owner_loc;
  entries.splice(entries.begin(), entries, position);
}
// Returns the first cache entry of `extra_state`, or a null pointer when
// `extra_state` is null or the SKIP_CODE sentinel (neither of which may be
// dereferenced).
CacheEntry* extract_cache_entry(ExtraState* extra_state) {
  const bool has_state = extra_state != nullptr && extra_state != SKIP_CODE;
  return has_state ? extra_state->get_first_entry() : nullptr;
}
// Returns the pointer obtained from extra_state->frame_state.ptr(), cast to
// FrameState*. Yields a null pointer when `extra_state` is null or the
// SKIP_CODE sentinel.
FrameState* extract_frame_state(ExtraState* extra_state) {
  if (extra_state != nullptr && extra_state != SKIP_CODE) {
    return (FrameState*)extra_state->frame_state.ptr();
  }
  return nullptr;
}
// Reads the value stored in the `extra_index` slot of `code` via
// _PyCode_GetExtra and returns it as an ExtraState*. The result stays null
// if the call does not write anything; other functions in this file also
// treat SKIP_CODE as a possible (non-dereferenceable) value.
ExtraState* get_extra_state(PyCodeObject* code) {
  ExtraState* state = nullptr;
  void** out_slot = (void**)&state;
  _PyCode_GetExtra((PyObject*)code, extra_index, out_slot);
  return state;
}
void destroy_extra_state(void* obj) {
ExtraState* extra = (ExtraState*)obj;
if (extra != NULL && extra != SKIP_CODE) {
delete extra;
}
}
// Stores `extra_state` in the `extra_index` slot of `code` via
// _PyCode_SetExtra. Before writing, CHECKs that the slot does not already
// hold this very same ExtraState object (a null or SKIP_CODE slot always
// passes the check).
void set_extra_state(PyCodeObject* code, ExtraState* extra_state) {
  ExtraState* previous = get_extra_state(code);
  const bool previous_is_placeholder =
      previous == nullptr || previous == SKIP_CODE;
  CHECK(previous_is_placeholder || previous != extra_state);
  _PyCode_SetExtra((PyObject*)code, extra_index, extra_state);
}
// Allocates a new ExtraState, attaches it to `code` through
// set_extra_state, and returns it.
ExtraState* init_and_set_extra_state(PyCodeObject* code) {
  // Precondition: nothing has been attached to this code object yet, so the
  // slot must still read back as null.
  ExtraState* existing = get_extra_state(code);
  CHECK(existing == nullptr);
  ExtraState* fresh = new ExtraState();
  NULL_CHECK(fresh);
  set_extra_state(code, fresh);
  return fresh;
}
// Walks the cache entries of `extra_state` in list order and calls each
// entry's check_fn with `f_locals`.
//
// Returns:
//   - the `code` pointer of the first entry whose check_fn result casts to
//     true; that entry is moved to the front of the list first,
//   - the pointer of the Python None object when no entry matches,
//   - NULL when check_fn raised a Python error; the error is restored so
//     that it is set for the caller.
PyObject* lookup(ExtraState* extra_state, PyObject* f_locals) {
  py::handle locals(f_locals);
  auto& entries = extra_state->cache_entry_list;
  CacheEntry* hit = nullptr;
  size_t position = 0;
  for (CacheEntry& candidate : entries) {
    py::object verdict = py::none();
    try {
      verdict = candidate.check_fn(locals);
    } catch (py::error_already_set& err) {
      if (guard_error_hook) {
        // The hook is told which entry failed and whether it was the last
        // one in the list.
        const bool is_last = position == entries.size() - 1;
        py::handle hook(guard_error_hook);
        hook(candidate.check_fn, candidate.code, locals, position, is_last);
      }
      // This function is called from C, so the exception cannot be
      // rethrown; hand it back to the Python error state instead.
      err.restore();
      return NULL;
    }
    if (verdict.cast<bool>()) {
      hit = &candidate;
      break;
    }
    ++position;
  }
  if (hit == nullptr) {
    return py::none().ptr();
  }
  extra_state->move_to_front(hit);
  return hit->code.ptr();
}
// Constructs a new CacheEntry from `guarded_code` at the front of
// extra_state->cache_entry_list, records its owner and its own list
// position on the entry, and returns a pointer to it.
CacheEntry* create_cache_entry(
    ExtraState* extra_state,
    PyObject* guarded_code) {
  auto& entries = extra_state->cache_entry_list;
  entries.emplace_front(guarded_code);
  const auto head = entries.begin();
  CacheEntry& entry = *head;
  entry._owner = extra_state;
  // The entry remembers its own iterator so it can later be relinked
  // (see ExtraState::move_to_front) without searching the list.
  entry._owner_loc = head;
  return &entry;
}
// Debug helper: returns a Python list containing the cache entries attached
// to the code object `code_obj`, in list order. Each element is cast with
// return_value_policy::reference, i.e. it refers to the existing CacheEntry
// rather than a copy.
//
// Throws py::type_error if `code_obj` is not an instance of types.CodeType.
// Returns an empty list when the code object has no extra state, or when its
// extra-state slot holds the SKIP_CODE sentinel.
py::list _debug_get_cache_entry_list(const py::handle& code_obj) {
  if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) {
    throw py::type_error("expected a code object!");
  }
  PyCodeObject* code = (PyCodeObject*)code_obj.ptr();
  ExtraState* extra = get_extra_state(code);
  py::list result;
  // SKIP_CODE is a non-null sentinel, not a real ExtraState; it must be
  // excluded before dereferencing, as the other accessors in this file do.
  if (extra != NULL && extra != SKIP_CODE) {
    for (CacheEntry& e : extra->cache_entry_list) {
      result.append(py::cast(e, py::return_value_policy::reference));
    }
  }
  return result;
}