[dynamo] "TorchDynamo Cache Lookup" event: use C++ api (#108436)

**Background**: "TorchDynamo Cache Lookup" events appear in traces to indicate a dynamo cache lookup; it's useful to check when cache lookups are taking a long time. To add a profiler event, one can use the `torch.profiler.record_function` context manager, or the C++ equivalent. Previously, the python version was used; first, when the profiler was enabled, callbacks for record_function_enter and record_function_exit were registered; then those would be called before and after every cache lookup.

**This PR**: Instead of calling the Python bindings for `torch.profiler.record_function`, call the C++ implementation directly. This removes a lot of the C/C++ binding code. It also improves performance: previously there was substantial overhead inside the "TorchDynamo Cache Lookup" event itself, which made the event look artificially long. After this change the events appear shorter because there is less overhead in starting/stopping the event; in other words, the profiler no longer distorts the results as much.
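
For comparison, the removed Python path (see the deleted `_enable_dynamo_cache_lookup_profiler` in the diff below) registered hooks along these lines, so every cache lookup paid for a Python callback into the profiler ops; this is a sketch based on the deleted code, not a new API:

```python
import torch


# Start/end hooks formerly registered with dynamo's eval_frame;
# the C code called them around every cache lookup.
def _profiler_start(name):
    return torch.ops.profiler._record_function_enter_new(name, None)


def _profiler_end(record):
    torch.ops.profiler._record_function_exit._RecordFunction(record)
```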

**Performance results**:
I ran the script below on a CPU-only 1.6 GHz machine and report the median duration (over 100 measurements) of a "TorchDynamo Cache Lookup" event before and after this PR. I think it is reasonable to attribute the difference to reduced overhead.

<details>

<summary>Benchmarking script</summary>

```python
import torch


def fn(x, y):
    return (x * y).relu()


a, b = [torch.rand((4, 4), requires_grad=True) for _ in range(2)]

opt_fn = torch.compile(fn)

# Warm up so the profiled call below only exercises the cache lookup path.
opt_fn(a, b)
opt_fn(a, b)

with torch.profiler.profile() as prof:
    opt_fn(a, b)
```

</details>

- Median before PR: 198-228 us (median of 100 events, measured 5 times)
- Median after PR: 27 us
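
As a sanity check on the numbers above, the per-event durations can be pulled out of a profile roughly like this (a sketch that reuses `prof` from the benchmarking script and assumes the `cpu_time_total` field, reported in microseconds, on profiler events):

```python
import statistics

# Collect the durations of all "TorchDynamo Cache Lookup" events recorded
# during the profiled run and report their median.
durations_us = [
    evt.cpu_time_total
    for evt in prof.events()
    if "TorchDynamo Cache Lookup" in evt.name
]
if durations_us:
    print(f"median lookup time: {statistics.median(durations_us):.1f} us")
```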

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108436
Approved by: https://github.com/anijain2305, https://github.com/jansel
diff --git a/build_variables.bzl b/build_variables.bzl
index d814560..cafc617 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -829,6 +829,7 @@
     "torch/csrc/autograd/python_variable.cpp",
     "torch/csrc/autograd/python_variable_indexing.cpp",
     "torch/csrc/dynamo/python_compiled_autograd.cpp",
+    "torch/csrc/dynamo/cpp_shim.cpp",
     "torch/csrc/dynamo/cpython_defs.c",
     "torch/csrc/dynamo/eval_frame.c",
     "torch/csrc/dynamo/guards.cpp",
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 3f1fb02..fa74907 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -47,7 +47,6 @@
 from torch.ao.quantization.fake_quantize import FakeQuantize
 from torch.ao.quantization.qconfig import QConfig
 from torch.ao.quantization.quantize_fx import prepare_qat_fx
-from torch.autograd.profiler import _enable_dynamo_cache_lookup_profiler
 from torch.fx.experimental.symbolic_shapes import ConstraintViolationError
 from torch.nn import functional as F
 from torch.testing._internal.common_cuda import (
@@ -2506,54 +2505,6 @@
         result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno)
         self.assertTrue(result[1] == fn.__code__.co_lnotab)
 
-    def test_profiler_cache_lookup(self):
-        def fn(x):
-            y = x**2
-            y = y + 2
-            z = y**3
-            return z
-
-        for profiler, get_events in (
-            (torch.autograd.profiler.profile, lambda prof: prof.function_events),
-            (torch.profiler.profiler.profile, lambda prof: prof.events()),
-        ):
-            x = torch.randn((2, 2), requires_grad=True)
-            ref = fn(x)
-            opt_fn = torch.compile(fn, backend="aot_eager")
-
-            # warmup
-            opt_fn(x)
-
-            # whenver we enter the profiler context, hooks are automatically registered
-            with profiler() as prof:
-                res = opt_fn(x)
-            events = list(
-                filter(
-                    lambda event: event.name == "TorchDynamo Cache Lookup",
-                    get_events(prof),
-                )
-            )
-
-            self.assertTrue(same(ref, res))
-            self.assertTrue(
-                len(events) == 1,
-                "Expected one lookup profiler event for one opt_fn run",
-            )
-
-            with profiler() as prof:
-                # just make sure the disable functionality works
-                _enable_dynamo_cache_lookup_profiler(False)
-                res = opt_fn(x)
-            events = list(
-                filter(
-                    lambda event: event.name == "TorchDynamo Cache Lookup",
-                    get_events(prof),
-                )
-            )
-
-            self.assertTrue(same(ref, res))
-            self.assertTrue(len(events) == 0, "Expected disabled profiling")
-
     def test_tensor_is_contiguous(self):
         def fn(x):
             input = torch.randn((1, 16, 1, 1))
diff --git a/test/dynamo/test_profiler.py b/test/dynamo/test_profiler.py
index 690f2b5..686d55b 100644
--- a/test/dynamo/test_profiler.py
+++ b/test/dynamo/test_profiler.py
@@ -7,6 +7,7 @@
 import torch._dynamo.testing
 import torch._dynamo.utils
 
+from torch._dynamo.testing import same
 from torch._dynamo.utils import dynamo_timed
 
 
@@ -91,6 +92,39 @@
         with torch.profiler.profile(record_shapes=True):
             opt_fn(*inputs)
 
+    def test_profiler_cache_lookup(self):
+        def fn(x):
+            y = x**2
+            y = y + 2
+            z = y**3
+            return z
+
+        for profiler, get_events in (
+            (torch.autograd.profiler.profile, lambda prof: prof.function_events),
+            (torch.profiler.profiler.profile, lambda prof: prof.events()),
+        ):
+            x = torch.randn((2, 2), requires_grad=True)
+            ref = fn(x)
+            opt_fn = torch.compile(fn, backend="aot_eager")
+
+            # warmup
+            opt_fn(x)
+
+            with profiler() as prof:
+                res = opt_fn(x)
+            events = list(
+                filter(
+                    lambda event: "TorchDynamo Cache Lookup" in event.name,
+                    get_events(prof),
+                )
+            )
+
+            self.assertTrue(same(ref, res))
+            self.assertTrue(
+                len(events) == 1,
+                "Expected one lookup profiler event for one opt_fn run",
+            )
+
     def test_profiler_cache_lookup_profiler_step(self):
         def fn(x, y, z):
             return torch.add(torch.sub(x, y), z)
diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi
index f3e064a..462cf1c 100644
--- a/torch/_C/_dynamo/eval_frame.pyi
+++ b/torch/_C/_dynamo/eval_frame.pyi
@@ -1,11 +1,6 @@
 import types
 
-from torch._dynamo.types import (
-    DynamoCallback,
-    DynamoGuardHook,
-    ProfilerEndHook,
-    ProfilerStartHook,
-)
+from torch._dynamo.types import DynamoCallback, DynamoGuardHook
 
 def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
 def reset_code(code: types.CodeType) -> None: ...
@@ -13,5 +8,3 @@
 def skip_code(code: types.CodeType) -> None: ...
 def set_guard_fail_hook(hook: DynamoGuardHook) -> None: ...
 def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
-def set_profiler_hooks(start: ProfilerStartHook, end: ProfilerEndHook) -> None: ...
-def clear_profiler_hooks() -> None: ...
diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py
index 650546f..3a09e5a 100644
--- a/torch/autograd/profiler.py
+++ b/torch/autograd/profiler.py
@@ -79,40 +79,12 @@
     _is_profiler_enabled = enable
 
 
-def _enable_dynamo_cache_lookup_profiler(enable: bool):
-    from torch._dynamo.eval_frame import (  # type: ignore[attr-defined]
-        clear_profiler_hooks,
-        set_profiler_hooks,
-    )
-
-    """
-    Registers a hook within dynamo eval_frame.c called before and after
-    the lookup process, which runs guards associated with each cached frame.
-
-    Clear deregisters the hooks, saving overhead.
-    """
-
-    if enable:
-
-        def _profiler_start(name):
-            return torch.ops.profiler._record_function_enter_new(name, None)
-
-        def _profiler_end(record):
-            torch.ops.profiler._record_function_exit._RecordFunction(record)
-
-        set_profiler_hooks(_profiler_start, _profiler_end)
-    else:
-        clear_profiler_hooks()
-
-
 def _run_on_profiler_start():
     _set_is_profiler_enabled(True)
-    _enable_dynamo_cache_lookup_profiler(True)
 
 
 def _run_on_profiler_stop():
     _set_is_profiler_enabled(False)
-    _enable_dynamo_cache_lookup_profiler(False)
 
 
 class profile:
diff --git a/torch/csrc/dynamo/cpp_shim.cpp b/torch/csrc/dynamo/cpp_shim.cpp
new file mode 100644
index 0000000..35c415f
--- /dev/null
+++ b/torch/csrc/dynamo/cpp_shim.cpp
@@ -0,0 +1,22 @@
+#include <torch/csrc/dynamo/cpp_shim.h>
+
+#include <ATen/record_function.h>
+
+struct _PytorchRecordFunctionState {
+  at::RecordFunction guard;
+
+  _PytorchRecordFunctionState() : guard(at::RecordScope::FUNCTION) {}
+};
+
+_PytorchRecordFunctionState* _pytorch_record_function_enter(const char* name) {
+  _PytorchRecordFunctionState* state = new _PytorchRecordFunctionState();
+  state->guard.before(name);
+  return state;
+}
+
+void _pytorch_record_function_exit(_PytorchRecordFunctionState* state) {
+  if (state == nullptr) {
+    return;
+  }
+  delete state;
+}
diff --git a/torch/csrc/dynamo/cpp_shim.h b/torch/csrc/dynamo/cpp_shim.h
new file mode 100644
index 0000000..5baf678
--- /dev/null
+++ b/torch/csrc/dynamo/cpp_shim.h
@@ -0,0 +1,15 @@
+#pragma once
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct _PytorchRecordFunctionState;
+typedef struct _PytorchRecordFunctionState _PytorchRecordFunctionState;
+
+_PytorchRecordFunctionState* _pytorch_record_function_enter(const char* name);
+void _pytorch_record_function_exit(_PytorchRecordFunctionState* state);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c
index 28526fd..e693e74 100644
--- a/torch/csrc/dynamo/eval_frame.c
+++ b/torch/csrc/dynamo/eval_frame.c
@@ -1,4 +1,5 @@
 #define PY_SSIZE_T_CLEAN
+#include <torch/csrc/dynamo/cpp_shim.h>
 #include <torch/csrc/dynamo/cpython_defs.h>
 #include <torch/csrc/utils/python_compat.h>
 #include <opcode.h>
@@ -179,9 +180,7 @@
 bool is_dynamo_compiling = false;
 static PyObject* guard_fail_hook = NULL;
 static PyObject* guard_error_hook = NULL;
-static PyObject* profiler_start_hook = NULL;
-static PyObject* profiler_end_hook = NULL;
-static PyObject* guard_profiler_name_str = NULL; /* cached py str */
+const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup";
 
 // Points to the extra scratch space on the code object
 static Py_ssize_t extra_index = -1;
@@ -645,22 +644,6 @@
       (e->next == (CacheEntry*)Py_None ? Py_True : Py_False));
 }
 
-static PyObject* call_profiler_start_hook(PyObject* name_str) {
-  if (profiler_start_hook == NULL) return NULL;
-  return PyObject_CallOneArg(profiler_start_hook, name_str);
-}
-
-static void call_profiler_end_hook(PyObject* record) {
-  // 'record' obj is the return value of calling _start_hook()
-  if (profiler_end_hook == NULL || record == NULL) return;
-  PyObject* res = PyObject_CallOneArg(profiler_end_hook, record);
-  if (res == NULL) {
-    PyErr_WriteUnraisable(profiler_end_hook);
-    return;
-  }
-  Py_DECREF(res);
-}
-
 // Return value: borrowed reference
 // Is either Py_None or a PyCodeObject
 static PyObject* lookup(CacheEntry* e, THP_EVAL_API_FRAME_OBJECT *frame, CacheEntry* prev, size_t index) {
@@ -939,10 +922,9 @@
   // we never compile.
   if (callback == Py_False) {
     DEBUG_TRACE("In run only mode %s", get_frame_name(frame));
-    PyObject* hook_record = call_profiler_start_hook(guard_profiler_name_str);
+    _PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
     PyObject* maybe_cached_code = lookup(cache_entry, frame, NULL, 0);
-    call_profiler_end_hook(hook_record);
-    Py_XDECREF(hook_record);
+    _pytorch_record_function_exit(rf);
 
     if (maybe_cached_code == NULL) {
       // guard eval failed, keep propagating
@@ -965,10 +947,9 @@
   // in the shim.
   eval_frame_callback_set(Py_None);
 
-  PyObject* hook_record = call_profiler_start_hook(guard_profiler_name_str);
+  _PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
   PyObject* maybe_cached_code = lookup(cache_entry, frame, NULL, 0);
-  call_profiler_end_hook(hook_record);
-  Py_XDECREF(hook_record);
+  _pytorch_record_function_exit(rf);
   if (maybe_cached_code == NULL) {
     // Python error
     return NULL;
@@ -1131,27 +1112,6 @@
   Py_RETURN_NONE;
 }
 
-static PyObject* clear_profiler_hooks(PyObject* module, PyObject* unused) {
-  Py_CLEAR(profiler_start_hook);
-  Py_CLEAR(profiler_end_hook);
-  Py_RETURN_NONE;
-}
-
-static PyObject* set_profiler_hooks(PyObject* module, PyObject* args) {
-  PyObject* start = NULL;
-  PyObject* end = NULL;
-  if (!PyArg_ParseTuple(args, "OO:set_profiler_hooks", &start, &end)) {
-    return NULL;
-  }
-  if (start == Py_None || end == Py_None) {
-    clear_profiler_hooks(module, NULL);
-  } else {
-    Py_XSETREF(profiler_start_hook, Py_NewRef(start));
-    Py_XSETREF(profiler_end_hook, Py_NewRef(end));
-  }
-  Py_RETURN_NONE;
-}
-
 static PyMethodDef _methods[] = {
     {"set_eval_frame", set_eval_frame_py, METH_O, NULL},
     {"reset_code", reset_code, METH_O, NULL},
@@ -1159,8 +1119,6 @@
     {"skip_code", skip_code, METH_O, NULL},
     {"set_guard_fail_hook", set_guard_fail_hook, METH_O, NULL},
     {"set_guard_error_hook", set_guard_error_hook, METH_O, NULL},
-    {"set_profiler_hooks", set_profiler_hooks, METH_VARARGS, NULL},
-    {"clear_profiler_hooks", clear_profiler_hooks, METH_NOARGS, NULL},
     {"_debug_get_cache_entry_list", _debug_get_cache_entry_list, METH_VARARGS, NULL},
     {NULL, NULL, 0, NULL}};
 
@@ -1180,11 +1138,6 @@
     return NULL;
   }
 
-  guard_profiler_name_str = PyUnicode_FromString("TorchDynamo Cache Lookup");
-  if (guard_profiler_name_str == NULL) {
-    return NULL;
-  }
-
   int result = PyThread_tss_create(&eval_frame_callback_key);
   CHECK(result == 0);