| #include <torch/csrc/dynamo/extra_state.h> |
| |
| #include <torch/csrc/dynamo/cache_entry.h> |
| #include <torch/csrc/dynamo/debug_macros.h> |
| |
| Py_ssize_t extra_index = -1; |
| |
| CacheEntry* ExtraState::get_first_entry() { |
| if (this->cache_entry_list.empty()) { |
| return NULL; |
| } |
| return &this->cache_entry_list.front(); |
| } |
| |
| void ExtraState::move_to_front(CacheEntry* cache_entry) { |
| CHECK(cache_entry->_owner == this); |
| CHECK(!this->cache_entry_list.empty()); |
| CHECK(cache_entry == &*cache_entry->_owner_loc); |
| this->cache_entry_list.splice( |
| this->cache_entry_list.begin(), |
| this->cache_entry_list, |
| cache_entry->_owner_loc); |
| } |
| |
| CacheEntry* extract_cache_entry(ExtraState* extra_state) { |
| if (extra_state == NULL || extra_state == SKIP_CODE) { |
| return NULL; |
| } |
| return extra_state->get_first_entry(); |
| } |
| |
| FrameState* extract_frame_state(ExtraState* extra_state) { |
| if (extra_state == NULL || extra_state == SKIP_CODE) { |
| return NULL; |
| } |
| return (FrameState*)extra_state->frame_state.ptr(); |
| } |
| |
| ExtraState* get_extra_state(PyCodeObject* code) { |
| ExtraState* extra = NULL; |
| _PyCode_GetExtra((PyObject*)code, extra_index, (void**)&extra); |
| return extra; |
| } |
| |
| void destroy_extra_state(void* obj) { |
| ExtraState* extra = (ExtraState*)obj; |
| if (extra != NULL && extra != SKIP_CODE) { |
| delete extra; |
| } |
| } |
| |
| void set_extra_state(PyCodeObject* code, ExtraState* extra_state) { |
| ExtraState* old_extra_state = get_extra_state(code); |
| CHECK( |
| old_extra_state == NULL || old_extra_state == SKIP_CODE || |
| old_extra_state != extra_state); |
| _PyCode_SetExtra((PyObject*)code, extra_index, extra_state); |
| } |
| |
| ExtraState* init_and_set_extra_state(PyCodeObject* code) { |
| // Invariant - Extra state should not have been set before, therefore it |
| // should be NULL. |
| CHECK(get_extra_state(code) == NULL); |
| ExtraState* extra_state = new ExtraState(); |
| NULL_CHECK(extra_state); |
| set_extra_state(code, extra_state); |
| return extra_state; |
| } |
| |
| PyObject* lookup(ExtraState* extra_state, PyObject* f_locals) { |
| size_t index = 0; |
| CacheEntry* found = nullptr; |
| py::handle locals(f_locals); |
| for (CacheEntry& cache_entry : extra_state->cache_entry_list) { |
| py::object valid = py::none(); |
| try { |
| valid = cache_entry.check_fn(locals); |
| } catch (py::error_already_set& e) { |
| if (guard_error_hook) { |
| py::handle guard_error_hook_handle(guard_error_hook); |
| guard_error_hook_handle( |
| cache_entry.check_fn, |
| cache_entry.code, |
| locals, |
| index, |
| index == extra_state->cache_entry_list.size() - 1); |
| } |
| // this function is called from C, so we cannot repropagate |
| // the exception |
| e.restore(); |
| return NULL; |
| } |
| if (valid.cast<bool>()) { |
| found = &cache_entry; |
| break; |
| } |
| ++index; |
| } |
| if (found) { |
| extra_state->move_to_front(found); |
| return found->code.ptr(); |
| } |
| return py::none().ptr(); |
| } |
| |
| CacheEntry* create_cache_entry( |
| ExtraState* extra_state, |
| PyObject* guarded_code) { |
| extra_state->cache_entry_list.emplace_front(guarded_code); |
| auto new_iter = extra_state->cache_entry_list.begin(); |
| new_iter->_owner = extra_state; |
| new_iter->_owner_loc = new_iter; |
| return &*new_iter; |
| } |
| |
| py::list _debug_get_cache_entry_list(const py::handle& code_obj) { |
| if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) { |
| throw py::type_error("expected a code object!"); |
| } |
| PyCodeObject* code = (PyCodeObject*)code_obj.ptr(); |
| ExtraState* extra = get_extra_state(code); |
| py::list result; |
| if (extra) { |
| for (CacheEntry& e : extra->cache_entry_list) { |
| result.append(py::cast(e, py::return_value_policy::reference)); |
| } |
| } |
| return result; |
| } |