[dynamo][guards] Move backend match to eval_frame (#121954)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121954
Approved by: https://github.com/jansel
diff --git a/test/dynamo/test_comptime.py b/test/dynamo/test_comptime.py
index 45f2a6c..1170010 100644
--- a/test/dynamo/test_comptime.py
+++ b/test/dynamo/test_comptime.py
@@ -223,13 +223,6 @@
             'obj_weakref': None
             'guarded_class': None
         }
-        global '' BACKEND_MATCH
-        {
-            'guard_types': None,
-            'code': None,
-            'obj_weakref': None
-            'guarded_class': None
-        }
         shape_env '' SHAPE_ENV
         {
             'guard_types': None,
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index fee1680..7ff780d 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -4827,10 +4827,6 @@
         opt_out = torch._dynamo.optimize(backend=cnt)(foo)(*args)
         self.assertEqual(exp_out, opt_out)
         self.assertEqual(cnt.frame_count, exp_frame_count)
-        self.assertEqual(
-            len(torch._dynamo.eval_frame.cached_backends),
-            exp_n_cached_backend,
-        )
 
     def test_backend_match_guard(self):
         x = torch.randn([3, 4])
@@ -4912,12 +4908,6 @@
         for thread in threads:
             thread.join()
 
-        # Threads are sharing the backend cache. We see two cnt backends and one None backend
-        self.assertEqual(
-            len(torch._dynamo.eval_frame.cached_backends),
-            3,
-        )
-
         self.assertEqual(len(thread_success), len(threads))
 
     def test_dynamo_min_operator_with_shape(self):
diff --git a/torch/__init__.py b/torch/__init__.py
index ffb8e61..476684e 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -1685,6 +1685,9 @@
         self.apply_mode(mode)
         self.apply_options(options)
 
+        # Stash the compiler_fn to be used for backend match guard.
+        from torch._inductor.compile_fx import compile_fx
+        self.compiler_fn = compile_fx
         if self.config.get("triton.cudagraphs", False):
             os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
             # FIXME: CUDA Graph does not work well with CUPTI teardown.
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index f5a66c5..7c6a244 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -16,7 +16,6 @@
 import os
 import sys
 import textwrap
-import threading
 import traceback
 import types
 import warnings
@@ -77,82 +76,19 @@
     token = 0
 
 
-unset = Unset.token
-
-guarded_backend_cache = threading.local()
 cached_backends: Dict[int, CompilerFn] = {}
 
-
-def check_current_backend(backend_obj_id: int):
-    """
-    Called from guards to check if we need to recompile due to a backend change
-    """
-    # TODO(jansel): we should move guarded_backend_cache to C++
-    try:
-        if guarded_backend_cache.skip_backend_check_for_run_only_mode:
-            return True
-    except AttributeError:
-        # Go slightly faster next time
-        guarded_backend_cache.skip_backend_check_for_run_only_mode = False
-    try:
-        current_backend = guarded_backend_cache.current_backend
-    except AttributeError:
-        current_backend = None
-    return (
-        # Avoid the dict lookup in case of exact same object
-        id(current_backend) == backend_obj_id
-        or current_backend == cached_backends.get(backend_obj_id, None)
-    )
+unset = Unset.token
 
 
 def _reset_guarded_backend_cache():
     global cached_backends
-    guarded_backend_cache.skip_backend_check_for_run_only_mode = False
-    guarded_backend_cache.current_backend = None
     for backend in cached_backends.values():
         if hasattr(backend, "reset"):
             backend.reset()
     cached_backends.clear()
 
 
-def backend_cache_manager(callback: DynamoCallback):
-    # callback is False for RunOnlyContext. RunOnlyContext is used
-    # as a way to re-use the previous compiled cache.
-    # We therefore skip the check and re-use whatever code that's already cached.
-    # Note: the cache that's actually used depends on the caching policy.
-    if callback is False:
-
-        def change():
-            try:
-                prev_skip = guarded_backend_cache.skip_backend_check_for_run_only_mode
-            except AttributeError:
-                prev_skip = False
-            guarded_backend_cache.skip_backend_check_for_run_only_mode = True
-
-            def revert():
-                guarded_backend_cache.skip_backend_check_for_run_only_mode = prev_skip
-
-            return revert
-
-    else:
-        backend = innermost_fn(callback)
-
-        def change():
-            cached_backends.setdefault(id(backend), backend)
-            try:
-                prev_backend = guarded_backend_cache.current_backend
-            except AttributeError:
-                prev_backend = None
-            guarded_backend_cache.current_backend = backend
-
-            def revert():
-                guarded_backend_cache.current_backend = prev_backend
-
-            return revert
-
-    return change
-
-
 DONT_WRAP_FILES = {
     # For tracing into fx modules
     inspect.getsourcefile(GraphModule),
@@ -306,9 +242,13 @@
         self.export = export
         self.compiler_config = compiler_config
         self.cleanup_fns: List[Callable[[], Any]] = []
-        self.enter_exit_hooks = [backend_cache_manager(self.callback)]
+        self.enter_exit_hooks = []
         patch_fn()
 
+        # Save the backends so that we can reset them during torch._dynamo.reset
+        backend = innermost_fn(callback)
+        cached_backends.setdefault(id(backend), backend)
+
         if dynamic is not None:
             self.enter_exit_hooks.append(make_set_enable_dynamic(dynamic))
 
@@ -672,6 +612,9 @@
             dynamic=dynamic,
             hooks=hooks,
         )
+    # The backend function is stashed in the callable returned by
+    # _optimize_catch_errors in the field _torchdynamo_orig_callable. This can
+    # be used by eval_frame.c to insert a guard on the backend.
     return _optimize_catch_errors(
         convert_frame.convert_frame(backend, hooks=hooks),
         hooks,
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index 668203f..3f1009b 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -647,15 +647,6 @@
             guard, [f"utils_device.CURRENT_DEVICE == {m.CURRENT_DEVICE!r}"]
         )
 
-    def BACKEND_MATCH(self, guard: Guard):
-        """Guard on backend matching based on id of current_backend"""
-        assert guard.source is GuardSource.GLOBAL
-        backend_id = (
-            f"{id(torch._dynamo.eval_frame.guarded_backend_cache.current_backend)}"
-        )
-        code = [f"___check_current_backend({backend_id})"]
-        self._produce_guard_code(guard, code)
-
     def SHAPE_ENV(self, guard: Guard):
         # Let's handle ShapeEnv guards.  To do this, we will resolve
         # shape variables to sources from tracked_fakes.  This must happen after
@@ -1203,7 +1194,6 @@
             "___check_tensors": check_tensors_fn,
             "___check_tensors_verbose": check_tensors_verbose_fn,
             "___check_global_state": global_state.check,
-            "___check_current_backend": torch._dynamo.eval_frame.check_current_backend,
             "tensor_check_names": tensor_check_names,
             **SYMPY_INTERP,
             **CLOSURE_VARS,
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 42fa15a..a158df9 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -471,8 +471,6 @@
             GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE)
         )
 
