blob: 33c49c0ec9b760464c4aec8700f0054b226d7da0 [file] [log] [blame]
#include <utility>
#include "caffe2/opt/annotations.h"
namespace caffe2 {
using namespace nom::repr;
void Caffe2Annotation::setOperatorDef(const caffe2::OperatorDef& opDef) {
OpDef = opDef;
OpDefExists = true;
}
bool Caffe2Annotation::hasOperatorDef() const {
return OpDefExists;
}
const caffe2::OperatorDef& Caffe2Annotation::getOperatorDef() const {
CAFFE_ENFORCE(
OpDefExists,
"OperatorDef was never set. Use Caffe2Annotation::setOperatorDef.");
return OpDef;
}
caffe2::OperatorDef* Caffe2Annotation::getMutableOperatorDef() {
CAFFE_ENFORCE(
OpDefExists,
"OperatorDef was never set. Use Caffe2Annotation::setOperatorDef.");
return &OpDef;
}
// Distributed annotations
void Caffe2Annotation::setDeviceOption(const caffe2::DeviceOption& devOpt) {
*OpDef.mutable_device_option() = devOpt;
}
bool Caffe2Annotation::hasDeviceOption() const {
return OpDef.has_device_option();
}
const caffe2::DeviceOption& Caffe2Annotation::getDeviceOption() const {
CAFFE_ENFORCE(
hasDeviceOption(),
"DeviceOption was never set. Use Caffe2Annotation::setDeviceOption.");
return OpDef.device_option();
}
caffe2::DeviceOption* Caffe2Annotation::getMutableDeviceOption() {
CAFFE_ENFORCE(
hasDeviceOption(),
"DeviceOption was never set. Use Caffe2Annotation::setDeviceOption.");
return OpDef.mutable_device_option();
}
void Caffe2Annotation::setDevice(std::string device) {
Device = device;
}
const std::string Caffe2Annotation::getDevice() const {
return Device;
}
void Caffe2Annotation::setDeviceType(int device) {
DeviceType = device;
}
int Caffe2Annotation::getDeviceType() const {
return DeviceType;
}
void Caffe2Annotation::setParallelization(
Caffe2Annotation::ParallelizationScheme s,
int num) {
parallelization_scheme_ = s;
parallelization_ = num;
}
Caffe2Annotation::ParallelizationScheme
Caffe2Annotation::getParallelizationScheme() const {
return parallelization_scheme_;
}
int Caffe2Annotation::getParallelization() const {
return parallelization_;
}
void Caffe2Annotation::setKeyNode(NNGraph::NodeRef n) {
key_node_ = n;
}
const NNGraph::NodeRef& Caffe2Annotation::getKeyNode() const {
CAFFE_ENFORCE(key_node_, "No key node has been annotated");
return key_node_;
}
void Caffe2Annotation::setLengthNode(NNGraph::NodeRef n) {
length_node_ = n;
}
const NNGraph::NodeRef& Caffe2Annotation::getLengthNode() const {
CAFFE_ENFORCE(length_node_, "No length node has been annotated");
return length_node_;
}
void Caffe2Annotation::setComponentLevels(std::vector<std::string> components) {
component_levels_ = std::move(components);
}
std::vector<std::string> Caffe2Annotation::getComponentLevels() const {
return component_levels_;
}
bool Caffe2Annotation::classof(const Annotation* A) {
return A->getKind() == AnnotationKind::Caffe2;
}
} // namespace caffe2