blob: 16a76573dbc138156d52f9808153c021ae542134 [file] [log] [blame]
#include "caffe2/core/net.h"
#include "caffe2/utils/proto_utils.h"
#include "compiler.h"
#include "parser.h"
namespace caffe2 {
namespace script {
namespace {
static std::unordered_set<std::string> ops_containing_nets = {
"If",
"While",
"RecurrentNetwork",
};
// record of defined function
// NetDef + metadata
struct FunctionDefinition {
explicit FunctionDefinition(Def tree)
: tree(new Def(tree)), net_def(new NetDef()) {}
explicit FunctionDefinition(std::unique_ptr<NetDef> def)
: tree(nullptr), net_def(std::move(def)) {
// we coop extern_inputs/extern_outputs to be the inputs/outputs to
// this net as a function
// but we _dont_ set these when creating the net in the workspace
// because they require the net to have valid inputs/outputs
inputs.insert(
inputs.begin(),
net_def->external_input().begin(),
net_def->external_input().end());
outputs.insert(
outputs.begin(),
net_def->external_output().begin(),
net_def->external_output().end());
net_def->clear_external_output();
net_def->clear_external_input();
}
bool isExtern() const {
return tree == nullptr;
}
std::unique_ptr<Def> tree;
std::unique_ptr<NetDef> net_def;
std::vector<std::string> inputs;
std::vector<std::string> outputs;
};
} // namespace
using SymbolTable = std::unordered_map<std::string, FunctionDefinition>;
struct DefCompiler {
DefCompiler(FunctionDefinition& def, SymbolTable& symbol_table)
: def(def),
net_def_stack({def.net_def.get()}),
symbol_table(symbol_table) {}
void run() {
auto& tree = *def.tree;
cur().set_name(tree.name().name());
for (auto input : tree.params()) {
auto& name = input.ident().name();
map(name, name);
def.inputs.push_back(name);
}
for (auto output : tree.returns()) {
auto& name = output.ident().name();
map(name, name);
def.outputs.push_back(name);
}
emitStatements(tree.statements());
}
void emitExpressionStatement(TreeRef stmt) {
// expression with no used outputs
emit(stmt, {});
}
void emitStatements(const ListView<TreeRef>& statements) {
for (auto stmt : statements) {
switch (stmt->kind()) {
case TK_IF:
emitIf(If(stmt));
break;
case TK_WHILE:
emitWhile(While(stmt));
break;
case TK_ASSIGN:
emitAssignment(Assign(stmt));
break;
case TK_GLOBAL:
for (auto ident : stmt->trees()) {
auto name = Ident(ident).name();
map(name, name);
}
break;
default:
emitExpressionStatement(stmt);
break;
}
}
}
void map(const std::string& name, const std::string& value) {
env[name] = value;
}
const std::string& lookup(const Ident& ident) {
if (env.count(ident.name()) == 0)
throw ErrorReport(ident) << "undefined value " << ident.name();
return env[ident.name()];
}
void emitAssignment(const Assign& stmt) {
std::vector<std::string> outputs;
for (auto lhs : stmt.lhs()) {
std::string name = getLHS(lhs);
// use of "_" gets renamed in Caffe2 graphs so that two uses
// don't unintentionally interfere with each other
if (name == "_") {
name = fresh();
}
outputs.push_back(name);
}
if (stmt.reduction() != '=') {
if (stmt.lhs().size() != 1) {
throw ErrorReport(stmt)
<< "reductions are only allow when there is a single variable "
<< "on the left-hand side.";
}
auto lhs = stmt.lhs()[0];
auto expr =
Compound::create(stmt.reduction(), stmt.range(), {lhs, stmt.rhs()});
emit(expr, outputs);
} else {
emit(stmt.rhs(), outputs);
}
int i = 0;
for (auto ident : stmt.lhs()) {
if (ident->kind() == TK_IDENT)
map(Ident(ident).name(), outputs.at(i));
i++;
}
}
void emitIf(const If& stmt) {
auto cond = getValue(stmt.cond());
auto op = cur().add_op();
op->set_type("If");
op->add_input(cond);
auto true_branch = op->add_arg();
true_branch->set_name("then_net");
auto nd = true_branch->mutable_n();
net_def_stack.push_back(nd);
emitStatements(stmt.trueBranch());
net_def_stack.pop_back();
if (stmt.falseBranch().size() > 0) {
auto false_branch = op->add_arg();
false_branch->set_name("else_net");
auto nd = false_branch->mutable_n();
net_def_stack.push_back(nd);
emitStatements(stmt.falseBranch());
net_def_stack.pop_back();
}
}
void emitWhile(const While& stmt) {
std::string loop_var = fresh();
emitConst(0, loop_var, "i"); // it needs a definition before loop
auto op = cur().add_op();
op->set_type("While");
auto cond = op->add_arg();
cond->set_name("cond_net");
auto cond_net = cond->mutable_n();
net_def_stack.push_back(cond_net);
emit(stmt.cond(), {loop_var});
net_def_stack.pop_back();
op->add_input(loop_var);
auto body = op->add_arg();
body->set_name("loop_net");
auto body_net = body->mutable_n();
net_def_stack.push_back(body_net);
emitStatements(stmt.body());
net_def_stack.pop_back();
}
std::string getLHS(const TreeRef& tree) {
switch (tree->kind()) {
case TK_IDENT: {
return Ident(tree).name();
} break;
case '.': {
auto sel = Select(tree);
std::string lhs = getValue(sel.value());
// TODO: check whether this subname exists in object lhs
return lhs + "/" + sel.selector().name();
} break;
default: {
throw ErrorReport(tree)
<< "This expression cannot appear on the left-hand size of an assignment";
} break;
}
}
std::string getValue(const TreeRef& tree) {
switch (tree->kind()) {
case TK_IDENT: {
return lookup(Ident(tree));
} break;
case '.': {
auto sel = Select(tree);
std::string lhs = getValue(sel.value());
// TODO: check whether this subname exists in object lhs
return lhs + "/" + sel.selector().name();
} break;
default: {
std::string name = fresh();
emit(tree, {name});
return name;
} break;
}
}
std::string fresh(std::string prefix = "$t") {
return std::string(prefix) + c10::to_string(next_fresh++);
}
const char* operatorName(int kind, int ninputs) {
switch (kind) {
case '+':
return "Add";
case '-':
if (ninputs == 1)
return "Negative";
else
return "Sub";
case '*':
return "Mul";
case '/':
return "Div";
case TK_NE:
return "NE";
case TK_EQ:
return "EQ";
case '<':
return "LT";
case '>':
return "GT";
case TK_LE:
return "LE";
case TK_GE:
return "GE";
case TK_IF_EXPR:
return "Conditional";
case TK_AND:
return "And";
case TK_OR:
return "Or";
case TK_NOT:
return "Not";
default:
throw std::runtime_error("unknown kind " + c10::to_string(kind));
}
}
void fillArg(Argument* arg, const Attribute& attr) {
std::string name = attr.name().name();
arg->set_name(name);
auto value = attr.value();
// TODO: handle non-float attributes
switch (value->kind()) {
case TK_CONST: {
auto v = value->tree(0)->doubleValue();
auto f = value->tree(1)->stringValue();
if (f == "f")
arg->set_f(v);
else
arg->set_i(v);
} break;
case TK_LIST:
for (auto t : value->trees()) {
auto v = t->tree(0)->doubleValue();
auto f = t->tree(1)->stringValue();
if (f == "f")
arg->add_floats(v);
else
arg->add_ints(v);
}
break;
}
}
template <typename Trees>
std::vector<std::string> getValues(const Trees& trees) {
std::vector<std::string> result;
for (const auto& tree : trees) {
result.push_back(getValue(tree));
}
return result;
}
bool renameLookup(
std::unordered_map<std::string, std::string>& rename_map,
const std::string& name,
std::string& rename) {
// first look for name in the map directly
auto it = rename_map.find(name);
if (it != rename_map.end()) {
rename = it->second;
return true;
}
// otherwise if we have a rename entry like a => b and a name "a/foo/bar"
// then replace it with "b/foo/bar"
auto p = name.find("/");
if (p == std::string::npos)
return false;
it = rename_map.find(name.substr(0, p));
if (it != rename_map.end()) {
rename = it->second + name.substr(p);
return true;
}
return false;
}
void renameOp(
std::unordered_map<std::string, std::string>& rename_map,
const Apply& apply,
const std::string& prefix,
bool isExtern,
OperatorDef* new_op) {
for (size_t i = 0; i < new_op->input().size(); i++) {
auto& name = new_op->input(i);
std::string renamed;
bool defined = renameLookup(rename_map, name, renamed);
if (!isExtern && !defined) {
throw ErrorReport(apply)
<< " unexpected undefined name '" << name
<< "' while attempting to inline '" << apply.name().name() << "'";
} else if (!defined) {
// extern function using a global name, assign it an identity mapping
rename_map[name] = name;
}
new_op->set_input(i, renamed);
}
for (size_t i = 0; i < new_op->output().size(); i++) {
auto& name = new_op->output(i);
std::string renamed;
if (!renameLookup(rename_map, name, renamed)) {
renamed = prefix + name;
rename_map[name] = renamed;
}
new_op->set_output(i, renamed);
}
// handle control flow inside the op as well
if (ops_containing_nets.count(new_op->type()) > 0) {
for (size_t i = 0; i < new_op->arg_size(); i++) {
auto* arg = new_op->mutable_arg(i);
if (arg->has_n()) {
auto* n = arg->mutable_n();
for (size_t j = 0; j < n->op_size(); j++) {
renameOp(rename_map, apply, prefix, isExtern, n->mutable_op(j));
}
}
}
}
}
bool hasBypassRename(const Apply& apply) {
for (auto attr : apply.attributes()) {
if (attr.name().name() == "rename") {
if (attr.value()->kind() != TK_CONST) {
throw ErrorReport(attr.value()) << "expected a single constant";
}
return attr.value()->tree(0)->doubleValue() == 0;
}
}
return false;
}
// emit a function call by inlining the function's NetDef into our
// net def, renaming temporaries func_name<unique_id>/orig_name
// renaming only happens for values defined by the function
// that are not marked outputs
// inputs/outputs are passed by reference
void emitFunctionCall(Apply& apply, const std::vector<std::string>& outputs) {
std::string fname = apply.name().name();
std::string prefix = fresh(fname) + "/";
auto& fn = symbol_table.at(apply.name().name());
bool isExtern = fn.isExtern();
auto inputs = getValues(apply.inputs());
std::unordered_map<std::string, std::string> rename_map;
if (inputs.size() != fn.inputs.size()) {
throw ErrorReport(apply) << fname << " expected " << fn.inputs.size()
<< " values but received " << inputs.size();
}
for (size_t i = 0; i < inputs.size(); i++) {
rename_map[fn.inputs[i]] = inputs[i];
}
if (outputs.size() != fn.outputs.size()) {
throw ErrorReport(apply) << fname << " expected " << fn.outputs.size()
<< " values but received " << outputs.size();
}
for (size_t i = 0; i < outputs.size(); i++) {
rename_map[fn.outputs[i]] = outputs[i];
}
for (auto& op : fn.net_def->op()) {
auto new_op = cur().add_op();
new_op->CopyFrom(op);
if (hasBypassRename(apply)) {
prefix = "";
}
renameOp(rename_map, apply, prefix, isExtern, new_op);
}
}
void expectOutputs(
const TreeRef& tree,
const std::vector<std::string>& outputs,
size_t size) {
if (outputs.size() != size) {
throw ErrorReport(tree)
<< "expected operator to produce " << outputs.size()
<< " outputs but it produced " << size;
}
}
void appendOutputs(
const TreeRef& tree,
OperatorDef* op,
const std::vector<std::string>& outputs,
size_t size) {
expectOutputs(tree, outputs, size);
for (size_t i = 0; i < size; i++) {
op->add_output(outputs[i]);
}
}
void emitOperator(
const Apply& apply,
const OpSchema* schema,
const std::vector<std::string>& outputs) {
// must be before add_op
auto values = getValues(apply.inputs());
if (values.size() < schema->min_input() ||
values.size() > schema->max_input()) {
if (schema->min_input() == schema->max_input()) {
throw ErrorReport(apply) << "operator expects " << schema->min_input()
<< " inputs but found " << values.size();
} else {
throw ErrorReport(apply)
<< "operator takes between " << schema->min_input() << " and "
<< schema->max_input() << " inputs but found " << values.size()
<< ".";
}
}
auto numActualOutputs = schema->CalculateOutput(values.size());
if (numActualOutputs != kCannotComputeNumOutputs &&
outputs.size() != numActualOutputs) {
throw ErrorReport(apply)
<< "operator produces " << numActualOutputs
<< " outputs but matched to " << outputs.size() << " outputs";
}
auto op = cur().add_op();
op->set_type(apply.name().name());
for (auto& v : values) {
op->add_input(v);
}
// assume 1 output unless matched to more
appendOutputs(apply, op, outputs, outputs.size());
for (auto attribute : apply.attributes()) {
fillArg(op->add_arg(), attribute);
}
// Ok, we checked the stuff where we can easily give a friendly error
// message, now verify against the schema and report the error at the line
if (!schema->Verify(*op)) {
throw ErrorReport(apply) << "failed schema checking";
}
}
// Emit an operation, writing results into 'outputs'.
// This will _always_ compute something, unlike 'getValue' which simply
// returns an already computed reference if possible.
// So if 'tree' is an identifier or nested identifier (foo.bar)
// this will cause it to be _copied_ into outputs.
void emit(const TreeRef& tree, const std::vector<std::string>& outputs) {
switch (tree->kind()) {
case TK_IDENT:
case '.': {
auto op = cur().add_op();
op->set_type("Copy");
op->add_input(getValue(tree));
appendOutputs(tree, op, outputs, 1);
} break;
case TK_NE:
case TK_EQ:
case '<':
case '>':
case TK_LE:
case TK_GE:
case '-':
case '*':
case '/':
case '+':
case TK_AND:
case TK_OR:
case TK_NOT:
case TK_IF_EXPR: {
// must be before add_op
auto values = getValues(tree->trees());
auto op = cur().add_op();
op->set_type(operatorName(tree->kind(), tree->trees().size()));
for (auto& v : values) {
op->add_input(v);
}
appendOutputs(tree, op, outputs, 1);
auto broadcast = op->add_arg();
broadcast->set_name("broadcast");
broadcast->set_i(1);
} break;
case TK_APPLY: {
auto apply = Apply(tree);
// Handle built-ins like zeros, ones, etc
if (builtins.count(apply.name().name()) > 0) {
builtins[apply.name().name()](this, apply, outputs);
break;
}
if (symbol_table.count(apply.name().name()) > 0) {
emitFunctionCall(apply, outputs);
break;
}
auto schema = OpSchemaRegistry::Schema(apply.name().name());
if (schema) {
emitOperator(apply, schema, outputs);
break;
}
throw ErrorReport(apply)
<< "attempting to call unknown operation or function '"
<< apply.name().name() << "'";
} break;
case TK_CAST: {
auto cast = Cast(tree);
auto c2type = getType(cast.type());
auto input = getValue(cast.input());
auto op = cur().add_op();
op->set_type("Cast");
op->add_input(input);
appendOutputs(tree, op, outputs, 1);
auto arg = op->add_arg();
arg->set_name("to");
arg->set_i(c2type);
} break;
case TK_CONST: {
expectOutputs(tree, outputs, 1);
emitConst(
tree->tree(0)->doubleValue(),
outputs[0],
tree->tree(1)->stringValue());
} break;
case TK_GATHER: {
const auto gather = Gather(tree);
desugarAndEmitOperator(
"Gather",
gather.range(),
{gather.value(), gather.indices()},
outputs);
break;
}
case TK_SLICE: {
const auto slice = Slice(tree);
desugarAndEmitOperator(
"Slice",
slice.range(),
{slice.value(), slice.startOr(0), slice.endOr(-1)},
outputs);
break;
}
default:
throw ErrorReport(tree) << "NYI: " << tree;
break;
}
}
// Desugars constructs that are syntactic sugar and emits the corresponding
// operator invocation, e.g. tensor[indices] -> tensor.Gather(indices).
void desugarAndEmitOperator(
const std::string& operatorName,
const SourceRange& range,
TreeList&& inputs,
const std::vector<std::string>& outputs) {
const auto applyName = Ident::create(range, operatorName);
const auto applyInputs =
Compound::create(TK_LIST, range, std::move(inputs));
const auto applyAttributes = Compound::create(TK_LIST, range, {});
const auto apply =
Apply::create(range, applyName, applyInputs, applyAttributes);
const auto schema = OpSchemaRegistry::Schema(operatorName);
assert(schema != nullptr);
emitOperator(Apply(apply), schema, outputs);
}
TensorProto_DataType getType(int type) {
switch (type) {
case TK_INT:
return TensorProto_DataType_INT32;
case TK_FLOAT:
return TensorProto_DataType_FLOAT;
case TK_LONG:
return TensorProto_DataType_INT64;
case TK_BOOL:
return TensorProto_DataType_BOOL;
default:
throw std::runtime_error(
"expected type token: " + c10::to_string(type));
}
}
OperatorDef* emitConst(
double v,
const std::string& output,
const std::string& type_ident) {
auto op = cur().add_op();
op->set_type("ConstantFill");
auto dtype = op->add_arg();
dtype->set_name("dtype");
auto value = op->add_arg();
value->set_name("value");
if (type_ident == "f") {
dtype->set_i(TensorProto_DataType_FLOAT);
value->set_f(v);
} else if (type_ident == "LL") {
dtype->set_i(TensorProto_DataType_INT64);
value->set_i(v);
} else if (type_ident == "b") {
dtype->set_i(TensorProto_DataType_BOOL);
value->set_i(v != 0);
} else if (type_ident == "i") {
dtype->set_i(TensorProto_DataType_INT32);
value->set_i(v);
} else {
throw std::runtime_error("unknown type_ident " + type_ident);
}
auto shape = op->add_arg();
shape->set_name("shape");
shape->add_ints(1);
op->add_output(output);
return op;
}
NetDef& cur() {
return *net_def_stack.back();
}
FunctionDefinition& def; // the def being constructed
std::unordered_map<std::string, std::string>
env; // map from name in Def to name in NetDef
std::vector<NetDef*> net_def_stack;
SymbolTable& symbol_table;
int next_fresh = 0;
private:
void emitFillOp(const Apply& apply, const std::vector<std::string>& outputs) {
auto builtin_type = apply.name().name();
auto values = getValues(apply.inputs());
if (values.size() > 1) {
throw ErrorReport(apply)
<< "Built-in " << builtin_type << " accepts 0 or 1 inputs.";
}
bool has_shape = false;
for (const auto& attribute : apply.attributes()) {
if (attribute.name().name() == "shape") {
has_shape = true;
} else {
throw ErrorReport(apply)
<< "Unrecognized attribute " << attribute.name().name()
<< " for built-in " << builtin_type;
}
}
if (builtin_type == "zeros" || builtin_type == "ones") {
if ((values.size() != 1) && !has_shape) {
throw ErrorReport(apply)
<< "Built-in " << builtin_type
<< " requires either 1 input or 1 shape attribute";
}
} else {
// zeros_like or ones_like
if (values.size() != 1) {
throw ErrorReport(apply)
<< "Built-in " << builtin_type << " requires 1 input";
}
}
auto op = cur().add_op();
op->set_type("ConstantFill");
if (values.size()) {
op->add_input(values[0]);
auto* input_as_shape = op->add_arg();
input_as_shape->set_name("input_as_shape");
if (builtin_type.find("_like") != std::string::npos) {
// zeros_like, ones_like take the shape of the input as constant
// tensor shape
input_as_shape->set_i(0);
} else {
// zeros, ones take the values in the tensor as constant tensor
// shape
input_as_shape->set_i(1);
}
} else {
fillArg(op->add_arg(), apply.attributes()[0]);
}
auto value = op->add_arg();
value->set_name("value");
if (builtin_type.find("ones") != std::string::npos) {
value->set_f(1.0f);
} else {
value->set_f(0.0f);
}
appendOutputs(apply, op, outputs, 1);
}
// emitModule doesn't actually do anything except for allow
// statements like a = Module() to register 'a' as a valid identifier
// so that a.b = ... will work
void emitModule(const Apply& apply, const std::vector<std::string>& outputs) {
expectOutputs(apply, outputs, 1);
}
std::unordered_map<
std::string,
std::function<void(
DefCompiler*,
const Apply&,
const std::vector<std::string>& outputs)>>
builtins{{"zeros", &DefCompiler::emitFillOp},
{"zeros_like", &DefCompiler::emitFillOp},
{"ones", &DefCompiler::emitFillOp},
{"ones_like", &DefCompiler::emitFillOp},
{"Module", &DefCompiler::emitModule}};
};
struct CompilationUnitImpl {
void defineFunction(const Def& def) {
if (functions.count(def.name().name()) > 0) {
throw ErrorReport(def) << def.name().name() << " already defined.";
}
DefCompiler c(
functions.emplace(def.name().name(), FunctionDefinition(def))
.first->second,
functions);
c.run();
}
void define(const std::string& str) {
Parser p(str);
while (p.lexer().cur().kind != TK_EOF) {
defineFunction(Def(p.parseFunction()));
}
}
std::unique_ptr<NetBase> createNet(Workspace* ws, const std::string& str) {
if (functions.count(str) == 0)
throw ErrorReport() << "undefined function: " << str << "\n";
auto& def = functions.at(str);
return caffe2::CreateNet(*def.net_def, ws);
}
void defineExtern(const std::string& name, std::unique_ptr<NetDef> net_def) {
// TODO: unify extern and function namespaces
if (functions.count(name) > 0) {
throw ErrorReport() << "function '" << name << "' already defined.";
}
functions.emplace(name, FunctionDefinition(std::move(net_def)));
}
std::string getProto(const std::string& functionName) {
return functions.at(functionName).net_def->DebugString();
}
private:
friend struct DefCompiler;
SymbolTable functions;
};
CompilationUnit::CompilationUnit() : pImpl(new CompilationUnitImpl()) {}
void CompilationUnit::define(const std::string& str) {
return pImpl->define(str);
}
void CompilationUnit::defineExtern(
const std::string& name,
std::unique_ptr<NetDef> nd) {
pImpl->defineExtern(name, std::move(nd));
}
std::unique_ptr<NetBase> CompilationUnit::createNet(
Workspace* ws,
const std::string& str) {
return pImpl->createNet(ws, str);
}
std::string CompilationUnit::getProto(const std::string& functionName) const {
return pImpl->getProto(functionName);
}
CompilationUnit::~CompilationUnit() {}
} // namespace script
} // namespace caffe2