blob: b301bb24bd3a6b7d4150d7dfcae1d5ed71e32aba [file] [log] [blame]
#include "caffe2/caffe2/fb/predictor/Transforms.h"
namespace caffe2 {
namespace {
// Returns the first name in the sequence prefix, prefix_1, prefix_2, ... for
// which ws.HasBlob() is false. At most max_tries candidates are examined; if
// all of them are taken, the failure is reported through CAFFE_THROW.
string
NextBlob(const Workspace& ws, const string& prefix, int max_tries = 1000000) {
  int attempt = 0;
  while (attempt < max_tries) {
    std::stringstream name_builder;
    name_builder << prefix;
    // The very first candidate is the bare prefix; later ones get a numeric
    // suffix.
    if (attempt != 0) {
      name_builder << '_' << attempt;
    }
    const string candidate = name_builder.str();
    if (!ws.HasBlob(candidate)) {
      return candidate;
    }
    ++attempt;
  }
  CAFFE_THROW("Failed to find a new blob name");
  // NOTE(review): presumably never reached because CAFFE_THROW raises; kept so
  // every path has a return statement.
  return "";
}
// Returns true if `blob` appears among the inputs of `op`.
bool HasInput(const string& blob, const OperatorDef& op) {
  // Iterate by const reference; iterating by value would copy every input
  // name just to compare it.
  for (const auto& in : op.input()) {
    if (blob == in) {
      return true;
    }
  }
  return false;
}
// Returns true if `blob` appears among the outputs of `op`.
bool HasOutput(const string& blob, const OperatorDef& op) {
  // Iterate by const reference; iterating by value would copy every output
  // name just to compare it.
  for (const auto& out : op.output()) {
    if (blob == out) {
      return true;
    }
  }
  return false;
}
/*
Tests if it is valid to simply rename "from" to "to" after start (at or after
start) ignoring in-place ops of operator types in ignoreTypes. There is a set of
ignoreTypes instead of a single ignoreType for use cases where multiple
op-types will be in-placed; in such cases this function will give a more optimal
answer if it can consider all types at the same time.

The scan walks ops [start, op_size) and returns:
  - false if "from" is read after "to" has been redefined,
  - false if the rename would make an op in-place whose type is not in
    ignoreTypes,
  - true as soon as an op redefines "from" (the rename stops there),
  - otherwise, at the end of the net, true iff "from" is not in netOutputs.
This function only inspects the net; it never modifies it.
*/
bool CanRenameForwards(
    const string& from,
    const string& to,
    const std::shared_ptr<NetDef>& net,
    const std::set<string>& netOutputs,
    const std::set<string>& ignoreTypes,
    int start) {
  bool redefined_to = false;
  for (int i = start; i < net->op_size(); i++) {
    // Bind by reference: the op is only read, so there is no need to copy the
    // whole OperatorDef on every iteration.
    const auto& op = net->op(i);
    bool uses_from = HasInput(from, op);
    // If there's a use of "from" after a redefine of "to" then we can't rename
    // this "from" to "to" b/c it will use the wrong "to" (this op's "to"
    // instead of the "to" that ops[start-1] would produce once its output
    // "from" is renamed to "to")
    if (redefined_to && uses_from) {
      VLOG(7) << "CanRenameForwards " << to << " is redefined before " << from
              << " is used. Cannot rename from " << from << " to " << to;
      return false;
    }
    // If "to" is redefined then we have to be careful of in-placing this op or
    // blocking future uses of "from"
    if (HasOutput(to, op)) {
      // If this op also uses "from", then renaming "from" to "to" will make
      // this op inplace
      if (uses_from) {
        // In-placing of this op is not allowed
        if (!ignoreTypes.count(op.type())) {
          VLOG(7) << "CanRenameForwards detected in-placing of op of type "
                  << op.type();
          return false;
        }
        VLOG(7) << "CanRenameForwards will make " << op.DebugString()
                << " in-place, but this is in the okay-types-to-inplace "
                << "whitelist";
      }
      // This op won't be inplaced (or it's okay if it is) but we still have to
      // watch out for future uses of "from"
      redefined_to = true;
      VLOG(7) << "CanRenameForwards " << to << " is redefined at " << i;
    }
    // If this op redefines "from" then renaming will stop with the inputs of
    // this op. Since we haven't found any problems with renaming, it's okay to
    // rename. Note that we don't need to check if "from" is a network output,
    // as this op will produce "from"
    if (HasOutput(from, op)) {
      VLOG(7) << "CanRenameForwards found another op making " << from
              << " so it's fine to rename " << from << " to " << to
              << " in earlier ops of the net.";
      return true;
    }
  }
  // We reached the end of the ops. There have been no redefinitions of "from",
  // so if "from" is needed in the network outputs then we can't rename it
  return !netOutputs.count(from);
}
// Tests whether the producer of "from" can be renamed to produce "to"
// instead. Ops are scanned backwards from index `end` down to 0 looking for
// the op that outputs "from". The scan returns false if it first meets an op
// that outputs or reads "to", or if renaming would in-place a producer whose
// type is not in ignoreTypes. Once the producer is found, the final answer is
// delegated to CanRenameForwards starting right after the producer. If no
// producer exists, "from" is enforced to be in netInputs and false is
// returned. This function only inspects the net; it never modifies it.
bool CanRenameBackwards(
    const string& from,
    const string& to,
    const std::shared_ptr<NetDef>& net,
    const std::set<string>& netInputs,
    const std::set<string>& netOutputs,
    const std::set<string>& ignoreTypes,
    int end) {
  for (int i = end; i >= 0; i--) {
    // Bind by reference: the op is only read, so there is no need to copy the
    // whole OperatorDef on every iteration.
    const auto& op = net->op(i);
    // If this op defines "to", then all ops after this point will use this op's
    // "to" instead of the producer-of-"from"s, so the producer can't be renamed
    // to produce "to" instead of "from"
    // FUTURE_POSSIBILITY: We might be able to rename this op to not produce to
    if (HasOutput(to, op)) {
      VLOG(7) << "CanRenameBackwards " << to << " is defined after " << from
              << " is. Cannot rename.";
      return false;
    }
    // Because of the previous check, we know that no op in between this op
    // and end has "to" as an output, so it is impossible to make any of them
    // in-place by adding "to" as an input to any of them
    // If we find the producer of "from", then we will stop renaming backwards
    // (we won't rename this op's inputs)
    if (HasOutput(from, op)) {
      // If this op has "to" as an input, then renaming "from" to "to" will
      // in-place this op
      if (HasInput(to, op)) {
        if (!ignoreTypes.count(op.type())) {
          // In future, you could maybe check if "to" could be renamed to a
          // brand new unique blob name
          VLOG(7) << "CanRenameBackwards will in-place producer of " << from;
          return false;
        }
        VLOG(7) << "CanRenameBackwards will in-place the producer of " << from
                << " but this is okay because it's in our in-placeable "
                << "whitelist";
      }
      // This op won't be in-placed (or it will but it's still okay), but we
      // have to check the forwards logic too
      // TODO why? what's the specific case again?
      VLOG(7) << "CanRenameBackwards found the parent of " << from
              << ". Recursively testing if we can rename the parent.";
      return CanRenameForwards(from, to, net, netOutputs, ignoreTypes, i + 1);
    }
    // After this point, inputs of this op may be renamed
    // If this op uses "to", then renaming the producer of "from" to produce
    // "to" will interfere with this op. Technically it will only interfere if
    // the producer of "from" would overwrite this op's "to", but if that wasn't
    // the case then there's some op that produces "to" after the producer of
    // "from", and this will be caught in the first if (HasOutput(to, op))
    if (HasInput(to, op)) {
      VLOG(7) << "CanRenameBackwards " << to << " is used after " << from
              << " is defined. Cannot rename.";
      return false;
    }
  }
  // Found no parent, so "from" must be a network input. We cannot rename it
  CAFFE_ENFORCE(netInputs.count(from));
  VLOG(7) << "CanRenameBackwards " << from << " is a network input. Cannot "
          << "rename.";
  return false;
}
void RenameInputs(const string& from, const string& to, OperatorDef* def) {
VLOG(6) << "RenameInputs(from=" << from << ", to=" << to << ", "
<< def->DebugString() << ")";
for (int i = 0; i < def->input_size(); i++) {
if (def->input(i) == from) {
*def->mutable_input(i) = to;
}
}
}
void RenameOutputs(const string& from, const string& to, OperatorDef* def) {
VLOG(6) << "RenameOutputs(from=" << from << ", to=" << to << ", "
<< def->DebugString() << ")";
for (string& output : *def->mutable_output()) {
if (output == from) {
output = to;
}
}
}
void RenameInputsInChildren(
const string& from,
const string& to,
std::shared_ptr<caffe2::NetDef> net,
int pidx) {
// This does NOT continue through in-place ops
VLOG(4) << "RenameInputsInChildren(from=" << from << ", to=" << to;
for (int j = pidx + 1; j < net->op_size(); j++) {
if (HasInput(from, net->op(j))) {
RenameInputs(from, to, net->mutable_op(j));
}
// If any child op redefines from, then future ops no longer use this op's
// (at j) version of from
if (HasOutput(from, net->op(j))) {
return;
}
}
}
} // namespace
// Tries to make every op of type `op_type` in graph.predict_net_def in-place
// (input blob name == output blob name) by renaming blobs in neighbouring ops.
// Only ops with exactly one input and one output are considered. For each such
// op two strategies are tried, in order:
//   1. forwards: rename the op's output (and later readers of it) to the op's
//      input name;
//   2. backwards: rename the producer of the op's input (and later readers of
//      it, including this op) to the op's output name.
// Ops for which neither strategy is valid are left unchanged. The net is
// modified through the shared pointer held by `graph`.
void InPlaceOps(const InferenceGraph& graph, const std::string& op_type) {
  int num_inplaced = 0;
  auto net = graph.predict_net_def;
  // Collect blob names that we can never rename
  std::set<string> netInputs(
      graph.parameter_names.begin(), graph.parameter_names.end());
  netInputs.insert(graph.input_names.begin(), graph.input_names.end());
  std::set<string> netOutputs(
      graph.output_names.begin(), graph.output_names.end());
  // In-place ops greedily in a forward manner
  for (int i = 0; i < net->op_size(); i++) {
    // Deliberately a copy: `in` and `out` below refer into it and must keep
    // their original values while the op stored in the net is being renamed.
    OperatorDef op = net->op(i);
    // Only inplace the requested
    if (op.type() != op_type) {
      VLOG(2) << "InPlaceOps: Type is " << op.type() << ". Not in-placing";
      continue;
    }
    if (op.input_size() != 1 || op.output_size() != 1) {
      LOG(ERROR) << "InPlaceOps only supports ops with exactly 1 output "
                 << "and exactly 1 input. Skipping op " << op.DebugString();
      continue;
    }
    const string& in = op.input(0);
    const string& out = op.output(0);
    // If it's already in place then let's not do any more work
    if (in == out) {
      continue;
    }
    // Otherwise check if we can rename things
    bool can_rename_forwards =
        CanRenameForwards(out, in, net, netOutputs, {op_type}, i + 1);
    // If renaming is impossible (or complicated) then skip this op
    if (!can_rename_forwards &&
        !CanRenameBackwards(
            in, out, net, netInputs, netOutputs, {op_type}, i - 1)) {
      VLOG(2) << "InPlaceOps: Complicated or impossible remove for op: "
              << op.DebugString();
      continue;
    }
    VLOG(2) << "InPlaceOps will inplace " << op.DebugString();
    num_inplaced++;
    // Handle renaming
    if (can_rename_forwards) {
      // Rename out to in
      VLOG(3) << "InPlaceOps can rename in children from " << out << " to "
              << in;
      // The op itself must also produce `in`; otherwise it would still have
      // in != out and never actually become in-place.
      RenameOutputs(out, in, net->mutable_op(i));
      RenameInputsInChildren(out, in, net, i);
    } else {
      // Forward renaming was rejected (e.g. out is an output of the network),
      // so we must rename in the parent of op instead. Renaming the parent's
      // children also rewrites this op's input to `out`, which in-places it.
      VLOG(3) << "InPlaceOps must find parent that produced " << in
              << " and rename in all of its children from " << in << " to "
              << out;
      for (int pidx = i - 1; pidx >= 0; pidx--) {
        if (HasOutput(in, net->op(pidx))) {
          VLOG(5) << "InPlaceOps found parent is "
                  << net->op(pidx).DebugString();
          RenameOutputs(in, out, net->mutable_op(pidx));
          RenameInputsInChildren(in, out, net, pidx);
          break;
        }
      }
    }
  } // For every op
  VLOG(1) << "InPlaceOps(" << op_type << ") renamed " << num_inplaced << " ops";
}
void RemoveOpsByType(const InferenceGraph& graph, const std::string& op_type) {
int num_removed = 0;
auto net = graph.predict_net_def;
// Rename all the ops we want to delete
InPlaceOps(graph, op_type);
// Now the only ops we can delete are inplaced ones
for (int i = 0; i < net->op_size(); i++) {
OperatorDef op = net->op(i);
// Only remove ops of the requested type
if (op.type() != op_type) {
VLOG(2) << "RemoveOpsByType: Type is " << op_type << ". Not removing";
continue;
}
if (op.input_size() != 1 || op.output_size() != 1) {
LOG(ERROR) << "RemoveOpsByType only supports ops with exactly 1 output "
<< "and exactly 1 input. Skipping op " << op.DebugString();
continue;
}
const string& in = op.input(0);
const string& out = op.output(0);
// If the op is in-place then we can always delete it
if (in == out) {
VLOG(1) << "RemoveOpsByType(" << op_type << ") deleting inplace op";
net->mutable_op()->erase(net->mutable_op()->begin() + i);
i--;
num_removed++;
} else {
VLOG(2) << "RemoveOpsByType(" << op_type << ") can't delete.";
}
} // For every op
VLOG(1) << "RemoveOpsByType(" << op_type << ") removed " << num_removed
<< " ops";
}
} // namespace caffe2