blob: dd783bc574aa8422dd50b2dec2830c0871e9947a [file] [log] [blame]
#pragma once
#include <c10/util/ThreadLocalDebugInfo.h>

#include <cstddef>
#include <memory>
#include <string>
#include <utility>
namespace torch {
class MobileDebugInfo : public c10::DebugInfoBase {
public:
const std::string& getModelName() {
return model_name_;
}
void setModelName(const std::string& model_name) {
model_name_ = model_name;
}
const std::string& getMethodName() {
return method_name_;
}
void setMethodName(const std::string& method_name) {
method_name_ = method_name;
}
size_t getOpIdx() {
return op_idx_;
}
void setOpIdx(size_t op_idx) {
op_idx_ = op_idx;
}
virtual ~MobileDebugInfo() {}
private:
std::string model_name_;
std::string method_name_;
size_t op_idx_ = 0;
};
/// Overridable callback interface for module-level notifications.
///
/// Every hook is a no-op by default, so a subclass only needs to override
/// the notifications it is interested in. The destructor is virtual so an
/// observer can be destroyed through a base-class pointer.
class MobileModuleObserver {
 public:
  virtual ~MobileModuleObserver() = default;

  /// Entry notification for method `method_name` of model `model_name`.
  /// Default implementation: does nothing.
  virtual void onEnter(
      const std::string& /* model_name */,
      const std::string& /* method_name */) {}

  /// Exit notification. Default implementation: does nothing.
  virtual void onExit() {}
};
/// Holds the (optional) module observer and owns it.
///
/// No synchronization is performed here; callers that mutate this object
/// concurrently with readers must provide their own locking.
class MobileObserverConfig {
 public:
  /// Installs `reporter` as the module observer and takes ownership of it.
  /// Any previously installed observer is destroyed. Passing nullptr clears
  /// the observer.
  void setModuleObserver(std::unique_ptr<MobileModuleObserver> reporter) {
    module_observer_ = std::move(reporter);
  }

  /// Returns the installed observer, or nullptr if none is set.
  /// The pointer is non-owning. It dangles once the observer is replaced
  /// via setModuleObserver() or this config is destroyed.
  MobileModuleObserver* getModuleObserver() const {
    return module_observer_.get();
  }

 private:
  std::unique_ptr<MobileModuleObserver> module_observer_;
};
MobileObserverConfig& observerConfig();
} // namespace torch