[dynamo][guards] Move backend match to eval_frame (#121954)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121954
Approved by: https://github.com/jansel
diff --git a/test/dynamo/test_comptime.py b/test/dynamo/test_comptime.py
index 45f2a6c..1170010 100644
--- a/test/dynamo/test_comptime.py
+++ b/test/dynamo/test_comptime.py
@@ -223,13 +223,6 @@
'obj_weakref': None
'guarded_class': None
}
- global '' BACKEND_MATCH
- {
- 'guard_types': None,
- 'code': None,
- 'obj_weakref': None
- 'guarded_class': None
- }
shape_env '' SHAPE_ENV
{
'guard_types': None,
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index fee1680..7ff780d 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -4827,10 +4827,6 @@
opt_out = torch._dynamo.optimize(backend=cnt)(foo)(*args)
self.assertEqual(exp_out, opt_out)
self.assertEqual(cnt.frame_count, exp_frame_count)
- self.assertEqual(
- len(torch._dynamo.eval_frame.cached_backends),
- exp_n_cached_backend,
- )
def test_backend_match_guard(self):
x = torch.randn([3, 4])
@@ -4912,12 +4908,6 @@
for thread in threads:
thread.join()
- # Threads are sharing the backend cache. We see two cnt backends and one None backend
- self.assertEqual(
- len(torch._dynamo.eval_frame.cached_backends),
- 3,
- )
-
self.assertEqual(len(thread_success), len(threads))
def test_dynamo_min_operator_with_shape(self):
diff --git a/torch/__init__.py b/torch/__init__.py
index ffb8e61..476684e 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -1685,6 +1685,9 @@
self.apply_mode(mode)
self.apply_options(options)
+ # Stash the compiler_fn to be used for backend match guard.
+ from torch._inductor.compile_fx import compile_fx
+ self.compiler_fn = compile_fx
if self.config.get("triton.cudagraphs", False):
os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
# FIXME: CUDA Graph does not work well with CUPTI teardown.
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index f5a66c5..7c6a244 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -16,7 +16,6 @@
import os
import sys
import textwrap
-import threading
import traceback
import types
import warnings
@@ -77,82 +76,19 @@
token = 0
-unset = Unset.token
-
-guarded_backend_cache = threading.local()
cached_backends: Dict[int, CompilerFn] = {}
-
-def check_current_backend(backend_obj_id: int):
- """
- Called from guards to check if we need to recompile due to a backend change
- """
- # TODO(jansel): we should move guarded_backend_cache to C++
- try:
- if guarded_backend_cache.skip_backend_check_for_run_only_mode:
- return True
- except AttributeError:
- # Go slightly faster next time
- guarded_backend_cache.skip_backend_check_for_run_only_mode = False
- try:
- current_backend = guarded_backend_cache.current_backend
- except AttributeError:
- current_backend = None
- return (
- # Avoid the dict lookup in case of exact same object
- id(current_backend) == backend_obj_id
- or current_backend == cached_backends.get(backend_obj_id, None)
- )
+unset = Unset.token
def _reset_guarded_backend_cache():
global cached_backends
- guarded_backend_cache.skip_backend_check_for_run_only_mode = False
- guarded_backend_cache.current_backend = None
for backend in cached_backends.values():
if hasattr(backend, "reset"):
backend.reset()
cached_backends.clear()
-def backend_cache_manager(callback: DynamoCallback):
- # callback is False for RunOnlyContext. RunOnlyContext is used
- # as a way to re-use the previous compiled cache.
- # We therefore skip the check and re-use whatever code that's already cached.
- # Note: the cache that's actually used depends on the caching policy.
- if callback is False:
-
- def change():
- try:
- prev_skip = guarded_backend_cache.skip_backend_check_for_run_only_mode
- except AttributeError:
- prev_skip = False
- guarded_backend_cache.skip_backend_check_for_run_only_mode = True
-
- def revert():
- guarded_backend_cache.skip_backend_check_for_run_only_mode = prev_skip
-
- return revert
-
- else:
- backend = innermost_fn(callback)
-
- def change():
- cached_backends.setdefault(id(backend), backend)
- try:
- prev_backend = guarded_backend_cache.current_backend
- except AttributeError:
- prev_backend = None
- guarded_backend_cache.current_backend = backend
-
- def revert():
- guarded_backend_cache.current_backend = prev_backend
-
- return revert
-
- return change
-
-
DONT_WRAP_FILES = {
# For tracing into fx modules
inspect.getsourcefile(GraphModule),
@@ -306,9 +242,13 @@
self.export = export
self.compiler_config = compiler_config
self.cleanup_fns: List[Callable[[], Any]] = []
- self.enter_exit_hooks = [backend_cache_manager(self.callback)]
+ self.enter_exit_hooks = []
patch_fn()
+ # Save the backends so that we can reset them during torch._dynamo.reset
+ backend = innermost_fn(callback)
+ cached_backends.setdefault(id(backend), backend)
+
if dynamic is not None:
self.enter_exit_hooks.append(make_set_enable_dynamic(dynamic))
@@ -672,6 +612,9 @@
dynamic=dynamic,
hooks=hooks,
)
+ # The backend function is stashed in the callable returned by
+ # _optimize_catch_errors in the field _torchdynamo_orig_callable. This can
+ # be used by eval_frame.c to insert a guard on the backend.
return _optimize_catch_errors(
convert_frame.convert_frame(backend, hooks=hooks),
hooks,
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index 668203f..3f1009b 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -647,15 +647,6 @@
guard, [f"utils_device.CURRENT_DEVICE == {m.CURRENT_DEVICE!r}"]
)
- def BACKEND_MATCH(self, guard: Guard):
- """Guard on backend matching based on id of current_backend"""
- assert guard.source is GuardSource.GLOBAL
- backend_id = (
- f"{id(torch._dynamo.eval_frame.guarded_backend_cache.current_backend)}"
- )
- code = [f"___check_current_backend({backend_id})"]
- self._produce_guard_code(guard, code)
-
def SHAPE_ENV(self, guard: Guard):
# Let's handle ShapeEnv guards. To do this, we will resolve
# shape variables to sources from tracked_fakes. This must happen after
@@ -1203,7 +1194,6 @@
"___check_tensors": check_tensors_fn,
"___check_tensors_verbose": check_tensors_verbose_fn,
"___check_global_state": global_state.check,
- "___check_current_backend": torch._dynamo.eval_frame.check_current_backend,
"tensor_check_names": tensor_check_names,
**SYMPY_INTERP,
**CLOSURE_VARS,
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 42fa15a..a158df9 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -471,8 +471,6 @@
GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE)
)
- self.guards.add(GlobalStateSource().make_guard(GuardBuilder.BACKEND_MATCH))
-
def synthetic_graph_input(self, fn, args):
"""
call fn(*args) before the graph runs and turn the result into a fake input.
diff --git a/torch/csrc/dynamo/cache_entry.cpp b/torch/csrc/dynamo/cache_entry.cpp
index 34e2764..8a07f59 100644
--- a/torch/csrc/dynamo/cache_entry.cpp
+++ b/torch/csrc/dynamo/cache_entry.cpp
@@ -4,9 +4,10 @@
#include <torch/csrc/dynamo/debug_macros.h>
#include <torch/csrc/dynamo/extra_state.h>
-CacheEntry::CacheEntry(const py::handle& guarded_code) {
+CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend) {
this->check_fn = guarded_code.attr("check_fn");
this->code = guarded_code.attr("code");
+ this->backend = backend;
// TODO - clean this up when enable_cpp_guard_manager is True by default
if (py::hasattr(this->check_fn, "root")) {
this->root_mgr = convert_to_root_guard_manager(this->check_fn.attr("root"));
@@ -39,3 +40,14 @@
}
return py::cast(e, py::return_value_policy::reference).release().ptr();
}
+
+PyObject* get_backend(PyObject* callback) {
+ py::handle handle = py::handle(callback);
+ while (py::hasattr(handle, "_torchdynamo_orig_callable")) {
+ handle = handle.attr("_torchdynamo_orig_callable");
+ }
+ if (py::hasattr(handle, "compiler_fn")) {
+ handle = handle.attr("compiler_fn");
+ }
+ return handle.ptr();
+}
diff --git a/torch/csrc/dynamo/cache_entry.h b/torch/csrc/dynamo/cache_entry.h
index 216d12b..25d7e70 100644
--- a/torch/csrc/dynamo/cache_entry.h
+++ b/torch/csrc/dynamo/cache_entry.h
@@ -45,12 +45,14 @@
py::object code;
// root guard manager if exists
void* root_mgr{nullptr};
+ // backend used to create this cache entry
+ PyObject* backend{nullptr};
// Reference to owning ExtraState
ExtraState* _owner{nullptr};
// Reference to this CacheEntry's location in owner's linked list
std::list<CacheEntry>::iterator _owner_loc;
- CacheEntry(const py::handle& guarded_code);
+ CacheEntry(const py::handle& guarded_code, PyObject* backend);
~CacheEntry();
// Warning: returns a reference whose lifetime is controlled by C++
diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c
index e1a63da..828e492 100644
--- a/torch/csrc/dynamo/eval_frame.c
+++ b/torch/csrc/dynamo/eval_frame.c
@@ -530,12 +530,14 @@
return NULL;
}
+ PyObject* backend = get_backend(callback);
+
// A callback of Py_False indicates "run only" mode, the cache is checked, but
// we never compile.
if (callback == Py_False) {
DEBUG_TRACE("In run only mode %s", get_frame_name(frame));
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
- PyObject* maybe_cached_code = lookup(extra, frame->f_locals);
+ PyObject* maybe_cached_code = lookup(extra, frame->f_locals, backend);
_pytorch_record_function_exit(rf);
if (maybe_cached_code == NULL) {
@@ -560,7 +562,7 @@
eval_frame_callback_set(Py_None);
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
- PyObject* maybe_cached_code = lookup(extra, frame->f_locals);
+ PyObject* maybe_cached_code = lookup(extra, frame->f_locals, backend);
_pytorch_record_function_exit(rf);
if (maybe_cached_code == NULL) {
// Python error
@@ -594,7 +596,7 @@
// extract_cache_entry returns a borrowed reference. Modifying a borrowed
// reference seems wrong. Therefore, we directly access the
// extra->cache_entry. extra wont be NULL here.
- CacheEntry* new_cache_entry = create_cache_entry(extra, result);
+ CacheEntry* new_cache_entry = create_cache_entry(extra, result, backend);
Py_DECREF(result);
// Update the existing cache_entry on the extra object. This extra object is
// sitting on the extra scratch space, we are just changing the cache_entry
diff --git a/torch/csrc/dynamo/extra_state.cpp b/torch/csrc/dynamo/extra_state.cpp
index c52fa64..f4b5de0 100644
--- a/torch/csrc/dynamo/extra_state.cpp
+++ b/torch/csrc/dynamo/extra_state.cpp
@@ -82,34 +82,40 @@
return extra_state;
}
-PyObject* lookup(ExtraState* extra_state, PyObject* f_locals) {
+PyObject* lookup(
+ ExtraState* extra_state,
+ PyObject* f_locals,
+ PyObject* backend) {
size_t index = 0;
CacheEntry* found = nullptr;
py::handle locals(f_locals);
for (CacheEntry& cache_entry : extra_state->cache_entry_list) {
- bool valid = false;
- try {
- // TODO(anijain2305) - Clean this up when enable_cpp_guard_manager is True
- // by default
- if (cache_entry.root_mgr != nullptr) {
- valid = run_root_guard_manager(cache_entry.root_mgr, f_locals);
- } else {
- valid = cache_entry.check_fn(locals).cast<bool>();
+ // Check backend. Py_False means run only mode.
+ bool valid = backend == Py_False || cache_entry.backend == backend;
+ if (valid) {
+ try {
+ // TODO(anijain2305) - Clean this up when enable_cpp_guard_manager is
+ // True by default
+ if (cache_entry.root_mgr != nullptr) {
+ valid = run_root_guard_manager(cache_entry.root_mgr, f_locals);
+ } else {
+ valid = cache_entry.check_fn(locals).cast<bool>();
+ }
+ } catch (py::error_already_set& e) {
+ if (guard_error_hook) {
+ py::handle guard_error_hook_handle(guard_error_hook);
+ guard_error_hook_handle(
+ cache_entry.check_fn,
+ cache_entry.code,
+ locals,
+ index,
+ index == extra_state->cache_entry_list.size() - 1);
+ }
+ // this function is called from C, so we cannot repropagate
+ // the exception
+ e.restore();
+ return NULL;
}
- } catch (py::error_already_set& e) {
- if (guard_error_hook) {
- py::handle guard_error_hook_handle(guard_error_hook);
- guard_error_hook_handle(
- cache_entry.check_fn,
- cache_entry.code,
- locals,
- index,
- index == extra_state->cache_entry_list.size() - 1);
- }
- // this function is called from C, so we cannot repropagate
- // the exception
- e.restore();
- return NULL;
}
if (valid) {
found = &cache_entry;
@@ -126,8 +132,9 @@
CacheEntry* create_cache_entry(
ExtraState* extra_state,
- PyObject* guarded_code) {
- extra_state->cache_entry_list.emplace_front(guarded_code);
+ PyObject* guarded_code,
+ PyObject* backend) {
+ extra_state->cache_entry_list.emplace_front(guarded_code, backend);
auto new_iter = extra_state->cache_entry_list.begin();
new_iter->_owner = extra_state;
new_iter->_owner_loc = new_iter;
diff --git a/torch/csrc/dynamo/extra_state.h b/torch/csrc/dynamo/extra_state.h
index a8bfb33..fba5f33 100644
--- a/torch/csrc/dynamo/extra_state.h
+++ b/torch/csrc/dynamo/extra_state.h
@@ -124,7 +124,10 @@
// - f_locals: Borrowed
// return:
// - Py_None or PyCodeObject: Borrowed reference.
-PyObject* lookup(ExtraState* extra_state, PyObject* f_locals);
+PyObject* lookup(
+ ExtraState* extra_state,
+ PyObject* f_locals,
+    PyObject* backend);
// Create a new cache entry at extra_state holding on to guarded_code.
// Ownership contract
@@ -133,7 +136,13 @@
// - guarded_code: Borrowed
// return:
// - cache_entry: Borrowed reference
-CacheEntry* create_cache_entry(ExtraState* extra_state, PyObject* guraded_code);
+CacheEntry* create_cache_entry(
+ ExtraState* extra_state,
+    PyObject* guarded_code,
+    PyObject* backend);
+
+// Extracts the backend fn from the callback.
+PyObject* get_backend(PyObject* callback);
#ifdef __cplusplus