record_function: add torchbind alternative API (#72301)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72301

First step in resolving #35026.

This adds `PythonRecordFunction` which is a `torch::CustomClassHolder`
for `at::RecordFunction` to keep the ATen code free of torch includes.
And adds a new, currently unused internal API function
`_record_function_enter_new` which returns the torchbind object.

Once the FC (forward compatibility) period has expired,
`torch.profiler.record_function` will be updated to use this new internal
API. Then, once the BC (backward compatibility) period has expired, the
cpp_custom_type_hack-based API can be removed.

Test Plan: Imported from OSS

Reviewed By: dagitses

Differential Revision: D34586311

Pulled By: robieta

fbshipit-source-id: d3eb9ffad7b348548a2b22c75203a92d1cb5115b
(cherry picked from commit 92d2ca808e5fbd20c9d6645dcabc3f059f9ef2d3)
diff --git a/aten/src/ATen/core/dispatch/ObservedOperators.cpp b/aten/src/ATen/core/dispatch/ObservedOperators.cpp
index 1d1ed4c..65545a2 100644
--- a/aten/src/ATen/core/dispatch/ObservedOperators.cpp
+++ b/aten/src/ATen/core/dispatch/ObservedOperators.cpp
@@ -15,6 +15,7 @@
     "aten::_version",
     "aten::is_complex",
     "profiler::_record_function_enter",
+    "profiler::_record_function_enter_new",
     "profiler::_record_function_exit",
   };
   return not_observed_ops;
diff --git a/test/test_autograd.py b/test/test_autograd.py
index c1fa75f..0cea90e 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -2931,6 +2931,21 @@
         foo_event = [event for event in function_events if "foo" in event.name][0]
         self.assertEqual(foo_event.count, 1)
 
+    def test_record_function_new_signatures(self):
+        # Test the new _record_function ops work
+        # Note: Remove once record_function uses these directly
+        x = torch.randn(10, 10)
+        with profile(use_kineto=kineto_available()) as p:
+            record = torch.ops.profiler._record_function_enter_new("bar", None)
+            try:
+                y = x * 2 + 4
+            finally:
+                torch.ops.profiler._record_function_exit(record)
+
+        function_events = p.function_events
+        foo_event = [event for event in function_events if "bar" in event.name][0]
+        self.assertEqual(foo_event.count, 1)
+
     def test_profiler_aggregation_fake(self):
         events = EventList()
         id = [0]
diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py
index 6062c09..dc505fb 100644
--- a/torch/autograd/profiler_util.py
+++ b/torch/autograd/profiler_util.py
@@ -642,6 +642,7 @@
     filtered_out_names = [
         MEMORY_EVENT_NAME,  # used only for the top-level memory events
         "profiler::_record_function_enter",
+        "profiler::_record_function_enter_new",
         "profiler::_record_function_exit",
         "aten::is_leaf",
         "aten::output_nr",
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index 8499fd9..289f1e8 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -9,7 +9,6 @@
 #include <torch/csrc/autograd/grad_mode.h>
 #include <torch/csrc/jit/python/pybind_utils.h>
 #include <ATen/autocast_mode.h>
-#include <ATen/cpp_custom_type_hack.h>
 #include <ATen/record_function.h>
 #include <torch/csrc/autograd/profiler.h>
 #include <torch/csrc/autograd/profiler_python.h>
@@ -21,6 +20,7 @@
 #include <torch/csrc/autograd/utils/python_arg_parsing.h>
 #include <torch/csrc/autograd/python_mode.h>
 #include <torch/csrc/autograd/python_variable.h>
+#include <torch/csrc/autograd/record_function_ops.h>
 #include <torch/csrc/utils/pycfunction_helpers.h>
 #include <c10/core/ScalarType.h>
 
@@ -241,7 +241,9 @@
   // Creates a new profiling scope using RecordFunction and invokes its starting
   // callbacks.
   m.def("_record_function_with_args_enter", [](const std::string& name, py::args args) {
-    auto rec = std::make_unique<at::RecordFunction>(at::RecordScope::USER_SCOPE);
+    using torch::autograd::profiler::PythonRecordFunction;
+    auto python_rec = c10::make_intrusive<PythonRecordFunction>(at::RecordScope::USER_SCOPE);
+    auto *rec = &python_rec->record;
     if (rec->isActive()) {
       if (rec->needsInputs()) {
         auto iv_inputs = std::vector<c10::IValue>();
@@ -253,16 +255,19 @@
         rec->before(name);
       }
     }
-    return at::cpp_custom_type_hack::create(std::move(rec), at::TensorOptions());
+    return torch::jit::toPyObject(std::move(python_rec));
   });
 
   // Ends the profiling scope created with record_function_with_param_enter.
-  m.def("_record_function_with_args_exit", [](const at::Tensor& handle) {
-    // We don't actually need to do anything with handle just need to persist the
-    // lifetime until now.
-    auto& rec = at::cpp_custom_type_hack::cast<at::RecordFunction>(handle);
-    rec.end();
-  });
+  m.def("_record_function_with_args_exit",
+        [](const py::object &obj) {
+          using torch::autograd::profiler::PythonRecordFunction;
+          auto python_record = torch::jit::toCustomClass<PythonRecordFunction>(obj);
+
+          // We don't actually need to do anything with handle just need to persist the
+          // lifetime until now.
+          python_record->record.end();
+        });
 
   m.def("_supported_activities", []() {
     std::set<ActivityType> activities {ActivityType::CPU};
diff --git a/torch/csrc/autograd/record_function_ops.cpp b/torch/csrc/autograd/record_function_ops.cpp
index 2cf427e..ad8bf33 100644
--- a/torch/csrc/autograd/record_function_ops.cpp
+++ b/torch/csrc/autograd/record_function_ops.cpp
@@ -1,8 +1,10 @@
+#include <torch/csrc/autograd/record_function_ops.h>
 #include <ATen/cpp_custom_type_hack.h>
 #include <ATen/record_function.h>
 #include <ATen/ThreadLocalState.h>
 
-#include <torch/csrc/jit/runtime/custom_operator.h>
+#include <torch/library.h>
+#include <torch/csrc/jit/runtime/operator.h>
 
 namespace caffe2 {
 // Required for cpp_custom_type_hack to work
@@ -16,47 +18,68 @@
 
 // Creates a new profiling scope using RecordFunction and invokes its starting
 // callbacks.
-at::Tensor record_function_enter(
+void record_function_enter(
+    const std::string& name,
+    const c10::optional<std::string>& args,
+    at::RecordFunction &rec) {
+  if (rec.isActive()) {
+    if (rec.needsInputs() && args.has_value()) {
+      rec.before(name, std::vector<c10::IValue>{c10::IValue{args.value()}});
+    } else {
+      rec.before(name);
+    }
+  }
+}
+
+// Legacy signature using cpp_custom_type_hack
+at::Tensor record_function_enter_legacy(
     const std::string& name,
     const c10::optional<std::string>& args) {
   auto rec = std::make_unique<at::RecordFunction>(at::RecordScope::USER_SCOPE);
-  if (rec->isActive()) {
-    if (rec->needsInputs() && args.has_value()) {
-      rec->before(name, std::vector<c10::IValue>{c10::IValue{args.value()}});
-    } else {
-      rec->before(name);
-    }
-  }
+  record_function_enter(name, args, *rec);
   return at::cpp_custom_type_hack::create(std::move(rec), at::TensorOptions());
 }
 
+// New signature using custom_class
+c10::intrusive_ptr<PythonRecordFunction> record_function_enter_new(
+    const std::string &name, const c10::optional<std::string> &args) {
+  auto rec = c10::make_intrusive<PythonRecordFunction>(at::RecordScope::USER_SCOPE);
+  record_function_enter(name, args, rec->record);
+  return rec;
+}
+
 at::RecordFunction& getRecordFunctionFromTensor(const at::Tensor& handle) {
   auto& rec = at::cpp_custom_type_hack::cast<at::RecordFunction>(handle);
   return rec;
 }
 
 // Ends the profiling scope created with record_function_enter.
-void record_function_exit(const at::Tensor& handle) {
-  // We don't actually need to do anything with handle just need to persist the
-  // lifetime until now.
-  auto& rec = getRecordFunctionFromTensor(handle);
+void record_function_exit(at::RecordFunction &rec) {
   rec.end();
 }
 
+// Legacy signature using cpp_custom_type_hack
+void record_function_exit_legacy(const at::Tensor &handle) {
+  // We don't actually need to do anything with handle just need to persist the
+  // lifetime until now.
+  auto& rec = getRecordFunctionFromTensor(handle);
+  record_function_exit(rec);
+}
+
+// New signature using custom_class
+void record_function_exit_new(const c10::intrusive_ptr<PythonRecordFunction> &record) {
+  record_function_exit(record->record);
+}
+
+template <typename Func>
 c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut(
-    const at::Tensor& handle,
+    Func get_record,
     const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
   // Profiling callback that ends the associated record_function
   // and returns the value of the passed in future.
   std::function<c10::IValue(c10::ivalue::Future&)> futureProfilingFunc =
-      [handle](c10::ivalue::Future& fut) {
-        TORCH_INTERNAL_ASSERT(
-            handle.defined(),
-            "Undefined RecordFunction handle. This can happen if the handle is "
-            "not correctly persisted and is destroyed before the future is "
-            "realized.");
-
-        auto& rec = getRecordFunctionFromTensor(handle);
+      [get_record = std::move(get_record)](c10::ivalue::Future& fut) {
+        auto& rec = get_record();
         rec.end();
         // Note: this future is returned to the user to ensure that a call to wait()
         // ensures that profiling callbacks have ran. To ensure that this is
@@ -67,36 +90,74 @@
       };
   // Define a future that completes after the profiling callbacks are run.
   auto profiledFut = fut->then(at::wrapPropagateTLSState(
-      futureProfilingFunc),
+      std::move(futureProfilingFunc)),
       fut->elementType()
       );
   return profiledFut;
 }
 
+// Legacy signature using cpp_custom_type_hack
+c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_legacy(
+    const at::Tensor &handle,
+    const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
+  return _call_end_callbacks_on_fut(
+      [handle] () -> at::RecordFunction& {
+        TORCH_INTERNAL_ASSERT(
+            handle.defined(),
+            "Undefined RecordFunction handle. This can happen if the handle is "
+            "not correctly persisted and is destroyed before the future is "
+            "realized.");
+
+        return getRecordFunctionFromTensor(handle);
+      },
+      fut
+    );
+}
+
+// New signature using custom_class
+c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_new(
+    const c10::intrusive_ptr<PythonRecordFunction> &record,
+    const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
+  return _call_end_callbacks_on_fut(
+      [record] () -> at::RecordFunction& { return record->record; }, fut);
+}
+
 // Internal only, do not use directly, use Python's record_function()
 TORCH_LIBRARY_FRAGMENT(profiler, m) {
-    m.def("_record_function_enter(str name, str? args=None) -> Tensor", &record_function_enter);
-    m.def("_record_function_exit", &record_function_exit);
-}
+  m.class_<PythonRecordFunction>("_RecordFunction");
 
-// Needed to register JIT operator in operator registry below
-c10::AliasAnalysisKind aliasAnalysisFromSchema() {
-  return c10::AliasAnalysisKind::FROM_SCHEMA;
-}
+  m.def("_record_function_enter(str name, str? args=None) -> Tensor",
+        &record_function_enter_legacy);
+  m.def("_record_function_enter_new(str name, str? args=None) -> "
+        "__torch__.torch.classes.profiler._RecordFunction",
+        &record_function_enter_new);
+  m.def("_record_function_exit", &record_function_exit_legacy);
+  m.def("_record_function_exit._RecordFunction", &record_function_exit_new);
 
-jit::RegisterOperators reg_fut_ops({
-    jit::Operator(
+  torch::jit::registerOperator(torch::jit::Operator(
         "profiler::_call_end_callbacks_on_jit_fut(Tensor x, Future(t) y) -> Future(t)",
         [](jit::Stack& stack) {
           // Pop inputs, which should be a future and a tensor
           auto fut = jit::pop(stack).toFuture();
           auto tensor = jit::pop(stack).toTensor();
-          auto profiledFut = _call_end_callbacks_on_fut(tensor, fut);
+          auto profiledFut = _call_end_callbacks_on_fut_legacy(tensor, fut);
           // return future that completes when profiling callbacks have run.
           jit::push(stack, std::move(profiledFut));
         },
-        aliasAnalysisFromSchema()),
-});
+        c10::AliasAnalysisKind::FROM_SCHEMA));
+  torch::jit::registerOperator(torch::jit::Operator(
+        "profiler::_call_end_callbacks_on_jit_fut._RecordFunction("
+            "__torch__.torch.classes.profiler._RecordFunction x, Future(t) y) -> Future(t)",
+        [](c10::Stack &stack) {
+          // Pop inputs, which should be a future and a PythonRecordFunction
+          auto fut = torch::jit::pop(stack).toFuture();
+          auto tensor = torch::jit::pop(stack).toCustomClass<PythonRecordFunction>();
+          auto profiledFut = _call_end_callbacks_on_fut_new(tensor, fut);
+          // return future that completes when profiling callbacks have run.
+          torch::jit::push(stack, std::move(profiledFut));
+        },
+        c10::AliasAnalysisKind::FROM_SCHEMA));
+}
 
 } // namespace profiler
 } // namespace autograd
diff --git a/torch/csrc/autograd/record_function_ops.h b/torch/csrc/autograd/record_function_ops.h
index 9042537..81cc584 100644
--- a/torch/csrc/autograd/record_function_ops.h
+++ b/torch/csrc/autograd/record_function_ops.h
@@ -1,17 +1,30 @@
 #pragma once
 #include <ATen/record_function.h>
 #include <c10/util/Optional.h>
+#include <torch/custom_class.h>
 
 namespace torch {
 namespace autograd {
 namespace profiler {
+
+struct PythonRecordFunction: public torch::CustomClassHolder {
+  at::RecordFunction record;
+
+  PythonRecordFunction(
+      at::RecordScope scope = at::RecordScope::FUNCTION,
+      bool pre_sampled = false)
+    : record(scope, pre_sampled)
+    {}
+};
+
 // Creates a new profiling scope using RecordFunction and invokes its starting
 // callbacks.
-TORCH_API at::Tensor record_function_enter(const std::string& name, const c10::optional<std::string>& args = c10::nullopt);
+TORCH_API c10::intrusive_ptr<PythonRecordFunction> record_function_enter_new(
+    const std::string &name, const c10::optional<std::string> &args = c10::nullopt);
 
 // Schedules RecordFunction's end callbacks to be run on completion of a future.
-TORCH_API c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut(
-    const at::Tensor& handle,
+TORCH_API c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_new(
+    const c10::intrusive_ptr<PythonRecordFunction> &record,
     const c10::intrusive_ptr<c10::ivalue::Future>& fut);
 
 } // namespace profiler
diff --git a/torch/csrc/distributed/rpc/torchscript_functions.cpp b/torch/csrc/distributed/rpc/torchscript_functions.cpp
index 464a290..2d54c52 100644
--- a/torch/csrc/distributed/rpc/torchscript_functions.cpp
+++ b/torch/csrc/distributed/rpc/torchscript_functions.cpp
@@ -21,10 +21,7 @@
     std::vector<c10::IValue>& stack,
     const float rpcTimeoutSeconds,
     const bool isAsyncExecution) {
-  // This dummy tensor holds an at::RecordFunction when profiling is enabled.
-  // This is because at::RecordFunction is not yet registered as a TorchScript
-  // custom class (https://github.com/pytorch/pytorch/issues/35026)
-  at::Tensor handle = at::zeros(1);
+  c10::intrusive_ptr<torch::autograd::profiler::PythonRecordFunction> record;
   auto shouldProfile = torch::autograd::profiler::profilerEnabled() &&
       !torch::distributed::rpc::RemoteProfilerManager::getInstance()
            .isCurrentKeySet();
@@ -35,7 +32,7 @@
             .qualifiedName(), /* name of torchscript function being run */
         RpcAgent::getCurrentRpcAgent()->getWorkerInfo().name_,
         dstWorkerName);
-    handle = torch::autograd::profiler::record_function_enter(rpcAsyncJitKey);
+    record = torch::autograd::profiler::record_function_enter_new(rpcAsyncJitKey);
     auto& remoteProfilerManager =
         torch::distributed::rpc::RemoteProfilerManager::getInstance();
     remoteProfilerManager.setCurrentKey(rpcAsyncJitKey);
@@ -75,7 +72,7 @@
   }));
   if (shouldProfile) {
     auto profiledFutPtr =
-        torch::autograd::profiler::_call_end_callbacks_on_fut(handle, futPtr);
+        torch::autograd::profiler::_call_end_callbacks_on_fut_new(record, futPtr);
     return profiledFutPtr;
   }
   return futPtr;
diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h
index 4676287..a922376 100644
--- a/torch/csrc/jit/python/pybind_utils.h
+++ b/torch/csrc/jit/python/pybind_utils.h
@@ -587,6 +587,16 @@
 // python_ivalue.h
 IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N);
 
+// Extract custom class registered with torchbind
+template <typename T>
+c10::intrusive_ptr<T> toCustomClass(py::handle obj) {
+  static_assert(
+      std::is_base_of<CustomClassHolder, T>::value, "T is not a CustomClass");
+  const auto& type = c10::getCustomClassType<c10::intrusive_ptr<T>>();
+  c10::IValue ivalue = toIValue(obj, type);
+  return std::move(ivalue).toCustomClass<T>();
+}
+
 // Small wrapper around getting the type name string from Python to make
 // types easier to interpret, e.g. give the structural type for a NamedTuple
 inline std::string friendlyTypeName(py::handle obj) {
diff --git a/torch/fx/node.py b/torch/fx/node.py
index eefa731..4311094 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -26,7 +26,9 @@
 ]]
 
 _side_effectful_functions: Set[Callable] = {
-    torch._assert, torch.ops.profiler._record_function_enter,
+    torch._assert,
+    torch.ops.profiler._record_function_enter,
+    torch.ops.profiler._record_function_enter_new,
     torch.ops.profiler._record_function_exit}
 
 # this is fixed on master, WAR for 1.5
