[pytorch][PR] [Profiler] Add EventFieldsVisitor
One source of complexity in profiler_kineto is that we do most things twice: once to set a field on `kineto_events_.back()`, and once for the metadata JSON. These have historically been chained, with the KinetoEvent used to populate the metadata fields. However, this is hard to read and error-prone: we end up with one giant block of assignments followed by another giant block, and the logic for whether a field is present or not is duplicated.
This PR replaces that logic with a single visitor that writes both together. E.g.
```
auto& dtypes = result_.get().inputs_.dtypes_;
if (!dtypes.empty()) {
  kineto_event_.get().dtypes(dtypes);
  annotations_.emplace_back("Input type", dtypesToStr(dtypes));
}
```
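To make the pattern concrete outside the profiler sources, here is a minimal standalone sketch of the same idea: a variant visitor that populates a structured sink and a JSON-annotation sink from one handler. The event and sink types here are hypothetical stand-ins, and `std::variant`/`std::visit` stand in for the `c10::variant`/`c10::visit` machinery used in the actual diff.

```
// Minimal standalone sketch; hypothetical types, std::visit in place of c10::visit.
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <variant>
#include <vector>

struct OpEvent { int64_t sequence_number; };
struct BackendEvent { std::string backend; };

struct FieldsVisitor {
  // Each handler writes the structured field and the matching trace
  // annotation in one place, so presence checks are never duplicated.
  void operator()(const OpEvent& e) {
    if (e.sequence_number >= 0) {
      sequence_numbers.push_back(e.sequence_number);  // structured sink
      annotations.emplace_back(                       // metadata sink
          "Sequence number", std::to_string(e.sequence_number));
    }
  }
  void operator()(const BackendEvent& e) {
    if (!e.backend.empty()) {
      backends.push_back(e.backend);
      annotations.emplace_back("Backend", "\"" + e.backend + "\"");
    }
  }

  std::vector<int64_t> sequence_numbers;
  std::vector<std::string> backends;
  std::vector<std::pair<std::string, std::string>> annotations;
};

int main() {
  std::variant<OpEvent, BackendEvent> event = OpEvent{42};
  FieldsVisitor visitor;
  std::visit(visitor, event);
  for (const auto& kv : visitor.annotations) {
    std::cout << kv.first << ": " << kv.second << "\n";
  }
}
```

The payoff is that each presence check (e.g. `sequence_number >= 0`) appears exactly once, which is precisely the duplication this PR removes.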
Differential Revision: [D36070202](https://our.internmc.facebook.com/intern/diff/D36070202/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77691
Approved by: https://github.com/aaronenyeshi
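Beyond the visitor, the diff also reshapes the event post-processing callback: instead of a `std::function<void(std::vector<KinetoEvent>&)>` invoked on the finished event vector, callers now supply a per-event `post_process_t` that receives a debug handle plus mutable JIT stack/module vectors (see the `profiler_kineto.h` hunk below). Here is a minimal sketch of a conforming callback; `lookupHierarchy` is a hypothetical stand-in for a real debug-handle lookup such as the mobile module's `getModuleHierarchy`.

```
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Shape of the new callback, matching the alias added to profiler_kineto.h.
using post_process_t = std::function<void(
    /*debug_handle*/ int64_t,
    /*jit_stack*/ std::vector<std::string>&,
    /*jit_modules*/ std::vector<std::string>&)>;

// Hypothetical lookup; stands in for e.g. m_.getModuleHierarchy(debug_handle).
std::string lookupHierarchy(int64_t debug_handle) {
  return "module_for_handle_" + std::to_string(debug_handle);
}

int main() {
  // The callback fills per-event stack/module info as events are materialized,
  // rather than mutating fully built KinetoEvents after the fact.
  post_process_t cb = [](int64_t debug_handle,
                         std::vector<std::string>& jit_stack,
                         std::vector<std::string>& jit_modules) {
    jit_stack = {"<stack elided in this sketch>"};
    jit_modules = {lookupHierarchy(debug_handle)};
  };

  std::vector<std::string> stack, modules;
  cb(/*debug_handle=*/7, stack, modules);
  std::cout << modules.front() << "\n";  // prints module_for_handle_7
}
```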
diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp
index 1f9756f..db48e94 100644
--- a/torch/csrc/autograd/profiler_kineto.cpp
+++ b/torch/csrc/autograd/profiler_kineto.cpp
@@ -118,6 +118,11 @@
namespace {
using torch::profiler::impl::ProfilerThreadLocalStateBase;
using torch::profiler::impl::ActiveProfilerType;
+using torch::profiler::impl::Result;
+using torch::profiler::impl::kineto::annotation_t;
+using torch::profiler::impl::shapesToStr;
+using torch::profiler::impl::dtypesToStr;
+using torch::profiler::impl::stacksToStr;
struct MemoryEventData {
torch::profiler::impl::approx_time_t start_time;
@@ -132,6 +137,91 @@
};
static_assert(std::is_pod<MemoryEventData>::value, "Non-POD member of MemoryEventData.");
+struct EventFieldsVisitor {
+ EventFieldsVisitor(const Result& result, KinetoEvent& kineto_event)
+ : result_{result}, kineto_event_{kineto_event} {
+ handleJIT(result_.get().jit_stack_, result_.get().jit_modules_);
+ c10::visit(*this, result.event_);
+ }
+
+ void operator()(const torch::profiler::impl::OpEvent& op_event) {
+ kineto_event_.get()
+ .endThreadId(op_event.end_thread_id_)
+ .scope(op_event.record_function_scope_)
+ .setAsync(op_event.is_async_)
+ .debugHandle(op_event.debug_handle_);
+
+ auto& shapes = result_.get().inputs_.shapes_;
+ if (!shapes.empty()) {
+ kineto_event_.get().shapes(shapes);
+ annotations_.emplace_back("Input Dims", shapesToStr(shapes));
+ }
+
+ auto& dtypes = result_.get().inputs_.dtypes_;
+ if (!dtypes.empty()) {
+ kineto_event_.get().dtypes(dtypes);
+ annotations_.emplace_back("Input type", dtypesToStr(dtypes));
+ }
+
+ if (!result_.get().extra_args_.empty()) {
+ kineto_event_.get().flops(
+ computeFlops(result_.get().name(), result_.get().extra_args_));
+ }
+ kineto_event_.get().cuda_event_start_ =
+ result_.get().gpu_fallback_.cuda_event_start_;
+ kineto_event_.get().cuda_event_end_ =
+ result_.get().gpu_fallback_.cuda_event_end_;
+
+ // add information about an associated forward op, if a sequence number
+ // is available (e.g. during training)
+ if (op_event.sequence_number_ >= 0) {
+ kineto_event_.get()
+ .sequenceNr(op_event.sequence_number_)
+ .fwdThreadId(op_event.forward_thread_id_);
+ annotations_.emplace_back(
+ "Fwd thread id", std::to_string(op_event.forward_thread_id_));
+ annotations_.emplace_back(
+ "Sequence number", std::to_string(op_event.sequence_number_));
+ }
+ }
+
+ void operator()(const torch::profiler::impl::BackendEvent& backend_event) {
+ kineto_event_.get()
+ .endThreadId(result_.get().start_tid_)
+ .scope(backend_event.record_function_scope_)
+ .debugHandle(backend_event.debug_handle_)
+ .backend(backend_event.backend_);
+
+ if (!backend_event.backend_.empty()) {
+ annotations_.emplace_back(
+ "Backend", "\"" + backend_event.backend_ + "\"");
+ }
+ }
+
+ void handleJIT(
+ const std::vector<std::string>& jit_stack,
+ const std::vector<std::string>& jit_modules) {
+ if (!jit_stack.empty()) {
+ // NB: This is only for the JIT stack. The python stack (if applicable)
+ // is constructed later.
+ kineto_event_.get().stack(jit_stack);
+ annotations_.emplace_back(
+ "Call stack", torch::profiler::impl::stacksToStr(jit_stack, ";"));
+ }
+
+ if (!jit_modules.empty()) {
+ kineto_event_.get().moduleHierarchy(jit_modules);
+ annotations_.emplace_back(
+ "Module Hierarchy",
+ torch::profiler::impl::stacksToStr(jit_modules, "."));
+ }
+ }
+
+ std::reference_wrapper<const Result> result_;
+ std::reference_wrapper<KinetoEvent> kineto_event_;
+ annotation_t annotations_;
+};
+
auto getAnnotations(const MemoryEventData& event) {
torch::profiler::impl::kineto::annotation_t out{
{"Device Type", std::to_string((int8_t)event.device_type)},
@@ -201,13 +291,11 @@
}
}
- const std::function<void(std::vector<KinetoEvent>&)>&
- getEventPostProcessingCallback() const {
+ const post_process_t& getEventPostProcessingCallback() const {
return event_post_process_cb_;
}
- void setEventPostProcessingCallback(
- std::function<void(std::vector<KinetoEvent>&)>&& cb) {
+ void setEventPostProcessingCallback(post_process_t&& cb) {
event_post_process_cb_ = std::move(cb);
}
@@ -215,12 +303,6 @@
auto end_time = getTimeUs();
materializeOpEvents();
- // Call events post processing callback before finalizing trace, if there is
- // one.
- if (getEventPostProcessingCallback()) {
- getEventPostProcessingCallback()(kineto_events_);
- }
-
finalizeCPUTrace(cpu_trace_.get());
{
std::lock_guard<std::mutex> guard(state_mutex_);
@@ -263,7 +345,7 @@
}
memory_events_.clear();
- for (const auto& e : record_queue_.getRecords(converter)) {
+ for (auto& e : record_queue_.getRecords(converter)) {
// `take_data` handles time conversion.
int64_t start_us = e.start_time_us_;
int64_t end_us = e.end_time_us_;
@@ -274,14 +356,14 @@
continue;
}
- cpu_trace_.addCPUActivity(
- e.name(),
- e.kinetoType(),
- e.kineto_info_,
- e.correlation_id(),
- start_us,
- end_us,
- /*annotations=*/{});
+ // Call events post processing callback before finalizing trace, if there
+ // is one.
+ if (getEventPostProcessingCallback()) {
+ getEventPostProcessingCallback()(
+ c10::visit([](const auto& i) { return i.debug_handle_; }, e.event_),
+ e.jit_stack_,
+ e.jit_modules_);
+ }
kineto_events_.emplace_back();
kineto_events_.back()
@@ -292,50 +374,18 @@
.deviceType(c10::DeviceType::CPU)
.startThreadId(e.start_tid_);
- c10::visit(
- c10::overloaded(
- [&](const torch::profiler::impl::OpEvent& op_event) {
- kineto_events_.back()
- .endThreadId(op_event.end_thread_id_)
- .sequenceNr(op_event.sequence_number_)
- .fwdThreadId(op_event.forward_thread_id_)
- .scope(op_event.record_function_scope_)
- .setAsync(op_event.is_async_)
- .debugHandle(op_event.debug_handle_);
- },
- [&](const torch::profiler::impl::BackendEvent& backend_event) {
- kineto_events_.back()
- .endThreadId(e.start_tid_)
- .scope(backend_event.record_function_scope_)
- .debugHandle(backend_event.debug_handle_)
- .backend(backend_event.backend_);
- }),
- e.event_);
+ // NB: also sets fields on `kineto_events_.back()`.
+ auto annotations =
+ EventFieldsVisitor(e, kineto_events_.back()).annotations_;
- if (!e.inputs_.shapes_.empty()) {
- kineto_events_.back().shapes(e.inputs_.shapes_);
- }
-
- if (!e.inputs_.dtypes_.empty()) {
- kineto_events_.back().dtypes(e.inputs_.dtypes_);
- }
-
- if (!e.jit_stack_.empty()) {
- kineto_events_.back().stack(e.jit_stack_);
- }
-
- if (!e.jit_modules_.empty()) {
- kineto_events_.back().moduleHierarchy(e.jit_modules_);
- }
-
- if (!e.extra_args_.empty()) {
- kineto_events_.back().flops(
- computeFlops(e.name(), e.extra_args_));
- }
- kineto_events_.back().cuda_event_start_ =
- e.gpu_fallback_.cuda_event_start_;
- kineto_events_.back().cuda_event_end_ =
- e.gpu_fallback_.cuda_event_end_;
+ cpu_trace_.addCPUActivity(
+ e.name(),
+ e.kinetoType(),
+ e.kineto_info_,
+ e.correlation_id(),
+ start_us,
+ end_us,
+ annotations);
}
}
@@ -348,6 +398,14 @@
// startThreadId_seqNum to pointer of activity.
// Low-16bits of startThreadId and low-48bits seqNum are concatenated into
// one uint64_t variable as key.
+
+ // For the time being, we need to disable the forward/backward correlation
+ // feature to work around a crash bug.
+ // TODO (Mike Guo): re-enable the forward/backward correlation once Kineto
+ // fixes the following raw pointer:
+ // GenericTraceActivity.flow.linkedActivity
+
+ /*
std::unordered_map<uint64_t, libkineto::GenericTraceActivity*>
tidSeq2activity;
@@ -355,44 +413,14 @@
auto& kineto_event = kineto_events_[idx];
auto& activity = cpu_trace->activities[idx];
- if (kineto_event.hasShapes()) {
- activity.addMetadata("Input Dims", torch::profiler::impl::shapesToStr(kineto_event.shapes()));
- }
- if (kineto_event.hasStack()) {
- // NB: This is only for the JIT stack. The python stack (if applicable)
- // is constructed later.
- activity.addMetadata(
- "Call stack", torch::profiler::impl::stacksToStr(kineto_event.stack(), ";"));
- }
- if (kineto_event.hasModuleHierarchy()) {
- activity.addMetadata(
- "Module Hierarchy",
- torch::profiler::impl::stacksToStr(kineto_event.moduleHierarchy(), "."));
- }
- if (kineto_event.hasTypes()) {
- activity.addMetadata("Input type", torch::profiler::impl::dtypesToStr(kineto_event.dtypes()));
- }
- if (!kineto_event.backend().empty()) {
- activity.addMetadata("Backend", "\"" + kineto_event.backend() + "\"");
- }
-
// add information about an associated forward op, if a sequence number
// is available (e.g. during training)
if (kineto_event.sequenceNr() >= 0) {
- activity.addMetadata(
- "Fwd thread id", std::to_string(kineto_event.fwdThreadId()));
- activity.addMetadata(
- "Sequence number", std::to_string(kineto_event.sequenceNr()));
-
- // From the time being, we need disable the forward/backward correlation feature to
- // workaround the crash bug.
- // TODO: by Mike Guo
- // reenable the forward/backward correlation when kineto fix the following raw pointer
- // GenericTraceActivity.flow.linkedActivity
- // generateForwardBackwardLink(
- // kineto_event, fwd_bwd_link_id, activity, tidSeq2activity);
+ generateForwardBackwardLink(
+ kineto_event, fwd_bwd_link_id, activity, tidSeq2activity);
}
}
+ */
addPythonEvents(cpu_trace);
}
@@ -610,7 +638,7 @@
torch::profiler::impl::kineto::TraceWrapper cpu_trace_;
std::vector<KinetoEvent> kineto_events_;
// Optional, if event post-processing is enabled.
- std::function<void(std::vector<KinetoEvent>&)> event_post_process_cb_;
+ post_process_t event_post_process_cb_;
};
static std::unique_ptr<KinetoThreadLocalState> globalStatePtr;
@@ -750,7 +778,7 @@
void enableProfilerWithEventPostProcess(
const torch::profiler::impl::ProfilerConfig& config,
const std::set<torch::profiler::impl::ActivityType>& activities,
- std::function<void(std::vector<KinetoEvent>&)>&& cb,
+ post_process_t&& cb,
const std::unordered_set<at::RecordScope>& scopes) {
TORCH_CHECK(
config.state != ProfilerState::NVTX,
diff --git a/torch/csrc/autograd/profiler_kineto.h b/torch/csrc/autograd/profiler_kineto.h
index c7b130c..c980096 100644
--- a/torch/csrc/autograd/profiler_kineto.h
+++ b/torch/csrc/autograd/profiler_kineto.h
@@ -344,10 +344,14 @@
* callback, via enableProfilerWithEventPostProcess, that takes these debug handles
* and generates stack trace and module hierarchy information, once profiling is done.
*/
+using post_process_t = std::function<void(
+ /*debug_handle */ int64_t,
+ /*jit_stack */ std::vector<std::string>&,
+ /*jit_modules */ std::vector<std::string>&)>;
TORCH_API void enableProfilerWithEventPostProcess(
const torch::profiler::impl::ProfilerConfig& config,
const std::set<torch::profiler::impl::ActivityType>& activities,
- std::function<void(std::vector<KinetoEvent>&)>&& cb,
+ post_process_t&& cb,
const std::unordered_set<at::RecordScope>& scopes = {});
TORCH_API std::unique_ptr<ProfilerResult> disableProfiler();
diff --git a/torch/csrc/jit/mobile/profiler_edge.cpp b/torch/csrc/jit/mobile/profiler_edge.cpp
index 40a6470..7236558 100644
--- a/torch/csrc/jit/mobile/profiler_edge.cpp
+++ b/torch/csrc/jit/mobile/profiler_edge.cpp
@@ -1,4 +1,5 @@
#include <c10/util/Exception.h>
+#include <c10/util/overloaded.h>
#include <torch/csrc/jit/mobile/profiler_edge.h>
#include <string>
#include <vector>
@@ -28,33 +29,26 @@
torch::autograd::profiler::prepareProfiler(
config, {torch::autograd::profiler::ActivityType::CPU});
if (with_modules || with_stack) {
- auto post_processing =
- [this, with_stack, with_modules](
- std::vector<torch::autograd::profiler::KinetoEvent>& events) {
- std::string no_debug_info(
- "Model was not saved with debug information");
- for (auto& e : events) {
- if (with_modules) {
- // Since KinetoEvents's module hierarchy takes vector of strings
- // we just construct a temporary vector using one string element
- if (this->m_.hasDebugHandles()) {
- e.moduleHierarchy(std::vector<std::string>(
- {this->m_.getModuleHierarchy(e.debugHandle())}));
- } else {
- e.moduleHierarchy(std::vector<std::string>({no_debug_info}));
- }
- } else if (with_stack) {
- // Since KinetoEvents's stack trace takes vector of strings we
- // just construct a temporary vector using one string element
- if (this->m_.hasDebugHandles()) {
- e.stack(std::vector<std::string>(
- {this->m_.getCallStack(e.debugHandle())}));
- } else {
- e.stack(std::vector<std::string>({no_debug_info}));
- }
- }
- }
- };
+ auto post_processing = [this, with_stack, with_modules](
+ int64_t debug_handle,
+ std::vector<std::string>& jit_stack,
+ std::vector<std::string>& jit_modules) {
+ std::string no_debug_info("Model was not saved with debug information");
+ if (with_modules) {
+ // Since KinetoEvent's module hierarchy takes a vector of strings,
+ // we just construct a temporary vector using one string element.
+ jit_modules = std::vector<std::string>(
+ {this->m_.hasDebugHandles()
+ ? this->m_.getModuleHierarchy(debug_handle)
+ : no_debug_info});
+ } else if (with_stack) {
+ // Since KinetoEvent's stack trace takes a vector of strings, we
+ // just construct a temporary vector using one string element.
+ jit_stack = std::vector<std::string>(
+ {this->m_.hasDebugHandles() ? this->m_.getCallStack(debug_handle)
+ : no_debug_info});
+ }
+ };
torch::autograd::profiler::enableProfilerWithEventPostProcess(
config,
{torch::autograd::profiler::ActivityType::CPU},