-        self.guards.add(GlobalStateSource().make_guard(GuardBuilder.BACKEND_MATCH))
-
     def synthetic_graph_input(self, fn, args):
         """
         call fn(*args) before the graph runs and turn the result into a fake input.
diff --git a/torch/csrc/dynamo/cache_entry.cpp b/torch/csrc/dynamo/cache_entry.cpp
index 34e2764..8a07f59 100644
--- a/torch/csrc/dynamo/cache_entry.cpp
+++ b/torch/csrc/dynamo/cache_entry.cpp
@@ -4,9 +4,10 @@
 #include <torch/csrc/dynamo/debug_macros.h>
 #include <torch/csrc/dynamo/extra_state.h>
 
-CacheEntry::CacheEntry(const py::handle& guarded_code) {
+CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend) {
   this->check_fn = guarded_code.attr("check_fn");
   this->code = guarded_code.attr("code");
+  this->backend = backend;
   // TODO - clean this up when enable_cpp_guard_manager is True by default
   if (py::hasattr(this->check_fn, "root")) {
     this->root_mgr = convert_to_root_guard_manager(this->check_fn.attr("root"));
@@ -39,3 +40,14 @@
   }
   return py::cast(e, py::return_value_policy::reference).release().ptr();
 }
+
+PyObject* get_backend(PyObject* callback) {
+  py::handle handle = py::handle(callback);
+  while (py::hasattr(handle, "_torchdynamo_orig_callable")) {
+    handle = handle.attr("_torchdynamo_orig_callable");
+  }
+  if (py::hasattr(handle, "compiler_fn")) {
+    handle = handle.attr("compiler_fn");
+  }
+  return handle.ptr();
+}
diff --git a/torch/csrc/dynamo/cache_entry.h b/torch/csrc/dynamo/cache_entry.h
index 216d12b..25d7e70 100644
--- a/torch/csrc/dynamo/cache_entry.h
+++ b/torch/csrc/dynamo/cache_entry.h
@@ -45,12 +45,14 @@
   py::object code;
   // root guard manager if exists
   void* root_mgr{nullptr};
+  // backend used to create this cache entry
+  PyObject* backend{nullptr};
   // Reference to owning ExtraState
   ExtraState* _owner{nullptr};
   // Reference to this CacheEntry's location in owner's linked list
   std::list<CacheEntry>::iterator _owner_loc;
 
-  CacheEntry(const py::handle& guarded_code);
+  CacheEntry(const py::handle& guarded_code, PyObject* backend);
   ~CacheEntry();
 
   // Warning: returns a reference whose lifetime is controlled by C++
diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c
index e1a63da..828e492 100644
--- a/torch/csrc/dynamo/eval_frame.c
+++ b/torch/csrc/dynamo/eval_frame.c
@@ -530,12 +530,14 @@
     return NULL;
   }
 
+  PyObject* backend = get_backend(callback);
+
   // A callback of Py_False indicates "run only" mode, the cache is checked, but
   // we never compile.
   if (callback == Py_False) {
     DEBUG_TRACE("In run only mode %s", get_frame_name(frame));
     _PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
-    PyObject* maybe_cached_code = lookup(extra, frame->f_locals);
+    PyObject* maybe_cached_code = lookup(extra, frame->f_locals, backend);
     _pytorch_record_function_exit(rf);
 
     if (maybe_cached_code == NULL) {
@@ -560,7 +562,7 @@
   eval_frame_callback_set(Py_None);
 
   _PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
-  PyObject* maybe_cached_code = lookup(extra, frame->f_locals);
+  PyObject* maybe_cached_code = lookup(extra, frame->f_locals, backend);
   _pytorch_record_function_exit(rf);
   if (maybe_cached_code == NULL) {
     // Python error
@@ -594,7 +596,7 @@
     // extract_cache_entry returns a borrowed reference. Modifying a borrowed
     // reference seems wrong. Therefore, we directly access the
     // extra->cache_entry. extra wont be NULL here.
-    CacheEntry* new_cache_entry = create_cache_entry(extra, result);
+    CacheEntry* new_cache_entry = create_cache_entry(extra, result, backend);
     Py_DECREF(result);
     // Update the existing cache_entry on the extra object. This extra object is
     // sitting on the extra scratch space, we are just changing the cache_entry
diff --git a/torch/csrc/dynamo/extra_state.cpp b/torch/csrc/dynamo/extra_state.cpp
index c52fa64..f4b5de0 100644
--- a/torch/csrc/dynamo/extra_state.cpp
+++ b/torch/csrc/dynamo/extra_state.cpp
@@ -82,34 +82,40 @@
   return extra_state;
 }
 
-PyObject* lookup(ExtraState* extra_state, PyObject* f_locals) {
+PyObject* lookup(
+    ExtraState* extra_state,
+    PyObject* f_locals,
+    PyObject* backend) {
   size_t index = 0;
   CacheEntry* found = nullptr;
   py::handle locals(f_locals);
   for (CacheEntry& cache_entry : extra_state->cache_entry_list) {
-    bool valid = false;
-    try {
-      // TODO(anijain2305) - Clean this up when enable_cpp_guard_manager is True
-      // by default
-      if (cache_entry.root_mgr != nullptr) {
-        valid = run_root_guard_manager(cache_entry.root_mgr, f_locals);
-      } else {
-        valid = cache_entry.check_fn(locals).cast<bool>();
+    // Check backend. Py_False means run only mode.
+    bool valid = backend == Py_False || cache_entry.backend == backend;
+    if (valid) {
+      try {
+        // TODO(anijain2305) - Clean this up when enable_cpp_guard_manager is
+        // True by default
+        if (cache_entry.root_mgr != nullptr) {
+          valid = run_root_guard_manager(cache_entry.root_mgr, f_locals);
+        } else {
+          valid = cache_entry.check_fn(locals).cast<bool>();
+        }
+      } catch (py::error_already_set& e) {
+        if (guard_error_hook) {
+          py::handle guard_error_hook_handle(guard_error_hook);
+          guard_error_hook_handle(
+              cache_entry.check_fn,
+              cache_entry.code,
+              locals,
+              index,
+              index == extra_state->cache_entry_list.size() - 1);
+        }
+        // this function is called from C, so we cannot repropagate
+        // the exception
+        e.restore();
+        return NULL;
       }
-    } catch (py::error_already_set& e) {
-      if (guard_error_hook) {
-        py::handle guard_error_hook_handle(guard_error_hook);
-        guard_error_hook_handle(
-            cache_entry.check_fn,
-            cache_entry.code,
-            locals,
-            index,
-            index == extra_state->cache_entry_list.size() - 1);
-      }
-      // this function is called from C, so we cannot repropagate
-      // the exception
-      e.restore();
-      return NULL;
     }
     if (valid) {
       found = &cache_entry;
@@ -126,8 +132,9 @@
 
 CacheEntry* create_cache_entry(
     ExtraState* extra_state,
-    PyObject* guarded_code) {
-  extra_state->cache_entry_list.emplace_front(guarded_code);
+    PyObject* guarded_code,
+    PyObject* backend) {
+  extra_state->cache_entry_list.emplace_front(guarded_code, backend);
   auto new_iter = extra_state->cache_entry_list.begin();
   new_iter->_owner = extra_state;
   new_iter->_owner_loc = new_iter;
diff --git a/torch/csrc/dynamo/extra_state.h b/torch/csrc/dynamo/extra_state.h
index a8bfb33..fba5f33 100644
--- a/torch/csrc/dynamo/extra_state.h
+++ b/torch/csrc/dynamo/extra_state.h
@@ -124,7 +124,10 @@
 //  - f_locals: Borrowed
 // return:
 //   - Py_None or PyCodeObject: Borrowed reference.
-PyObject* lookup(ExtraState* extra_state, PyObject* f_locals);
+PyObject* lookup(
+    ExtraState* extra_state,
+    PyObject* f_locals,
    PyObject* backend);
 
 // Create a new cache entry at extra_state holding on to guarded_code.
 // Ownership contract
@@ -133,7 +136,13 @@
 //  - guarded_code: Borrowed
 // return:
 //  - cache_entry: Borrowed reference
-CacheEntry* create_cache_entry(ExtraState* extra_state, PyObject* guraded_code);
+CacheEntry* create_cache_entry(
+    ExtraState* extra_state,
+    PyObject* guarded_code,
+    PyObject* backend);
+
+// Extracts the backend fn from the callback.
+PyObject* get_backend(PyObject* callback);
 
 #ifdef __cplusplus