// blob: 76ced6fb7b53d44a56393e45d3143569159210ce [file] [log] [blame]
#pragma once
#include <c10/util/Exception.h>
#include <vector>
#include <mutex>
namespace torch { namespace autograd { namespace utils {
// Warning handler for multi-threaded contexts. Gathers warnings from
// all threads into a single queue so that they can be processed together
// at the end, in the main thread.
class DelayWarningHandler: public at::WarningHandler {
public:
~DelayWarningHandler() override = default;
void replay_warnings();
private:
void process(const at::SourceLocation &source_location,
const std::string &msg,
bool verbatim) override;
struct Warning {
c10::SourceLocation source_location;
std::string msg;
bool verbatim;
};
std::vector<Warning> warnings_;
std::mutex mutex_;
};
}}} // namespace torch::autograd::utils