#include "caffe2/opt/tvm_transformer.h"
#include "caffe2/opt/backend_cutting.h"
C10_DEFINE_bool(
caffe2_tvm_profiling_based_jit,
true,
"Use profiling based jit for TVM transform");
C10_DEFINE_int32(
caffe2_tvm_min_ops,
10,
"Minimal number of supported ops for the subgraph to be lowered to TVM");
namespace caffe2 {
NetDef TvmTransformer::buildTvmOp(
const caffe2::NetDef& net,
const std::unordered_set<std::string>& weights,
const ShapeInfoMap& shape_hints) {
if (opts_.min_ops > net.op_size()) {
return net;
}
caffe2::NetDef net_opt;
auto* op = net_opt.add_op();
op->set_type("TVMJit");
// Remove the second output of Concat/Reshape from external_output,
// and figure out which primary inputs of the net are sequence look-ups.
std::unordered_set<std::string> split_infos;
std::unordered_set<std::string> input_set(
net.external_input().begin(), net.external_input().end());
std::unordered_set<std::string> seq_input_set;
for (const auto& op0 : net.op()) {
if ((op0.type() == "Concat" || op0.type() == "Reshape") &&
op0.output_size() == 2) {
split_infos.emplace(op0.output(1));
} else if (
op0.type() == "SparseLengthsSum" ||
op0.type() == "SparseLengthsSumFused8BitRowwise") {
// The indices input of SparseLengthsSum should be of SEQ type
if (op0.input_size() > 1 && input_set.count(op0.input(1))) {
seq_input_set.emplace(op0.input(1));
}
} else if (
op0.type() == "SparseLengthsWeightedSum" ||
op0.type() == "SparseLengthsWeightedSumFused8BitRowwise") {
// The weight and indices inputs of SparseLengthsWeightedSum should be of
// SEQ type
if (op0.input_size() > 1 && input_set.count(op0.input(1))) {
seq_input_set.emplace(op0.input(1));
}
if (op0.input_size() > 2 && input_set.count(op0.input(2))) {
seq_input_set.emplace(op0.input(2));
}
}
}
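// Note on input positions (Caffe2 convention): SparseLengthsSum takes
// (DATA, INDICES, LENGTHS) and SparseLengthsWeightedSum takes
// (DATA, WEIGHTS, INDICES, LENGTHS), which is why input(1) and input(2)
// above are the candidates for SEQ-typed primary inputs.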
// C2 operators bind inputs/outputs by position (and they can be rewritten,
// e.g. by Memonger), while the TVM runtime binds them by name. Therefore, we
// need to record the input/output names.
auto* input_arg = op->add_arg();
input_arg->set_name("inputs");
auto* output_arg = op->add_arg();
output_arg->set_name("outputs");
// We expose both inputs and weights as inputs of TVMJitOp
for (const auto& i : net.external_input()) {
net_opt.add_external_input(i);
op->add_input(i);
input_arg->add_strings(i);
}
for (const auto& i : net.external_output()) {
if (split_infos.count(i)) {
continue;
}
net_opt.add_external_output(i);
op->add_output(i);
output_arg->add_strings(i);
}
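// For illustration only (the blob names below are hypothetical), the op built
// so far would look roughly like this in text format:
//   op {
//     type: "TVMJit"
//     input: "float_features"
//     output: "prediction"
//     arg { name: "inputs" strings: "float_features" }
//     arg { name: "outputs" strings: "prediction" }
//   }
// This lets the TVM runtime bind blobs by name even if positions are rewritten.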
// Record the weights that are actually referenced by ops in the net
auto* w_arg = op->add_arg();
std::unordered_set<std::string> referred_weights;
for (const auto& op0 : net.op()) {
for (const auto& i : op0.input()) {
if (weights.count(i)) {
referred_weights.emplace(i);
}
}
}
w_arg->set_name("weights");
for (const auto& w : referred_weights) {
w_arg->add_strings(w);
}
// Add input shape info via the "input_shape_info" argument of the op
if (!opts_.profiling_based_jit) {
auto* shape_arg = op->add_arg();
shape_arg->set_name("input_shape_info");
for (const auto& i : net_opt.external_input()) {
shape_arg->mutable_tensors()->Add()->CopyFrom(
wrapShapeInfoIntoTensorProto(i, shape_hints.at(i)));
}
}
// Add original net as a fallback
auto* original_net_arg = op->add_arg();
original_net_arg->set_name("original_net");
original_net_arg->mutable_n()->CopyFrom(net);
// Add model id
AddArgument("model_id", model_id_, op);
// Add op id
AddArgument("tvm_op_id", tvm_op_id_++, op);
// Add nominal batch size
AddArgument("nominal_batch_size", opts_.bound_shape_spec.max_batch_size, op);
// Add nominal sequence size
AddArgument("nominal_seq_size", opts_.bound_shape_spec.max_seq_size, op);
// Indices of the input blobs with sequence type
auto* seq_input_indices_arg = op->add_arg();
seq_input_indices_arg->set_name("seq_input_indices");
int64_t input_idx = 0;
for (const auto& input : net_opt.external_input()) {
if (seq_input_set.count(input)) {
seq_input_indices_arg->add_ints(input_idx);
}
++input_idx;
}
if (opts_.debug) {
AddArgument("debug", 1, op);
}
if (opts_.profiling_based_jit) {
AddArgument("profiling_based_jit", 1, op);
}
return net_opt;
}
// Cut off the TVM-supported parts of the net and replace them with TVMJitOps.
// Assumes the net is topologically sorted.
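// The flow below is: (1) SSA-rewrite the net and record the input name
// mapping, (2) infer bound shapes unless profiling-based JIT is enabled,
// (3) cut the net into supported subgraphs and replace each with a TVMJit op
// via applyTvmTransform, and (4) copy the original args/device option back.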
void TvmTransformer::transform(
Workspace* ws,
NetDef* pred_net,
const std::vector<std::string>& weight_names,
const ShapeInfoMap& input_shape_hints,
const std::unordered_set<int>& blacklisted_ops) {
CAFFE_ENFORCE(ws);
CAFFE_ENFORCE(pred_net, "Predict net cannot be nullptr");
// Save the args of the net so that we can copy them to the optimized net later
std::vector<Argument> args;
for (const auto& arg : pred_net->arg()) {
args.emplace_back(arg);
}
// Get model id and reset TVM op id to 0
model_id_ = getModelId(*pred_net);
tvm_op_id_ = 0;
std::unordered_set<std::string> weights(
weight_names.begin(), weight_names.end());
// SSA Rewrite the net
auto shape_hints_mapped =
ssaRewriteAndMapNames(ws, pred_net, input_shape_hints);
// Populate shape info
Workspace mapped_ws(ws, input_mapping_);
ShapeInfoMap shape_hints;
if (!opts_.profiling_based_jit) {
shape_hints = inferShapes(
&mapped_ws, pred_net, shape_hints_mapped, opts_.bound_shape_spec);
}
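// When profiling-based JIT is enabled, shapes are expected to be obtained at
// runtime, so static shape inference is skipped and shape_hints stays empty.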
if (opts_.debug) {
dumpNet(*pred_net, shape_hints, "debug_ssa_net.pb_txt");
}
// We are ready to transform the net
NetDef net_opt =
applyTvmTransform(pred_net, weights, blacklisted_ops, shape_hints);
// Copy the properties
for (const auto& arg : args) {
net_opt.add_arg()->CopyFrom(arg);
}
net_opt.mutable_device_option()->CopyFrom(pred_net->device_option());
pred_net->Swap(&net_opt);
if (opts_.debug) {
dumpNet(*pred_net, shape_hints, "debug_full_opt_net.pb_txt");
}
}
NetDef TvmTransformer::applyTvmTransform(
NetDef* pred_net,
const std::unordered_set<std::string>& weights,
const std::unordered_set<int>& blacklisted_ops,
const ShapeInfoMap& shape_hints) {
auto profiling_based_jit = opts_.profiling_based_jit;
auto tvm_supports = [&blacklisted_ops, &shape_hints, &profiling_based_jit](
const caffe2::OperatorDef& op) {
const static std::unordered_set<std::string> supported_ops{
"Add", "Sum",
"FC", "FCTransposed",
"Flatten", "Relu",
"Sigmoid", "Softmax",
"Split", "EnsureCPUOutput",
"Reshape", "ExpandDims",
"Concat", "BatchMatMul",
"MatMul", "BatchGather",
"DotProduct", "Transpose",
"Mul", "Tanh",
"Logit", "Cast",
"Copy"};
try {
// If the op position is blacklisted, return false
int pos =
ArgumentHelper::GetSingleArgument<OperatorDef, int>(op, kNetPos, -1);
if (blacklisted_ops.count(pos)) {
LOG(INFO) << "Blacklisting op" << op.type() << " at position " << pos;
return false;
}
// If we don't have proper shape info for the op, we cannot compile it
// properly, return false
if (!profiling_based_jit) {
for (const auto& i : op.input()) {
if (shape_hints.find(i) == shape_hints.end()) {
LOG(INFO) << "Skipping op " << op.type()
<< " due to missing shape info for input " << i;
return false;
}
}
}
// If the TVM C2 frontend doesn't support this op, return false
// TODO: This should be something like TVMC2Frontend::supports(op);
return (supported_ops.count(op.type()) != 0);
} catch (const std::exception& ex) {
LOG(ERROR) << "Caught exception when querying op " << op.type()
<< ", what: " << ex.what();
return false;
}
};
auto tvm_op_converter =
[this, &weights, &shape_hints](const caffe2::NetDef& net) {
return buildTvmOp(net, weights, shape_hints);
};
return opt::OptimizeForBackend(*pred_net, tvm_supports, tvm_op_converter);
}
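// Example call site for the convenience wrapper below (illustrative only; the
// blob names, workspace setup, and sizes are hypothetical and not part of this
// file):
//
//   caffe2::Workspace ws;                  // assumed to already hold the weights
//   caffe2::NetDef net = loadPredictNet(); // hypothetical helper
//   caffe2::ShapeInfoMap hints;            // optional per-blob shape hints
//   caffe2::tvmTransform(
//       &net,
//       &ws,
//       {"float_features"},  // input_names
//       {"prediction"},      // output_names
//       {"fc_w", "fc_b"},    // weight_names
//       hints,
//       /*blacklisted_ops=*/{},
//       /*max_batch_size=*/32,
//       /*max_seq_size=*/256,
//       /*debug=*/false);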
void tvmTransform(
NetDef* net,
Workspace* ws,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::vector<std::string>& weight_names,
const ShapeInfoMap& shape_hints,
const std::unordered_set<int>& blacklisted_ops,
size_t max_batch_size,
size_t max_seq_size,
bool debug) {
TvmTransformOptions opts;
opts.bound_shape_spec.max_batch_size = max_batch_size;
opts.bound_shape_spec.max_seq_size = max_seq_size;
opts.debug = debug;
opts.profiling_based_jit = FLAGS_caffe2_tvm_profiling_based_jit;
opts.min_ops = FLAGS_caffe2_tvm_min_ops;
TvmTransformer ts(opts);
// Clean up the external input/output of the net
net->mutable_external_input()->Clear();
net->mutable_external_output()->Clear();
for (const auto& i : input_names) {
net->add_external_input(i);
}
for (const auto& w : weight_names) {
net->add_external_input(w);
}
for (const auto& o : output_names) {
net->add_external_output(o);
}
ts.transform(ws, net, weight_names, shape_hints, blacklisted_ops);
}
} // namespace caffe2