| #pragma once |
| |
| #include <c10/macros/Export.h> |
| |
| #include <memory> |
| #include <string> |
| |
| namespace at { |
| |
// Thread-local debug information is propagated across the forward pass
// (including async fork tasks) and the backward pass. It is intended to
// let user code pass extra information from the higher layers (e.g. a
// model id) down to the operator callbacks (e.g. for logging).
| |
| class CAFFE2_API ThreadLocalDebugInfoBase { |
| public: |
| ThreadLocalDebugInfoBase() {} |
| virtual ~ThreadLocalDebugInfoBase() {} |
| }; |
| |
| CAFFE2_API std::shared_ptr<ThreadLocalDebugInfoBase> |
| getThreadLocalDebugInfo() noexcept; |
| |
| // Sets thread local debug information, returns the previously set |
| // debug information |
| CAFFE2_API std::shared_ptr<ThreadLocalDebugInfoBase> |
| setThreadLocalDebugInfo( |
| std::shared_ptr<ThreadLocalDebugInfoBase> info) noexcept; |
| |
| class CAFFE2_API DebugInfoGuard { |
| public: |
| explicit DebugInfoGuard( |
| std::shared_ptr<ThreadLocalDebugInfoBase> info) { |
| prev_info_ = setThreadLocalDebugInfo(std::move(info)); |
| } |
| |
| ~DebugInfoGuard() { |
| setThreadLocalDebugInfo(std::move(prev_info_)); |
| } |
| |
| private: |
| std::shared_ptr<ThreadLocalDebugInfoBase> prev_info_; |
| }; |
| |
| } // namespace at |