// blob: ad861959f0325605a7bf51673495416c7818eba6 [file] [log] [blame]
#include "torch/csrc/autograd/profiler.h"
#include "torch/csrc/autograd/function.h"
namespace torch { namespace autograd { namespace profiler {

// Current profiling mode; ProfilerState::Disabled when no session is active.
// NOTE(review): written here without synchronization. If other threads read it
// while a session is being started/stopped this is a data race; making it
// atomic would require changing its extern declaration in profiler.h.
ProfilerState state = ProfilerState::Disabled;
// Not used in this file; presumably the counter that hands out `thread_id`
// values (see profiler.h) — confirm.
uint32_t next_thread_id = 0;
// Held by disableProfiler() while it walks `all_event_lists`.
std::mutex all_event_lists_mutex;
// Registry of per-thread event lists. disableProfiler() drains every entry and
// drops the ones nothing else references any more.
std::list<std::shared_ptr<RangeEventList>> all_event_lists;
// This thread's event list. Not touched in this file; presumably created
// lazily and registered in `all_event_lists` by profiler.h — confirm.
thread_local std::shared_ptr<RangeEventList> event_list;
// Profiler id of the current thread (assigned outside this file).
thread_local int32_t thread_id;

// Opens a profiler range named after the autograd Function `fn` (must be
// non-null; it is dereferenced). NOTE(review): presumably defined out-of-line
// here because Function is only forward-declared in profiler.h.
void RecordFunction::pushFunctionRange(Function* fn) {
  pushRange(fn->name());
}

#ifdef USE_CUDA
// Calls `op(device)` once for every visible CUDA device, making that device
// current before each call. `gpu_guard` presumably restores the previously
// current device when it goes out of scope. Throws if the device count cannot
// be queried (via TORCH_CUDA_CHECK).
static void onEachDevice(const std::function<void(int)>& op) {
  AutoGPU gpu_guard;
  int count = 0;
  TORCH_CUDA_CHECK(cudaGetDeviceCount(&count));
  for (int device = 0; device < count; ++device) {
    gpu_guard.setDevice(device);
    op(device);
  }
}

// Records the CUDA marks used to line up GPU timing with the CPU clock.
// Event recording appears to have some startup overhead, so a few dummy
// "__cuda_startup" marks are generated on every device first (synchronizing
// the device after each). CUDA events must be on the same device, so one
// "__cuda_start_event" mark is then recorded per GPU; that event is used to
// synchronize time on the GPU with the CPU clock.
static void recordCudaStartEvents() {
  constexpr int kWarmupRounds = 5;
  for (int round = 0; round < kWarmupRounds; ++round) {
    onEachDevice([](int /*device*/) {
      mark("__cuda_startup");
      cudaDeviceSynchronize();
    });
  }
  onEachDevice([](int /*device*/) {
    mark("__cuda_start_event");
  });
}
#endif

// Starts (or re-marks) a profiling session of kind `new_state`, which must not
// be ProfilerState::Disabled.
// Throws std::runtime_error if:
//  - NVTX is requested in a build without CUDA, or
//  - a session of a *different* kind is already running.
// Calling it again with the same kind is allowed and records another
// "__start_profile" mark.
// If the CUDA warm-up fails, `state` is restored to its previous value and the
// exception is rethrown.
void enableProfiler(ProfilerState new_state) {
  TORCH_ASSERT(new_state != ProfilerState::Disabled);
#ifndef USE_CUDA
  if (new_state == ProfilerState::NVTX) {
    throw std::runtime_error("Can't use NVTX profiler - PyTorch was compiled without CUDA");
  }
#endif
  if (state != ProfilerState::Disabled && new_state != state) {
    throw std::runtime_error("can't change kind of profiling (e.g. NVTX to CPU) while profiler is running");
  }
#ifdef USE_CUDA
  // Remembered so a failed CUDA warm-up can roll the state back.
  const ProfilerState previous_state = state;
#endif
  // Set before any mark() below; presumably mark() consults `state` to decide
  // how (and whether with CUDA events) to record.
  state = new_state;
#ifdef USE_CUDA
  if (state == ProfilerState::CUDA) {
    try {
      recordCudaStartEvents();
    } catch (...) {
      // Without this rollback a CUDA failure (e.g. TORCH_CUDA_CHECK throwing in
      // onEachDevice) would leave `state` at CUDA although the session never
      // started, and every later enableProfiler() with another kind would throw.
      state = previous_state;
      throw;
    }
  }
#endif
  // NOTE(review): in a build without USE_CUDA a ProfilerState::CUDA request is
  // not rejected and simply skips the warm-up above; confirm that is intended.
  mark("__start_profile", false);
}

// Ends the running profiling session and returns the recorded events, one
// consolidated entry per registered thread event list. Returns an empty result
// for NVTX sessions. Throws std::runtime_error if no session is running.
// Lists that no longer have any other owner are removed from the registry.
thread_event_lists disableProfiler() {
  if (state == ProfilerState::Disabled) {
    throw std::runtime_error("can't disable profiler when it's not running");
  }
  const ProfilerState old_state = state;
  mark("__stop_profile");
  state = ProfilerState::Disabled;

  if (old_state == ProfilerState::NVTX) {
    return thread_event_lists();
  }

  thread_event_lists result;
  std::lock_guard<std::mutex> guard(all_event_lists_mutex);
  for (auto it = all_event_lists.begin(); it != all_event_lists.end();) {
    result.emplace_back((*it)->consolidate());
    // use_count() == 1: only the registry still holds this list (presumably its
    // thread has exited and released its thread_local copy), so GC it.
    if (it->use_count() == 1) {
      it = all_event_lists.erase(it);
    } else {
      ++it;
    }
  }
  return result;
}

}}}