// blob: f021d263106d6a9224f1d6346ebf389dde08b242 [file] [log] [blame]
#include "glow_net_transform.h"
#include <caffe2/opt/onnxifi_transformer.h>
#include <caffe2/opt/shape_info.h>
#include <caffe2/utils/string_utils.h>
#include <stdexcept>
#include <string>
#include <unordered_set>
// Command-line flags consumed by caffe2::glow::onnxifi() in this file. Each
// comment below notes where the flag is read.

// Copied into OnnxifiTransformerOptions::debug. When set, onnxifi() also
// writes the transformed net to "debug_transformed_net.pb_txt".
C10_DEFINE_bool(onnxifi_debug_mode, false, "Enable onnxifi debug mode.");
// Copied into OnnxifiTransformerOptions::adjust_batch.
C10_DEFINE_bool(
onnxifi_adjust_batch,
true,
"Attach AdjustBatch ops at input/outputs of the Onnxifi ops");
// Copied into OnnxifiTransformerOptions::loop_test.
C10_DEFINE_bool(
onnxifi_loop_test_mode,
false,
"For test purpose only. Build a dummy net just to test the functionality");
// Copied into OnnxifiTransformerOptions::enforce_fp32_inputs_into_fp16.
C10_DEFINE_bool(
enforce_fp32_inputs_into_fp16,
false,
"Whether to enforce fp32 to fp16 conversion for external inputs.");
// Copied into OnnxifiTransformerOptions::merge_fp32_inputs_into_fp16.
C10_DEFINE_bool(
merge_fp32_inputs_into_fp16,
false,
"Merge all the fp32 input tensors into one, convert it to fp16 and split it back");
// Copied into OnnxifiTransformerOptions::min_ops.
C10_DEFINE_int32(
onnxifi_min_ops,
1,
"Minimum number of ops for a subgraph to be lowered to backend");
// Copied into OnnxifiTransformerOptions::timeout.
C10_DEFINE_int32(
onnxifi_timeout_ms,
0,
"Timeout limit for onnxifi inference in milliseconds. 0 means no timeout");
// When non-empty, parsed with parseShapeInfoMapFromString() into a copy of
// the shape hints that were passed to onnxifi() as an argument.
C10_DEFINE_string(
onnxifi_shape_hints,
"",
"Shape hints in the form of Name:d0,d1:d2;");
// Parsed with ParseNetPositionList() and merged with the blacklist that was
// passed to onnxifi() as an argument.
C10_DEFINE_string(
onnxifi_blacklist,
"",
"A list of net positions whose corresponding op will be ignored "
"to onnxifi. Example 0-50,61,62-70");
// Parsed with ParseBlackListOps(). Only consulted when use_onnx is false; the
// net positions of all ops with a matching type are added to the blacklist.
C10_DEFINE_string(
onnxifi_blacklist_ops,
"",
"A list of operator types that will be ignored "
"to onnxifi. Example Tanh,Mul");
// Parsed with ParseNetPositionList(). For every listed op, onnxifi() attaches
// blacklisted Copy ops to its input and output tensors so they can be observed.
C10_DEFINE_string(
onnxifi_input_output_observe_list,
"",
"A list of net positions whose corresponding op's inputs and outputs will be"
" observed. ");
namespace caffe2 {
namespace glow {
// Parses a comma-separated list of net positions into a set.
//
// The list is in the form of "0-3,5,6-7", which yields the net positions
// [0,1,2,3,5,6,7]. The literal token "-1" is accepted and yields -1.
//
// Malformed tokens are skipped with a warning rather than aborting the whole
// parse: tokens with more than one '-', tokens whose numbers std::stoi cannot
// parse (non-numeric text, other negative numbers such as "-5", values that do
// not fit in an int) and reversed ranges such as "5-3". Empty tokens (e.g. the
// middle of "1,,2") are skipped silently.
//
// @param str The textual list; may be empty.
// @return The set of parsed net positions (empty if `str` is empty).
std::unordered_set<int> ParseNetPositionList(const std::string& str) {
  std::unordered_set<int> net_position_list;
  if (str.empty()) {
    return net_position_list;
  }
  const auto tokens = caffe2::split(',', str);
  for (const auto& token : tokens) {
    // "-1" would otherwise be split on '-' into an empty piece and "1".
    if (token == "-1") {
      net_position_list.emplace(-1);
      continue;
    }
    const auto range = caffe2::split('-', token);
    if (range.empty()) {
      // Empty token: nothing to add.
      continue;
    }
    if (range.size() > 2) {
      LOG(WARNING) << "Ignoring illegal range: " << token;
      continue;
    }
    try {
      // A single number N is treated as the one-element range N-N.
      const int from = std::stoi(range.front());
      const int to = std::stoi(range.back());
      if (from > to) {
        LOG(WARNING) << "Ignoring empty range: " << token;
        continue;
      }
      for (int i = from; i <= to; ++i) {
        net_position_list.emplace(i);
      }
    } catch (const std::logic_error&) {
      // std::stoi reports failures via std::invalid_argument and
      // std::out_of_range, both of which derive from std::logic_error.
      LOG(WARNING) << "Ignoring illegal range: " << token;
    }
  }
  return net_position_list;
}
// Splits a comma-separated list of operator types (e.g. "Tanh,Mul") into a
// set of type names. An empty string yields an empty set.
std::unordered_set<std::string> ParseBlackListOps(const std::string& str) {
  std::unordered_set<std::string> blacklisted_types;
  if (!str.empty()) {
    const auto pieces = caffe2::split(',', str);
    blacklisted_types.insert(pieces.begin(), pieces.end());
  }
  return blacklisted_types;
}
// Carries out the ONNXIFI transform on `net`, in place.
//
// In order, this function:
//   1. splits SparseLengthsSumSparse ops via splitSparseLengthsSumSparse();
//   2. resets the net's external inputs to `input_names` followed by
//      `weight_names`, and its external outputs to `output_names`;
//   3. fills OnnxifiTransformerOptions from the arguments and the FLAGS_*
//      values defined at the top of this file;
//   4. merges the shape hints and blacklist arguments with the ones given
//      through flags;
//   5. attaches blacklisted Copy ops for --onnxifi_input_output_observe_list;
//   6. runs OnnxifiTransformer::transform(), optionally dumps the result, and
//      removes the blobs named in `input_names` from `ws`.
//
// @param net The net to transform; modified in place.
// @param ws Workspace handed to the transformer; the input blobs are removed
//     from it at the end.
// @param input_names Names registered as external inputs of the net.
// @param output_names Names registered as external outputs of the net.
// @param weight_names Names registered as external inputs after the inputs,
//     and passed to the transformer as the weight list.
// @param blacklist Net positions to exclude from lowering, merged with
//     --onnxifi_blacklist.
// @param shape_hints Shape hints, merged with --onnxifi_shape_hints.
// @param use_onnx Forwarded to the transformer options. When true,
//     --onnxifi_blacklist_ops is not applied.
// @param max_batch_size Forwarded to bound_shape_spec.max_batch_size.
// @param max_seq_size Forwarded to bound_shape_spec.max_seq_size.
// @param load_model_by_blob Forwarded to the transformer options.
// @param predictor_net_ssa_rewritten Forwarded to the transformer options.
//
// Throws (via CAFFE_THROW) if --onnxifi_input_output_observe_list names a
// position that is not a valid op index of the net.
void onnxifi(
    NetDef* net,
    Workspace* ws,
    const std::vector<std::string>& input_names,
    const std::vector<std::string>& output_names,
    const std::vector<std::string>& weight_names,
    const std::unordered_set<int>& blacklist,
    const ShapeInfoMap& shape_hints,
    bool use_onnx,
    size_t max_batch_size,
    size_t max_seq_size,
    bool load_model_by_blob,
    bool predictor_net_ssa_rewritten) {
  // Split SparseLengthsSumSparse so that we can lower the SparseLengthsSum part
  splitSparseLengthsSumSparse(net, *ws);

  // Clean up the external input/output of the net
  net->mutable_external_input()->Clear();
  net->mutable_external_output()->Clear();
  for (const auto& i : input_names) {
    net->add_external_input(i);
  }
  for (const auto& w : weight_names) {
    net->add_external_input(w);
  }
  for (const auto& o : output_names) {
    net->add_external_output(o);
  }

  // ONNXIFI transform
  OnnxifiTransformerOptions opts;
  opts.use_onnx = use_onnx;
  opts.bound_shape_spec.max_batch_size = max_batch_size;
  opts.bound_shape_spec.max_seq_size = max_seq_size;
  opts.debug = FLAGS_onnxifi_debug_mode;
  opts.adjust_batch = FLAGS_onnxifi_adjust_batch;
  opts.min_ops = FLAGS_onnxifi_min_ops;
  opts.load_model_by_blob = load_model_by_blob;
  opts.enforce_fp32_inputs_into_fp16 = FLAGS_enforce_fp32_inputs_into_fp16;
  opts.merge_fp32_inputs_into_fp16 = FLAGS_merge_fp32_inputs_into_fp16;
  opts.loop_test = FLAGS_onnxifi_loop_test_mode;
  opts.predictor_net_ssa_rewritten = predictor_net_ssa_rewritten;
  opts.timeout = FLAGS_onnxifi_timeout_ms;

  // Shape hints from the flag are added on top of a copy of the caller's.
  ShapeInfoMap more_shape_hints = shape_hints;
  if (!FLAGS_onnxifi_shape_hints.empty()) {
    parseShapeInfoMapFromString(FLAGS_onnxifi_shape_hints, more_shape_hints);
  }

  // Before applying the blacklist, make sure the ops in the net all have a
  // net_pos.
  caffe2::BackendTransformerBase::annotateOpIndex(net);

  // Parse the blacklist
  auto more_blacklist = ParseNetPositionList(FLAGS_onnxifi_blacklist);
  for (const auto& b : blacklist) {
    more_blacklist.emplace(b);
  }

  // ONNX mode will change the op order so it doesn't apply here
  if (!opts.use_onnx) {
    const auto blacklisted_ops =
        ParseBlackListOps(FLAGS_onnxifi_blacklist_ops);
    for (const auto& op : net->op()) {
      if (blacklisted_ops.count(op.type())) {
        more_blacklist.emplace(
            ArgumentHelper::GetSingleArgument(op, kNetPos, -1));
      }
    }
  }

  // Attach observation nodes
  //
  // When we want to observe intermediate tensors value out of the onnxifi op,
  // we use the following trick:
  //
  // 1. for specified op, we find its input and outputs.
  // 2. for each input and output, we create a new copy op and attach it as an
  //    input to the copy.
  // 3. we blacklist these new copy operators from onnxification. This forces
  //    these intermediate tensors to also become outputs of the onnxifi op.
  // 4. we put the right arguments on the copy ops so TensorObserver can print
  //    out the values.
  const auto ops_to_observe =
      ParseNetPositionList(FLAGS_onnxifi_input_output_observe_list);
  std::unordered_set<std::string> tensors_to_observe;
  for (const auto& pos : ops_to_observe) {
    // ParseNetPositionList() can return -1, so the lower bound has to be
    // checked too before the position is used as an index.
    if (pos < 0 || pos >= net->op().size()) {
      CAFFE_THROW(
          "Cannot observe operator at position ", pos, " (out of range)");
    }
    const auto& op_to_observe = net->op(pos);
    tensors_to_observe.insert(
        op_to_observe.input().begin(), op_to_observe.input().end());
    // For two-output Concat/Reshape only the first output is observed.
    // NOTE(review): presumably the second output is auxiliary metadata that
    // is not worth observing -- confirm.
    if ((op_to_observe.type() == "Concat" ||
         op_to_observe.type() == "Reshape") &&
        op_to_observe.output().size() == 2) {
      tensors_to_observe.insert(op_to_observe.output(0));
    } else {
      tensors_to_observe.insert(
          op_to_observe.output().begin(), op_to_observe.output().end());
    }
  }
  for (const auto& tensor : tensors_to_observe) {
    OperatorDef copy_op;
    copy_op.set_type("Copy");
    copy_op.add_input(tensor);
    copy_op.add_output(tensor + "_copy_output_ignore");
    // The copy op is appended at the end, so its net position is the current
    // op count.
    auto pos = net->op().size();
    AddArgument(kNetPos, pos, &copy_op);
    AddArgument("observe_input_tensors", 1, &copy_op);
    net->add_op()->CopyFrom(copy_op);
    more_blacklist.emplace(pos);
  }

  OnnxifiTransformer ts(opts);
  ts.transform(ws, net, weight_names, more_shape_hints, more_blacklist);
  if (FLAGS_onnxifi_debug_mode) {
    WriteProtoToTextFile(*net, "debug_transformed_net.pb_txt");
  }

  // Cleanup the input from the workspace
  for (const auto& i : input_names) {
    ws->RemoveBlob(i);
  }
}
} // namespace glow
} // namespace caffe2