[dynamo] "TorchDynamo Cache Lookup" event: use C++ api (#108436)
**Background**: "TorchDynamo Cache Lookup" events appear in traces to indicate a dynamo cache lookup; it's useful to check when cache lookups are taking a long time. To add a profiler event, one can use the `torch.profiler.record_function` context manager, or the C++ equivalent. Previously, the python version was used; first, when the profiler was enabled, callbacks for record_function_enter and record_function_exit were registered; then those would be called before and after every cache lookup.
**This PR**: Instead of going through the Python bindings for `torch.profiler.record_function`, call the C++ implementation directly. This simplifies much of the C/C++ binding code. It also improves performance: previously the "TorchDynamo Cache Lookup" event carried a lot of overhead, which made it appear artificially long. After this change the events are shorter because starting and stopping them is cheaper; in other words, the profiler no longer distorts the results as much.
**Performance results**:
I ran the script below on a CPU-only 1.6 GHz machine and report the median duration (over 100 measurements) of a "TorchDynamo Cache Lookup" event before and after this PR; a sketch of how the medians can be extracted follows the script. I think it is reasonable to attribute the difference to the reduction in overhead.
<details>
<summary>Benchmarking script</summary>
```python
import torch

def fn(x, y):
    return (x * y).relu()

a, b = [torch.rand((4, 4), requires_grad=True) for _ in range(2)]
opt_fn = torch.compile(fn)

# warm up so the profiled call only exercises the cache-lookup path
opt_fn(a, b)
opt_fn(a, b)

with torch.profiler.profile() as prof:
    opt_fn(a, b)
```
</details>
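The exact measurement code is not part of the PR; the following is a rough sketch of how the medians could be pulled out of the profile above, assuming the standard `cpu_time_total` field on profiler events (reported in microseconds):
```python
import statistics

# durations of every cache-lookup event captured by the profiler;
# repeating the profiled call yields more samples to take a median over
lookup_times_us = [
    e.cpu_time_total  # microseconds
    for e in prof.events()
    if e.name == "TorchDynamo Cache Lookup"
]
print("median lookup time (us):", statistics.median(lookup_times_us))
```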
Median before PR: 198-228 us (median of 100, measured 5 times)
Median after PR: 27 us
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108436
Approved by: https://github.com/anijain2305, https://github.com/jansel
diff --git a/build_variables.bzl b/build_variables.bzl
index d814560..cafc617 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -829,6 +829,7 @@
"torch/csrc/autograd/python_variable.cpp",
"torch/csrc/autograd/python_variable_indexing.cpp",
"torch/csrc/dynamo/python_compiled_autograd.cpp",
+ "torch/csrc/dynamo/cpp_shim.cpp",
"torch/csrc/dynamo/cpython_defs.c",
"torch/csrc/dynamo/eval_frame.c",
"torch/csrc/dynamo/guards.cpp",
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 3f1fb02..fa74907 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -47,7 +47,6 @@
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.quantize_fx import prepare_qat_fx
-from torch.autograd.profiler import _enable_dynamo_cache_lookup_profiler
from torch.fx.experimental.symbolic_shapes import ConstraintViolationError
from torch.nn import functional as F
from torch.testing._internal.common_cuda import (
@@ -2506,54 +2505,6 @@
result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno)
self.assertTrue(result[1] == fn.__code__.co_lnotab)
- def test_profiler_cache_lookup(self):
- def fn(x):
- y = x**2
- y = y + 2
- z = y**3
- return z
-
- for profiler, get_events in (
- (torch.autograd.profiler.profile, lambda prof: prof.function_events),
- (torch.profiler.profiler.profile, lambda prof: prof.events()),
- ):
- x = torch.randn((2, 2), requires_grad=True)
- ref = fn(x)
- opt_fn = torch.compile(fn, backend="aot_eager")
-
- # warmup
- opt_fn(x)
-
- # whenver we enter the profiler context, hooks are automatically registered
- with profiler() as prof:
- res = opt_fn(x)
- events = list(
- filter(
- lambda event: event.name == "TorchDynamo Cache Lookup",
- get_events(prof),
- )
- )
-
- self.assertTrue(same(ref, res))
- self.assertTrue(
- len(events) == 1,
- "Expected one lookup profiler event for one opt_fn run",
- )
-
- with profiler() as prof:
- # just make sure the disable functionality works
- _enable_dynamo_cache_lookup_profiler(False)
- res = opt_fn(x)
- events = list(
- filter(
- lambda event: event.name == "TorchDynamo Cache Lookup",
- get_events(prof),
- )
- )
-
- self.assertTrue(same(ref, res))
- self.assertTrue(len(events) == 0, "Expected disabled profiling")
-
def test_tensor_is_contiguous(self):
def fn(x):
input = torch.randn((1, 16, 1, 1))
diff --git a/test/dynamo/test_profiler.py b/test/dynamo/test_profiler.py
index 690f2b5..686d55b 100644
--- a/test/dynamo/test_profiler.py
+++ b/test/dynamo/test_profiler.py
@@ -7,6 +7,7 @@
import torch._dynamo.testing
import torch._dynamo.utils
+from torch._dynamo.testing import same
from torch._dynamo.utils import dynamo_timed
@@ -91,6 +92,39 @@
with torch.profiler.profile(record_shapes=True):
opt_fn(*inputs)
+ def test_profiler_cache_lookup(self):
+ def fn(x):
+ y = x**2
+ y = y + 2
+ z = y**3
+ return z
+
+ for profiler, get_events in (
+ (torch.autograd.profiler.profile, lambda prof: prof.function_events),
+ (torch.profiler.profiler.profile, lambda prof: prof.events()),
+ ):
+ x = torch.randn((2, 2), requires_grad=True)
+ ref = fn(x)
+ opt_fn = torch.compile(fn, backend="aot_eager")
+
+ # warmup
+ opt_fn(x)
+
+ with profiler() as prof:
+ res = opt_fn(x)
+ events = list(
+ filter(
+ lambda event: "TorchDynamo Cache Lookup" in event.name,
+ get_events(prof),
+ )
+ )
+
+ self.assertTrue(same(ref, res))
+ self.assertTrue(
+ len(events) == 1,
+ "Expected one lookup profiler event for one opt_fn run",
+ )
+
def test_profiler_cache_lookup_profiler_step(self):
def fn(x, y, z):
return torch.add(torch.sub(x, y), z)
diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi
index f3e064a..462cf1c 100644
--- a/torch/_C/_dynamo/eval_frame.pyi
+++ b/torch/_C/_dynamo/eval_frame.pyi
@@ -1,11 +1,6 @@
import types
-from torch._dynamo.types import (
- DynamoCallback,
- DynamoGuardHook,
- ProfilerEndHook,
- ProfilerStartHook,
-)
+from torch._dynamo.types import DynamoCallback, DynamoGuardHook
def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
def reset_code(code: types.CodeType) -> None: ...
@@ -13,5 +8,3 @@
def skip_code(code: types.CodeType) -> None: ...
def set_guard_fail_hook(hook: DynamoGuardHook) -> None: ...
def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
-def set_profiler_hooks(start: ProfilerStartHook, end: ProfilerEndHook) -> None: ...
-def clear_profiler_hooks() -> None: ...
diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py
index 650546f..3a09e5a 100644
--- a/torch/autograd/profiler.py
+++ b/torch/autograd/profiler.py
@@ -79,40 +79,12 @@
_is_profiler_enabled = enable
-def _enable_dynamo_cache_lookup_profiler(enable: bool):
- from torch._dynamo.eval_frame import ( # type: ignore[attr-defined]
- clear_profiler_hooks,
- set_profiler_hooks,
- )
-
- """
- Registers a hook within dynamo eval_frame.c called before and after
- the lookup process, which runs guards associated with each cached frame.
-
- Clear deregisters the hooks, saving overhead.
- """
-
- if enable:
-
- def _profiler_start(name):
- return torch.ops.profiler._record_function_enter_new(name, None)
-
- def _profiler_end(record):
- torch.ops.profiler._record_function_exit._RecordFunction(record)
-
- set_profiler_hooks(_profiler_start, _profiler_end)
- else:
- clear_profiler_hooks()
-
-
def _run_on_profiler_start():
_set_is_profiler_enabled(True)
- _enable_dynamo_cache_lookup_profiler(True)
def _run_on_profiler_stop():
_set_is_profiler_enabled(False)
- _enable_dynamo_cache_lookup_profiler(False)
class profile:
diff --git a/torch/csrc/dynamo/cpp_shim.cpp b/torch/csrc/dynamo/cpp_shim.cpp
new file mode 100644
index 0000000..35c415f
--- /dev/null
+++ b/torch/csrc/dynamo/cpp_shim.cpp
@@ -0,0 +1,22 @@
+#include <torch/csrc/dynamo/cpp_shim.h>
+
+#include <ATen/record_function.h>
+
+struct _PytorchRecordFunctionState {
+ at::RecordFunction guard;
+
+ _PytorchRecordFunctionState() : guard(at::RecordScope::FUNCTION) {}
+};
+
+_PytorchRecordFunctionState* _pytorch_record_function_enter(const char* name) {
+ _PytorchRecordFunctionState* state = new _PytorchRecordFunctionState();
+ state->guard.before(name);
+ return state;
+}
+
+void _pytorch_record_function_exit(_PytorchRecordFunctionState* state) {
+ if (state == nullptr) {
+ return;
+ }
+ delete state;
+}
diff --git a/torch/csrc/dynamo/cpp_shim.h b/torch/csrc/dynamo/cpp_shim.h
new file mode 100644
index 0000000..5baf678
--- /dev/null
+++ b/torch/csrc/dynamo/cpp_shim.h
@@ -0,0 +1,15 @@
+#pragma once
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct _PytorchRecordFunctionState;
+typedef struct _PytorchRecordFunctionState _PytorchRecordFunctionState;
+
+_PytorchRecordFunctionState* _pytorch_record_function_enter(const char* name);
+void _pytorch_record_function_exit(_PytorchRecordFunctionState* state);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c
index 28526fd..e693e74 100644
--- a/torch/csrc/dynamo/eval_frame.c
+++ b/torch/csrc/dynamo/eval_frame.c
@@ -1,4 +1,5 @@
#define PY_SSIZE_T_CLEAN
+#include <torch/csrc/dynamo/cpp_shim.h>
#include <torch/csrc/dynamo/cpython_defs.h>
#include <torch/csrc/utils/python_compat.h>
#include <opcode.h>
@@ -179,9 +180,7 @@
bool is_dynamo_compiling = false;
static PyObject* guard_fail_hook = NULL;
static PyObject* guard_error_hook = NULL;
-static PyObject* profiler_start_hook = NULL;
-static PyObject* profiler_end_hook = NULL;
-static PyObject* guard_profiler_name_str = NULL; /* cached py str */
+const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup";
// Points to the extra scratch space on the code object
static Py_ssize_t extra_index = -1;
@@ -645,22 +644,6 @@
(e->next == (CacheEntry*)Py_None ? Py_True : Py_False));
}
-static PyObject* call_profiler_start_hook(PyObject* name_str) {
- if (profiler_start_hook == NULL) return NULL;
- return PyObject_CallOneArg(profiler_start_hook, name_str);
-}
-
-static void call_profiler_end_hook(PyObject* record) {
- // 'record' obj is the return value of calling _start_hook()
- if (profiler_end_hook == NULL || record == NULL) return;
- PyObject* res = PyObject_CallOneArg(profiler_end_hook, record);
- if (res == NULL) {
- PyErr_WriteUnraisable(profiler_end_hook);
- return;
- }
- Py_DECREF(res);
-}
-
// Return value: borrowed reference
// Is either Py_None or a PyCodeObject
static PyObject* lookup(CacheEntry* e, THP_EVAL_API_FRAME_OBJECT *frame, CacheEntry* prev, size_t index) {
@@ -939,10 +922,9 @@
// we never compile.
if (callback == Py_False) {
DEBUG_TRACE("In run only mode %s", get_frame_name(frame));
- PyObject* hook_record = call_profiler_start_hook(guard_profiler_name_str);
+ _PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
PyObject* maybe_cached_code = lookup(cache_entry, frame, NULL, 0);
- call_profiler_end_hook(hook_record);
- Py_XDECREF(hook_record);
+ _pytorch_record_function_exit(rf);
if (maybe_cached_code == NULL) {
// guard eval failed, keep propagating
@@ -965,10 +947,9 @@
// in the shim.
eval_frame_callback_set(Py_None);
- PyObject* hook_record = call_profiler_start_hook(guard_profiler_name_str);
+ _PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
PyObject* maybe_cached_code = lookup(cache_entry, frame, NULL, 0);
- call_profiler_end_hook(hook_record);
- Py_XDECREF(hook_record);
+ _pytorch_record_function_exit(rf);
if (maybe_cached_code == NULL) {
// Python error
return NULL;
@@ -1131,27 +1112,6 @@
Py_RETURN_NONE;
}
-static PyObject* clear_profiler_hooks(PyObject* module, PyObject* unused) {
- Py_CLEAR(profiler_start_hook);
- Py_CLEAR(profiler_end_hook);
- Py_RETURN_NONE;
-}
-
-static PyObject* set_profiler_hooks(PyObject* module, PyObject* args) {
- PyObject* start = NULL;
- PyObject* end = NULL;
- if (!PyArg_ParseTuple(args, "OO:set_profiler_hooks", &start, &end)) {
- return NULL;
- }
- if (start == Py_None || end == Py_None) {
- clear_profiler_hooks(module, NULL);
- } else {
- Py_XSETREF(profiler_start_hook, Py_NewRef(start));
- Py_XSETREF(profiler_end_hook, Py_NewRef(end));
- }
- Py_RETURN_NONE;
-}
-
static PyMethodDef _methods[] = {
{"set_eval_frame", set_eval_frame_py, METH_O, NULL},
{"reset_code", reset_code, METH_O, NULL},
@@ -1159,8 +1119,6 @@
{"skip_code", skip_code, METH_O, NULL},
{"set_guard_fail_hook", set_guard_fail_hook, METH_O, NULL},
{"set_guard_error_hook", set_guard_error_hook, METH_O, NULL},
- {"set_profiler_hooks", set_profiler_hooks, METH_VARARGS, NULL},
- {"clear_profiler_hooks", clear_profiler_hooks, METH_NOARGS, NULL},
{"_debug_get_cache_entry_list", _debug_get_cache_entry_list, METH_VARARGS, NULL},
{NULL, NULL, 0, NULL}};
@@ -1180,11 +1138,6 @@
return NULL;
}
- guard_profiler_name_str = PyUnicode_FromString("TorchDynamo Cache Lookup");
- if (guard_profiler_name_str == NULL) {
- return NULL;
- }
-
int result = PyThread_tss_create(&eval_frame_callback_key);
CHECK(result == 0);