Make RecordFunction callbacks thread local and modernize interface (#37491)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37491

This PR modernizes RecordFunction API and adds thread local callbacks
in addition to the global ones

Changes:
 - support for TLS callbacks, this is going to be the foundation of profiler and other tools
 - modernize the interface around a simple set of functions (add|remove|has|clear)(Global|ThreadLocal)(Callback) and add RecordFunctionCallback to easily construct callbacks to be passed
 - we also add `.setShouldRun` to the callback interface to support cases when simple uniform sampling is not enough
 - to properly support add/remove, we introduce the idea of a callback handle returned by add
 - internal implementation still uses SmallVector to store intermediate state (as before) - in this case these are vectors of handles of callbacks that were picked to run
 - to speed up runtime we keep these vectors sorted; this way we can quickly enumerate callbacks that need to be run
 - added tests for new functionality

Test Plan:
BUILD_BINARY=1 USE_BLAS=MKL USE_MKLDNN=0 USE_CUDA=0 python setup.py
develop install
./build/bin/test_jit
CI

record_function_benchmark: https://gist.github.com/ilia-cher/f1e094dae47fe23e55e7672ac4dcda2f

Imported from OSS

Differential Revision: D21300448

fbshipit-source-id: 6d55c26dbf20b33d35c3f1604dcc07bb063c8c43
diff --git a/binaries/record_function_benchmark.cc b/binaries/record_function_benchmark.cc
index b271466..34a21ad 100644
--- a/binaries/record_function_benchmark.cc
+++ b/binaries/record_function_benchmark.cc
@@ -22,22 +22,22 @@
 
 void setupCallbacks() {
   // non-sampled callback
-  profiler::pushCallback(
+  profiler::addGlobalCallback(profiler::RecordFunctionCallback(
       [&](const profiler::RecordFunction& fn) {
         return true;
       },
-      [](const profiler::RecordFunction&) {},
-      /* needs_inputs */ true);
+      [](const profiler::RecordFunction&) {})
+    .needsInputs(true));
 
   // sampled
   for (auto idx = 0; idx < kNumSampledCb; ++idx) {
-    profiler::pushCallback(
+    profiler::addGlobalCallback(profiler::RecordFunctionCallback(
         [](const profiler::RecordFunction& fn) {
           return true;
         },
-        [](const profiler::RecordFunction&) {},
-        /* needs_inputs */ true,
-        /* sampling_prob */ kSampingProb
+        [](const profiler::RecordFunction&) {})
+      .needsInputs(true)
+      .samplingProb(kSampingProb)
     );
   }
 }
diff --git a/docs/source/notes/large_scale_deployments.rst b/docs/source/notes/large_scale_deployments.rst
index 509c8fe..4c96d80 100644
--- a/docs/source/notes/large_scale_deployments.rst
+++ b/docs/source/notes/large_scale_deployments.rst
@@ -26,7 +26,7 @@
 across the entire set of machines.
 
 New callbacks for any operator invocation can be added with
-``torch::autograd::profiler::pushCallback``. Hooks will be called with
+``torch::autograd::profiler::addGlobalCallback``. Hooks will be called with
 ``torch::autograd::profiler::RecordFunction`` struct that describes invocation
 context (e.g. `name`). If enabled, ``RecordFunction::inputs()`` contains arguments
 of the function represented as ``torch::IValue`` variant type. Note, that inputs
@@ -42,9 +42,9 @@
 
 Invoking callbacks adds some overhead, so usually it's useful to just randomly
 sample operator invocations. This can be enabled on per-callback basis with an
-optional sampling rate passed into ``torch::autograd::profiler::pushCallback``.
+optional sampling rate passed into ``torch::autograd::profiler::addGlobalCallback``.
 
-Note, that ``pushCallback`` is not thread-safe and can be called only when no
+Note, that ``addGlobalCallback`` is not thread-safe and can be called only when no
 PyTorch operator is running. Usually, it's a good idea to call them once during
 initialization.
 
@@ -55,22 +55,20 @@
     // Called somewhere in the program beginning
     void init() {
         // Sample one in a hundred operator runs randomly
-        pushCallback(
+        addGlobalCallback(
+          RecordFunctionCallback(
             &onFunctionEnter,
-            &onFunctionExit,
-            /* needs_inputs */ true,
-            /* sampling_prob */ 0.01
+            &onFunctionExit)
+          .needsInputs(true)
+          .samplingProb(0.01)
         );
         // Note, to enable observers in the model calling thread,
-        // call enableObservers() in the thread before running a model
+        // call enableRecordFunction() in the thread before running a model
     }
 
-    bool onFunctionEnter(const RecordFunction& fn) {
+    void onFunctionEnter(const RecordFunction& fn) {
         std::cerr << "Before function " << fn.name()
                   << " with " << fn.inputs().size() << " inputs" << std::endl;
-        // Returning false would mean that the callback is not interested
-        // in this RecordFunction and onFunctionExit won't be called
-        return true;
     }
 
     void onFunctionExit(const RecordFunction& fn) {
diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp
index 793b958..990dda4 100644
--- a/test/cpp/jit/test_misc.cpp
+++ b/test/cpp/jit/test_misc.cpp
@@ -67,6 +67,8 @@
 #include <utility>
 #include <vector>
 
+using namespace torch::autograd::profiler;
+
 namespace torch {
 namespace jit {
 inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
@@ -755,17 +757,11 @@
 
 using namespace torch::autograd;
 
-void cleanUpScopeCallbacks() {
-  while (profiler::hasCallbacks()) {
-    profiler::popCallback();
-  }
-}
-
 void checkScopeCallbacks() {
   bool found_function_scope = false;
   bool found_method_scope = false;
   bool found_user_scope = false;
-  profiler::pushCallback(
+  profiler::addGlobalCallback(profiler::RecordFunctionCallback(
       [&](const profiler::RecordFunction& fn) {
         if (fn.scope() == profiler::RecordScope::FUNCTION &&
             std::string(fn.name().str()) == "test_function") {
@@ -779,26 +775,23 @@
             std::string(fn.name().str()) == "test_user_scope") {
           found_user_scope = true;
         }
-        return true;
       },
-      [](const profiler::RecordFunction&) {},
-      /* needs_inputs */ false);
+      [](const profiler::RecordFunction&) {}));
 
   bool bad_scope = false;
   auto pushScopedCallback = [&](profiler::RecordScope scope, size_t& cnt) {
-    profiler::pushCallback(
-        [&bad_scope, &cnt, scope](const profiler::RecordFunction& fn) {
-          if (fn.scope() == scope) {
-            ++cnt;
-          } else {
-            bad_scope = true;
-          }
-          return true;
-        },
-        [](const profiler::RecordFunction&) {},
-        /* needs_inputs */ false,
-        /* sampling_prob */ 1.0,
-        /* scopes */ {scope});
+    profiler::addGlobalCallback(
+        profiler::RecordFunctionCallback(
+            [&bad_scope, &cnt, scope](const profiler::RecordFunction& fn) {
+              if (fn.scope() == scope) {
+                ++cnt;
+              } else {
+                bad_scope = true;
+              }
+              return true;
+            },
+            [](const profiler::RecordFunction&) {})
+            .scopes({scope}));
   };
 
   size_t fun_cnt = 0;
@@ -835,28 +828,26 @@
   // [(fn, [[sizes], [sizes], ...]), ...]
   TracedTestInputs traced_inputs;
   std::unordered_set<std::string> ts_names;
-  autograd::profiler::pushCallback(
-      [&](const autograd::profiler::RecordFunction& fn) {
-        if (fn.scope() == autograd::profiler::RecordScope::FUNCTION) {
-          auto inputs = fn.inputs();
-          std::vector<std::vector<int64_t>> sizes;
-          for (const auto& input : inputs) {
-            if (input.isTensor()) {
-              sizes.push_back(input.toTensor().sizes().vec());
-            } else if (input.isScalar()) {
-              sizes.push_back(std::vector<int64_t>());
+  addGlobalCallback(
+      RecordFunctionCallback(
+          [&](const RecordFunction& fn) {
+            if (fn.scope() == RecordScope::FUNCTION) {
+              auto inputs = fn.inputs();
+              std::vector<std::vector<int64_t>> sizes;
+              for (const auto& input : inputs) {
+                if (input.isTensor()) {
+                  sizes.push_back(input.toTensor().sizes().vec());
+                } else if (input.isScalar()) {
+                  sizes.push_back(std::vector<int64_t>());
+                }
+              }
+              traced_inputs.push_back(std::make_tuple(fn.name().str(), sizes));
+            } else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
+              ts_names.insert(fn.name().str());
             }
-          }
-          traced_inputs.push_back(std::make_tuple(fn.name().str(), sizes));
-        } else if (
-            fn.scope() ==
-            autograd::profiler::RecordScope::TORCHSCRIPT_FUNCTION) {
-          ts_names.insert(fn.name().str());
-        }
-        return true;
-      },
-      [](const autograd::profiler::RecordFunction&) {},
-      /* needs_inputs */ true);
+          },
+          [](const RecordFunction&) {})
+          .needsInputs(true));
 
   TracedTestInputs eager_inputs, jit_inputs;
   {
@@ -876,7 +867,6 @@
     jit_inputs = traced_inputs;
     traced_inputs.clear();
   }
-  autograd::profiler::popCallback();
 
   TORCH_CHECK(ts_names.size() == 2);
   TORCH_CHECK(ts_names.find("forward") != ts_names.end());
@@ -884,31 +874,33 @@
 
   checkTracedInputs(eager_inputs);
   checkTracedInputs(jit_inputs);
-  cleanUpScopeCallbacks();
+  profiler::clearCallbacks();
 
   // test sampled callbacks
   int sampled_cb_ctr = 0;
-  autograd::profiler::pushCallback(
-      [&sampled_cb_ctr](const autograd::profiler::RecordFunction& fn) {
-        if (std::string(fn.name().str()) == "test") {
-          ++sampled_cb_ctr;
-        }
-        return true;
-      },
-      [](const autograd::profiler::RecordFunction&) {},
-      /* needs_inputs */ false,
-      /* sampling_prob */ 0.5);
+  auto setup_sampled_callback = [&sampled_cb_ctr](double sampling_prob) {
+    return addGlobalCallback(RecordFunctionCallback(
+                                 [&sampled_cb_ctr](const RecordFunction& fn) {
+                                   if (std::string(fn.name().str()) == "test") {
+                                     ++sampled_cb_ctr;
+                                   }
+                                   return true;
+                                 },
+                                 [](const RecordFunction&) {})
+                                 .samplingProb(sampling_prob));
+  };
 
   int non_sampled_cb_ctr = 0;
-  autograd::profiler::pushCallback(
-      [&non_sampled_cb_ctr](const autograd::profiler::RecordFunction& fn) {
+  addGlobalCallback(RecordFunctionCallback(
+      [&non_sampled_cb_ctr](const RecordFunction& fn) {
         if (std::string(fn.name().str()) == "test") {
           ++non_sampled_cb_ctr;
         }
         return true;
       },
-      [](const autograd::profiler::RecordFunction&) {},
-      /* needs_inputs */ false);
+      [](const RecordFunction&) {}));
+
+  auto handle = setup_sampled_callback(0.5);
 
   auto run_test_function = []() {
     auto t = torch::randn({1, 2, 3}, at::kCPU);
@@ -922,45 +914,45 @@
   TORCH_CHECK(sampled_cb_ctr > 0 && sampled_cb_ctr < 1000);
 
   sampled_cb_ctr = 0;
-  autograd::profiler::TEST_setGlobalSamplingProbability(0.0);
+  removeCallback(handle);
+  handle = setup_sampled_callback(0.0);
   run_test_function();
 
   TORCH_CHECK(non_sampled_cb_ctr == 2000);
   TORCH_CHECK(sampled_cb_ctr == 0);
 
   sampled_cb_ctr = 0;
-  autograd::profiler::TEST_setGlobalSamplingProbability(1.0);
+  removeCallback(handle);
+  handle = setup_sampled_callback(1.0);
   run_test_function();
 
   TORCH_CHECK(non_sampled_cb_ctr == 3000);
   TORCH_CHECK(sampled_cb_ctr == 1000);
-  autograd::profiler::TEST_unsetGlobalSamplingProbability();
-  cleanUpScopeCallbacks();
+  clearCallbacks();
 
   // test the scope of the callbacks
   checkScopeCallbacks();
-  cleanUpScopeCallbacks();
+  clearCallbacks();
 
   // check record function guard
   std::vector<std::string> fn_names;
   std::mutex mtx;
-  autograd::profiler::pushCallback(
-      [&fn_names, &mtx](const autograd::profiler::RecordFunction& fn) {
+  addGlobalCallback(RecordFunctionCallback(
+      [&fn_names, &mtx](const RecordFunction& fn) {
         std::lock_guard<std::mutex> lock(mtx);
         fn_names.push_back(fn.name().str());
         return true;
       },
-      [](const autograd::profiler::RecordFunction&) {},
-      /* needs_inputs */ false);
+      [](const RecordFunction&) {}));
   {
-    autograd::profiler::RecordFunctionGuard g1(false);
+    RecordFunctionGuard g1(false);
     {
       RECORD_USER_SCOPE("A");
       {
-        autograd::profiler::RecordFunctionGuard g2(true);
+        RecordFunctionGuard g2(true);
         RECORD_USER_SCOPE("B");
         {
-          autograd::profiler::DisableRecordFunctionGuard g3;
+          DisableRecordFunctionGuard g3;
           RECORD_USER_SCOPE("C");
         }
       }
@@ -969,7 +961,103 @@
   }
   TORCH_CHECK(fn_names.size() == 1);
   TORCH_CHECK(fn_names[0] == "B");
-  cleanUpScopeCallbacks();
+  clearCallbacks();
+
+  // test add/remove
+  std::vector<size_t> ids;
+  auto add_remove_test_add_cb = [&ids](size_t id) {
+    return addGlobalCallback(RecordFunctionCallback(
+        [&ids, id](const RecordFunction& fn) { ids.push_back(id); },
+        [](const RecordFunction&) {}));
+  };
+
+  auto h1 = add_remove_test_add_cb(1);
+  auto h2 = add_remove_test_add_cb(2);
+  auto h3 = add_remove_test_add_cb(3);
+
+  { RECORD_USER_SCOPE("test"); }
+
+  TORCH_CHECK(ids.size() == 3);
+  TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
+  TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
+  TORCH_CHECK(std::find(ids.begin(), ids.end(), 3) != ids.end());
+
+  ids.clear();
+  removeCallback(h1);
+
+  { RECORD_USER_SCOPE("test"); }
+
+  TORCH_CHECK(ids.size() == 2);
+  TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
+  TORCH_CHECK(std::find(ids.begin(), ids.end(), 3) != ids.end());
+
+  ids.clear();
+  removeCallback(h3);
+
+  { RECORD_USER_SCOPE("test"); }
+
+  TORCH_CHECK(ids.size() == 1);
+  TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
+
+  clearCallbacks();
+
+  // thread local / global callbacks
+
+  ids.clear();
+  addGlobalCallback(RecordFunctionCallback(
+      [&ids](const RecordFunction& fn) { ids.push_back(1); },
+      [](const RecordFunction&) {}));
+
+  { RECORD_USER_SCOPE("test"); }
+
+  TORCH_CHECK(ids.size() == 1);
+  TORCH_CHECK(ids[0] == 1);
+  ids.clear();
+
+  auto th = std::thread([&ids]() {
+    c10::impl::IncludeDispatchKeyGuard observer_guard(
+        c10::DispatchKey::Profiler);
+    addThreadLocalCallback(RecordFunctionCallback(
+        [&ids](const RecordFunction& fn) { ids.push_back(2); },
+        [](const RecordFunction&) {}));
+
+    { RECORD_USER_SCOPE("test_thread"); }
+  });
+  th.join();
+  TORCH_CHECK(ids.size() == 2);
+  TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
+  TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
+  ids.clear();
+
+  { RECORD_USER_SCOPE("test"); }
+
+  TORCH_CHECK(ids.size() == 1);
+  TORCH_CHECK(ids[0] == 1);
+  ids.clear();
+
+  // test should_run
+
+  bool ran = false;
+  bool should_run = false;
+  addGlobalCallback(
+      RecordFunctionCallback(
+          [&ran](const RecordFunction& fn) { ran = true; },
+          [](const RecordFunction&) {})
+          .setShouldRun([&should_run](const RecordFunctionCallback&) {
+            return should_run;
+          }));
+
+  { RECORD_USER_SCOPE("test"); }
+
+  TORCH_CHECK(!ran);
+
+  should_run = true;
+
+  { RECORD_USER_SCOPE("test"); }
+
+  TORCH_CHECK(ran);
+
+  clearCallbacks();
 }
 
 class TestThreadLocalDebugInfo : public at::DebugInfoBase {
@@ -1028,13 +1116,13 @@
   TORCH_CHECK(
       at::ThreadLocalDebugInfo::get(at::DebugInfoKind::TEST_INFO) == nullptr);
   done = false;
-  autograd::profiler::pushCallback(
-      [&done](const autograd::profiler::RecordFunction&) {
+  auto handle = addGlobalCallback(RecordFunctionCallback(
+      [&done](const RecordFunction&) {
         checkDebugInfo(at::DebugInfoKind::TEST_INFO, 42);
         done = true;
         return true;
       },
-      [](const autograd::profiler::RecordFunction&) {});
+      [](const RecordFunction&) {}));
   {
     at::DebugInfoGuard guard(at::DebugInfoKind::TEST_INFO, debug_info);
     auto t = torch::randn({1, 2, 3}, at::kCPU);
@@ -1042,7 +1130,7 @@
     auto t2 = t.pow(2);
     t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
   }
-  autograd::profiler::popCallback();
+  removeCallback(handle);
   TORCH_CHECK(done);
 
   // check nested debug info
@@ -1094,7 +1182,7 @@
 
   std::stringstream ss;
   {
-    autograd::profiler::RecordProfile guard(ss);
+    RecordProfile guard(ss);
     for (size_t i = 0; i < 100; ++i) {
       std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
     }
diff --git a/torch/csrc/autograd/profiler.cpp b/torch/csrc/autograd/profiler.cpp
index ca14297..4e68868 100644
--- a/torch/csrc/autograd/profiler.cpp
+++ b/torch/csrc/autograd/profiler.cpp
@@ -35,6 +35,8 @@
 // use RecordFunctionGuard to keep track of observers,
 // enable/disableProfiler are tied to the code range
 thread_local std::vector<std::shared_ptr<RecordFunctionGuard>> g_;
+// use thread_local vector to save profiler callback ids
+thread_local std::vector<uint64_t> callback_handles_;
 
 } // namespace
 
@@ -143,7 +145,7 @@
     throw std::runtime_error("can't change kind of profiling (e.g. NVTX to CPU) while profiler is running");
   }
 
-  pushCallback(
+  auto handle = addGlobalCallback(RecordFunctionCallback(
       [config](const RecordFunction& fn) {
         auto* msg = (fn.seqNr() >= 0) ? ", seq = " : "";
         if (config.report_input_shapes) {
@@ -196,11 +198,11 @@
         } else {
           popRange();
         }
-      },
-      /* needs_inputs */ config.report_input_shapes,
-      /* sampling_prob */ 1.0,
-      /* scopes */ {RecordScope::FUNCTION, RecordScope::USER_SCOPE});
+      })
+      .needsInputs(config.report_input_shapes)
+      .scopes({RecordScope::FUNCTION, RecordScope::USER_SCOPE}));
   state = new_state;
+  callback_handles_.push_back(handle);
   g_.emplace_back(std::make_shared<RecordFunctionGuard>());
 
   if(state == ProfilerState::CUDA) {
@@ -230,10 +232,12 @@
   ProfilerState old_state = state;
   mark("__stop_profile");
 
-  popCallback();
-  state = ProfilerState::Disabled;
+  TORCH_INTERNAL_ASSERT(!callback_handles_.empty());
+  removeCallback(callback_handles_.back());
+  callback_handles_.pop_back();
   TORCH_INTERNAL_ASSERT(!g_.empty());
   g_.pop_back();
+  state = ProfilerState::Disabled;
 
   if (old_state == ProfilerState::NVTX) {
     return thread_event_lists();
diff --git a/torch/csrc/autograd/record_function.cpp b/torch/csrc/autograd/record_function.cpp
index 3a3611c..c784842 100644
--- a/torch/csrc/autograd/record_function.cpp
+++ b/torch/csrc/autograd/record_function.cpp
@@ -1,7 +1,7 @@
+#include <algorithm>
 #include <torch/csrc/autograd/record_function.h>
 #include <torch/csrc/autograd/function.h>
 #include <torch/csrc/autograd/profiler.h>
-#include <torch/csrc/utils/memory.h>
 #include <cstdlib>
 #include <random>
 
@@ -11,177 +11,189 @@
 
 namespace {
 
-float sample_zero_one() {
-  static thread_local auto gen =
-      torch::make_unique<std::mt19937>(std::random_device()());
-  std::uniform_real_distribution<float> dist(0.0, 1.0);
-  return dist(*gen);
+// Used to generate unique callback handles
+CallbackHandle next_unique_callback_handle() {
+  static std::atomic<uint64_t> unique_id {0};
+  return CallbackHandle(++unique_id);
 }
 
+// Thread local vector of callbacks, holds pairs (callbacks, unique_id);
+// must be sorted in increasing handles order
+thread_local RecordFunctionCallbacks sorted_tls_callbacks_;
+
 class CallbackManager {
  public:
-  void pushCallback(
-      std::function<bool(const RecordFunction&)> start,
-      std::function<void(const RecordFunction&)> end,
-      bool needs_inputs,
-      double sampling_prob,
-      std::unordered_set<RecordScope, std::hash<RecordScope>> scopes) {
-    callbacks_.emplace_back(
-      std::move(start),
-      std::move(end),
-      needs_inputs,
-      sampling_prob,
-      std::move(scopes)
-    );
-    recomputeFlags();
-
-    // make sure we mark the change in callbacks
-    ++callbacks_version_;
+  CallbackHandle addThreadLocalCallback(RecordFunctionCallback cb) {
+    // note: monotonically increasing callbacks_unique_id keeps
+    // sorted_tls_callbacks_ sorted
+    auto handle = next_unique_callback_handle();
+    sorted_tls_callbacks_.emplace_back(std::move(cb), handle);
+    return handle;
   }
 
-  void popCallback() {
-    if (callbacks_.empty()) {
-      throw std::runtime_error("Empty callbacks stack");
+  CallbackHandle addGlobalCallback(RecordFunctionCallback cb) {
+    auto handle = next_unique_callback_handle();
+    sorted_global_callbacks_.emplace_back(std::move(cb), handle);
+    return handle;
+  }
+
+  void removeCallback(CallbackHandle handle) {
+    auto find_and_remove = [handle](RecordFunctionCallbacks& cbs) {
+      auto it = std::find_if(
+        cbs.begin(), cbs.end(),
+        [handle](
+            const std::pair<
+                RecordFunctionCallback,
+                CallbackHandle>& el) {
+          return el.second == handle;
+        });
+      if (it != cbs.end()) {
+        // keeps it sorted
+        cbs.erase(it);
+        return true;
+      }
+      return false;
+    };
+    auto found = find_and_remove(sorted_tls_callbacks_);
+    if (!found) {
+      found = find_and_remove(sorted_global_callbacks_);
     }
-    callbacks_.pop_back();
-    recomputeFlags();
-    ++callbacks_version_;
+    if (!found) {
+      LOG(WARNING) << "Requested callback is not found";
+    }
   }
 
-  inline bool hasCallbacks() const {
-    return !callbacks_.empty();
+  void clearGlobalCallbacks() {
+    sorted_global_callbacks_.clear();
   }
 
-  inline bool needsInputs() const {
-    return has_callbacks_with_inputs_;
+  void clearThreadLocalCallbacks() {
+    sorted_tls_callbacks_.clear();
+  }
+
+  inline bool hasGlobalCallbacks() const {
+    return !sorted_global_callbacks_.empty();
+  }
+
+  inline bool hasThreadLocalCallbacks() const {
+    return !sorted_tls_callbacks_.empty();
+  }
+
+  // init is called by RecordFunction in constructor to
+  // determine which thread local and global callbacks are going
+  // to be executed and whether any of them need inputs
+  inline void init(RecordFunction& rec_fn) {
+    auto scope = rec_fn.scope();
+    bool found_active_cb = false;
+    bool found_needs_inputs = false;
+    auto init_handles = [scope, &found_active_cb, &found_needs_inputs](
+        CallbackHandles& handles, RecordFunctionCallbacks& cbs) {
+      handles.clear();
+      for (const auto& cb : cbs) {
+        if (cb.first.shouldRun(scope)) {
+          handles.push_back(cb.second);
+          found_active_cb = true;
+          if (cb.first.needsInputs()) {
+            found_needs_inputs = true;
+          }
+        }
+      }
+    };
+
+    init_handles(rec_fn.sorted_active_tls_handles_, sorted_tls_callbacks_);
+    init_handles(rec_fn.sorted_active_global_handles_, sorted_global_callbacks_);
+    rec_fn.active_ = found_active_cb;
+    rec_fn.needs_inputs_ = found_needs_inputs;
   }
 
   void runStartCallbacks(RecordFunction& rf) {
-    rf._setCallbacksVersion(callbacks_version_);
-    rf._activeCallbacks().clear();
-    for (size_t cb_idx = 0; cb_idx < callbacks_.size(); ++cb_idx) {
-      if (shouldRunCallback(cb_idx, rf.scope())) {
-        try {
-          bool cb_ret = callbacks_[cb_idx].start_cb_(rf);
-          rf._activeCallbacks().push_back(cb_ret);
-        } catch (const std::exception &e) {
-          LOG(WARNING) << "Exception in RecordFunction start observer: "
-                       << e.what();
-          rf._activeCallbacks().push_back(false);
-        } catch (...) {
-          LOG(WARNING) << "Exception in RecordFunction start observer: unknown";
-          rf._activeCallbacks().push_back(false);
-        }
-      } else {
-        rf._activeCallbacks().push_back(false);
-      }
-    }
+    mergeRunCallbacks(
+        sorted_global_callbacks_,
+        rf.sorted_active_global_handles_,
+        /* is_start */ true,
+        rf);
+    mergeRunCallbacks(
+        sorted_tls_callbacks_,
+        rf.sorted_active_tls_handles_,
+        /* is_start */ true,
+        rf);
   }
 
   void runEndCallbacks(RecordFunction& rf) {
-    if (rf._callbacksVersion() == callbacks_version_) {
-      for (size_t cb_idx = 0; cb_idx < rf._activeCallbacks().size(); ++cb_idx) {
-        if (!rf._activeCallbacks()[cb_idx]) {
-          continue;
-        }
-        try {
-          callbacks_[cb_idx].end_cb_(rf);
-        } catch (const std::exception &e) {
-          LOG(WARNING) << "Exception in RecordFunction end observer: "
-                       << e.what();
-        } catch (...) {
-          LOG(WARNING) << "Exception in RecordFunction end observer: unknown";
-        }
-      }
-    } else {
-      C10_LOG_EVERY_MS(WARNING, 1000)
-          << "Callbacks changed while running a record function, "
-          << "you might be partially overlapping a record function "
-          << "with a profiling scope";
-    }
-  }
-
-  inline void TEST_setGlobalSamplingProbability(double sampling_prob) {
-    global_prob_ = sampling_prob;
-    use_global_prob_ = true;
-  }
-
-  inline void TEST_unsetGlobalSamplingProbability() {
-    global_prob_ = 0.0;
-    use_global_prob_ = false;
+    mergeRunCallbacks(
+        sorted_global_callbacks_,
+        rf.sorted_active_global_handles_,
+        /* is_start */ false,
+        rf);
+    mergeRunCallbacks(
+        sorted_tls_callbacks_,
+        rf.sorted_active_tls_handles_,
+        /* is_start */ false,
+        rf);
   }
 
  private:
-  void recomputeFlags() {
-    has_callbacks_with_inputs_ = false;
-    for (const auto& cb : callbacks_) {
-      has_callbacks_with_inputs_ |= cb.needs_inputs_;
+  bool tryRunCallback(
+      const std::function<void(const RecordFunction&)>& fn,
+      RecordFunction& rf) {
+    try {
+      fn(rf);
+      return true;
+    } catch (const std::exception &e) {
+      LOG(WARNING) << "Exception in RecordFunction callback: "
+          << e.what() << " , for the range " << rf.name();
+      return false;
+    } catch (...) {
+      LOG(WARNING) << "Exception in RecordFunction callback: unknown"
+          << " , for the range " << rf.name();
+      return false;
     }
   }
 
-  inline double samplingProbability(size_t cb_idx) const {
-    TORCH_INTERNAL_ASSERT(cb_idx < callbacks_.size());
-    if (callbacks_[cb_idx].is_sampled_) {
-      return use_global_prob_ ? global_prob_ : callbacks_[cb_idx].sampling_prob_;
-    } else {
-      return 1.0;
-    }
-  }
-
-  inline bool shouldRunCallback(size_t cb_idx, RecordScope scope) const {
-    TORCH_INTERNAL_ASSERT(cb_idx < callbacks_.size());
-    return callbacks_[cb_idx].scopes_[static_cast<size_t>(scope)] &&
-           (!callbacks_[cb_idx].is_sampled_ ||
-            (sample_zero_one() < samplingProbability(cb_idx)));
-  }
-
-  struct Callback;
-  std::vector<Callback> callbacks_;
-
-  double global_prob_ = 0.0;
-  bool use_global_prob_ = false;
-  bool has_callbacks_with_inputs_ = false;
-
-  // tracks the current 'version' of callbacks;
-  // every time we push or pop callbacks, we bump this counter
-  uint64_t callbacks_version_ = 0;
-
-  struct Callback {
-    Callback(
-        std::function<bool(const RecordFunction&)> start_cb,
-        std::function<void(const RecordFunction&)> end_cb,
-        bool needs_inputs,
-        double sampling_prob,
-        std::unordered_set<RecordScope, std::hash<RecordScope>> scopes
-    ) : start_cb_(std::move(start_cb)),
-        end_cb_(std::move(end_cb)),
-        needs_inputs_(needs_inputs),
-        sampling_prob_(sampling_prob),
-        is_sampled_(sampling_prob != 1.0) {
-      if (!scopes.empty()) {
-        scopes_.fill(false);
-        for (auto sc : scopes) {
-          scopes_[static_cast<size_t>(sc)] = true;
+  void mergeRunCallbacks(
+      const RecordFunctionCallbacks& sorted_callbacks,
+      const CallbackHandles& sorted_handles,
+      bool is_start,
+      RecordFunction& rf) {
+    size_t num_executed = 0;
+    size_t idx_c = 0;
+    for (size_t idx_h = 0; idx_h < sorted_handles.size(); ++idx_h) {
+      while (idx_c < sorted_callbacks.size() &&
+            sorted_callbacks[idx_c].second < sorted_handles[idx_h]) {
+        ++idx_c;
+      }
+      if (idx_c >= sorted_callbacks.size()) {
+        break;
+      }
+      if (sorted_callbacks[idx_c].second == sorted_handles[idx_h]) {
+        if (is_start) {
+          tryRunCallback(sorted_callbacks[idx_c].first.start(), rf);
+        } else {
+          tryRunCallback(sorted_callbacks[idx_c].first.end(), rf);
         }
-      } else {
-        scopes_.fill(true);
+        ++num_executed;
       }
     }
 
-    std::function<bool(const RecordFunction&)> start_cb_;
-    std::function<void(const RecordFunction&)> end_cb_;
-    std::array<bool, static_cast<size_t>(RecordScope::NUM_SCOPES)> scopes_;
-    const bool needs_inputs_;
-    const double sampling_prob_;
-    const bool is_sampled_;
-  };
+    if (num_executed != sorted_handles.size()) {
+      C10_LOG_EVERY_MS(WARNING, 1000)
+          << "Could not match some of the start callbacks with the corresponding end callbacks, "
+          << "callbacks changed during RecordFunction lifetime; you might be trying to profile "
+          << "the code after profiler is finished";
+    }
+  }
+
+  // Global callbacks; must be sorted in increasing handle order
+  RecordFunctionCallbacks sorted_global_callbacks_;
 };
 
-std::mutex next_thread_id_mutex_;
-uint16_t next_thread_id_ = 0;
-thread_local uint16_t current_thread_id_ = 0;
+// Enumerates thread ids logically;
+// note: std::this_thread::get_id may return potentially
+// reused thread id
+std::atomic<uint64_t> next_thread_id_ {0};
+thread_local uint64_t current_thread_id_ = 0;
 
-// points to the currently active RecordFunction
+// Points to the currently active RecordFunction
 thread_local RecordFunction* current_record_func_ = nullptr;
 
 inline CallbackManager& manager() {
@@ -191,44 +203,66 @@
 
 } // namespace
 
+/* static */
+double RecordFunctionCallback::sample_zero_one() {
+  static thread_local auto gen =
+      torch::make_unique<std::mt19937>(std::random_device()());
+  std::uniform_real_distribution<double> dist(0.0, 1.0);
+  return dist(*gen);
+}
+
 bool hasCallbacks() {
-  return manager().hasCallbacks();
+  auto& m = manager();
+  return m.hasGlobalCallbacks() || m.hasThreadLocalCallbacks();
 }
 
-void pushCallback(
-    std::function<bool(const RecordFunction&)> start,
-    std::function<void(const RecordFunction&)> end,
-    bool needs_inputs,
-    double sampling_prob,
-    std::unordered_set<RecordScope, std::hash<RecordScope>> scopes) {
-  manager().pushCallback(
-      std::move(start),
-      std::move(end),
-      needs_inputs,
-      sampling_prob,
-      std::move(scopes));
+bool hasGlobalCallbacks() {
+  return manager().hasGlobalCallbacks();
 }
 
-void popCallback() {
-  manager().popCallback();
+bool hasThreadLocalCallbacks() {
+  return manager().hasThreadLocalCallbacks();
 }
 
-bool observersEnabled() {
+CallbackHandle addThreadLocalCallback(
+    RecordFunctionCallback cb) {
+  return manager().addThreadLocalCallback(std::move(cb));
+}
+
+CallbackHandle addGlobalCallback(
+    RecordFunctionCallback cb) {
+  return manager().addGlobalCallback(std::move(cb));
+}
+
+void removeCallback(CallbackHandle handle) {
+  manager().removeCallback(handle);
+}
+
+void clearGlobalCallbacks() {
+  manager().clearGlobalCallbacks();
+}
+
+void clearThreadLocalCallbacks() {
+  manager().clearThreadLocalCallbacks();
+}
+
+void clearCallbacks() {
+  auto& m = manager();
+  m.clearGlobalCallbacks();
+  m.clearThreadLocalCallbacks();
+}
+
+bool isRecordFunctionEnabled() {
   return c10::impl::tls_is_dispatch_key_included(c10::DispatchKey::Profiler);
 }
 
-void enableObservers(bool enable) {
+void enableRecordFunction(bool enable) {
   c10::impl::tls_set_dispatch_key_included(c10::DispatchKey::Profiler, enable);
 }
 
-void _runBeforeCallbacks(RecordFunction* rf, const std::string& funcName) {
-  TORCH_INTERNAL_ASSERT(rf != nullptr);
-  rf->_before(funcName);
-}
-
 RecordFunction::RecordFunction(RecordScope scope) : scope_(scope) {
-  if (manager().hasCallbacks() && observersEnabled()) {
-    active_ = true;
+  if (hasCallbacks() && isRecordFunctionEnabled()) {
+    manager().init(*this);
   }
 }
 
@@ -239,23 +273,9 @@
 }
 
 /* static */
-bool RecordFunction::_needsInputs() {
-  return manager().needsInputs();
-}
-
-void TEST_setGlobalSamplingProbability(double sampling_prob) {
-  manager().TEST_setGlobalSamplingProbability(sampling_prob);
-}
-
-void TEST_unsetGlobalSamplingProbability() {
-  manager().TEST_unsetGlobalSamplingProbability();
-}
-
-/* static */
-uint16_t RecordFunction::currentThreadId() {
+uint64_t RecordFunction::currentThreadId() {
   if (!current_thread_id_) {
     // happens only once per thread
-    std::lock_guard<std::mutex> guard(next_thread_id_mutex_);
     current_thread_id_ = ++next_thread_id_;
   }
   return current_thread_id_;
@@ -267,8 +287,9 @@
   }
   name_ = StringView(name);
   sequence_nr_ = sequence_nr;
+  thread_id_ = currentThreadId();
 
-  processCallbacks();
+  manager().runStartCallbacks(*this);
 }
 
 void RecordFunction::_before(std::string name, int64_t sequence_nr) {
@@ -277,8 +298,9 @@
   }
   name_ = StringView(std::move(name));
   sequence_nr_ = sequence_nr;
+  thread_id_ = currentThreadId();
 
-  processCallbacks();
+  manager().runStartCallbacks(*this);
 }
 
 void RecordFunction::_before(Node* fn, int64_t sequence_nr) {
@@ -288,12 +310,8 @@
   fn_ = fn;
   name_ = StringView(fn->name());
   sequence_nr_ = (sequence_nr >= 0) ? sequence_nr : fn->sequence_nr();
-
-  processCallbacks();
-}
-
-void RecordFunction::processCallbacks() {
   thread_id_ = currentThreadId();
+
   manager().runStartCallbacks(*this);
 }
 
diff --git a/torch/csrc/autograd/record_function.h b/torch/csrc/autograd/record_function.h
index 21bad21..a13e7a3 100644
--- a/torch/csrc/autograd/record_function.h
+++ b/torch/csrc/autograd/record_function.h
@@ -4,6 +4,7 @@
 #include <ATen/ThreadLocalState.h>
 #include <c10/util/SmallVector.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/utils/memory.h>
 
 #include <functional>
 
@@ -81,7 +82,9 @@
 };
 
 // Soft limit on the number of callbacks to use;
-constexpr std::size_t kSoftLimitCallbacks = 32;
+constexpr std::size_t kSoftLimitCallbacks = 4;
+
+typedef c10::SmallVector<uint64_t, kSoftLimitCallbacks> CallbackHandles;
 
 struct TORCH_API RecordFunction {
   // Default constructor is used with before function called afterwards:
@@ -114,7 +117,7 @@
   // Retrieves the thread_id that this RecordFunction ran start callbacks with.
   // Useful for writing thread safe end callbacks that may be potentially
   // executed in a different thread (async ops)
-  inline uint16_t getStartCallbacksThreadId() const {
+  inline uint64_t getStartCallbacksThreadId() const {
     return thread_id_;
   }
 
@@ -126,15 +129,10 @@
   static RecordFunction* current();
 
   // Returns logical thread_id for the current thread
-  static uint16_t currentThreadId();
+  static uint64_t currentThreadId();
 
   // Internal functions, do not use directly;
-  // might be called from python's context manager
-
-  // Returns whether this record function runs callbacks
-  bool _active() const {
-    return active_;
-  }
+  // used in python's context manager
 
   // _before functions initialize RecordFunction members and call
   // start callbacks
@@ -160,7 +158,8 @@
     _before(fn, current_sequence_nr);
   }
 
-  // Internal, only for the use within RECORD_FUNCTION macro;
+  // Internal, only for the use within RECORD_FUNCTION macro
+  // (i.e. stack based RecordFunctions with scope lifetime);
   // sets this function as the current() thread local function;
   // original value of current() is restored in destructor/_end
   void _setCurrent();
@@ -169,79 +168,159 @@
   void _end();
 
   // Returns whether some of the callbacks require function inputs
-  static bool _needsInputs();
+  bool _needsInputs();
 
-  inline uint64_t _callbacksVersion() const {
-    return callbacks_version_;
-  }
-
-  inline void _setCallbacksVersion(uint64_t cv) {
-    callbacks_version_ = cv;
-  }
-
-  // Returns boolean set of active (ran start callback) callbacks
-  inline c10::SmallVector<bool, kSoftLimitCallbacks>& _activeCallbacks() {
-    return active_callbacks_;
-  }
+  // Used internally to keep track of thread local and global callbacks
+  // that were picked to run; must be sorted;
+  // public because of anonymous "friend" class
+  CallbackHandles sorted_active_tls_handles_;
+  CallbackHandles sorted_active_global_handles_;
+  // Whether this RecordFunction runs any callbacks
+  bool active_ = false;
+  /// Whether any of the picked callbacks require inputs
+  bool needs_inputs_ = false;
 
  private:
-  void processCallbacks();
-
   Node* fn_ = nullptr;
   StringView name_;
   int64_t sequence_nr_ = -1;
   std::vector<c10::IValue> inputs_;
+
   // parent_ points to the parent RecordFunction and must out live this;
   // only to be used together with RECORD_FUNCTION macro
+  // (with stack based RecordFunction instances with scope lifetime)
   RecordFunction* parent_ = nullptr;
 
-  // Holds the status of the callbacks after executing start callbacks.
-  // If a start callback was not called (sampling) or returned false
-  // (error or skipping the run), then the corresponding value in
-  // the small vector is false and the end callback won't be called,
-  // otherwise the value is true.
-  c10::SmallVector<bool, kSoftLimitCallbacks> active_callbacks_;
-
   // is_current_ true means that this record function updates thread local
   // current record function pointer;
   // true only in case of scope-based record functions, i.e.
   // RECORD_FUNCTION macro
   bool is_current_ = false;
-  bool active_ = false;
+
+  // Kind of scope this RecordFunction is observing
   const RecordScope scope_;
 
-  // The logical thread_id that this RecordFunction was created with.
-  uint16_t thread_id_ = 0;
-
-  // Callbacks' version this record function was started with.
-  // Used to ensure that the set of callbacks was not changed
-  // during the record function's lifetime, between start and
-  // end invocations.
-  uint64_t callbacks_version_ = 0;
+  // The logical thread_id that this RecordFunction was created with
+  uint64_t thread_id_ = 0;
 };
 
-// Returns whether there're callbacks registered with pushCallback
-TORCH_API bool hasCallbacks();
+//
+// PyTorch callbacks/observers API:
+//
 
-// Internal only, do not use:
-// use C++ RECORD_* or python context manager record_function() instead;
-// Given a record function, run the (possibly sampled) start callbacks that have
-// been pushed via pushCallback().
-TORCH_API void _runBeforeCallbacks(
-    RecordFunction* rf,
-    const std::string& funcName);
+/**
+ * RecordFunctionCallback represents a pair of callbacks to be used with
+ * RecordFunction, members:
+ *   start, end - the callbacks to run when entering and exiting the scope;
+ *   needs_inputs - whether the callbacks need the inputs passed from the observed
+ *     function/range; NOTE: passing the inputs incurs an additional overhead;
+ *   sampling_probability - if not 1.0, then the callback is probabilistically sampled
+ *     to run; NOTE: start and end callbacks always run as a pair and are sampled
+ *     together;
+ *   scopes - types of scopes to execute the callbacks on (see RecordScope);
+ *     passing empty set means the callbacks will be executed for all possible
+ *     scope types
+ *   should_run - optional function that returns whether this callback should run;
+ *     overwrites the effect of setting sampling_probability
+ */
+class TORCH_API RecordFunctionCallback {
+ public:
+  explicit RecordFunctionCallback(
+      std::function<void(const RecordFunction&)> start,
+      std::function<void(const RecordFunction&)> end =
+        [](const RecordFunction&) {}):
+      start_(std::move(start)),
+      end_(std::move(end)) {
+    scopes_.fill(true);
+  }
 
-// Used in tests, overrides sampling probability for all callbacks;
-TORCH_API void TEST_setGlobalSamplingProbability(double sampling_prob);
-TORCH_API void TEST_unsetGlobalSamplingProbability();
+  RecordFunctionCallback& needsInputs(bool needs_inputs) {
+    needs_inputs_ = needs_inputs;
+    return *this;
+  }
+
+  RecordFunctionCallback& samplingProb(double sampling_prob) {
+    TORCH_CHECK(sampling_prob >= 0.0 && sampling_prob <= 1.0,
+        "Invalid sampling probability");
+    sampling_prob_ = sampling_prob;
+    return *this;
+  }
+
+  RecordFunctionCallback& scopes(
+      const std::unordered_set<RecordScope, std::hash<RecordScope>>& scopes) {
+    if (!scopes.empty()) {
+      scopes_.fill(false);
+      for (auto sc : scopes) {
+        scopes_[static_cast<size_t>(sc)] = true;
+      }
+    } else {
+      scopes_.fill(true);
+    }
+    return *this;
+  }
+
+  RecordFunctionCallback& setShouldRun(
+      std::function<bool(const RecordFunctionCallback&)> should_run) {
+    should_run_ = std::move(should_run);
+    return *this;
+  }
+
+  inline bool needsInputs() const {
+    return needs_inputs_;
+  }
+
+  inline double samplingProb() const {
+    return sampling_prob_;
+  }
+
+  inline bool checkScope(RecordScope sc) const {
+    return scopes_[(size_t)sc];
+  }
+
+  inline const std::function<void(const RecordFunction&)>& start() const {
+    return start_;
+  }
+
+  inline const std::function<void(const RecordFunction&)>& end() const {
+    return end_;
+  }
+
+  // whether this callback should run in the given scope
+  inline bool shouldRun(RecordScope scope) const {
+    // first check whether this callback is interested in
+    // the given scope type
+    if (!checkScope(scope)) {
+      return false;
+    }
+    // if we have registered should_run_ function, use it
+    if (should_run_) {
+      return should_run_(*this);
+    }
+    // otherwise potentially do the uniform sampling
+    if (sampling_prob_ != 1.0) {
+      return (sample_zero_one() < sampling_prob_);
+    }
+    return true;
+  }
+
+ private:
+  std::function<void(const RecordFunction&)> start_;
+  std::function<void(const RecordFunction&)> end_;
+  std::function<bool(const RecordFunctionCallback&)> should_run_;
+  bool needs_inputs_ = false;
+  double sampling_prob_ = 1.0;
+  std::array<bool, static_cast<size_t>(RecordScope::NUM_SCOPES)> scopes_ = {};
+
+  static double sample_zero_one();
+};
 
 // Using macro to minimize inputs copies,
 // optional argument - function's seq_no
 #define RECORD_FUNCTION_WITH_SCOPE(scope, fn, inputs, ...) \
   torch::autograd::profiler::RecordFunction guard(scope); \
-  if (guard._active()) { \
+  if (guard.active_) { \
     guard._setCurrent(); \
-    if (torch::autograd::profiler::RecordFunction::_needsInputs()) { \
+    if (guard.needs_inputs_) { \
       guard._before(fn, inputs, ##__VA_ARGS__); \
     } else { \
       guard._before(fn, ##__VA_ARGS__); \
@@ -262,52 +341,108 @@
   RECORD_FUNCTION_WITH_SCOPE( \
     torch::autograd::profiler::RecordScope::USER_SCOPE, fn, {})
 
-/**
- * pushCallback adds a pair of callbacks to run with RecordFunction:
- *  start, end - the callbacks to run when entering and exiting the scope;
- *    if start callback returns false, end callback won't be executed;
- *  needs_inputs - whether the callbacks need the inputs passed from the observed
- *    function/range; NOTE: passing the inputs incurs an additional overhead;
- *  sampling_prob - whether the callbacks are sampled and the sampling
- *    probability;
- *  scopes - types of scopes to execute the callbacks on (see RecordScope);
- *    passing empty set means the callbacks will be executed for all possible
- *    scope types
- *
- * WARNING: not thread safe, must not overlap with other PyTorch code execution
- */
-TORCH_API void pushCallback(
-    std::function<bool(const RecordFunction&)> start,
-    std::function<void(const RecordFunction&)> end =
-        [](const RecordFunction&) {},
-    bool needs_inputs = false,
-    double sampling_prob = 1.0,
-    std::unordered_set<RecordScope, std::hash<RecordScope>> scopes =
-        std::unordered_set<RecordScope, std::hash<RecordScope>>());
+// Notes:
+//  - two types of callbacks are provided: thread local and global
+//     - thread local callbacks are added/removed only for the given thread
+//       and are stored locally for each thread and separately from the list
+//       of the global callbacks
+//     - global callbacks are stored in a single per process list and are
+//       invoked by every RecordFunction, in addition to the thread local
+//       callbacks specific to the given thread
+//  - we allow the added callbacks to be sampled, by specifying a sampling
+//    probability for each callback pair, if the start callback is
+//    not picked to run, the corresponding end callback won't be called
+//  - a typical use case for the global callbacks is passive monitoring
+//    in the background (e.g. fleet-wide monitoring), without focusing on
+//    the specific piece of code
+//  - in contrast, thread local callbacks are enabled locally, on demand,
+//    for the specific piece of code (range) and are not sampled
+//  - a typical use case for thread local callbacks is profiler and code
+//    execution tracer
+//     - note, some functionality (e.g. profiler) can automatically
+//       propagate its callbacks across threads by using ThreadLocalState
+//       mechanism, but in general callbacks are not propagated
+//  - adding/removing global callbacks is not thread safe and should be done
+//    only when no other code is running, e.g. during the initialization
+
+typedef uint64_t CallbackHandle;
+
+// Holds pairs (callbacks, unique_id)
+typedef std::vector<std::pair<RecordFunctionCallback, CallbackHandle>>
+    RecordFunctionCallbacks;
 
 /**
- * popCallback removes the last pair of callbacks previously added with
- *  pushCallback
- *
- * WARNING: not thread safe, must not overlap with other PyTorch code execution
+ * addThreadLocalCallback adds a thread local callback to run with RecordFunction,
+ * returns handle to use with removeCallback
  */
-TORCH_API void popCallback();
+TORCH_API CallbackHandle addThreadLocalCallback(
+    RecordFunctionCallback cb);
 
-// Enable observers thread locally
-TORCH_API void enableObservers(bool enable = true);
+/**
+ * hasThreadLocalCallbacks returns whether there're callbacks registered
+ * with addThreadLocalCallback
+ */
+TORCH_API bool hasThreadLocalCallbacks();
 
-// Returns whether observers are enabled (thread locally)
-TORCH_API bool observersEnabled();
+/**
+ * clearThreadLocalCallbacks removes all thread local callbacks
+ */
+TORCH_API void clearThreadLocalCallbacks();
+
+/**
+ * addGlobalCallback adds a global callback to run with RecordFunction:
+ *
+ * WARNING: not thread safe, typically addGlobalCallback can be called
+ * only during the program initialization
+ */
+TORCH_API CallbackHandle addGlobalCallback(
+    RecordFunctionCallback cb);
+
+/**
+ * removeCallback removes a callback given the handle returned by
+ * addThreadLocalCallback or addGlobalCallback;
+ *
+ * WARNING: removing a global callback is not thread safe,
+ * no other code can run simultaneously
+ */
+TORCH_API void removeCallback(CallbackHandle handle);
+
+/**
+ * hasGlobalCallbacks returns whether there're global callbacks
+ * registered with addGlobalCallback
+ */
+TORCH_API bool hasGlobalCallbacks();
+
+/**
+ * clearGlobalCallbacks removes all global callbacks
+ * WARNING: not thread safe
+ */
+TORCH_API void clearGlobalCallbacks();
+
+// for both thread local and global callbacks
+TORCH_API bool hasCallbacks();
+TORCH_API void clearCallbacks(); // not thread safe
+
+/**
+ * enableRecordFunction enables RecordFunction thread locally
+ */
+TORCH_API void enableRecordFunction(bool enable = true);
+
+/**
+ * isRecordFunctionEnabled returns whether RecordFunction
+ * is enabled thread locally
+ */
+TORCH_API bool isRecordFunctionEnabled();
 
 class TORCH_API RecordFunctionGuard {
  public:
   explicit RecordFunctionGuard(bool is_enabled = true)
-      : prev_value_(observersEnabled()) {
-    enableObservers(is_enabled);
+      : prev_value_(isRecordFunctionEnabled()) {
+    enableRecordFunction(is_enabled);
   }
 
   virtual ~RecordFunctionGuard() {
-    enableObservers(prev_value_);
+    enableRecordFunction(prev_value_);
   }
 
  private:
@@ -321,4 +456,5 @@
 };
 
 } // namespace profiler
-}} // namespace torch::autograd
+} // namespace autograd
+} // namespace torch