record_function: add torchbind alternative API (#72301)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72301
First step toward resolving #35026.
This adds `PythonRecordFunction`, a `torch::CustomClassHolder` wrapper
around `at::RecordFunction`, keeping the ATen code free of torch includes.
It also adds a new, currently unused internal API function,
`_record_function_enter_new`, which returns the torchbind object instead
of a `cpp_custom_type_hack` tensor handle.
Once the forward-compatibility (FC) period has expired,
`torch.profiler.record_function` will be updated to use this new internal
API. Then, once the backward-compatibility (BC) period has expired, the
`cpp_custom_type_hack`-based API can be removed.
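As an illustration of the intended end state, here is a minimal sketch (not
the actual implementation; the real migration lands in a follow-up) of how a
record_function-style context manager could drive the new op pair from eager
Python:

    import torch

    class _RecordFunctionSketch:
        # Illustrative only; not the real torch.profiler.record_function.
        def __init__(self, name: str):
            self.name = name

        def __enter__(self):
            # Returns a torchbind object (torch.classes.profiler._RecordFunction)
            # instead of the old cpp_custom_type_hack tensor handle.
            self.record = torch.ops.profiler._record_function_enter_new(self.name, None)
            return self

        def __exit__(self, *exc_info):
            # Resolves to the _record_function_exit._RecordFunction overload.
            torch.ops.profiler._record_function_exit(self.record)
            return False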
Test Plan: Imported from OSS
Reviewed By: dagitses
Differential Revision: D34586311
Pulled By: robieta
fbshipit-source-id: d3eb9ffad7b348548a2b22c75203a92d1cb5115b
(cherry picked from commit 92d2ca808e5fbd20c9d6645dcabc3f059f9ef2d3)
diff --git a/aten/src/ATen/core/dispatch/ObservedOperators.cpp b/aten/src/ATen/core/dispatch/ObservedOperators.cpp
index 1d1ed4c..65545a2 100644
--- a/aten/src/ATen/core/dispatch/ObservedOperators.cpp
+++ b/aten/src/ATen/core/dispatch/ObservedOperators.cpp
@@ -15,6 +15,7 @@
"aten::_version",
"aten::is_complex",
"profiler::_record_function_enter",
+ "profiler::_record_function_enter_new",
"profiler::_record_function_exit",
};
return not_observed_ops;
diff --git a/test/test_autograd.py b/test/test_autograd.py
index c1fa75f..0cea90e 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -2931,6 +2931,21 @@
        foo_event = [event for event in function_events if "foo" in event.name][0]
        self.assertEqual(foo_event.count, 1)

+    def test_record_function_new_signatures(self):
+        # Test that the new _record_function ops work
+        # Note: remove once record_function uses these directly
+        x = torch.randn(10, 10)
+        with profile(use_kineto=kineto_available()) as p:
+            record = torch.ops.profiler._record_function_enter_new("bar", None)
+            try:
+                y = x * 2 + 4
+            finally:
+                torch.ops.profiler._record_function_exit(record)
+
+        function_events = p.function_events
+        bar_event = [event for event in function_events if "bar" in event.name][0]
+        self.assertEqual(bar_event.count, 1)
+
    def test_profiler_aggregation_fake(self):
        events = EventList()
        id = [0]
diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py
index 6062c09..dc505fb 100644
--- a/torch/autograd/profiler_util.py
+++ b/torch/autograd/profiler_util.py
@@ -642,6 +642,7 @@
filtered_out_names = [
MEMORY_EVENT_NAME, # used only for the top-level memory events
"profiler::_record_function_enter",
+ "profiler::_record_function_enter_new",
"profiler::_record_function_exit",
"aten::is_leaf",
"aten::output_nr",
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index 8499fd9..289f1e8 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -9,7 +9,6 @@
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <ATen/autocast_mode.h>
-#include <ATen/cpp_custom_type_hack.h>
#include <ATen/record_function.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/profiler_python.h>
@@ -21,6 +20,7 @@
#include <torch/csrc/autograd/utils/python_arg_parsing.h>
#include <torch/csrc/autograd/python_mode.h>
#include <torch/csrc/autograd/python_variable.h>
+#include <torch/csrc/autograd/record_function_ops.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <c10/core/ScalarType.h>
@@ -241,7 +241,9 @@
// Creates a new profiling scope using RecordFunction and invokes its starting
// callbacks.
m.def("_record_function_with_args_enter", [](const std::string& name, py::args args) {
- auto rec = std::make_unique<at::RecordFunction>(at::RecordScope::USER_SCOPE);
+ using torch::autograd::profiler::PythonRecordFunction;
+ auto python_rec = c10::make_intrusive<PythonRecordFunction>(at::RecordScope::USER_SCOPE);
+ auto *rec = &python_rec->record;
if (rec->isActive()) {
if (rec->needsInputs()) {
auto iv_inputs = std::vector<c10::IValue>();
@@ -253,16 +255,19 @@
rec->before(name);
}
}
- return at::cpp_custom_type_hack::create(std::move(rec), at::TensorOptions());
+ return torch::jit::toPyObject(std::move(python_rec));
});
// Ends the profiling scope created with _record_function_with_args_enter.
- m.def("_record_function_with_args_exit", [](const at::Tensor& handle) {
- // We don't actually need to do anything with handle just need to persist the
- // lifetime until now.
- auto& rec = at::cpp_custom_type_hack::cast<at::RecordFunction>(handle);
- rec.end();
- });
+ m.def("_record_function_with_args_exit",
+ [](const py::object &obj) {
+ using torch::autograd::profiler::PythonRecordFunction;
+ auto python_record = torch::jit::toCustomClass<PythonRecordFunction>(obj);
+
+ // We don't actually need to do anything with the handle; we just need to
+ // keep it alive until this point.
+ python_record->record.end();
+ });
m.def("_supported_activities", []() {
std::set<ActivityType> activities {ActivityType::CPU};
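The `_record_function_with_args_enter`/`_record_function_with_args_exit`
bindings above now round-trip the torchbind object through pybind
(`toPyObject`/`toCustomClass`) instead of a `cpp_custom_type_hack` tensor. A
minimal sketch of the round trip from Python, assuming these private bindings
remain re-exported on `torch.autograd`:

    import torch

    # The handle is now a torchbind _RecordFunction object, not a dummy tensor.
    handle = torch.autograd._record_function_with_args_enter("region", torch.ones(1))
    torch.autograd._record_function_with_args_exit(handle)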
diff --git a/torch/csrc/autograd/record_function_ops.cpp b/torch/csrc/autograd/record_function_ops.cpp
index 2cf427e..ad8bf33 100644
--- a/torch/csrc/autograd/record_function_ops.cpp
+++ b/torch/csrc/autograd/record_function_ops.cpp
@@ -1,8 +1,10 @@
+#include <torch/csrc/autograd/record_function_ops.h>
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/record_function.h>
#include <ATen/ThreadLocalState.h>
-#include <torch/csrc/jit/runtime/custom_operator.h>
+#include <torch/library.h>
+#include <torch/csrc/jit/runtime/operator.h>
namespace caffe2 {
// Required for cpp_custom_type_hack to work
@@ -16,47 +18,68 @@
// Creates a new profiling scope using RecordFunction and invokes its starting
// callbacks.
-at::Tensor record_function_enter(
+void record_function_enter(
+ const std::string& name,
+ const c10::optional<std::string>& args,
+ at::RecordFunction &rec) {
+ if (rec.isActive()) {
+ if (rec.needsInputs() && args.has_value()) {
+ rec.before(name, std::vector<c10::IValue>{c10::IValue{args.value()}});
+ } else {
+ rec.before(name);
+ }
+ }
+}
+
+// Legacy signature using cpp_custom_type_hack
+at::Tensor record_function_enter_legacy(
const std::string& name,
const c10::optional<std::string>& args) {
auto rec = std::make_unique<at::RecordFunction>(at::RecordScope::USER_SCOPE);
- if (rec->isActive()) {
- if (rec->needsInputs() && args.has_value()) {
- rec->before(name, std::vector<c10::IValue>{c10::IValue{args.value()}});
- } else {
- rec->before(name);
- }
- }
+ record_function_enter(name, args, *rec);
return at::cpp_custom_type_hack::create(std::move(rec), at::TensorOptions());
}
+// New signature using custom_class
+c10::intrusive_ptr<PythonRecordFunction> record_function_enter_new(
+ const std::string &name, const c10::optional<std::string> &args) {
+ auto rec = c10::make_intrusive<PythonRecordFunction>(at::RecordScope::USER_SCOPE);
+ record_function_enter(name, args, rec->record);
+ return rec;
+}
+
at::RecordFunction& getRecordFunctionFromTensor(const at::Tensor& handle) {
auto& rec = at::cpp_custom_type_hack::cast<at::RecordFunction>(handle);
return rec;
}
// Ends the profiling scope created with record_function_enter.
-void record_function_exit(const at::Tensor& handle) {
- // We don't actually need to do anything with handle just need to persist the
- // lifetime until now.
- auto& rec = getRecordFunctionFromTensor(handle);
+void record_function_exit(at::RecordFunction &rec) {
rec.end();
}
+// Legacy signature using cpp_custom_type_hack
+void record_function_exit_legacy(const at::Tensor &handle) {
+ // We don't actually need to do anything with the handle; we just need to
+ // keep it alive until this point.
+ auto& rec = getRecordFunctionFromTensor(handle);
+ record_function_exit(rec);
+}
+
+// New signature using custom_class
+void record_function_exit_new(const c10::intrusive_ptr<PythonRecordFunction> &record) {
+ record_function_exit(record->record);
+}
+
+template <typename Func>
c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut(
- const at::Tensor& handle,
+ Func get_record,
const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
// Profiling callback that ends the associated record_function
// and returns the value of the passed in future.
std::function<c10::IValue(c10::ivalue::Future&)> futureProfilingFunc =
- [handle](c10::ivalue::Future& fut) {
- TORCH_INTERNAL_ASSERT(
- handle.defined(),
- "Undefined RecordFunction handle. This can happen if the handle is "
- "not correctly persisted and is destroyed before the future is "
- "realized.");
-
- auto& rec = getRecordFunctionFromTensor(handle);
+ [get_record = std::move(get_record)](c10::ivalue::Future& fut) {
+ auto& rec = get_record();
rec.end();
// Note: this future is returned to the user to ensure that a call to wait()
// ensures that profiling callbacks have run. To ensure that this is
@@ -67,36 +90,74 @@
};
// Define a future that completes after the profiling callbacks are run.
auto profiledFut = fut->then(at::wrapPropagateTLSState(
- futureProfilingFunc),
+ std::move(futureProfilingFunc)),
fut->elementType()
);
return profiledFut;
}
+// Legacy signature using cpp_custom_type_hack
+c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_legacy(
+ const at::Tensor &handle,
+ const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
+ return _call_end_callbacks_on_fut(
+ [handle] () -> at::RecordFunction& {
+ TORCH_INTERNAL_ASSERT(
+ handle.defined(),
+ "Undefined RecordFunction handle. This can happen if the handle is "
+ "not correctly persisted and is destroyed before the future is "
+ "realized.");
+
+ return getRecordFunctionFromTensor(handle);
+ },
+ fut
+ );
+}
+
+// New signature using custom_class
+c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_new(
+ const c10::intrusive_ptr<PythonRecordFunction> &record,
+ const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
+ return _call_end_callbacks_on_fut(
+ [record] () -> at::RecordFunction& { return record->record; }, fut);
+}
+
// Internal only, do not use directly, use Python's record_function()
TORCH_LIBRARY_FRAGMENT(profiler, m) {
- m.def("_record_function_enter(str name, str? args=None) -> Tensor", &record_function_enter);
- m.def("_record_function_exit", &record_function_exit);
-}
+ m.class_<PythonRecordFunction>("_RecordFunction");
-// Needed to register JIT operator in operator registry below
-c10::AliasAnalysisKind aliasAnalysisFromSchema() {
- return c10::AliasAnalysisKind::FROM_SCHEMA;
-}
+ m.def("_record_function_enter(str name, str? args=None) -> Tensor",
+ &record_function_enter_legacy);
+ m.def("_record_function_enter_new(str name, str? args=None) -> "
+ "__torch__.torch.classes.profiler._RecordFunction",
+ &record_function_enter_new);
+ m.def("_record_function_exit", &record_function_exit_legacy);
+ m.def("_record_function_exit._RecordFunction", &record_function_exit_new);
-jit::RegisterOperators reg_fut_ops({
- jit::Operator(
+ torch::jit::registerOperator(torch::jit::Operator(
"profiler::_call_end_callbacks_on_jit_fut(Tensor x, Future(t) y) -> Future(t)",
[](jit::Stack& stack) {
// Pop inputs, which should be a future and a tensor
auto fut = jit::pop(stack).toFuture();
auto tensor = jit::pop(stack).toTensor();
- auto profiledFut = _call_end_callbacks_on_fut(tensor, fut);
+ auto profiledFut = _call_end_callbacks_on_fut_legacy(tensor, fut);
// return future that completes when profiling callbacks have run.
jit::push(stack, std::move(profiledFut));
},
- aliasAnalysisFromSchema()),
-});
+ c10::AliasAnalysisKind::FROM_SCHEMA));
+ torch::jit::registerOperator(torch::jit::Operator(
+ "profiler::_call_end_callbacks_on_jit_fut._RecordFunction("
+ "__torch__.torch.classes.profiler._RecordFunction x, Future(t) y) -> Future(t)",
+ [](c10::Stack &stack) {
+ // Pop inputs, which should be a future and a PythonRecordFunction
+ auto fut = torch::jit::pop(stack).toFuture();
+ auto record = torch::jit::pop(stack).toCustomClass<PythonRecordFunction>();
+ auto profiledFut = _call_end_callbacks_on_fut_new(record, fut);
+ // return future that completes when profiling callbacks have run.
+ torch::jit::push(stack, std::move(profiledFut));
+ },
+ c10::AliasAnalysisKind::FROM_SCHEMA));
+}
} // namespace profiler
} // namespace autograd
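Registering the new signatures as named overloads
(`_record_function_exit._RecordFunction`,
`_call_end_callbacks_on_jit_fut._RecordFunction`) keeps a single op name
while the dispatcher picks the overload from the argument type. A hedged
eager-mode sketch exercising both paths:

    import torch
    from torch.autograd.profiler import profile

    with profile():
        # Legacy path: the handle is a cpp_custom_type_hack tensor.
        legacy = torch.ops.profiler._record_function_enter("legacy", None)
        torch.ops.profiler._record_function_exit(legacy)

        # New path: the handle is a torchbind _RecordFunction; the same op
        # name resolves to the ._RecordFunction overload.
        record = torch.ops.profiler._record_function_enter_new("new", None)
        torch.ops.profiler._record_function_exit(record)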
diff --git a/torch/csrc/autograd/record_function_ops.h b/torch/csrc/autograd/record_function_ops.h
index 9042537..81cc584 100644
--- a/torch/csrc/autograd/record_function_ops.h
+++ b/torch/csrc/autograd/record_function_ops.h
@@ -1,17 +1,30 @@
#pragma once
#include <ATen/record_function.h>
#include <c10/util/Optional.h>
+#include <torch/custom_class.h>
namespace torch {
namespace autograd {
namespace profiler {
+
+struct PythonRecordFunction: public torch::CustomClassHolder {
+ at::RecordFunction record;
+
+ PythonRecordFunction(
+ at::RecordScope scope = at::RecordScope::FUNCTION,
+ bool pre_sampled = false)
+ : record(scope, pre_sampled)
+ {}
+};
+
// Creates a new profiling scope using RecordFunction and invokes its starting
// callbacks.
-TORCH_API at::Tensor record_function_enter(const std::string& name, const c10::optional<std::string>& args = c10::nullopt);
+TORCH_API c10::intrusive_ptr<PythonRecordFunction> record_function_enter_new(
+ const std::string &name, const c10::optional<std::string> &args = c10::nullopt);
// Schedules RecordFunction's end callbacks to be run on completion of a future.
-TORCH_API c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut(
- const at::Tensor& handle,
+TORCH_API c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_new(
+ const c10::intrusive_ptr<PythonRecordFunction> &record,
const c10::intrusive_ptr<c10::ivalue::Future>& fut);
} // namespace profiler
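`PythonRecordFunction` simply owns the `at::RecordFunction` and inherits from
`torch::CustomClassHolder`, so its lifetime is managed by `c10::intrusive_ptr`
and it can cross the Python/TorchScript boundary. From Python the handle
surfaces as a script-object wrapper; a small sketch (the exact wrapper type is
an implementation detail and may vary):

    import torch

    rf = torch.ops.profiler._record_function_enter_new("demo", None)
    print(type(rf))  # a torch._C.ScriptObject-style wrapper
    torch.ops.profiler._record_function_exit(rf)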
diff --git a/torch/csrc/distributed/rpc/torchscript_functions.cpp b/torch/csrc/distributed/rpc/torchscript_functions.cpp
index 464a290..2d54c52 100644
--- a/torch/csrc/distributed/rpc/torchscript_functions.cpp
+++ b/torch/csrc/distributed/rpc/torchscript_functions.cpp
@@ -21,10 +21,7 @@
std::vector<c10::IValue>& stack,
const float rpcTimeoutSeconds,
const bool isAsyncExecution) {
- // This dummy tensor holds an at::RecordFunction when profiling is enabled.
- // This is because at::RecordFunction is not yet registered as a TorchScript
- // custom class (https://github.com/pytorch/pytorch/issues/35026)
- at::Tensor handle = at::zeros(1);
+ c10::intrusive_ptr<torch::autograd::profiler::PythonRecordFunction> record;
auto shouldProfile = torch::autograd::profiler::profilerEnabled() &&
!torch::distributed::rpc::RemoteProfilerManager::getInstance()
.isCurrentKeySet();
@@ -35,7 +32,7 @@
.qualifiedName(), /* name of torchscript function being run */
RpcAgent::getCurrentRpcAgent()->getWorkerInfo().name_,
dstWorkerName);
- handle = torch::autograd::profiler::record_function_enter(rpcAsyncJitKey);
+ record = torch::autograd::profiler::record_function_enter_new(rpcAsyncJitKey);
auto& remoteProfilerManager =
torch::distributed::rpc::RemoteProfilerManager::getInstance();
remoteProfilerManager.setCurrentKey(rpcAsyncJitKey);
@@ -75,7 +72,7 @@
}));
if (shouldProfile) {
auto profiledFutPtr =
- torch::autograd::profiler::_call_end_callbacks_on_fut(handle, futPtr);
+ torch::autograd::profiler::_call_end_callbacks_on_fut_new(record, futPtr);
return profiledFutPtr;
}
return futPtr;
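The RPC path above now keeps the record alive in an `intrusive_ptr` and
defers its end callbacks until the returned future completes. The same
pattern is reachable from eager Python; a hedged sketch that assumes an
initialized RPC agent with a peer named "worker0":

    import torch
    import torch.distributed.rpc as rpc

    record = torch.ops.profiler._record_function_enter_new("rpc_region", None)
    fut = rpc.rpc_async("worker0", torch.add, args=(torch.ones(1), torch.ones(1)))
    # The returned future completes only after the profiling end callbacks run.
    profiled = torch.ops.profiler._call_end_callbacks_on_jit_fut(record, fut)
    profiled.wait()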
diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h
index 4676287..a922376 100644
--- a/torch/csrc/jit/python/pybind_utils.h
+++ b/torch/csrc/jit/python/pybind_utils.h
@@ -587,6 +587,16 @@
// python_ivalue.h
IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N);
+// Extract a custom class registered with torchbind
+template <typename T>
+c10::intrusive_ptr<T> toCustomClass(py::handle obj) {
+ static_assert(
+ std::is_base_of<CustomClassHolder, T>::value, "T is not a CustomClass");
+ const auto& type = c10::getCustomClassType<c10::intrusive_ptr<T>>();
+ c10::IValue ivalue = toIValue(obj, type);
+ return std::move(ivalue).toCustomClass<T>();
+}
+
// Small wrapper around getting the type name string from Python to make
// types easier to interpret, e.g. give the structural type for a NamedTuple
inline std::string friendlyTypeName(py::handle obj) {
diff --git a/torch/fx/node.py b/torch/fx/node.py
index eefa731..4311094 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -26,7 +26,9 @@
]]
_side_effectful_functions: Set[Callable] = {
- torch._assert, torch.ops.profiler._record_function_enter,
+ torch._assert,
+ torch.ops.profiler._record_function_enter,
+ torch.ops.profiler._record_function_enter_new,
torch.ops.profiler._record_function_exit}
diff --git a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
index 5ba47c7..383e627 100644
--- a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
+++ b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
@@ -153,8 +153,11 @@
        t: Tensor = torch.ones(1)
        with record_function(block) as rf:
            fut1 = rpc.rpc_async(dst_worker_name, script_add_ones, (t, ))
+            # Extra operator call to avoid de-duplication of the next async call
+            # see https://github.com/pytorch/pytorch/pull/62710#discussion_r694680279
+            zero = torch.zeros_like(t)
            fut2 = rpc.rpc_async(dst_worker_name, script_add_ones, (t, ))
-            res = fut1.wait() + fut2.wait()
+            res = fut1.wait() + fut2.wait() + zero
        return res
diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py
index 1d69808..1d11a85 100644
--- a/torch/testing/_internal/distributed/rpc/rpc_test.py
+++ b/torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -2381,6 +2381,24 @@
            fut.wait()

    @dist_init
+    def test_async_record_function_double_end_callbacks_new_signatures(self):
+        # Test that the new _record_function ops work
+        # Note: remove once record_function uses these directly
+        num_sleep_seconds = 1
+        if self.rank == 1:
+            with _profile() as pf:
+                try:
+                    record = torch.ops.profiler._record_function_enter_new("foo", None)
+                    fut = rpc.rpc_async(
+                        worker_name(0), my_sleep_func, args=(num_sleep_seconds,)
+                    )
+                    torch.ops.profiler._call_end_callbacks_on_jit_fut(record, fut)
+                finally:
+                    torch.ops.profiler._record_function_exit(record)
+
+            fut.wait()
+
+    @dist_init
    def test_async_record_function_cbs_jit_call(self):
        if self.rank == 1:
            with _profile() as pf: