| #define TORCH_ASSERT_ONLY_METHOD_OPERATORS |
| #include <cstring> |
| #include <torch/csrc/autograd/profiler_kineto.h> |
| |
| #include <c10/macros/Export.h> |
| #include <c10/util/ApproximateClock.h> |
| #include <c10/util/Exception.h> |
| #include <c10/util/flat_hash_map.h> |
| #include <c10/util/irange.h> |
| #include <c10/util/overloaded.h> |
| |
| #include <torch/csrc/profiler/api.h> |
| #include <torch/csrc/profiler/collection.h> |
| #include <torch/csrc/profiler/containers.h> |
| #include <torch/csrc/profiler/events.h> |
| #include <torch/csrc/profiler/kineto_shim.h> |
| #include <torch/csrc/profiler/orchestration/observer.h> |
| #include <torch/csrc/profiler/perf.h> |
| #include <torch/csrc/profiler/standalone/itt_observer.h> |
| #include <torch/csrc/profiler/standalone/nvtx_observer.h> |
| #include <torch/csrc/profiler/standalone/privateuse1_observer.h> |
| #include <torch/csrc/profiler/util.h> |
| |
| #include <ATen/Context.h> |
| |
| #include <stdexcept> |
| #include <utility> |
| |
| #ifdef USE_KINETO |
| #include <ApproximateClock.h> |
| #include <libkineto.h> |
| #include <time_since_epoch.h> |
| |
| #ifndef _MSC_VER |
| // TODO: To be removed once this properly works from libkineto. |
| // Literal copy-and-paste from third_party/kineto/libkineto/src/WeakSymbols.cpp |
| extern "C" { |
| // This function is needed to avoid a superfluous dependency on the GNU OpenMP |
| // library when cuPTI is linked statically. For more details see |
| // https://github.com/pytorch/pytorch/issues/51026 |
| __attribute__((weak)) int acc_get_device_type(); |
| __attribute__((weak)) int acc_get_device_type() { |
| throw std::runtime_error( |
| "Dummy implementation of acc_get_device_type is not supposed to be called!"); |
| } |
| } // extern "C" |
| #endif // _MSC_VER |
| #endif // USE_KINETO |
| |
| namespace torch { |
| namespace autograd::profiler { |
| |
| namespace { |
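| // Wall-clock time in nanoseconds. With Kineto available we use its |
| // epoch-based clock so host events line up with the device timeline Kineto |
| // records; otherwise we fall back to c10::getTime(). |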
| inline int64_t getTimeNs() { |
| #ifdef USE_KINETO |
| return libkineto::timeSinceEpoch(std::chrono::system_clock::now()); |
| #else |
| return c10::getTime(); |
| #endif // USE_KINETO |
| } |
| |
| using torch::profiler::impl::ActiveProfilerType; |
| using torch::profiler::impl::EventType; |
| using torch::profiler::impl::ExtraFields; |
| using torch::profiler::impl::get_record_concrete_inputs_enabled; |
| using torch::profiler::impl::ivalueListToStr; |
| using torch::profiler::impl::ivalueToStr; |
| using torch::profiler::impl::op_input_t; |
| using torch::profiler::impl::ProfilerStateBase; |
| using torch::profiler::impl::PyExtraFieldsBase; |
| using torch::profiler::impl::Result; |
| using torch::profiler::impl::shape; |
| using torch::profiler::impl::shapesToStr; |
| using torch::profiler::impl::stacksToStr; |
| using torch::profiler::impl::strListToStr; |
| using torch::profiler::impl::TensorMetadata; |
| using torch::profiler::impl::variantShapesToStr; |
| |
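| // Parsed view of the inputs recorded for a single op: shapes, strides and |
| // dtypes, plus concrete scalar values when recording them is enabled, in the |
| // formats expected by both the Kineto trace metadata and the KinetoEvent API. |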
| struct OpArgData { |
| bool hasData; |
| std::vector<shape> shapes; |
| std::vector<std::string> dtypes; |
| std::vector<c10::IValue> concreteInputs; |
| std::vector<std::vector<int64_t>> shapesForKinetoEvent; |
| std::vector<shape> strides; |
| }; |
| |
| auto parseArgData( |
| const std::vector<op_input_t>& input_shapes, |
| const std::vector<op_input_t>& concreteInputs) { |
| if (input_shapes.empty()) { |
| return OpArgData{false, {}, {}, {}, {}, {}}; |
| } |
| |
| std::vector<shape> shapes(input_shapes.size()); |
| std::vector<shape> strides(input_shapes.size()); |
| std::vector<std::vector<int64_t>> shapesForKinetoEvent(input_shapes.size()); |
| |
| std::vector<std::string> dtypes(input_shapes.size()); |
| std::vector<c10::IValue> concrete_inputs_list; |
| |
| for (const auto& i : c10::irange(input_shapes.size())) { |
| std::visit( |
| c10::overloaded( |
| [&](const TensorMetadata& t) { |
| shapes[i] = t.sizes_; |
| shapesForKinetoEvent[i] = t.sizes_; |
| dtypes[i] = std::string(scalarTypeToTypeMeta(t.dtype_).name()); |
| strides[i] = t.strides_; |
| }, |
| [&](const std::vector<TensorMetadata>& l) { |
| std::vector<std::vector<int64_t>> shape; |
| shape.reserve(l.size()); |
| std::vector<std::vector<int64_t>> stride; |
| stride.reserve(l.size()); |
| for (const auto& t : l) { |
| shape.emplace_back(t.sizes_); |
| stride.emplace_back(t.strides_); |
| } |
| shapes[i] = shape; |
| strides[i] = stride; |
| dtypes[i] = "TensorList"; |
| }, |
| [&](const c10::IValue&) { dtypes[i] = "Scalar"; }, |
| [&](const auto&) {}), |
| input_shapes[i]); |
| } |
| |
| // If we recorded concrete inputs, then parse them |
| if (input_shapes.size() == concreteInputs.size() && !concreteInputs.empty()) { |
| concrete_inputs_list.resize(input_shapes.size()); |
| |
| for (const auto& i : c10::irange(input_shapes.size())) { |
| std::visit( |
| c10::overloaded( |
| [&](const c10::IValue& val) { concrete_inputs_list[i] = val; }, |
| [&](const auto&) {}), |
| input_shapes[i]); |
| std::visit( |
| c10::overloaded( |
| [&](const c10::IValue& val) { |
| concrete_inputs_list[i] = val; |
| dtypes[i] = "ScalarList"; |
| }, |
| [&](const auto&) {}), |
| concreteInputs[i]); |
| } |
| } |
| |
| return OpArgData{ |
| true, |
| shapes, |
| dtypes, |
| concrete_inputs_list, |
| shapesForKinetoEvent, |
| strides}; |
| } |
| |
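| // Base class for the metadata visitors below. It resolves the Kineto |
| // activity backing a Result (when it is safe to use) and exposes |
| // addMetadata(), which attaches a key/value pair to that activity in the |
| // exported trace. |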
| struct MetadataBase { |
| /* implicit */ MetadataBase(const std::shared_ptr<Result>& result) |
| : kinetoActivity_{result->kineto_activity_} { |
| if (std::holds_alternative<ExtraFields<EventType::Kineto>>( |
| result->extra_fields_)) { |
| // In order to add metadata we have to downcast from |
| // `libkineto::ITraceActivity` to `libkineto::GenericTraceActivity`. We |
| // know that all activities provided by PyTorch are of the correct type, |
| // however Kineto profilers can (and do) add events that inherit directly |
| // from ITraceActivity. As a result, any Result which was constructed from |
| // an event that Kineto provided is unsafe to cast. |
| if (!(SOFT_ASSERT(!hasKinetoActivity()))) { |
| result->kineto_activity_ = nullptr; |
| } |
| kinetoActivity_ = result->kineto_activity_; |
| } |
| } |
| |
| void addMetadata(const std::string& key, const std::string& value) { |
| if (kinetoActivity_ && !value.empty() && value != "\"\"") { |
| torch::profiler::impl::kineto::addMetadata( |
| // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) |
| const_cast<torch::profiler::impl::kineto::activity_t*>( |
| kinetoActivity_), |
| key, |
| value); |
| } |
| } |
| |
| bool hasKinetoActivity() const { |
| return kinetoActivity_ != nullptr; |
| } |
| |
| private: |
| const torch::profiler::impl::kineto::activity_t* kinetoActivity_{nullptr}; |
| }; |
| |
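| // Visitor that attaches the fields the TensorBoard plugin consumes: module |
| // hierarchy, call stack, and Python frame ids (used to rebuild the Python |
| // call tree from the trace). |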
| struct AddTensorboardFields : public MetadataBase { |
| AddTensorboardFields( |
| const std::shared_ptr<Result>& result, |
| KinetoEvent& kineto_event) |
| : MetadataBase(result) { |
| result->visit(*this); |
| const auto module_hierarchy = kineto_event.moduleHierarchy(); |
| addMetadata("Module Hierarchy", stacksToStr(module_hierarchy.vec(), ".")); |
| addMetadata("Call stack", stacksToStr(kineto_event.stack().vec(), ";")); |
| |
| result->visit_if_base<PyExtraFieldsBase>([&, this](const auto& i) -> void { |
| this->addMetadata("Python id", std::to_string(i.id_)); |
| |
| std::optional<std::string> parent_id; |
| std::shared_ptr<Result> parent = result->parent_.lock(); |
| while (parent && !parent_id.has_value()) { |
| parent->visit_if_base<PyExtraFieldsBase>( |
| [&](const auto& j) { parent_id = std::to_string(j.id_); }); |
| parent = parent->parent_.lock(); |
| } |
| this->addMetadata("Python parent id", parent_id.value_or("null")); |
| }); |
| } |
| |
| void operator()(const ExtraFields<EventType::PyCall>& py_call) { |
| if (py_call.module_.has_value()) { |
| addMetadata("Python module id", std::to_string(py_call.module_->id_)); |
| } |
| } |
| |
| template <typename T> |
| void operator()(const T&) {} |
| }; |
| |
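| // Visitor that attaches general-purpose metadata: input dims/strides/types, |
| // concrete inputs, kwargs, perf-event counters, sequence numbers, and the |
| // allocation / out-of-memory details surfaced in memory traces. |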
| struct AddGenericMetadata : public MetadataBase { |
| AddGenericMetadata( |
| std::shared_ptr<Result>& result, |
| const torch::profiler::impl::ProfilerConfig* config) |
| : MetadataBase(result), config_(config) { |
| result->visit(*this); |
| if (config->experimental_config.verbose) { |
| result->visit_if_base<PyExtraFieldsBase>( |
| [&, this](const auto& i) -> void { |
| this->addMetadata("Python thread", std::to_string(i.python_tid_)); |
| }); |
| } |
| } |
| |
| void operator()(ExtraFields<EventType::TorchOp>& op_event) { |
| const auto arg_data = |
| parseArgData(op_event.inputs_, op_event.concrete_inputs_); |
| |
| if (arg_data.hasData) { |
| if (get_record_concrete_inputs_enabled()) { |
| addMetadata("Input Dims", variantShapesToStr(arg_data.shapes)); |
| addMetadata("Input Strides", variantShapesToStr(arg_data.strides)); |
| } else { |
| addMetadata("Input Dims", shapesToStr(arg_data.shapesForKinetoEvent)); |
| } |
| addMetadata("Input type", strListToStr(arg_data.dtypes)); |
| if (!arg_data.concreteInputs.empty()) { |
| addMetadata( |
| "Concrete Inputs", ivalueListToStr(arg_data.concreteInputs)); |
| } |
| } |
| |
| // Add metadata for kwinputs if they exist |
| for (const auto& [key, val] : op_event.kwinputs_) { |
| addMetadata(key, ivalueToStr(val)); |
| } |
| // Add extra metadata if any |
| for (const auto& [key, val] : op_event.extra_meta_) { |
| addMetadata(key, val); |
| } |
| |
| if (config_ && !config_->experimental_config.performance_events.empty()) { |
| auto& event_names = config_->experimental_config.performance_events; |
| for (const auto i : c10::irange(op_event.perf_event_counters_->size())) { |
| addMetadata( |
| event_names[i], |
| std::to_string((*op_event.perf_event_counters_)[i])); |
| } |
| } |
| |
| // Add information about an associated forward op if a sequence number |
| // is available (e.g. during training). |
| if (op_event.sequence_number_ >= 0) { |
| addMetadata("Fwd thread id", std::to_string(op_event.forward_tid_)); |
| addMetadata("Sequence number", std::to_string(op_event.sequence_number_)); |
| } |
| addMetadata( |
| "Record function id", std::to_string(op_event.record_function_id_)); |
| } |
| |
| void operator()(ExtraFields<EventType::Backend>& backend_event) { |
| if (!backend_event.backend_.empty()) { |
| addMetadata("Backend", "\"" + backend_event.backend_ + "\""); |
| } |
| } |
| |
| void operator()(const ExtraFields<EventType::Allocation>& alloc) { |
| addMetadata("Device Type", std::to_string((int8_t)alloc.device_type_)); |
| addMetadata("Device Id", std::to_string(alloc.device_index_)); |
| addMetadata("Addr", std::to_string(reinterpret_cast<intptr_t>(alloc.ptr_))); |
| addMetadata("Bytes", std::to_string(alloc.alloc_size_)); |
| addMetadata("Total Allocated", std::to_string(alloc.total_allocated_)); |
| addMetadata("Total Reserved", std::to_string(alloc.total_reserved_)); |
| } |
| |
| void operator()(const ExtraFields<EventType::OutOfMemory>& alloc) { |
| addMetadata("Device Type", std::to_string((int8_t)alloc.device_type_)); |
| addMetadata("Device Id", std::to_string(alloc.device_index_)); |
| addMetadata("Bytes", std::to_string(alloc.alloc_size_)); |
| addMetadata("Total Allocated", std::to_string(alloc.total_allocated_)); |
| addMetadata("Total Reserved", std::to_string(alloc.total_reserved_)); |
| } |
| |
| template <typename T> |
| void operator()(const T&) {} |
| |
| private: |
| /* To get names of the performance events */ |
| const torch::profiler::impl::ProfilerConfig* config_; |
| }; |
| |
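| // Profiler state for the Kineto backend. It owns the RecordQueue that |
| // collects events for the session and, in finalizeTrace(), converts the |
| // collected Results into KinetoEvents plus the Kineto activity trace. |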
| struct KinetoThreadLocalState : public ProfilerStateBase { |
| explicit KinetoThreadLocalState( |
| const ProfilerConfig& config, |
| std::set<torch::profiler::impl::ActivityType> activities) |
| : ProfilerStateBase(config), |
| startTime(getTimeNs()), |
| recordQueue(config, std::move(activities)) {} |
| ~KinetoThreadLocalState() override = default; |
| |
| static KinetoThreadLocalState* get(bool global) { |
| auto* state = ProfilerStateBase::get(/*global=*/global); |
| TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
| state == nullptr || |
| state->profilerType() == ActiveProfilerType::KINETO); |
| return static_cast<KinetoThreadLocalState*>(state); |
| } |
| |
| ActiveProfilerType profilerType() override { |
| return ActiveProfilerType::KINETO; |
| } |
| |
| void reportVulkanEventToProfiler(torch::profiler::impl::vulkan_id_t id) { |
| if (!config_.disabled()) { |
| recordQueue.getSubqueue()->emplace_vulkan_event( |
| c10::getApproximateTime(), id); |
| } |
| } |
| |
| void reportMemoryUsage( |
| void* ptr, |
| int64_t alloc_size, |
| size_t total_allocated, |
| size_t total_reserved, |
| c10::Device device) override { |
| if (config_.profile_memory && !config_.disabled()) { |
| recordQueue.getSubqueue()->emplace_allocation_event( |
| c10::getApproximateTime(), |
| ptr, |
| alloc_size, |
| total_allocated, |
| total_reserved, |
| device.type(), |
| device.index()); |
| } |
| } |
| |
| void reportOutOfMemory( |
| int64_t alloc_size, |
| size_t total_allocated, |
| size_t total_reserved, |
| c10::Device device) override { |
| if (config_.profile_memory && !config_.disabled()) { |
| recordQueue.getSubqueue()->emplace_ooms_event( |
| c10::getApproximateTime(), |
| alloc_size, |
| total_allocated, |
| total_reserved, |
| device.type(), |
| device.index()); |
| } |
| } |
| |
| void setEventPostProcessingCallback(post_process_t&& cb) { |
| eventPostProcessCb = std::move(cb); |
| } |
| |
| void pausePython() { |
| recordQueue.stop(); |
| } |
| |
| void resumePython() { |
| recordQueue.restart(); |
| } |
| |
| std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper> |
| finalizeTrace() { |
| auto end_time = getTimeNs(); |
| recordQueue.stop(); |
| |
| std::lock_guard<std::mutex> guard(state_mutex_); |
| auto converter = clockConverter.makeConverter(); |
| #ifdef USE_KINETO |
| libkineto::get_time_converter() = converter; |
| #endif |
| auto records_and_trace = |
| recordQueue.getRecords(std::move(converter), startTime, end_time); |
| |
| materializeOpEvents(records_and_trace.first); |
| |
| // `kinetoEvents` does not include Python events. Instead it exposes them |
| // via the `stacks` property. |
| kinetoEvents.erase( |
| std::remove_if( |
| kinetoEvents.begin(), |
| kinetoEvents.end(), |
| [](const auto& i) { return i.isPythonFunction(); }), |
| kinetoEvents.end()); |
| |
| return std::move(records_and_trace.second); |
| } |
| |
| template <typename T> |
| void invokeCallback(T& t) { |
| if (eventPostProcessCb) { |
| eventPostProcessCb(t.debug_handle_, t.jit_stack_, t.jit_modules_); |
| } |
| } |
| |
| void materializeOpEvents(std::vector<std::shared_ptr<Result>>& events) { |
| for (auto& e : events) { |
| if (e->parent_.expired() && e->deviceType() == c10::DeviceType::CPU) { |
| eventTree.push_back(e); |
| } |
| |
| if (e->finished_) { |
| e->visit(c10::overloaded( |
| [this](ExtraFields<EventType::TorchOp>& i) { invokeCallback(i); }, |
| [this](ExtraFields<EventType::Backend>& i) { invokeCallback(i); }, |
| [](auto&) {})); |
| |
| kinetoEvents.emplace_back(e, config_.experimental_config.verbose); |
| AddTensorboardFields add_tb(e, kinetoEvents.back()); |
| AddGenericMetadata add_generic(e, &config_); |
| |
| // It is not safe to use the activity after post processing. |
| e->kineto_activity_ = nullptr; |
| } |
| } |
| } |
| |
| uint64_t startTime; |
| c10::ApproximateClockToUnixTimeConverter clockConverter; |
| torch::profiler::impl::RecordQueue recordQueue; |
| std::vector<KinetoEvent> kinetoEvents; |
| std::vector<experimental_event_t> eventTree; |
| // Optional, if event post-processing is enabled. |
| post_process_t eventPostProcessCb; |
| }; |
| |
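| // RecordFunction hooks. onFunctionEnter/onFunctionExit are registered by |
| // pushProfilingCallbacks() below and route every observed op into the |
| // RecordQueue subqueue for the current thread. |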
| template <bool use_global_state_ptr = false> |
| std::unique_ptr<at::ObserverContext> onFunctionEnter( |
| const at::RecordFunction& fn) { |
| auto state_ptr = KinetoThreadLocalState::get(use_global_state_ptr); |
| if (!state_ptr) { |
| return nullptr; |
| } |
| return state_ptr->recordQueue.getSubqueue()->begin_op(fn); |
| } |
| |
| // @lint-ignore CLANGTIDY clang-diagnostic-unused-parameter |
| template <bool use_global_state_ptr = false> |
| void onFunctionExit( |
| const at::RecordFunction& fn, |
| at::ObserverContext* ctx_ptr) { |
| auto state_ptr = KinetoThreadLocalState::get(use_global_state_ptr); |
| if (!state_ptr) { |
| return; |
| } |
| const auto& config = state_ptr->config(); |
| auto* kineto_ctx_ptr = |
| static_cast<torch::profiler::impl::KinetoObserverContext*>(ctx_ptr); |
| TORCH_INTERNAL_ASSERT(kineto_ctx_ptr != nullptr); |
| kineto_ctx_ptr->event_->end_time_ = c10::getApproximateTime(); |
| if (!config.experimental_config.performance_events.empty()) { |
| state_ptr->recordQueue.getSubqueue()->disable_perf_profiler( |
| *kineto_ctx_ptr->event_->counters_); |
| } |
| kineto_ctx_ptr->event_->basic_fields_.end_tid_ = |
| at::RecordFunction::currentThreadId(); |
| if (config.state == ProfilerState::KINETO_GPU_FALLBACK) { |
| try { |
| auto fallback = kineto_ctx_ptr->fallback_; |
| TORCH_INTERNAL_ASSERT(fallback != nullptr); |
| torch::profiler::impl::cudaStubs()->record( |
| nullptr, &fallback->device_event_end_, nullptr); |
| } catch (const std::exception& e) { |
| LOG(WARNING) << "Failed to record CUDA event. " << e.what(); |
| } |
| } else if (config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK) { |
| auto fallback = kineto_ctx_ptr->fallback_; |
| TORCH_INTERNAL_ASSERT(fallback != nullptr); |
| torch::profiler::impl::privateuse1Stubs()->record( |
| nullptr, &fallback->device_event_end_, nullptr); |
| } |
| |
| if (fn.scope() == at::RecordScope::USER_SCOPE) { |
| torch::profiler::impl::kineto::popUserCorrelationId(); |
| } else { |
| torch::profiler::impl::kineto::popCorrelationId(); |
| } |
| } |
| |
| template <bool use_global_callback = false> |
| void pushProfilingCallbacks(const std::unordered_set<at::RecordScope>& scopes) { |
| auto registration_state_ptr = |
| KinetoThreadLocalState::get(use_global_callback); |
| TORCH_INTERNAL_ASSERT(registration_state_ptr, "Expected profiler state set"); |
| auto recordFunctionCallback = |
| at::RecordFunctionCallback( |
| onFunctionEnter<use_global_callback>, |
| onFunctionExit<use_global_callback>) |
| .needsInputs(registration_state_ptr->config().report_input_shapes) |
| .scopes(scopes); |
| |
| if constexpr (use_global_callback) { |
| registration_state_ptr->setCallbackHandle( |
| at::addGlobalCallback(recordFunctionCallback)); |
| } else { |
| registration_state_ptr->setCallbackHandle( |
| at::addThreadLocalCallback(recordFunctionCallback)); |
| } |
| } |
| |
| struct ProfilerStateInfo { |
| std::shared_ptr<KinetoThreadLocalState> state_ptr; |
| std::unordered_set<at::RecordScope> scopes; |
| }; |
| std::shared_ptr<ProfilerStateInfo> profiler_state_info_ptr{nullptr}; |
| |
| } // namespace |
| |
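| // Allows backends to report events with their own timestamps and debug |
| // handles into the currently active (thread-local) Kineto profile. |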
| void reportBackendEventToActiveKinetoProfiler( |
| const int64_t start_time_us, |
| const int64_t end_time_us, |
| const int64_t debug_handle, |
| const at::RecordScope scope, |
| const std::string& event_name, |
| const std::string& backend_name) { |
| TORCH_INTERNAL_ASSERT( |
| KinetoThreadLocalState::get(/*global=*/true) == nullptr, |
| "On-demand profiling does not support post processing callback"); |
| |
| auto state_ptr = KinetoThreadLocalState::get(/*global=*/false); |
| if (!state_ptr) { |
| return; |
| } |
| |
| state_ptr->recordQueue.getSubqueue()->emplace_backend_event( |
| start_time_us, |
| end_time_us, |
| debug_handle, |
| scope, |
| event_name, |
| backend_name); |
| |
| /* no support for input shapes now? |
| if (config.report_input_shapes) { |
| ctx_ptr->shapes = inputSizes(fn); |
| ctx_ptr->dtypes = inputTypes(fn); |
| } |
| */ |
| } |
| |
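| // Typical C++ usage of this API, as a rough sketch. It assumes a |
| // Kineto-enabled build, and the exact ProfilerConfig constructor arguments |
| // and default scopes may differ from what is shown here: |
| // |
| //   using namespace torch::autograd::profiler; |
| //   ProfilerConfig cfg{ProfilerState::KINETO}; |
| //   std::set<torch::profiler::impl::ActivityType> acts{ |
| //       torch::profiler::impl::ActivityType::CPU}; |
| //   prepareProfiler(cfg, acts); |
| //   enableProfiler(cfg, acts, {at::RecordScope::FUNCTION}); |
| //   // ... run the workload to be profiled ... |
| //   auto result = disableProfiler(); |
| //   result->save("trace.json"); |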
| void prepareProfiler( |
| const torch::profiler::impl::ProfilerConfig& config, |
| const std::set<torch::profiler::impl::ActivityType>& activities) { |
| if (config.state == ProfilerState::NVTX || |
| config.state == ProfilerState::ITT) { |
| return; |
| } |
| TORCH_CHECK( |
| config.state == ProfilerState::KINETO || |
| config.state == ProfilerState::KINETO_GPU_FALLBACK || |
| config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK, |
| "Supported only in Kineto profiler"); |
| torch::profiler::impl::kineto::prepareTrace( |
| /*cpuOnly=*/!( |
| at::hasCUDA() || at::hasXPU() || at::hasMTIA() || |
| c10::get_privateuse1_backend() != "privateuseone"), |
| activities, |
| config.experimental_config); |
| |
| if (!config.experimental_config.performance_events.empty()) { |
| /* For now only CPU activity is supported */ |
| TORCH_CHECK( |
| activities.count(torch::autograd::profiler::ActivityType::CPU), |
| "Cannot run the CPU hardware profiler without CPU activities; please only use the CPU activity type"); |
| /* |
| * Send a warning and pass the non-standard event through to the backend. |
| * The backend can abort if the event is not supported. |
| * TODO: Should we gracefully drop the invalid event if we have at least |
| * one valid one? |
| */ |
| auto is_standard_event = [](const std::string& event) -> bool { |
| for (auto e : torch::profiler::ProfilerPerfEvents) { |
| if (!std::strcmp(event.c_str(), e)) { |
| return true; |
| } |
| } |
| return false; |
| }; |
| |
| for (const auto& e : config.experimental_config.performance_events) { |
| if (!is_standard_event(e)) { |
| TORCH_WARN("Forwarding a non-standard CPU performance event: ", e); |
| } |
| } |
| } |
| } |
| |
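| // Registers or removes the RecordFunction callbacks for the running profiler |
| // so that Torch op collection can be paused and resumed mid-trace. |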
| static void toggleTorchOpCollectionDynamic(bool enable) { |
| auto state_ptr = ProfilerStateBase::get(); |
| if (state_ptr) { |
| const auto& config = state_ptr->config(); |
| if (enable) { |
| auto scopes = profiler_state_info_ptr->scopes; |
| config.global() ? pushProfilingCallbacks</*global=*/true>(scopes) |
| : pushProfilingCallbacks</*global=*/false>(scopes); |
| } else { |
| state_ptr->removeCallback(); |
| } |
| } |
| } |
| |
| // Mark this function as unused: the profiler implementation needs more |
| // refactoring before dynamic toggling of Python op collection can be |
| // supported. |
| #ifdef _MSC_VER |
| #define UNUSED |
| #else |
| #define UNUSED __attribute__((unused)) |
| #endif |
| static UNUSED void togglePythonCollectionDynamic(bool enable) { |
| auto state_ptr = ProfilerStateBase::get(); |
| if (state_ptr) { |
| auto global = state_ptr->config().global(); |
| KinetoThreadLocalState* kineto_thread_local_state_ptr = |
| KinetoThreadLocalState::get(global); |
| if (enable) { |
| kineto_thread_local_state_ptr->resumePython(); |
| } else { |
| kineto_thread_local_state_ptr->pausePython(); |
| } |
| } |
| } |
| |
| static void toggleCPUCollectionDynamic(bool enable) { |
| toggleTorchOpCollectionDynamic(enable); |
| // For now we only support dynamic toggling of Torch op collection. |
| // Supporting Python ops would require string parsing to strip out the |
| // toggling events (and other unfinished events), as well as changes to the |
| // stack logic. |
| // togglePythonCollectionDynamic(enable); |
| } |
| |
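| // Toggles event collection per activity type while a trace is running: CUDA |
| // toggling is delegated to Kineto, CPU toggling to the callback logic above. |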
| void toggleCollectionDynamic( |
| const bool enable, |
| const std::set<torch::profiler::impl::ActivityType>& activities) { |
| if (activities.count(torch::autograd::profiler::ActivityType::CPU) > 0 && |
| activities.count(torch::autograd::profiler::ActivityType::CUDA) == 0) { |
| LOG(WARNING) |
| << "Toggling CPU activity with CUDA activity on may result in traces with CUDA events on arbitrary tracks"; |
| } |
| for (auto act : activities) { |
| if (act == torch::autograd::profiler::ActivityType::CUDA) { |
| torch::profiler::impl::kineto::toggleCollectionDynamic(enable); |
| } else if (act == torch::autograd::profiler::ActivityType::CPU) { |
| toggleCPUCollectionDynamic(enable); |
| } else { |
| LOG(WARNING) |
| << "Dynamic toggle is only supported for CPU/GPU activity, skipping toggling of " |
| << actToString(act); |
| continue; |
| } |
| } |
| } |
| |
| void enableProfilerWithEventPostProcess( |
| const torch::profiler::impl::ProfilerConfig& config, |
| const std::set<torch::profiler::impl::ActivityType>& activities, |
| post_process_t&& cb, |
| const std::unordered_set<at::RecordScope>& scopes) { |
| TORCH_CHECK( |
| config.state != ProfilerState::NVTX, |
| "NVTX does not support post processing callback."); |
| TORCH_CHECK( |
| config.state != ProfilerState::ITT, |
| "ITT does not support post processing callback."); |
| TORCH_INTERNAL_ASSERT( |
| KinetoThreadLocalState::get(/*global=*/true) == nullptr, |
| "On-demand profiling does not support post processing callback"); |
| |
| enableProfiler(config, activities, scopes); |
| auto state_ptr = KinetoThreadLocalState::get(config.global()); |
| state_ptr->setEventPostProcessingCallback(std::move(cb)); |
| } |
| |
| void enableProfiler( |
| const torch::profiler::impl::ProfilerConfig& config, |
| const std::set<torch::profiler::impl::ActivityType>& activities, |
| const std::unordered_set<at::RecordScope>& scopes) { |
| const auto has_cpu = activities.count(ActivityType::CPU); |
| TORCH_CHECK( |
| KinetoThreadLocalState::get(/*global=*/config.global()) == nullptr, |
| "Profiler is already enabled", |
| (config.global() ? "." : " on this thread.")); |
| |
| if (config.state == ProfilerState::NVTX) { |
| torch::profiler::impl::pushNVTXCallbacks(config, scopes); |
| return; |
| } else if (config.state == ProfilerState::ITT) { |
| torch::profiler::impl::pushITTCallbacks(config, scopes); |
| return; |
| } else if (config.state == ProfilerState::PRIVATEUSE1) { |
| torch::profiler::impl::pushPRIVATEUSE1CallbacksStub(config, scopes); |
| return; |
| } |
| |
| TORCH_CHECK( |
| config.state == ProfilerState::KINETO || |
| config.state == ProfilerState::KINETO_GPU_FALLBACK || |
| config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK || |
| config.global()); |
| TORCH_CHECK(!activities.empty(), "No activities specified."); |
| TORCH_INTERNAL_ASSERT( |
| has_cpu || !config.global(), |
| "On-demand profiling must enable CPU tracing"); |
| |
| auto state_ptr = std::make_shared<KinetoThreadLocalState>(config, activities); |
| KinetoThreadLocalState::push(state_ptr); |
| |
| if (has_cpu) { |
| config.global() ? pushProfilingCallbacks</*global=*/true>(scopes) |
| : pushProfilingCallbacks</*global=*/false>(scopes); |
| } |
| |
| if (!config.global()) { |
| torch::profiler::impl::kineto::startTrace(); |
| } |
| |
| if (has_cpu) { |
| auto state_info_ptr = std::make_shared<ProfilerStateInfo>(); |
| state_info_ptr->state_ptr = state_ptr; |
| state_info_ptr->scopes = scopes; |
| profiler_state_info_ptr = state_info_ptr; |
| } |
| } |
| |
| bool isProfilerEnabledInMainThread() { |
| return profiler_state_info_ptr != nullptr; |
| } |
| |
| void enableProfilerInChildThread() { |
| auto state_info_ptr = profiler_state_info_ptr; |
| TORCH_CHECK(state_info_ptr, "Profiler is not enabled in main thread."); |
| TORCH_CHECK( |
| KinetoThreadLocalState::get(/*global=*/false) == nullptr, |
| "Profiler is already enabled in this thread."); |
| |
| KinetoThreadLocalState::push(state_info_ptr->state_ptr); |
| pushProfilingCallbacks</*global=*/false>(state_info_ptr->scopes); |
| } |
| |
| void disableProfilerInChildThread() { |
| auto state_ptr = ProfilerStateBase::pop(); |
| TORCH_CHECK( |
| state_ptr, |
| "Can't disable Kineto profiler when it's not running in this thread"); |
| state_ptr->removeCallback(); |
| } |
| |
| std::unique_ptr<ProfilerResult> disableProfiler() { |
| // Release the shared state info so child threads know to stop profiling. |
| profiler_state_info_ptr = nullptr; |
| |
| auto state_ptr = ProfilerStateBase::pop(); |
| TORCH_CHECK( |
| state_ptr, "Can't disable Kineto profiler when it's not running"); |
| // Only read the config after checking that the state is non-null. |
| const auto& config = state_ptr->config(); |
| TORCH_CHECK( |
| config.state == ProfilerState::KINETO || |
| config.state == ProfilerState::KINETO_GPU_FALLBACK || |
| config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK || |
| config.state == ProfilerState::KINETO_ONDEMAND || |
| config.state == ProfilerState::NVTX || |
| config.state == ProfilerState::ITT || |
| config.state == ProfilerState::PRIVATEUSE1, |
| "Can't disable Kineto profiler when it's not running"); |
| |
| state_ptr->removeCallback(); |
| |
| // For the on-demand (global) flow, traces are converged by libkineto |
| // automatically, so we finalize and return an empty ProfilerResult. |
| if (state_ptr->config().global()) { |
| (void)std::static_pointer_cast<KinetoThreadLocalState>(state_ptr) |
| ->finalizeTrace(); |
| return std::make_unique<ProfilerResult>(); |
| } |
| |
| // Shared among NVTX, PRIVATEUSE1, KINETO, KINETO_GPU_FALLBACK, |
| // KINETO_PRIVATEUSE1_FALLBACK |
| std::unique_ptr<ProfilerResult> result; |
| if (state_ptr->config().state == ProfilerState::NVTX || |
| state_ptr->config().state == ProfilerState::PRIVATEUSE1) { |
| result = std::make_unique<ProfilerResult>(); |
| } |
| |
| if (config.state == ProfilerState::KINETO || |
| config.state == ProfilerState::KINETO_GPU_FALLBACK || |
| config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK) { |
| auto kineto_state_ptr = |
| std::static_pointer_cast<KinetoThreadLocalState>(state_ptr); |
| auto trace = kineto_state_ptr->finalizeTrace(); |
| result = std::make_unique<ProfilerResult>( |
| kineto_state_ptr->startTime, |
| std::move(kineto_state_ptr->kinetoEvents), |
| std::move(trace), |
| std::move(kineto_state_ptr->eventTree)); |
| } |
| |
| return result; |
| } |
| |
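| // KinetoEvent is a thin, stable wrapper around a profiler Result. Most |
| // accessors below simply forward to the underlying Result (see the |
| // FORWARD_FROM_RESULT / TYPED_ATTR macros further down). |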
| KinetoEvent::KinetoEvent( |
| const std::shared_ptr<const torch::profiler::impl::Result>& result, |
| const bool verbose) |
| : result_{result} { |
| TORCH_INTERNAL_ASSERT(result != nullptr); |
| |
| if (verbose) { |
| // Populate Python stack |
| auto parent = result_->parent_.lock(); |
| while (parent != nullptr) { |
| parent->visit_if_base<PyExtraFieldsBase>( |
| [&](const auto&) { python_stack_.push_back(parent->name()); }); |
| parent = parent->parent_.lock(); |
| } |
| } |
| |
| result->visit_if_base<ExtraFields<EventType::TorchOp>>([&](const auto& op) { |
| auto arg_data = parseArgData(op.inputs_, op.concrete_inputs_); |
| shapes_ = std::move(arg_data.shapesForKinetoEvent); |
| dtypes_ = std::move(arg_data.dtypes); |
| concrete_inputs_ = std::move(arg_data.concreteInputs); |
| kwinputs_ = std::move(op.kwinputs_); |
| }); |
| } |
| |
| bool KinetoEvent::isPythonFunction() const { |
| bool out{false}; |
| result_->visit_if_base<PyExtraFieldsBase>([&](const auto&) { out = true; }); |
| return out; |
| } |
| |
| bool KinetoEvent::hasShapes() const { |
| return !shapes_.empty(); |
| } |
| |
| const c10::ArrayRef<std::vector<int64_t>> KinetoEvent::shapes() const { |
| return shapes_; |
| } |
| |
| bool KinetoEvent::hasTypes() const { |
| return !dtypes_.empty(); |
| } |
| |
| const c10::ArrayRef<std::string> KinetoEvent::dtypes() const { |
| return dtypes_; |
| } |
| |
| bool KinetoEvent::hasConcreteInputs() const { |
| return !concrete_inputs_.empty(); |
| } |
| |
| const c10::ArrayRef<c10::IValue> KinetoEvent::concreteInputs() const { |
| return concrete_inputs_; |
| } |
| |
| bool KinetoEvent::hasKwinputs() const { |
| return !kwinputs_.empty(); |
| } |
| |
| const std::unordered_map<std::string, c10::IValue> KinetoEvent::kwinputs() |
| const { |
| return kwinputs_; |
| } |
| |
| const c10::ArrayRef<std::string> KinetoEvent::stack() const { |
| auto get = [&](const auto& i) -> auto& { |
| return !i.jit_stack_.empty() ? i.jit_stack_ : python_stack_; |
| }; |
| |
| auto const& extra_fields = result_->extra_fields_; |
| if (auto p = std::get_if<ExtraFields<EventType::TorchOp>>(&extra_fields)) { |
| return get(*p); |
| } |
| if (auto p = std::get_if<ExtraFields<EventType::Backend>>(&extra_fields)) { |
| return get(*p); |
| } |
| return python_stack_; |
| } |
| |
| const c10::ArrayRef<std::string> KinetoEvent::moduleHierarchy() const { |
| auto const& extra_fields = result_->extra_fields_; |
| if (auto p = std::get_if<ExtraFields<EventType::TorchOp>>(&extra_fields)) { |
| return p->jit_modules_; |
| } |
| if (auto p = std::get_if<ExtraFields<EventType::Backend>>(&extra_fields)) { |
| return p->jit_modules_; |
| } |
| return {}; |
| } |
| |
| uint64_t KinetoEvent::endNs() const { |
| return result_->endTimeNS(); |
| } |
| |
| uint64_t KinetoEvent::durationNs() const { |
| return (result_->endTimeNS() - result_->start_time_ns_); |
| } |
| |
| int64_t KinetoEvent::debugHandle() const { |
| return result_->visit(c10::overloaded( |
| [](const ExtraFields<EventType::TorchOp>& i) { return i.debug_handle_; }, |
| [](const ExtraFields<EventType::Backend>& i) { return i.debug_handle_; }, |
| [](const auto&) -> int64_t { return -1; })); |
| } |
| |
| int KinetoEvent::deviceIndex() const { |
| return result_->visit(c10::overloaded( |
| [](const ExtraFields<EventType::Allocation>& i) { |
| return static_cast<int>(i.device_index_); |
| }, |
| [](const ExtraFields<EventType::OutOfMemory>& i) { |
| return static_cast<int>(i.device_index_); |
| }, |
| [&](const auto&) { |
| return static_cast<int>(result_->kineto_info_.device); |
| })); |
| } |
| |
| bool KinetoEvent::hasStack() const { |
| return !stack().empty(); |
| } |
| |
| int64_t KinetoEvent::cudaElapsedUs() const { |
| auto cuda_event_start = fallbackStart(); |
| auto cuda_event_end = fallbackEnd(); |
| if (!cuda_event_start || !cuda_event_end) { |
| return -1; |
| } |
| try { |
| return (int64_t)torch::profiler::impl::cudaStubs()->elapsed( |
| &cuda_event_start, &cuda_event_end); |
| } catch (std::exception& e) { |
| LOG(WARNING) << "Failed to measure time between two CUDA events. " |
| << e.what(); |
| } |
| return -1; |
| } |
| |
| int64_t KinetoEvent::privateuse1ElapsedUs() const { |
| auto privateuse1_event_start = fallbackStart(); |
| auto privateuse1_event_end = fallbackEnd(); |
| if (!privateuse1_event_start || !privateuse1_event_end) { |
| return -1; |
| } |
| return (int64_t)torch::profiler::impl::privateuse1Stubs()->elapsed( |
| &privateuse1_event_start, &privateuse1_event_end); |
| } |
| |
| void KinetoEvent::getPerfEventCounters(std::vector<uint64_t>& in) const { |
| return result_->visit(c10::overloaded( |
| [&in](const ExtraFields<EventType::TorchOp>& e) -> void { |
| const size_t n = e.perf_event_counters_->size(); |
| // should be rare |
| if (in.size() < n) { |
| in.resize(n, 0); |
| } |
| for (size_t i = 0; i < n; ++i) { |
| in[i] = (*e.perf_event_counters_)[i]; |
| } |
| }, |
| [](const auto&) -> void { return; })); |
| } |
| |
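| // Generates trivial accessors that forward directly to the wrapped Result. |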
| #define FORWARD_FROM_RESULT(method_name, result_expr) \ |
| decltype(std::declval<KinetoEvent>().method_name()) \ |
| KinetoEvent::method_name() const { \ |
| return static_cast<decltype(std::declval<KinetoEvent>().method_name())>( \ |
| result_->result_expr); \ |
| } |
| |
| FORWARD_FROM_RESULT(startThreadId, start_tid_) |
| FORWARD_FROM_RESULT(endThreadId, endTID()) |
| FORWARD_FROM_RESULT(activityType, kinetoType()) |
| FORWARD_FROM_RESULT(name, name()) |
| FORWARD_FROM_RESULT(deviceType, deviceType()) |
| FORWARD_FROM_RESULT(startNs, start_time_ns_) |
| FORWARD_FROM_RESULT(correlationId, correlationID()) |
| FORWARD_FROM_RESULT(deviceResourceId, kineto_info_.resource) |
| #undef FORWARD_FROM_RESULT |
| |
| // Most of the fields in `KinetoEvent` only make sense for a single event type. |
| // (Generally TorchOp.) For all other types they simply return the default |
| // value. This macro provides a succinct way of expressing this behavior. |
| #define TYPED_ATTR_WITH_DEFAULT( \ |
| event_type, method_name, expression, default_value) \ |
| decltype(std::declval<KinetoEvent>().method_name()) \ |
| KinetoEvent::method_name() const { \ |
| using out_t = decltype(std::declval<KinetoEvent>().method_name()); \ |
| return result_->visit(c10::overloaded( \ |
| [](const ExtraFields<EventType::event_type>& e) -> out_t { \ |
| return expression; \ |
| }, \ |
| [](const auto&) -> out_t { return default_value; })); \ |
| } |
| |
| #define TYPED_ATTR(event_type, method_name, expression) \ |
| TYPED_ATTR_WITH_DEFAULT(event_type, method_name, expression, {}) |
| |
| TYPED_ATTR_WITH_DEFAULT(TorchOp, sequenceNr, e.sequence_number_, -1) |
| TYPED_ATTR(TorchOp, fwdThreadId, e.sequence_number_ >= 0 ? e.forward_tid_ : 0) |
| TYPED_ATTR(TorchOp, scope, static_cast<uint8_t>(e.scope_)) |
| TYPED_ATTR(TorchOp, hasModuleHierarchy, !e.jit_modules_.empty()) |
| TYPED_ATTR(TorchOp, isAsync, e.is_async_) |
| TYPED_ATTR(TorchOp, extraMeta, e.extra_meta_) |
| TYPED_ATTR(TorchOp, fallbackStart, e.device_fallback_.device_event_start_) |
| TYPED_ATTR(TorchOp, fallbackEnd, e.device_fallback_.device_event_end_) |
| TYPED_ATTR( |
| TorchOp, |
| flops, |
| !e.extra_args_.empty() |
| ? torch::profiler::impl::computeFlops(e.name_, e.extra_args_) |
| : 0) |
| TYPED_ATTR(Backend, backend, e.backend_) |
| TYPED_ATTR(Allocation, nBytes, e.alloc_size_) |
| TYPED_ATTR(Kineto, linkedCorrelationId, [&]() { |
| const auto linked = e.linked_activity_.lock(); |
| return linked ? linked->correlationID() : 0; |
| }()) |
| #undef TYPED_ATTR |
| #undef TYPED_ATTR_WITH_DEFAULT |
| |
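| // ProfilerResult bundles everything disableProfiler() returns: the trace |
| // start time, the flattened KinetoEvents, the Kineto activity trace used by |
| // save(), and the experimental event tree. |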
| ProfilerResult::ProfilerResult( |
| uint64_t start_time, |
| std::vector<KinetoEvent> events, |
| std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>&& |
| trace, |
| std::vector<experimental_event_t>&& event_tree) |
| : trace_start_ns_(start_time), |
| events_(std::move(events)), |
| trace_(std::move(trace)), |
| event_tree_(std::move(event_tree)) {} |
| ProfilerResult::ProfilerResult() = default; |
| ProfilerResult::~ProfilerResult() = default; |
| |
| void ProfilerResult::save(const std::string& path) { |
| trace_->save(path); |
| } |
| |
| } // namespace autograd::profiler |
| |
| namespace profiler::impl { |
| void _reportVulkanEventToProfiler(vulkan_id_t id) { |
| auto state_ptr = ::torch::autograd::profiler::KinetoThreadLocalState::get( |
| /*global=*/false); |
| if (state_ptr) { |
| state_ptr->reportVulkanEventToProfiler(id); |
| } |
| } |
| } // namespace profiler::impl |
| |
| } // namespace torch |