blob: 20acd6f563d8250fffa496fc0bd42fa79f056d61 [file] [log] [blame]
#include <torch/csrc/profiler/api.h>
namespace torch {
namespace profiler {
namespace impl {
namespace {
// Slot positions of the ProfilerConfig fields inside the generic list that
// ProfilerConfig::toIValue() builds and ProfilerConfig::fromIValue() reads.
enum ProfilerIValueIdx {
  STATE = 0,
  REPORT_INPUT_SHAPES = 1,
  PROFILE_MEMORY = 2,
  // Sentinel: auto-numbered from the entry above, so it equals the number of
  // slots. It must stay the final enumerator.
  NUM_PROFILER_CFG_IVALUE_IDX
};
} // namespace
// Packs this config into a generic IValue list. Entries are appended in the
// order state, report_input_shapes, profile_memory, which is the order of the
// ProfilerIValueIdx enumerators that fromIValue() indexes with.
at::IValue ProfilerConfig::toIValue() const {
  c10::impl::GenericList packed(at::AnyType::get());
  packed.reserve(NUM_PROFILER_CFG_IVALUE_IDX);

  // The profiler state enum is stored as a 64-bit integer.
  const auto state_as_int = static_cast<int64_t>(state);
  packed.emplace_back(state_as_int);
  packed.emplace_back(report_input_shapes);
  packed.emplace_back(profile_memory);

  return packed;
}
// Rebuilds a ProfilerConfig from the list representation produced by
// toIValue().
//
// profilerConfigIValue must hold a list with exactly
// NUM_PROFILER_CFG_IVALUE_IDX entries; both conditions are enforced with
// TORCH_INTERNAL_ASSERT. The entries are read by ProfilerIValueIdx position:
// STATE as an int (cast back to ProfilerState), REPORT_INPUT_SHAPES and
// PROFILE_MEMORY as bools.
ProfilerConfig ProfilerConfig::fromIValue(
    const at::IValue& profilerConfigIValue) {
  TORCH_INTERNAL_ASSERT(
      profilerConfigIValue.isList(),
      "Expected IValue to contain type c10::impl::GenericList");
  auto ivalues = profilerConfigIValue.toList();
  TORCH_INTERNAL_ASSERT(
      ivalues.size() == NUM_PROFILER_CFG_IVALUE_IDX,
      c10::str(
          "Expected exactly ",
          NUM_PROFILER_CFG_IVALUE_IDX,
          " ivalues to reconstruct ProfilerConfig."));
  return ProfilerConfig(
      static_cast<ProfilerState>(ivalues.get(ProfilerIValueIdx::STATE).toInt()),
      ivalues.get(ProfilerIValueIdx::REPORT_INPUT_SHAPES).toBool(),
      ivalues.get(ProfilerIValueIdx::PROFILE_MEMORY).toBool());
}
namespace {
// Looks up the debug info stored under DebugInfoKind::PROFILER_STATE via
// c10::ThreadLocalDebugInfo and returns it cast to the profiler state base
// type. Callers in this file null-check the result before using it.
ProfilerThreadLocalStateBase* getProfilerTLSState() {
  auto* debug_info =
      c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE);
  return static_cast<ProfilerThreadLocalStateBase*>(debug_info);
}
} // namespace
// Returns true when a profiler state is available from getProfilerTLSState()
// and its configured state is anything other than ProfilerState::Disabled.
bool profilerEnabled() {
  auto* tls_state = getProfilerTLSState();
  if (!tls_state) {
    // No profiler state registered, so nothing is enabled.
    return false;
  }
  return tls_state->config().state !=
      torch::profiler::impl::ProfilerState::Disabled;
}
// Reports which kind of profiler is active: ActiveProfilerType::NONE when
// getProfilerTLSState() yields no state, otherwise whatever that state's
// profilerType() returns.
TORCH_API ActiveProfilerType profilerType() {
  auto* tls_state = getProfilerTLSState();
  if (tls_state == nullptr) {
    return ActiveProfilerType::NONE;
  }
  return tls_state->profilerType();
}
// Returns a copy of the config held by the current profiler state.
// Fails a TORCH_CHECK if getProfilerTLSState() yields no state.
torch::profiler::impl::ProfilerConfig getProfilerConfig() {
  auto* tls_state = getProfilerTLSState();
  TORCH_CHECK(
      tls_state != nullptr,
      "Tried to access profiler config, but profiler is not enabled!");
  return tls_state->config();
}
// Out-of-line definition of the CUDAStubs destructor: it is declared with the
// class and defaulted here.
// NOTE(review): presumably kept out of line so the class's vtable is emitted
// in this translation unit — confirm against the header.
CUDAStubs::~CUDAStubs() = default;
namespace {
// Fallback CUDAStubs implementation used until registerCUDAMethods() installs
// another one. enabled() reports false; every other operation reports an
// error through AT_ERROR because no CUDA backend is available.
struct DefaultCUDAStubs : public CUDAStubs {
  ~DefaultCUDAStubs() override = default;

  // The only operation that does not report an error: this stub set never
  // provides CUDA support.
  bool enabled() const override {
    return false;
  }

  // --- Event timing -------------------------------------------------------
  void record(
      int* /*device*/,
      CUDAEventStub* /*event*/,
      int64_t* /*cpu_ns*/) const override {
    reportCudaUnavailable();
  }

  float elapsed(
      const CUDAEventStub* /*event*/,
      const CUDAEventStub* /*event2*/) const override {
    reportCudaUnavailable();
    return 0.f;
  }

  // --- NVTX markers and ranges --------------------------------------------
  void nvtxMarkA(const char* /*name*/) const override {
    reportCudaUnavailable();
  }

  void nvtxRangePushA(const char* /*name*/) const override {
    reportCudaUnavailable();
  }

  void nvtxRangePop() const override {
    reportCudaUnavailable();
  }

  // --- Device-wide operations ---------------------------------------------
  void onEachDevice(std::function<void(int)> /*op*/) const override {
    reportCudaUnavailable();
  }

  void synchronize() const override {
    reportCudaUnavailable();
  }

 private:
  // Shared error path for every operation that needs a CUDA backend.
  void reportCudaUnavailable() const {
    AT_ERROR("CUDA used in profiler but not enabled.");
  }
};
// File-local instance of the fallback stubs.
const DefaultCUDAStubs default_stubs;
// Address of default_stubs as a compile-time constant; cuda_stubs() uses it
// as the initial value of the registered-stubs pointer.
constexpr const DefaultCUDAStubs* default_stubs_addr = &default_stubs;
// Returns a mutable reference to the pointer holding the active stubs.
// The pointer starts out as the address of the default stubs. That initial
// value is a constant initializer, so the pointer is guaranteed to be set up
// before any static initializer that may call registerCUDAMethods.
inline const CUDAStubs*& cuda_stubs() {
  static const CUDAStubs* active_stubs{
      static_cast<const CUDAStubs*>(default_stubs_addr)};
  return active_stubs;
}
} // namespace
// Returns the stubs currently held by cuda_stubs(): the default stubs unless
// registerCUDAMethods() has replaced them.
const CUDAStubs* cudaStubs() {
  const CUDAStubs* const current = cuda_stubs();
  return current;
}
// Installs `stubs` as the implementation returned by cudaStubs() from now on.
// The pointer is stored as given; no copy is made and no check is performed.
void registerCUDAMethods(CUDAStubs* stubs) {
  const CUDAStubs*& slot = cuda_stubs();
  slot = stubs;
}
} // namespace impl
} // namespace profiler
} // namespace torch