Make RecordFunction callbacks thread local and modernize interface (#37491)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37491
This PR modernizes the RecordFunction API and adds thread-local callbacks
in addition to the existing global ones.
Changes:
- support for TLS callbacks; these will be the foundation of the profiler and other tools
- modernize the interface around a simple set of functions, (add|remove|has|clear)(Global|ThreadLocal)(Callback), and add RecordFunctionCallback to make it easy to construct the callbacks being passed (see the usage sketch after this list)
- add `.setShouldRun` to the callback interface to support cases where simple uniform sampling is not enough
- to properly support add/remove, introduce callback handles returned by the add functions
- the internal implementation still uses SmallVector to store intermediate state (as before); in this case, vectors of handles of the callbacks that were picked to run
- to speed up the runtime we keep these vectors sorted, so we can quickly enumerate the callbacks that need to be run
- added tests for the new functionality
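A minimal usage sketch of the new interface (API names as introduced by this PR; the wrapper function and the observer bodies are illustrative placeholders):

  #include <torch/csrc/autograd/record_function.h>

  using namespace torch::autograd::profiler;

  void initObservers() {
    // Global callback: considered by every RecordFunction on every
    // thread; sampled at 1% and passed the observed function's inputs.
    CallbackHandle h = addGlobalCallback(
        RecordFunctionCallback(
            [](const RecordFunction& fn) { /* start observer */ },
            [](const RecordFunction& fn) { /* end observer */ })
            .needsInputs(true)
            .samplingProb(0.01));

    // Thread local callback: registered only for the current thread.
    addThreadLocalCallback(RecordFunctionCallback(
        [](const RecordFunction& fn) { /* start observer */ },
        [](const RecordFunction& fn) { /* end observer */ }));

    // When uniform sampling is not enough, .setShouldRun installs a
    // custom predicate deciding whether the callback pair runs.
    addGlobalCallback(
        RecordFunctionCallback(
            [](const RecordFunction& fn) { /* start observer */ },
            [](const RecordFunction& fn) { /* end observer */ })
            .setShouldRun(
                [](const RecordFunctionCallback&) { return true; }));

    // The handle returned by the add functions identifies the
    // registration; removeCallback removes exactly that callback.
    removeCallback(h);
  }

Start and end callbacks always run as a pair: if the start callback is not picked to run (sampling or should_run), the corresponding end callback is skipped as well.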
Test Plan:
BUILD_BINARY=1 USE_BLAS=MKL USE_MKLDNN=0 USE_CUDA=0 python setup.py develop install
./build/bin/test_jit
CI
record_function_benchmark: https://gist.github.com/ilia-cher/f1e094dae47fe23e55e7672ac4dcda2f
Imported from OSS
Differential Revision: D21300448
fbshipit-source-id: 6d55c26dbf20b33d35c3f1604dcc07bb063c8c43
diff --git a/binaries/record_function_benchmark.cc b/binaries/record_function_benchmark.cc
index b271466..34a21ad 100644
--- a/binaries/record_function_benchmark.cc
+++ b/binaries/record_function_benchmark.cc
@@ -22,22 +22,22 @@
void setupCallbacks() {
// non-sampled callback
- profiler::pushCallback(
+ profiler::addGlobalCallback(profiler::RecordFunctionCallback(
[&](const profiler::RecordFunction& fn) {
return true;
},
- [](const profiler::RecordFunction&) {},
- /* needs_inputs */ true);
+ [](const profiler::RecordFunction&) {})
+ .needsInputs(true));
// sampled
for (auto idx = 0; idx < kNumSampledCb; ++idx) {
- profiler::pushCallback(
+ profiler::addGlobalCallback(profiler::RecordFunctionCallback(
[](const profiler::RecordFunction& fn) {
return true;
},
- [](const profiler::RecordFunction&) {},
- /* needs_inputs */ true,
- /* sampling_prob */ kSampingProb
+ [](const profiler::RecordFunction&) {})
+ .needsInputs(true)
+ .samplingProb(kSampingProb)
);
}
}
diff --git a/docs/source/notes/large_scale_deployments.rst b/docs/source/notes/large_scale_deployments.rst
index 509c8fe..4c96d80 100644
--- a/docs/source/notes/large_scale_deployments.rst
+++ b/docs/source/notes/large_scale_deployments.rst
@@ -26,7 +26,7 @@
across the entire set of machines.
New callbacks for any operator invocation can be added with
-``torch::autograd::profiler::pushCallback``. Hooks will be called with
+``torch::autograd::profiler::addGlobalCallback``. Hooks will be called with
``torch::autograd::profiler::RecordFunction`` struct that describes invocation
context (e.g. `name`). If enabled, ``RecordFunction::inputs()`` contains arguments
of the function represented as ``torch::IValue`` variant type. Note, that inputs
@@ -42,9 +42,9 @@
Invoking callbacks adds some overhead, so usually it's useful to just randomly
sample operator invocations. This can be enabled on per-callback basis with an
-optional sampling rate passed into ``torch::autograd::profiler::pushCallback``.
+optional sampling rate passed into ``torch::autograd::profiler::addGlobalCallback``.
-Note, that ``pushCallback`` is not thread-safe and can be called only when no
+Note that ``addGlobalCallback`` is not thread-safe and can be called only when no
PyTorch operator is running. Usually, it's a good idea to call them once during
initialization.
@@ -55,22 +55,20 @@
// Called somewhere in the program beginning
void init() {
// Sample one in a hundred operator runs randomly
- pushCallback(
+ addGlobalCallback(
+ RecordFunctionCallback(
&onFunctionEnter,
- &onFunctionExit,
- /* needs_inputs */ true,
- /* sampling_prob */ 0.01
+ &onFunctionExit)
+ .needsInputs(true)
+ .samplingProb(0.01)
);
// Note, to enable observers in the model calling thread,
- // call enableObservers() in the thread before running a model
+ // call enableRecordFunction() in the thread before running a model
}
- bool onFunctionEnter(const RecordFunction& fn) {
+ void onFunctionEnter(const RecordFunction& fn) {
std::cerr << "Before function " << fn.name()
<< " with " << fn.inputs().size() << " inputs" << std::endl;
- // Returning false would mean that the callback is not interested
- // in this RecordFunction and onFunctionExit won't be called
- return true;
}
void onFunctionExit(const RecordFunction& fn) {
diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp
index 793b958..990dda4 100644
--- a/test/cpp/jit/test_misc.cpp
+++ b/test/cpp/jit/test_misc.cpp
@@ -67,6 +67,8 @@
#include <utility>
#include <vector>
+using namespace torch::autograd::profiler;
+
namespace torch {
namespace jit {
inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
@@ -755,17 +757,11 @@
using namespace torch::autograd;
-void cleanUpScopeCallbacks() {
- while (profiler::hasCallbacks()) {
- profiler::popCallback();
- }
-}
-
void checkScopeCallbacks() {
bool found_function_scope = false;
bool found_method_scope = false;
bool found_user_scope = false;
- profiler::pushCallback(
+ profiler::addGlobalCallback(profiler::RecordFunctionCallback(
[&](const profiler::RecordFunction& fn) {
if (fn.scope() == profiler::RecordScope::FUNCTION &&
std::string(fn.name().str()) == "test_function") {
@@ -779,26 +775,23 @@
std::string(fn.name().str()) == "test_user_scope") {
found_user_scope = true;
}
- return true;
},
- [](const profiler::RecordFunction&) {},
- /* needs_inputs */ false);
+ [](const profiler::RecordFunction&) {}));
bool bad_scope = false;
auto pushScopedCallback = [&](profiler::RecordScope scope, size_t& cnt) {
- profiler::pushCallback(
- [&bad_scope, &cnt, scope](const profiler::RecordFunction& fn) {
- if (fn.scope() == scope) {
- ++cnt;
- } else {
- bad_scope = true;
- }
- return true;
- },
- [](const profiler::RecordFunction&) {},
- /* needs_inputs */ false,
- /* sampling_prob */ 1.0,
- /* scopes */ {scope});
+ profiler::addGlobalCallback(
+ profiler::RecordFunctionCallback(
+ [&bad_scope, &cnt, scope](const profiler::RecordFunction& fn) {
+ if (fn.scope() == scope) {
+ ++cnt;
+ } else {
+ bad_scope = true;
+ }
+ return true;
+ },
+ [](const profiler::RecordFunction&) {})
+ .scopes({scope}));
};
size_t fun_cnt = 0;
@@ -835,28 +828,26 @@
// [(fn, [[sizes], [sizes], ...]), ...]
TracedTestInputs traced_inputs;
std::unordered_set<std::string> ts_names;
- autograd::profiler::pushCallback(
- [&](const autograd::profiler::RecordFunction& fn) {
- if (fn.scope() == autograd::profiler::RecordScope::FUNCTION) {
- auto inputs = fn.inputs();
- std::vector<std::vector<int64_t>> sizes;
- for (const auto& input : inputs) {
- if (input.isTensor()) {
- sizes.push_back(input.toTensor().sizes().vec());
- } else if (input.isScalar()) {
- sizes.push_back(std::vector<int64_t>());
+ addGlobalCallback(
+ RecordFunctionCallback(
+ [&](const RecordFunction& fn) {
+ if (fn.scope() == RecordScope::FUNCTION) {
+ auto inputs = fn.inputs();
+ std::vector<std::vector<int64_t>> sizes;
+ for (const auto& input : inputs) {
+ if (input.isTensor()) {
+ sizes.push_back(input.toTensor().sizes().vec());
+ } else if (input.isScalar()) {
+ sizes.push_back(std::vector<int64_t>());
+ }
+ }
+ traced_inputs.push_back(std::make_tuple(fn.name().str(), sizes));
+ } else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
+ ts_names.insert(fn.name().str());
}
- }
- traced_inputs.push_back(std::make_tuple(fn.name().str(), sizes));
- } else if (
- fn.scope() ==
- autograd::profiler::RecordScope::TORCHSCRIPT_FUNCTION) {
- ts_names.insert(fn.name().str());
- }
- return true;
- },
- [](const autograd::profiler::RecordFunction&) {},
- /* needs_inputs */ true);
+ },
+ [](const RecordFunction&) {})
+ .needsInputs(true));
TracedTestInputs eager_inputs, jit_inputs;
{
@@ -876,7 +867,6 @@
jit_inputs = traced_inputs;
traced_inputs.clear();
}
- autograd::profiler::popCallback();
TORCH_CHECK(ts_names.size() == 2);
TORCH_CHECK(ts_names.find("forward") != ts_names.end());
@@ -884,31 +874,33 @@
checkTracedInputs(eager_inputs);
checkTracedInputs(jit_inputs);
- cleanUpScopeCallbacks();
+ profiler::clearCallbacks();
// test sampled callbacks
int sampled_cb_ctr = 0;
- autograd::profiler::pushCallback(
- [&sampled_cb_ctr](const autograd::profiler::RecordFunction& fn) {
- if (std::string(fn.name().str()) == "test") {
- ++sampled_cb_ctr;
- }
- return true;
- },
- [](const autograd::profiler::RecordFunction&) {},
- /* needs_inputs */ false,
- /* sampling_prob */ 0.5);
+ auto setup_sampled_callback = [&sampled_cb_ctr](double sampling_prob) {
+ return addGlobalCallback(RecordFunctionCallback(
+ [&sampled_cb_ctr](const RecordFunction& fn) {
+ if (std::string(fn.name().str()) == "test") {
+ ++sampled_cb_ctr;
+ }
+ return true;
+ },
+ [](const RecordFunction&) {})
+ .samplingProb(sampling_prob));
+ };
int non_sampled_cb_ctr = 0;
- autograd::profiler::pushCallback(
- [&non_sampled_cb_ctr](const autograd::profiler::RecordFunction& fn) {
+ addGlobalCallback(RecordFunctionCallback(
+ [&non_sampled_cb_ctr](const RecordFunction& fn) {
if (std::string(fn.name().str()) == "test") {
++non_sampled_cb_ctr;
}
return true;
},
- [](const autograd::profiler::RecordFunction&) {},
- /* needs_inputs */ false);
+ [](const RecordFunction&) {}));
+
+ auto handle = setup_sampled_callback(0.5);
auto run_test_function = []() {
auto t = torch::randn({1, 2, 3}, at::kCPU);
@@ -922,45 +914,45 @@
TORCH_CHECK(sampled_cb_ctr > 0 && sampled_cb_ctr < 1000);
sampled_cb_ctr = 0;
- autograd::profiler::TEST_setGlobalSamplingProbability(0.0);
+ removeCallback(handle);
+ handle = setup_sampled_callback(0.0);
run_test_function();
TORCH_CHECK(non_sampled_cb_ctr == 2000);
TORCH_CHECK(sampled_cb_ctr == 0);
sampled_cb_ctr = 0;
- autograd::profiler::TEST_setGlobalSamplingProbability(1.0);
+ removeCallback(handle);
+ handle = setup_sampled_callback(1.0);
run_test_function();
TORCH_CHECK(non_sampled_cb_ctr == 3000);
TORCH_CHECK(sampled_cb_ctr == 1000);
- autograd::profiler::TEST_unsetGlobalSamplingProbability();
- cleanUpScopeCallbacks();
+ clearCallbacks();
// test the scope of the callbacks
checkScopeCallbacks();
- cleanUpScopeCallbacks();
+ clearCallbacks();
// check record function guard
std::vector<std::string> fn_names;
std::mutex mtx;
- autograd::profiler::pushCallback(
- [&fn_names, &mtx](const autograd::profiler::RecordFunction& fn) {
+ addGlobalCallback(RecordFunctionCallback(
+ [&fn_names, &mtx](const RecordFunction& fn) {
std::lock_guard<std::mutex> lock(mtx);
fn_names.push_back(fn.name().str());
return true;
},
- [](const autograd::profiler::RecordFunction&) {},
- /* needs_inputs */ false);
+ [](const RecordFunction&) {}));
{
- autograd::profiler::RecordFunctionGuard g1(false);
+ RecordFunctionGuard g1(false);
{
RECORD_USER_SCOPE("A");
{
- autograd::profiler::RecordFunctionGuard g2(true);
+ RecordFunctionGuard g2(true);
RECORD_USER_SCOPE("B");
{
- autograd::profiler::DisableRecordFunctionGuard g3;
+ DisableRecordFunctionGuard g3;
RECORD_USER_SCOPE("C");
}
}
@@ -969,7 +961,103 @@
}
TORCH_CHECK(fn_names.size() == 1);
TORCH_CHECK(fn_names[0] == "B");
- cleanUpScopeCallbacks();
+ clearCallbacks();
+
+ // test add/remove
+ std::vector<size_t> ids;
+ auto add_remove_test_add_cb = [&ids](size_t id) {
+ return addGlobalCallback(RecordFunctionCallback(
+ [&ids, id](const RecordFunction& fn) { ids.push_back(id); },
+ [](const RecordFunction&) {}));
+ };
+
+ auto h1 = add_remove_test_add_cb(1);
+ auto h2 = add_remove_test_add_cb(2);
+ auto h3 = add_remove_test_add_cb(3);
+
+ { RECORD_USER_SCOPE("test"); }
+
+ TORCH_CHECK(ids.size() == 3);
+ TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
+ TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
+ TORCH_CHECK(std::find(ids.begin(), ids.end(), 3) != ids.end());
+
+ ids.clear();
+ removeCallback(h1);
+
+ { RECORD_USER_SCOPE("test"); }
+
+ TORCH_CHECK(ids.size() == 2);
+ TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
+ TORCH_CHECK(std::find(ids.begin(), ids.end(), 3) != ids.end());
+
+ ids.clear();
+ removeCallback(h3);
+
+ { RECORD_USER_SCOPE("test"); }
+
+ TORCH_CHECK(ids.size() == 1);
+ TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
+
+ clearCallbacks();
+
+ // thread local / global callbacks
+
+ ids.clear();
+ addGlobalCallback(RecordFunctionCallback(
+ [&ids](const RecordFunction& fn) { ids.push_back(1); },
+ [](const RecordFunction&) {}));
+
+ { RECORD_USER_SCOPE("test"); }
+
+ TORCH_CHECK(ids.size() == 1);
+ TORCH_CHECK(ids[0] == 1);
+ ids.clear();
+
+ auto th = std::thread([&ids]() {
+ c10::impl::IncludeDispatchKeyGuard observer_guard(
+ c10::DispatchKey::Profiler);
+ addThreadLocalCallback(RecordFunctionCallback(
+ [&ids](const RecordFunction& fn) { ids.push_back(2); },
+ [](const RecordFunction&) {}));
+
+ { RECORD_USER_SCOPE("test_thread"); }
+ });
+ th.join();
+ TORCH_CHECK(ids.size() == 2);
+ TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
+ TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
+ ids.clear();
+
+ { RECORD_USER_SCOPE("test"); }
+
+ TORCH_CHECK(ids.size() == 1);
+ TORCH_CHECK(ids[0] == 1);
+ ids.clear();
+
+ // test should_run
+
+ bool ran = false;
+ bool should_run = false;
+ addGlobalCallback(
+ RecordFunctionCallback(
+ [&ran](const RecordFunction& fn) { ran = true; },
+ [](const RecordFunction&) {})
+ .setShouldRun([&should_run](const RecordFunctionCallback&) {
+ return should_run;
+ }));
+
+ { RECORD_USER_SCOPE("test"); }
+
+ TORCH_CHECK(!ran);
+
+ should_run = true;
+
+ { RECORD_USER_SCOPE("test"); }
+
+ TORCH_CHECK(ran);
+
+ clearCallbacks();
}
class TestThreadLocalDebugInfo : public at::DebugInfoBase {
@@ -1028,13 +1116,13 @@
TORCH_CHECK(
at::ThreadLocalDebugInfo::get(at::DebugInfoKind::TEST_INFO) == nullptr);
done = false;
- autograd::profiler::pushCallback(
- [&done](const autograd::profiler::RecordFunction&) {
+ auto handle = addGlobalCallback(RecordFunctionCallback(
+ [&done](const RecordFunction&) {
checkDebugInfo(at::DebugInfoKind::TEST_INFO, 42);
done = true;
return true;
},
- [](const autograd::profiler::RecordFunction&) {});
+ [](const RecordFunction&) {}));
{
at::DebugInfoGuard guard(at::DebugInfoKind::TEST_INFO, debug_info);
auto t = torch::randn({1, 2, 3}, at::kCPU);
@@ -1042,7 +1130,7 @@
auto t2 = t.pow(2);
t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
}
- autograd::profiler::popCallback();
+ removeCallback(handle);
TORCH_CHECK(done);
// check nested debug info
@@ -1094,7 +1182,7 @@
std::stringstream ss;
{
- autograd::profiler::RecordProfile guard(ss);
+ RecordProfile guard(ss);
for (size_t i = 0; i < 100; ++i) {
std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
}
diff --git a/torch/csrc/autograd/profiler.cpp b/torch/csrc/autograd/profiler.cpp
index ca14297..4e68868 100644
--- a/torch/csrc/autograd/profiler.cpp
+++ b/torch/csrc/autograd/profiler.cpp
@@ -35,6 +35,8 @@
// use RecordFunctionGuard to keep track of observers,
// enable/disableProfiler are tied to the code range
thread_local std::vector<std::shared_ptr<RecordFunctionGuard>> g_;
+// use a thread_local vector to save profiler callback handles
+thread_local std::vector<uint64_t> callback_handles_;
} // namespace
@@ -143,7 +145,7 @@
throw std::runtime_error("can't change kind of profiling (e.g. NVTX to CPU) while profiler is running");
}
- pushCallback(
+ auto handle = addGlobalCallback(RecordFunctionCallback(
[config](const RecordFunction& fn) {
auto* msg = (fn.seqNr() >= 0) ? ", seq = " : "";
if (config.report_input_shapes) {
@@ -196,11 +198,11 @@
} else {
popRange();
}
- },
- /* needs_inputs */ config.report_input_shapes,
- /* sampling_prob */ 1.0,
- /* scopes */ {RecordScope::FUNCTION, RecordScope::USER_SCOPE});
+ })
+ .needsInputs(config.report_input_shapes)
+ .scopes({RecordScope::FUNCTION, RecordScope::USER_SCOPE}));
state = new_state;
+ callback_handles_.push_back(handle);
g_.emplace_back(std::make_shared<RecordFunctionGuard>());
if(state == ProfilerState::CUDA) {
@@ -230,10 +232,12 @@
ProfilerState old_state = state;
mark("__stop_profile");
- popCallback();
- state = ProfilerState::Disabled;
+ TORCH_INTERNAL_ASSERT(!callback_handles_.empty());
+ removeCallback(callback_handles_.back());
+ callback_handles_.pop_back();
TORCH_INTERNAL_ASSERT(!g_.empty());
g_.pop_back();
+ state = ProfilerState::Disabled;
if (old_state == ProfilerState::NVTX) {
return thread_event_lists();
diff --git a/torch/csrc/autograd/record_function.cpp b/torch/csrc/autograd/record_function.cpp
index 3a3611c..c784842 100644
--- a/torch/csrc/autograd/record_function.cpp
+++ b/torch/csrc/autograd/record_function.cpp
@@ -1,7 +1,7 @@
+#include <algorithm>
#include <torch/csrc/autograd/record_function.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/profiler.h>
-#include <torch/csrc/utils/memory.h>
#include <cstdlib>
#include <random>
@@ -11,177 +11,189 @@
namespace {
-float sample_zero_one() {
- static thread_local auto gen =
- torch::make_unique<std::mt19937>(std::random_device()());
- std::uniform_real_distribution<float> dist(0.0, 1.0);
- return dist(*gen);
+// Used to generate unique callback handles
+CallbackHandle next_unique_callback_handle() {
+ static std::atomic<uint64_t> unique_id {0};
+ return CallbackHandle(++unique_id);
}
+// Thread local vector of callbacks; holds (callback, unique_id) pairs
+// and must be kept sorted in increasing handle order
+thread_local RecordFunctionCallbacks sorted_tls_callbacks_;
+
class CallbackManager {
public:
- void pushCallback(
- std::function<bool(const RecordFunction&)> start,
- std::function<void(const RecordFunction&)> end,
- bool needs_inputs,
- double sampling_prob,
- std::unordered_set<RecordScope, std::hash<RecordScope>> scopes) {
- callbacks_.emplace_back(
- std::move(start),
- std::move(end),
- needs_inputs,
- sampling_prob,
- std::move(scopes)
- );
- recomputeFlags();
-
- // make sure we mark the change in callbacks
- ++callbacks_version_;
+ CallbackHandle addThreadLocalCallback(RecordFunctionCallback cb) {
+ // note: monotonically increasing handles from
+ // next_unique_callback_handle() keep sorted_tls_callbacks_ sorted
+ auto handle = next_unique_callback_handle();
+ sorted_tls_callbacks_.emplace_back(std::move(cb), handle);
+ return handle;
}
- void popCallback() {
- if (callbacks_.empty()) {
- throw std::runtime_error("Empty callbacks stack");
+ CallbackHandle addGlobalCallback(RecordFunctionCallback cb) {
+ auto handle = next_unique_callback_handle();
+ sorted_global_callbacks_.emplace_back(std::move(cb), handle);
+ return handle;
+ }
+
+ void removeCallback(CallbackHandle handle) {
+ auto find_and_remove = [handle](RecordFunctionCallbacks& cbs) {
+ auto it = std::find_if(
+ cbs.begin(), cbs.end(),
+ [handle](
+ const std::pair<
+ RecordFunctionCallback,
+ CallbackHandle>& el) {
+ return el.second == handle;
+ });
+ if (it != cbs.end()) {
+ // keeps it sorted
+ cbs.erase(it);
+ return true;
+ }
+ return false;
+ };
+ auto found = find_and_remove(sorted_tls_callbacks_);
+ if (!found) {
+ found = find_and_remove(sorted_global_callbacks_);
}
- callbacks_.pop_back();
- recomputeFlags();
- ++callbacks_version_;
+ if (!found) {
+ LOG(WARNING) << "Requested callback is not found";
+ }
}
- inline bool hasCallbacks() const {
- return !callbacks_.empty();
+ void clearGlobalCallbacks() {
+ sorted_global_callbacks_.clear();
}
- inline bool needsInputs() const {
- return has_callbacks_with_inputs_;
+ void clearThreadLocalCallbacks() {
+ sorted_tls_callbacks_.clear();
+ }
+
+ inline bool hasGlobalCallbacks() const {
+ return !sorted_global_callbacks_.empty();
+ }
+
+ inline bool hasThreadLocalCallbacks() const {
+ return !sorted_tls_callbacks_.empty();
+ }
+
+ // init is called by RecordFunction in its constructor to
+ // determine which thread local and global callbacks are going
+ // to be executed and whether any of them need inputs
+ inline void init(RecordFunction& rec_fn) {
+ auto scope = rec_fn.scope();
+ bool found_active_cb = false;
+ bool found_needs_inputs = false;
+ auto init_handles = [scope, &found_active_cb, &found_needs_inputs](
+ CallbackHandles& handles, RecordFunctionCallbacks& cbs) {
+ handles.clear();
+ for (const auto& cb : cbs) {
+ if (cb.first.shouldRun(scope)) {
+ handles.push_back(cb.second);
+ found_active_cb = true;
+ if (cb.first.needsInputs()) {
+ found_needs_inputs = true;
+ }
+ }
+ }
+ };
+
+ init_handles(rec_fn.sorted_active_tls_handles_, sorted_tls_callbacks_);
+ init_handles(rec_fn.sorted_active_global_handles_, sorted_global_callbacks_);
+ rec_fn.active_ = found_active_cb;
+ rec_fn.needs_inputs_ = found_needs_inputs;
}
void runStartCallbacks(RecordFunction& rf) {
- rf._setCallbacksVersion(callbacks_version_);
- rf._activeCallbacks().clear();
- for (size_t cb_idx = 0; cb_idx < callbacks_.size(); ++cb_idx) {
- if (shouldRunCallback(cb_idx, rf.scope())) {
- try {
- bool cb_ret = callbacks_[cb_idx].start_cb_(rf);
- rf._activeCallbacks().push_back(cb_ret);
- } catch (const std::exception &e) {
- LOG(WARNING) << "Exception in RecordFunction start observer: "
- << e.what();
- rf._activeCallbacks().push_back(false);
- } catch (...) {
- LOG(WARNING) << "Exception in RecordFunction start observer: unknown";
- rf._activeCallbacks().push_back(false);
- }
- } else {
- rf._activeCallbacks().push_back(false);
- }
- }
+ mergeRunCallbacks(
+ sorted_global_callbacks_,
+ rf.sorted_active_global_handles_,
+ /* is_start */ true,
+ rf);
+ mergeRunCallbacks(
+ sorted_tls_callbacks_,
+ rf.sorted_active_tls_handles_,
+ /* is_start */ true,
+ rf);
}
void runEndCallbacks(RecordFunction& rf) {
- if (rf._callbacksVersion() == callbacks_version_) {
- for (size_t cb_idx = 0; cb_idx < rf._activeCallbacks().size(); ++cb_idx) {
- if (!rf._activeCallbacks()[cb_idx]) {
- continue;
- }
- try {
- callbacks_[cb_idx].end_cb_(rf);
- } catch (const std::exception &e) {
- LOG(WARNING) << "Exception in RecordFunction end observer: "
- << e.what();
- } catch (...) {
- LOG(WARNING) << "Exception in RecordFunction end observer: unknown";
- }
- }
- } else {
- C10_LOG_EVERY_MS(WARNING, 1000)
- << "Callbacks changed while running a record function, "
- << "you might be partially overlapping a record function "
- << "with a profiling scope";
- }
- }
-
- inline void TEST_setGlobalSamplingProbability(double sampling_prob) {
- global_prob_ = sampling_prob;
- use_global_prob_ = true;
- }
-
- inline void TEST_unsetGlobalSamplingProbability() {
- global_prob_ = 0.0;
- use_global_prob_ = false;
+ mergeRunCallbacks(
+ sorted_global_callbacks_,
+ rf.sorted_active_global_handles_,
+ /* is_start */ false,
+ rf);
+ mergeRunCallbacks(
+ sorted_tls_callbacks_,
+ rf.sorted_active_tls_handles_,
+ /* is_start */ false,
+ rf);
}
private:
- void recomputeFlags() {
- has_callbacks_with_inputs_ = false;
- for (const auto& cb : callbacks_) {
- has_callbacks_with_inputs_ |= cb.needs_inputs_;
+ bool tryRunCallback(
+ const std::function<void(const RecordFunction&)>& fn,
+ RecordFunction& rf) {
+ try {
+ fn(rf);
+ return true;
+ } catch (const std::exception &e) {
+ LOG(WARNING) << "Exception in RecordFunction callback: "
+ << e.what() << " , for the range " << rf.name();
+ return false;
+ } catch (...) {
+ LOG(WARNING) << "Exception in RecordFunction callback: unknown"
+ << " , for the range " << rf.name();
+ return false;
}
}
- inline double samplingProbability(size_t cb_idx) const {
- TORCH_INTERNAL_ASSERT(cb_idx < callbacks_.size());
- if (callbacks_[cb_idx].is_sampled_) {
- return use_global_prob_ ? global_prob_ : callbacks_[cb_idx].sampling_prob_;
- } else {
- return 1.0;
- }
- }
-
- inline bool shouldRunCallback(size_t cb_idx, RecordScope scope) const {
- TORCH_INTERNAL_ASSERT(cb_idx < callbacks_.size());
- return callbacks_[cb_idx].scopes_[static_cast<size_t>(scope)] &&
- (!callbacks_[cb_idx].is_sampled_ ||
- (sample_zero_one() < samplingProbability(cb_idx)));
- }
-
- struct Callback;
- std::vector<Callback> callbacks_;
-
- double global_prob_ = 0.0;
- bool use_global_prob_ = false;
- bool has_callbacks_with_inputs_ = false;
-
- // tracks the current 'version' of callbacks;
- // every time we push or pop callbacks, we bump this counter
- uint64_t callbacks_version_ = 0;
-
- struct Callback {
- Callback(
- std::function<bool(const RecordFunction&)> start_cb,
- std::function<void(const RecordFunction&)> end_cb,
- bool needs_inputs,
- double sampling_prob,
- std::unordered_set<RecordScope, std::hash<RecordScope>> scopes
- ) : start_cb_(std::move(start_cb)),
- end_cb_(std::move(end_cb)),
- needs_inputs_(needs_inputs),
- sampling_prob_(sampling_prob),
- is_sampled_(sampling_prob != 1.0) {
- if (!scopes.empty()) {
- scopes_.fill(false);
- for (auto sc : scopes) {
- scopes_[static_cast<size_t>(sc)] = true;
+ void mergeRunCallbacks(
+ const RecordFunctionCallbacks& sorted_callbacks,
+ const CallbackHandles& sorted_handles,
+ bool is_start,
+ RecordFunction& rf) {
+ size_t num_executed = 0;
+ size_t idx_c = 0;
+ for (size_t idx_h = 0; idx_h < sorted_handles.size(); ++idx_h) {
+ while (idx_c < sorted_callbacks.size() &&
+ sorted_callbacks[idx_c].second < sorted_handles[idx_h]) {
+ ++idx_c;
+ }
+ if (idx_c >= sorted_callbacks.size()) {
+ break;
+ }
+ if (sorted_callbacks[idx_c].second == sorted_handles[idx_h]) {
+ if (is_start) {
+ tryRunCallback(sorted_callbacks[idx_c].first.start(), rf);
+ } else {
+ tryRunCallback(sorted_callbacks[idx_c].first.end(), rf);
}
- } else {
- scopes_.fill(true);
+ ++num_executed;
}
}
- std::function<bool(const RecordFunction&)> start_cb_;
- std::function<void(const RecordFunction&)> end_cb_;
- std::array<bool, static_cast<size_t>(RecordScope::NUM_SCOPES)> scopes_;
- const bool needs_inputs_;
- const double sampling_prob_;
- const bool is_sampled_;
- };
+ if (num_executed != sorted_handles.size()) {
+ C10_LOG_EVERY_MS(WARNING, 1000)
+ << "Could not match some of the start callbacks with the corresponding end callbacks, "
+ << "callbacks changed during RecordFunction lifetime; you might be trying to profile "
+ << "the code after profiler is finished";
+ }
+ }
+
+ // Global callbacks; must be sorted in increasing handle order
+ RecordFunctionCallbacks sorted_global_callbacks_;
};
-std::mutex next_thread_id_mutex_;
-uint16_t next_thread_id_ = 0;
-thread_local uint16_t current_thread_id_ = 0;
+// Enumerates thread ids logically;
+// note: std::this_thread::get_id may return a potentially
+// reused thread id
+std::atomic<uint64_t> next_thread_id_ {0};
+thread_local uint64_t current_thread_id_ = 0;
-// points to the currently active RecordFunction
+// Points to the currently active RecordFunction
thread_local RecordFunction* current_record_func_ = nullptr;
inline CallbackManager& manager() {
@@ -191,44 +203,66 @@
} // namespace
+/* static */
+double RecordFunctionCallback::sample_zero_one() {
+ static thread_local auto gen =
+ torch::make_unique<std::mt19937>(std::random_device()());
+ std::uniform_real_distribution<double> dist(0.0, 1.0);
+ return dist(*gen);
+}
+
bool hasCallbacks() {
- return manager().hasCallbacks();
+ auto& m = manager();
+ return m.hasGlobalCallbacks() || m.hasThreadLocalCallbacks();
}
-void pushCallback(
- std::function<bool(const RecordFunction&)> start,
- std::function<void(const RecordFunction&)> end,
- bool needs_inputs,
- double sampling_prob,
- std::unordered_set<RecordScope, std::hash<RecordScope>> scopes) {
- manager().pushCallback(
- std::move(start),
- std::move(end),
- needs_inputs,
- sampling_prob,
- std::move(scopes));
+bool hasGlobalCallbacks() {
+ return manager().hasGlobalCallbacks();
}
-void popCallback() {
- manager().popCallback();
+bool hasThreadLocalCallbacks() {
+ return manager().hasThreadLocalCallbacks();
}
-bool observersEnabled() {
+CallbackHandle addThreadLocalCallback(
+ RecordFunctionCallback cb) {
+ return manager().addThreadLocalCallback(std::move(cb));
+}
+
+CallbackHandle addGlobalCallback(
+ RecordFunctionCallback cb) {
+ return manager().addGlobalCallback(std::move(cb));
+}
+
+void removeCallback(CallbackHandle handle) {
+ manager().removeCallback(handle);
+}
+
+void clearGlobalCallbacks() {
+ manager().clearGlobalCallbacks();
+}
+
+void clearThreadLocalCallbacks() {
+ manager().clearThreadLocalCallbacks();
+}
+
+void clearCallbacks() {
+ auto& m = manager();
+ m.clearGlobalCallbacks();
+ m.clearThreadLocalCallbacks();
+}
+
+bool isRecordFunctionEnabled() {
return c10::impl::tls_is_dispatch_key_included(c10::DispatchKey::Profiler);
}
-void enableObservers(bool enable) {
+void enableRecordFunction(bool enable) {
c10::impl::tls_set_dispatch_key_included(c10::DispatchKey::Profiler, enable);
}
-void _runBeforeCallbacks(RecordFunction* rf, const std::string& funcName) {
- TORCH_INTERNAL_ASSERT(rf != nullptr);
- rf->_before(funcName);
-}
-
RecordFunction::RecordFunction(RecordScope scope) : scope_(scope) {
- if (manager().hasCallbacks() && observersEnabled()) {
- active_ = true;
+ if (hasCallbacks() && isRecordFunctionEnabled()) {
+ manager().init(*this);
}
}
@@ -239,23 +273,9 @@
}
/* static */
-bool RecordFunction::_needsInputs() {
- return manager().needsInputs();
-}
-
-void TEST_setGlobalSamplingProbability(double sampling_prob) {
- manager().TEST_setGlobalSamplingProbability(sampling_prob);
-}
-
-void TEST_unsetGlobalSamplingProbability() {
- manager().TEST_unsetGlobalSamplingProbability();
-}
-
-/* static */
-uint16_t RecordFunction::currentThreadId() {
+uint64_t RecordFunction::currentThreadId() {
if (!current_thread_id_) {
// happens only once per thread
- std::lock_guard<std::mutex> guard(next_thread_id_mutex_);
current_thread_id_ = ++next_thread_id_;
}
return current_thread_id_;
@@ -267,8 +287,9 @@
}
name_ = StringView(name);
sequence_nr_ = sequence_nr;
+ thread_id_ = currentThreadId();
- processCallbacks();
+ manager().runStartCallbacks(*this);
}
void RecordFunction::_before(std::string name, int64_t sequence_nr) {
@@ -277,8 +298,9 @@
}
name_ = StringView(std::move(name));
sequence_nr_ = sequence_nr;
+ thread_id_ = currentThreadId();
- processCallbacks();
+ manager().runStartCallbacks(*this);
}
void RecordFunction::_before(Node* fn, int64_t sequence_nr) {
@@ -288,12 +310,8 @@
fn_ = fn;
name_ = StringView(fn->name());
sequence_nr_ = (sequence_nr >= 0) ? sequence_nr : fn->sequence_nr();
-
- processCallbacks();
-}
-
-void RecordFunction::processCallbacks() {
thread_id_ = currentThreadId();
+
manager().runStartCallbacks(*this);
}
diff --git a/torch/csrc/autograd/record_function.h b/torch/csrc/autograd/record_function.h
index 21bad21..a13e7a3 100644
--- a/torch/csrc/autograd/record_function.h
+++ b/torch/csrc/autograd/record_function.h
@@ -4,6 +4,7 @@
#include <ATen/ThreadLocalState.h>
#include <c10/util/SmallVector.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/utils/memory.h>
#include <functional>
@@ -81,7 +82,9 @@
};
// Soft limit on the number of callbacks to use;
-constexpr std::size_t kSoftLimitCallbacks = 32;
+constexpr std::size_t kSoftLimitCallbacks = 4;
+
+typedef c10::SmallVector<uint64_t, kSoftLimitCallbacks> CallbackHandles;
struct TORCH_API RecordFunction {
// Default constructor is used with before function called afterwards:
@@ -114,7 +117,7 @@
// Retrieves the thread_id that this RecordFunction ran start callbacks with.
// Useful for writing thread safe end callbacks that may be potentially
// executed in a different thread (async ops)
- inline uint16_t getStartCallbacksThreadId() const {
+ inline uint64_t getStartCallbacksThreadId() const {
return thread_id_;
}
@@ -126,15 +129,10 @@
static RecordFunction* current();
// Returns logical thread_id for the current thread
- static uint16_t currentThreadId();
+ static uint64_t currentThreadId();
// Internal functions, do not use directly;
- // might be called from python's context manager
-
- // Returns whether this record function runs callbacks
- bool _active() const {
- return active_;
- }
+ // used in python's context manager
// _before functions initialize RecordFunction members and call
// start callbacks
@@ -160,7 +158,8 @@
_before(fn, current_sequence_nr);
}
- // Internal, only for the use within RECORD_FUNCTION macro;
+ // Internal, only for use within the RECORD_FUNCTION macro
+ // (i.e. stack based RecordFunctions with scope lifetime);
// sets this function as the current() thread local function;
// original value of current() is restored in destructor/_end
void _setCurrent();
@@ -169,79 +168,159 @@
void _end();
// Returns whether some of the callbacks require function inputs
- static bool _needsInputs();
+ bool _needsInputs();
- inline uint64_t _callbacksVersion() const {
- return callbacks_version_;
- }
-
- inline void _setCallbacksVersion(uint64_t cv) {
- callbacks_version_ = cv;
- }
-
- // Returns boolean set of active (ran start callback) callbacks
- inline c10::SmallVector<bool, kSoftLimitCallbacks>& _activeCallbacks() {
- return active_callbacks_;
- }
+ // Used internally to keep track of thread local and global callbacks
+ // that were picked to run; must be sorted;
+ // public because the callback manager lives in an anonymous
+ // namespace and cannot be declared a friend
+ CallbackHandles sorted_active_tls_handles_;
+ CallbackHandles sorted_active_global_handles_;
+ // Whether this RecordFunction runs any callbacks
+ bool active_ = false;
+ // Whether any of the picked callbacks require inputs
+ bool needs_inputs_ = false;
private:
- void processCallbacks();
-
Node* fn_ = nullptr;
StringView name_;
int64_t sequence_nr_ = -1;
std::vector<c10::IValue> inputs_;
+
// parent_ points to the parent RecordFunction and must out live this;
// only to be used together with RECORD_FUNCTION macro
+ // (with stack based RecordFunction instances with scope lifetime)
RecordFunction* parent_ = nullptr;
- // Holds the status of the callbacks after executing start callbacks.
- // If a start callback was not called (sampling) or returned false
- // (error or skipping the run), then the corresponding value in
- // the small vector is false and the end callback won't be called,
- // otherwise the value is true.
- c10::SmallVector<bool, kSoftLimitCallbacks> active_callbacks_;
-
// is_current_ true means that this record function updates thread local
// current record function pointer;
// true only in case of scope-based record functions, i.e.
// RECORD_FUNCTION macro
bool is_current_ = false;
- bool active_ = false;
+
+ // Kind of scope this RecordFunction is observing
const RecordScope scope_;
- // The logical thread_id that this RecordFunction was created with.
- uint16_t thread_id_ = 0;
-
- // Callbacks' version this record function was started with.
- // Used to ensure that the set of callbacks was not changed
- // during the record function's lifetime, between start and
- // end invocations.
- uint64_t callbacks_version_ = 0;
+ // The logical thread_id that this RecordFunction was created with
+ uint64_t thread_id_ = 0;
};
-// Returns whether there're callbacks registered with pushCallback
-TORCH_API bool hasCallbacks();
+//
+// PyTorch callbacks/observers API:
+//
-// Internal only, do not use:
-// use C++ RECORD_* or python context manager record_function() instead;
-// Given a record function, run the (possibly sampled) start callbacks that have
-// been pushed via pushCallback().
-TORCH_API void _runBeforeCallbacks(
- RecordFunction* rf,
- const std::string& funcName);
+/**
+ * RecordFunctionCallback represents a pair of callbacks to be used with
+ * RecordFunction, members:
+ * start, end - the callbacks to run when entering and exiting the scope;
+ * needs_inputs - whether the callbacks need the inputs passed from the observed
+ * function/range; NOTE: passing the inputs incurs an additional overhead;
+ * sampling_probability - if not 1.0, then the callback is probabilistically sampled
+ * to run; NOTE: start and end callbacks always run as a pair and are sampled
+ * together;
+ * scopes - types of scopes to execute the callbacks on (see RecordScope);
+ * passing empty set means the callbacks will be executed for all possible
+ * scope types
+ * should_run - optional function that returns whether this callback should run;
+ * overrides the effect of setting sampling_probability
+ */
+class TORCH_API RecordFunctionCallback {
+ public:
+ explicit RecordFunctionCallback(
+ std::function<void(const RecordFunction&)> start,
+ std::function<void(const RecordFunction&)> end =
+ [](const RecordFunction&) {}):
+ start_(std::move(start)),
+ end_(std::move(end)) {
+ scopes_.fill(true);
+ }
-// Used in tests, overrides sampling probability for all callbacks;
-TORCH_API void TEST_setGlobalSamplingProbability(double sampling_prob);
-TORCH_API void TEST_unsetGlobalSamplingProbability();
+ RecordFunctionCallback& needsInputs(bool needs_inputs) {
+ needs_inputs_ = needs_inputs;
+ return *this;
+ }
+
+ RecordFunctionCallback& samplingProb(double sampling_prob) {
+ TORCH_CHECK(sampling_prob >= 0.0 && sampling_prob <= 1.0,
+ "Invalid sampling probability");
+ sampling_prob_ = sampling_prob;
+ return *this;
+ }
+
+ RecordFunctionCallback& scopes(
+ const std::unordered_set<RecordScope, std::hash<RecordScope>>& scopes) {
+ if (!scopes.empty()) {
+ scopes_.fill(false);
+ for (auto sc : scopes) {
+ scopes_[static_cast<size_t>(sc)] = true;
+ }
+ } else {
+ scopes_.fill(true);
+ }
+ return *this;
+ }
+
+ RecordFunctionCallback& setShouldRun(
+ std::function<bool(const RecordFunctionCallback&)> should_run) {
+ should_run_ = std::move(should_run);
+ return *this;
+ }
+
+ inline bool needsInputs() const {
+ return needs_inputs_;
+ }
+
+ inline double samplingProb() const {
+ return sampling_prob_;
+ }
+
+ inline bool checkScope(RecordScope sc) const {
+ return scopes_[(size_t)sc];
+ }
+
+ inline const std::function<void(const RecordFunction&)>& start() const {
+ return start_;
+ }
+
+ inline const std::function<void(const RecordFunction&)>& end() const {
+ return end_;
+ }
+
+ // whether this callback should run in the given scope
+ inline bool shouldRun(RecordScope scope) const {
+ // first check whether this callback is interested in
+ // the given scope type
+ if (!checkScope(scope)) {
+ return false;
+ }
+ // if we have registered should_run_ function, use it
+ if (should_run_) {
+ return should_run_(*this);
+ }
+ // otherwise potentially do the uniform sampling
+ if (sampling_prob_ != 1.0) {
+ return (sample_zero_one() < sampling_prob_);
+ }
+ return true;
+ }
+
+ private:
+ std::function<void(const RecordFunction&)> start_;
+ std::function<void(const RecordFunction&)> end_;
+ std::function<bool(const RecordFunctionCallback&)> should_run_;
+ bool needs_inputs_ = false;
+ double sampling_prob_ = 1.0;
+ std::array<bool, static_cast<size_t>(RecordScope::NUM_SCOPES)> scopes_ = {};
+
+ static double sample_zero_one();
+};
// Using macro to minimize inputs copies,
// optional argument - function's seq_no
#define RECORD_FUNCTION_WITH_SCOPE(scope, fn, inputs, ...) \
torch::autograd::profiler::RecordFunction guard(scope); \
- if (guard._active()) { \
+ if (guard.active_) { \
guard._setCurrent(); \
- if (torch::autograd::profiler::RecordFunction::_needsInputs()) { \
+ if (guard.needs_inputs_) { \
guard._before(fn, inputs, ##__VA_ARGS__); \
} else { \
guard._before(fn, ##__VA_ARGS__); \
@@ -262,52 +341,108 @@
RECORD_FUNCTION_WITH_SCOPE( \
torch::autograd::profiler::RecordScope::USER_SCOPE, fn, {})
-/**
- * pushCallback adds a pair of callbacks to run with RecordFunction:
- * start, end - the callbacks to run when entering and exiting the scope;
- * if start callback returns false, end callback won't be executed;
- * needs_inputs - whether the callbacks need the inputs passed from the observed
- * function/range; NOTE: passing the inputs incurs an additional overhead;
- * sampling_prob - whether the callbacks are sampled and the sampling
- * probability;
- * scopes - types of scopes to execute the callbacks on (see RecordScope);
- * passing empty set means the callbacks will be executed for all possible
- * scope types
- *
- * WARNING: not thread safe, must not overlap with other PyTorch code execution
- */
-TORCH_API void pushCallback(
- std::function<bool(const RecordFunction&)> start,
- std::function<void(const RecordFunction&)> end =
- [](const RecordFunction&) {},
- bool needs_inputs = false,
- double sampling_prob = 1.0,
- std::unordered_set<RecordScope, std::hash<RecordScope>> scopes =
- std::unordered_set<RecordScope, std::hash<RecordScope>>());
+// Notes:
+// - two types of callbacks are provided: thread local and global
+// - thread local callbacks are added/removed only for the given thread
+// and are stored locally for each thread and separately from the list
+// of the global callbacks
+// - global callbacks are stored in a single per process list and are
+// invoked by every RecordFunction, in addition to the thread local
+// callbacks specific to the given thread
+// - we allow the added callbacks to be sampled by specifying a sampling
+// probability for each callback pair; if the start callback is
+// not picked to run, the corresponding end callback won't be called
+// - a typical use case for the global callbacks is passive monitoring
+// in the background (e.g. fleet-wide monitoring), without focusing on
+// the specific piece of code
+// - in contrast, thread local callbacks are enabled locally, on demand,
+// for the specific piece of code (range) and are not sampled
+// - a typical use case for thread local callbacks is profiler and code
+// execution tracer
+// - note: some functionality (e.g. the profiler) can automatically
+// propagate its callbacks across threads using the ThreadLocalState
+// mechanism, but in general callbacks are not propagated
+// - adding/removing global callbacks is not thread safe and should be done
+// only when no other code is running, e.g. during the initialization
+
+typedef uint64_t CallbackHandle;
+
+// Holds pairs (callbacks, unique_id)
+typedef std::vector<std::pair<RecordFunctionCallback, CallbackHandle>>
+ RecordFunctionCallbacks;
/**
- * popCallback removes the last pair of callbacks previously added with
- * pushCallback
- *
- * WARNING: not thread safe, must not overlap with other PyTorch code execution
+ * addThreadLocalCallback adds a thread local callback to run with RecordFunction,
+ * returns handle to use with removeThreadLocalCallback
*/
-TORCH_API void popCallback();
+TORCH_API CallbackHandle addThreadLocalCallback(
+ RecordFunctionCallback cb);
-// Enable observers thread locally
-TORCH_API void enableObservers(bool enable = true);
+/**
+ * hasThreadLocalCallbacks returns whether there are callbacks registered
+ * with addThreadLocalCallback
+ */
+TORCH_API bool hasThreadLocalCallbacks();
-// Returns whether observers are enabled (thread locally)
-TORCH_API bool observersEnabled();
+/**
+ * clearThreadLocalCallbacks removes all thread local callbacks
+ */
+TORCH_API void clearThreadLocalCallbacks();
+
+/**
+ * addGlobalCallback adds a global callback to run with RecordFunction;
+ *
+ * WARNING: not thread safe; typically addGlobalCallback should be
+ * called only during program initialization
+ */
+TORCH_API CallbackHandle addGlobalCallback(
+ RecordFunctionCallback cb);
+
+/**
+ * removeCallback removes a callback given the handle returned by
+ * addThreadLocalCallback or addGlobalCallback;
+ *
+ * WARNING: removing a global callback is not thread safe;
+ * no other PyTorch code may run simultaneously
+ */
+TORCH_API void removeCallback(CallbackHandle handle);
+
+/**
+ * hasGlobalCallbacks returns whether there are global callbacks
+ * registered with addGlobalCallback
+ */
+TORCH_API bool hasGlobalCallbacks();
+
+/**
+ * clearGlobalCallbacks removes all global callbacks
+ * WARNING: not thread safe
+ */
+TORCH_API void clearGlobalCallbacks();
+
+// for both thread local and global callbacks
+TORCH_API bool hasCallbacks();
+TORCH_API void clearCallbacks(); // not thread safe
+
+/**
+ * enableRecordFunction enables RecordFunction thread locally
+ */
+TORCH_API void enableRecordFunction(bool enable = true);
+
+/**
+ * isRecordFunctionEnabled returns whether RecordFunction
+ * is enabled thread locally
+ */
+TORCH_API bool isRecordFunctionEnabled();
class TORCH_API RecordFunctionGuard {
public:
explicit RecordFunctionGuard(bool is_enabled = true)
- : prev_value_(observersEnabled()) {
- enableObservers(is_enabled);
+ : prev_value_(isRecordFunctionEnabled()) {
+ enableRecordFunction(is_enabled);
}
virtual ~RecordFunctionGuard() {
- enableObservers(prev_value_);
+ enableRecordFunction(prev_value_);
}
private:
@@ -321,4 +456,5 @@
};
} // namespace profiler
-}} // namespace torch::autograd
+} // namespace autograd
+} // namespace torch