diff --git a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
index 5ba47c7..383e627 100644
--- a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
+++ b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
@@ -153,8 +153,11 @@
     t: Tensor = torch.ones(1)
     with record_function(block) as rf:
         fut1 = rpc.rpc_async(dst_worker_name, script_add_ones, (t, ))
+        # Extra operator call to avoid de-duplication of the next async call
+        # see https://github.com/pytorch/pytorch/pull/62710#discussion_r694680279
+        zero = torch.zeros_like(t)
         fut2 = rpc.rpc_async(dst_worker_name, script_add_ones, (t, ))
-        res = fut1.wait() + fut2.wait()
+        res = fut1.wait() + fut2.wait() + zero
     return res
 
 
diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py
index 1d69808..1d11a85 100644
--- a/torch/testing/_internal/distributed/rpc/rpc_test.py
+++ b/torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -2381,6 +2381,24 @@
                 fut.wait()
 
     @dist_init
+    def test_async_record_function_double_end_callbacks_new_signatures(self):
+        # Test the new _record_function ops work
+        # Note: Remove once record_function uses these directly
+        num_sleep_seconds = 1
+        if self.rank == 1:
+            with _profile() as pf:
+                try:
+                    record = torch.ops.profiler._record_function_enter_new("foo", None)
+                    fut = rpc.rpc_async(
+                        worker_name(0), my_sleep_func, args=(num_sleep_seconds,)
+                    )
+                    torch.ops.profiler._call_end_callbacks_on_jit_fut(record, fut)
+                finally:
+                    torch.ops.profiler._record_function_exit(record)
+
+                fut.wait()
+
+    @dist_init
     def test_async_record_function_cbs_jit_call(self):
         if self.rank == 1:
             with _profile() as pf: