blob: f464e4a4c68c7bc9172a5eb78463d4c8b25edc43 [file] [log] [blame]
#pragma once
#include <c10/macros/Export.h>
#include <memory>
#include <string>
namespace at {
// Base class for user-provided thread local debug information.
//
// Thread local debug information is propagated across the forward pass
// (including async fork tasks) and the backward pass. It is intended to be
// used by user code to pass extra information from higher layers
// (e.g. a model id) down to operator callbacks (e.g. for logging).
//
// Users derive from this class and attach their own data; instances are
// handled through std::shared_ptr<ThreadLocalDebugInfoBase>, so the
// destructor is virtual to make deletion through the base pointer safe.
class CAFFE2_API ThreadLocalDebugInfoBase {
 public:
  ThreadLocalDebugInfoBase() = default;
  virtual ~ThreadLocalDebugInfoBase() = default;
};
// Returns the thread local debug information that is currently set.
CAFFE2_API std::shared_ptr<ThreadLocalDebugInfoBase>
getThreadLocalDebugInfo() noexcept;

// Replaces the thread local debug information with `info` and returns the
// value that was set before the call, so the caller can restore it later
// (DebugInfoGuard does exactly this).
CAFFE2_API std::shared_ptr<ThreadLocalDebugInfoBase> setThreadLocalDebugInfo(
    std::shared_ptr<ThreadLocalDebugInfoBase> info) noexcept;
class CAFFE2_API DebugInfoGuard {
public:
explicit DebugInfoGuard(
std::shared_ptr<ThreadLocalDebugInfoBase> info) {
prev_info_ = setThreadLocalDebugInfo(std::move(info));
}
~DebugInfoGuard() {
setThreadLocalDebugInfo(std::move(prev_info_));
}
private:
std::shared_ptr<ThreadLocalDebugInfoBase> prev_info_;
};
} // namespace at