blob: 981a881a6148ca637e0c7656747702f7afcf0f46 [file] [log] [blame]
#include "caffe2/opt/converter.h"
#include "caffe2/utils/cast.h"
#include "caffe2/utils/proto_utils.h"
namespace caffe2 {
namespace {
using namespace nom;
using namespace nom::repr;
/// Converts a Caffe2 "BatchMatMul" OperatorDef into a nom repr::BatchMatMul.
/// The optional integer arguments trans_a, trans_b and broadcast are copied
/// onto the node as booleans (nonzero after narrowing to int means true).
/// Conversion back to an OperatorDef uses the default Converter behavior.
class BatchMatMulConverter : public Converter {
  std::unique_ptr<repr::NeuralNetOperator> convertToNeuralNetOperator(
      const OperatorDef& op) override {
    std::unique_ptr<repr::NeuralNetOperator> nnOp =
        util::make_unique<repr::BatchMatMul>();
    auto argMap = getArgumentsFromOperator(op);
    auto c = dyn_cast<repr::BatchMatMul>(nnOp.get());
    // Each flag must be stored as an integer (Argument::i). The error message
    // names the offending argument so a bad model is easy to diagnose.
    if (argMap.count("trans_a")) {
      CAFFE_ENFORCE(argMap["trans_a"].has_i(), "Invalid trans_a argument");
      int trans_a = static_cast<int>(argMap["trans_a"].i());
      c->setTransA(!!trans_a);
    }
    if (argMap.count("trans_b")) {
      CAFFE_ENFORCE(argMap["trans_b"].has_i(), "Invalid trans_b argument");
      int trans_b = static_cast<int>(argMap["trans_b"].i());
      c->setTransB(!!trans_b);
    }
    if (argMap.count("broadcast")) {
      CAFFE_ENFORCE(argMap["broadcast"].has_i(), "Invalid broadcast argument");
      int broadcast = static_cast<int>(argMap["broadcast"].i());
      c->setBroadcast(!!broadcast);
    }
    return nnOp;
  }
  // Does not override default converter to OperatorDef
  ~BatchMatMulConverter() override {}
};
REGISTER_CONVERTER(BatchMatMul, BatchMatMulConverter);
// TRIVIAL_CONVERTER(BatchGather) defines BatchGatherConverter (the class name
// registered on the next line). Presumably it maps the operator without any
// argument translation. NOTE(review): confirm against the macro definition in
// caffe2/opt/converter.h.
TRIVIAL_CONVERTER(BatchGather);
REGISTER_CONVERTER(BatchGather, BatchGatherConverter);
class MulConverter : public Converter {
std::unique_ptr<repr::NeuralNetOperator> convertToNeuralNetOperator(
const OperatorDef& op) override {
std::unique_ptr<repr::NeuralNetOperator> nnOp =
util::make_unique<repr::Mul>();
auto argMap = getArgumentsFromOperator(op);
auto c = dyn_cast<repr::Mul>(nnOp.get());
if (argMap.count("broadcast")) {
CAFFE_ENFORCE(argMap["broadcast"].has_i(), "Invalid broadcast argument");
int broadcast = static_cast<int>(argMap["broadcast"].i());
c->setBroadcast(!!broadcast);
}
return nnOp;
}
// Does not override default converter to OperatorDef
~MulConverter() override {}
};
REGISTER_CONVERTER(Mul, MulConverter);
class AddConverter : public Converter {
std::unique_ptr<repr::NeuralNetOperator> convertToNeuralNetOperator(
const OperatorDef& op) override {
std::unique_ptr<repr::NeuralNetOperator> nnOp =
util::make_unique<repr::Add>();
auto argMap = getArgumentsFromOperator(op);
auto c = dyn_cast<repr::Add>(nnOp.get());
if (argMap.count("broadcast")) {
CAFFE_ENFORCE(argMap["broadcast"].has_i(), "Invalid broadcast argument");
int broadcast = static_cast<int>(argMap["broadcast"].i());
c->setBroadcast(!!broadcast);
}
return nnOp;
}
// Does not override default converter to OperatorDef
~AddConverter() override {}
};
REGISTER_CONVERTER(Add, AddConverter);
/// Converts a Caffe2 "Cast" OperatorDef into a nom repr::Cast. The target data
/// type comes from the "to" argument, resolved by cast::GetCastDataType.
/// Conversion back to an OperatorDef uses the default Converter behavior.
class CastConverter : public Converter {
  std::unique_ptr<repr::NeuralNetOperator> convertToNeuralNetOperator(
      const OperatorDef& op) override {
    std::unique_ptr<repr::NeuralNetOperator> nnOp =
        util::make_unique<repr::Cast>();
    auto c = dyn_cast<repr::Cast>(nnOp.get());
    // Only an ArgumentHelper is needed: GetCastDataType reads "to" from it, so
    // no separate name->Argument map is built.
    ArgumentHelper helper(op);
    c->setTo(cast::GetCastDataType(helper, "to"));
    return nnOp;
  }
  // Does not override default converter to OperatorDef
  ~CastConverter() override {}
};
REGISTER_CONVERTER(Cast, CastConverter);
class ReplaceNaNConverter : public Converter {
std::unique_ptr<repr::NeuralNetOperator> convertToNeuralNetOperator(
const OperatorDef& op) override {
std::unique_ptr<repr::NeuralNetOperator> nnOp =
util::make_unique<repr::ReplaceNaN>();
auto argMap = getArgumentsFromOperator(op);
auto c = dyn_cast<repr::ReplaceNaN>(nnOp.get());
if (argMap.count("value")) {
CAFFE_ENFORCE(argMap["value"].has_f(), "Invalid 'value' argument");
float value = static_cast<float>(argMap["value"].f());
c->setValue(value);
}
return nnOp;
}
// Does not override default converter to OperatorDef
~ReplaceNaNConverter() override {}
};
REGISTER_CONVERTER(ReplaceNaN, ReplaceNaNConverter);
/// Two-way converter for the fused ConcatAddMulReplaceNaNClip operator.
/// Reading: the optional float "clip_min" and "clip_max" arguments set the
/// node's lower and upper clip bounds. Writing: both bounds are always emitted
/// as float arguments, together with the node's device option.
class ConcatAddMulReplaceNaNClipConverter : public Converter {
  std::unique_ptr<repr::NeuralNetOperator> convertToNeuralNetOperator(
      const OperatorDef& op) override {
    std::unique_ptr<repr::NeuralNetOperator> nnOp =
        util::make_unique<repr::ConcatAddMulReplaceNaNClip>();
    auto argMap = getArgumentsFromOperator(op);
    auto c = dyn_cast<repr::ConcatAddMulReplaceNaNClip>(nnOp.get());
    if (argMap.count("clip_min")) {
      CAFFE_ENFORCE(argMap["clip_min"].has_f(), "Invalid 'clip_min' argument");
      c->setClipMin(static_cast<float>(argMap["clip_min"].f()));
    }
    if (argMap.count("clip_max")) {
      CAFFE_ENFORCE(argMap["clip_max"].has_f(), "Invalid 'clip_max' argument");
      // Must target the upper bound. Writing clip_max into ClipMin would
      // clobber the lower bound and leave ClipMax unset.
      c->setClipMax(static_cast<float>(argMap["clip_max"].f()));
    }
    return nnOp;
  }
  OperatorDef convertToOperatorDef(
      const nom::repr::NeuralNetOperator* nnOp) override {
    auto cc_amrc = dyn_cast<repr::ConcatAddMulReplaceNaNClip>(nnOp);
    OperatorDef op;
    op.set_type("ConcatAddMulReplaceNaNClip");
    auto min_arg = op.add_arg();
    min_arg->set_name("clip_min");
    min_arg->set_f(cc_amrc->getClipMin());
    auto max_arg = op.add_arg();
    max_arg->set_name("clip_max");
    max_arg->set_f(cc_amrc->getClipMax());
    op.mutable_device_option()->CopyFrom(getDeviceOption(nnOp));
    return op;
  }
  ~ConcatAddMulReplaceNaNClipConverter() override {}
};
REGISTER_CONVERTER(
    ConcatAddMulReplaceNaNClip,
    ConcatAddMulReplaceNaNClipConverter);
class SliceConverter : public Converter {
std::unique_ptr<repr::NeuralNetOperator> convertToNeuralNetOperator(
const OperatorDef& op) override {
std::unique_ptr<repr::NeuralNetOperator> nnOp =
util::make_unique<repr::Slice>();
const caffe2::ArgumentHelper args(op);
auto c = dyn_cast<repr::Slice>(nnOp.get());
if (args.HasArgument("starts")) {
c->setStarts(args.GetRepeatedArgument<int64_t>("starts"));
}
if (args.HasArgument("ends")) {
c->setEnds(args.GetRepeatedArgument<int64_t>("ends"));
}
return nnOp;
}
OperatorDef convertToOperatorDef(
const nom::repr::NeuralNetOperator* nnOp) override {
auto slice = dyn_cast<repr::Slice>(nnOp);
OperatorDef op;
op.set_type("Slice");
op.add_arg()->CopyFrom(
caffe2::MakeArgument<vector<int64_t>>("starts", slice->getStarts()));
op.add_arg()->CopyFrom(
caffe2::MakeArgument<vector<int64_t>>("ends", slice->getEnds()));
op.mutable_device_option()->CopyFrom(getDeviceOption(nnOp));
return op;
}
~SliceConverter() override {}
};
REGISTER_CONVERTER(Slice, SliceConverter);
class ClipRangesGatherSigridHashConverter : public Converter {
std::unique_ptr<repr::NeuralNetOperator> convertToNeuralNetOperator(
const OperatorDef& op) override {
std::unique_ptr<repr::NeuralNetOperator> nnOp =
util::make_unique<repr::ClipRangesGatherSigridHash>();
const caffe2::ArgumentHelper args(op);
auto c = dyn_cast<repr::ClipRangesGatherSigridHash>(nnOp.get());
if (args.HasArgument("feature_indices")) {
c->setFeatureIndices(
args.GetRepeatedArgument<int64_t>("feature_indices"));
}
if (args.HasArgument("max_lengths")) {
c->setMaxLengths(args.GetRepeatedArgument<int64_t>("max_lengths"));
}
if (args.HasArgument("salts")) {
c->setSalts(args.GetRepeatedArgument<int64_t>("salts"));
}
if (args.HasArgument("max_values")) {
c->setMaxValues(args.GetRepeatedArgument<int64_t>("max_values"));
}
if (args.HasArgument("hash_into_int32")) {
c->setHashIntoInt32(
args.GetSingleArgument<bool>("hash_into_int32", false));
}
return nnOp;
}
OperatorDef convertToOperatorDef(
const nom::repr::NeuralNetOperator* nnOp) override {
auto fuse = dyn_cast<repr::ClipRangesGatherSigridHash>(nnOp);
OperatorDef op;
op.set_type("ClipRangesGatherSigridHash");
op.add_arg()->CopyFrom(caffe2::MakeArgument<vector<int64_t>>(
"feature_indices", fuse->getFeatureIndices()));
op.add_arg()->CopyFrom(caffe2::MakeArgument<vector<int64_t>>(
"max_lengths", fuse->getMaxLengths()));
op.add_arg()->CopyFrom(
caffe2::MakeArgument<vector<int64_t>>("salts", fuse->getSalts()));
op.add_arg()->CopyFrom(caffe2::MakeArgument<vector<int64_t>>(
"max_values", fuse->getMaxValues()));
op.add_arg()->CopyFrom(caffe2::MakeArgument<bool>(
"hash_into_int32", fuse->getHashIntoInt32()));
op.mutable_device_option()->CopyFrom(getDeviceOption(nnOp));
return op;
}
~ClipRangesGatherSigridHashConverter() override {}
};
REGISTER_CONVERTER(
ClipRangesGatherSigridHash,
ClipRangesGatherSigridHashConverter);
class ClipRangesGatherSigridHashV2Converter : public Converter {
std::unique_ptr<repr::NeuralNetOperator> convertToNeuralNetOperator(
const OperatorDef& op) override {
std::unique_ptr<repr::NeuralNetOperator> nnOp =
util::make_unique<repr::ClipRangesGatherSigridHashV2>();
const caffe2::ArgumentHelper args(op);
auto c = dyn_cast<repr::ClipRangesGatherSigridHashV2>(nnOp.get());
if (args.HasArgument("max_lengths")) {
c->setMaxLengths(args.GetRepeatedArgument<int64_t>("max_lengths"));
}
if (args.HasArgument("salts")) {
c->setSalts(args.GetRepeatedArgument<int64_t>("salts"));
}
if (args.HasArgument("max_values")) {
c->setMaxValues(args.GetRepeatedArgument<int64_t>("max_values"));
}
if (args.HasArgument("hash_into_int32")) {
c->setHashIntoInt32(
args.GetSingleArgument<bool>("hash_into_int32", false));
}
return nnOp;
}
OperatorDef convertToOperatorDef(
const nom::repr::NeuralNetOperator* nnOp) override {
auto fuse = dyn_cast<repr::ClipRangesGatherSigridHashV2>(nnOp);
OperatorDef op;
op.set_type("ClipRangesGatherSigridHashV2");
op.add_arg()->CopyFrom(caffe2::MakeArgument<vector<int64_t>>(
"max_lengths", fuse->getMaxLengths()));
op.add_arg()->CopyFrom(
caffe2::MakeArgument<vector<int64_t>>("salts", fuse->getSalts()));
op.add_arg()->CopyFrom(caffe2::MakeArgument<vector<int64_t>>(
"max_values", fuse->getMaxValues()));
op.add_arg()->CopyFrom(caffe2::MakeArgument<bool>(
"hash_into_int32", fuse->getHashIntoInt32()));
op.mutable_device_option()->CopyFrom(getDeviceOption(nnOp));
return op;
}
~ClipRangesGatherSigridHashV2Converter() override {}
};
REGISTER_CONVERTER(
ClipRangesGatherSigridHashV2,
ClipRangesGatherSigridHashV2Converter);
class ClipRangesConverter : public Converter {
std::unique_ptr<repr::NeuralNetOperator> convertToNeuralNetOperator(
const OperatorDef& op) override {
std::unique_ptr<repr::NeuralNetOperator> nnOp =
util::make_unique<repr::ClipRanges>();
const caffe2::ArgumentHelper args(op);
auto c = dyn_cast<repr::ClipRanges>(nnOp.get());
c->setMaxLength(args.GetSingleArgument<int64_t>("max_length", 0));
return nnOp;
}
OperatorDef convertToOperatorDef(
const nom::repr::NeuralNetOperator* nnOp) override {
auto clipRanges = dyn_cast<repr::ClipRanges>(nnOp);
OperatorDef op;
op.set_type("ClipRanges");
op.add_arg()->CopyFrom(caffe2::MakeArgument<int64_t>(
"max_length", clipRanges->getMaxLength()));
op.mutable_device_option()->CopyFrom(getDeviceOption(nnOp));
return op;
}
~ClipRangesConverter() override {}
};
REGISTER_CONVERTER(ClipRanges, ClipRangesConverter);
class SigridHashConverter : public Converter {
std::unique_ptr<repr::NeuralNetOperator> convertToNeuralNetOperator(
const OperatorDef& op) override {
std::unique_ptr<repr::NeuralNetOperator> nnOp =
util::make_unique<repr::SigridHash>();
const caffe2::ArgumentHelper args(op);
auto c = dyn_cast<repr::SigridHash>(nnOp.get());
c->setSalt(args.GetSingleArgument<int64_t>("salt", 0));
c->setMaxValue(args.GetSingleArgument<int64_t>("maxValue", 0));
c->setHashIntoInt32(args.GetSingleArgument<bool>("hashIntoInt32", false));
return nnOp;
}
OperatorDef convertToOperatorDef(
const nom::repr::NeuralNetOperator* nnOp) override {
auto sigridHash = dyn_cast<repr::SigridHash>(nnOp);
OperatorDef op;
op.set_type("SigridHash");
op.add_arg()->CopyFrom(
caffe2::MakeArgument<int64_t>("salt", sigridHash->getSalt()));
op.add_arg()->CopyFrom(
caffe2::MakeArgument<int64_t>("maxValue", sigridHash->getMaxValue()));
op.add_arg()->CopyFrom(caffe2::MakeArgument<bool>(
"hashIntoInt32", sigridHash->getHashIntoInt32()));
op.mutable_device_option()->CopyFrom(getDeviceOption(nnOp));
return op;
}
~SigridHashConverter() override {}
};
REGISTER_CONVERTER(SigridHash, SigridHashConverter);
} // namespace
} // namespace caffe2