blob: a8e57b01ae1cef4b03c11d7605034442e6eb015f [file] [log] [blame]
#include "caffe2/core/net.h"
#include "caffe2/utils/proto_utils.h"
#include "compiler.h"
#include "parser.h"
namespace caffe2 {
namespace script {
struct DefCompiler {
DefCompiler(const Def& def, NetDef& net_def)
: def(def), net_def_stack({&net_def}) {}
void run() {
cur().set_name(def.name().name());
for (auto input : def.params()) {
auto& name = input.ident().name();
map(name, name);
// cur().add_external_input(name);
}
for (auto output : def.returns()) {
auto& name = output.ident().name();
map(name, name);
// cur().add_external_output(name);
}
emitStatements(def.statements());
}
void emitStatements(const ListView<TreeRef>& statements) {
for (auto stmt : statements) {
switch (stmt->kind()) {
case TK_IF:
emitIf(If(stmt));
break;
case TK_WHILE:
emitWhile(While(stmt));
break;
case TK_ASSIGN:
emitAssignment(Assign(stmt));
break;
default:
throw ErrorReport(stmt) << "NYI: " << stmt;
}
}
}
void map(const std::string& name, const std::string& value) {
env[name] = value;
}
const std::string& lookup(const Ident& ident) {
if (env.count(ident.name()) == 0)
throw ErrorReport(ident) << "undefined value " << ident.name();
return env[ident.name()];
}
void emitAssignment(const Assign& stmt) {
auto op = emit(stmt.rhs());
while (op->output_size() < stmt.idents().size())
op->add_output();
int i = 0;
for (auto ident : stmt.idents()) {
op->set_output(i++, ident.name());
map(ident.name(), ident.name());
}
}
void emitIf(const If& stmt) {
auto cond = getValue(stmt.cond());
auto op = cur().add_op();
op->set_type("If");
op->add_input(cond);
auto true_branch = op->add_arg();
true_branch->set_name("then_net");
auto nd = true_branch->mutable_n();
net_def_stack.push_back(nd);
emitStatements(stmt.trueBranch());
net_def_stack.pop_back();
if (stmt.falseBranch().size() > 0) {
auto false_branch = op->add_arg();
false_branch->set_name("else_net");
auto nd = false_branch->mutable_n();
net_def_stack.push_back(nd);
emitStatements(stmt.falseBranch());
net_def_stack.pop_back();
}
}
void emitWhile(const While& stmt) {
std::string loop_var = fresh();
emitConst(0, loop_var, ""); // it needs a definition before loop
auto op = cur().add_op();
op->set_type("While");
auto cond = op->add_arg();
cond->set_name("cond_net");
auto cond_net = cond->mutable_n();
net_def_stack.push_back(cond_net);
auto cond_op = emit(stmt.cond());
cond_op->set_output(0, loop_var);
net_def_stack.pop_back();
op->add_input(loop_var);
auto body = op->add_arg();
body->set_name("loop_net");
auto body_net = body->mutable_n();
net_def_stack.push_back(body_net);
emitStatements(stmt.body());
net_def_stack.pop_back();
}
const std::string& getValue(const TreeRef& tree) {
switch (tree->kind()) {
case TK_IDENT:
return lookup(Ident(tree));
default:
auto op = emit(tree);
return op->output(0);
}
}
std::string fresh() {
return std::string("$t") + caffe2::to_string(next_fresh++);
}
const char* operatorName(int kind, int ninputs) {
switch (kind) {
case '+':
return "Add";
case '-':
if (ninputs == 1)
return "Negative";
else
return "Sub";
case '*':
return "Mul";
case '/':
return "Div";
case TK_NE:
return "NE";
case TK_EQ:
return "EQ";
case '<':
return "LT";
case '>':
return "GT";
case TK_LE:
return "LE";
case TK_GE:
return "GE";
default:
throw std::runtime_error("unknown kind " + caffe2::to_string(kind));
}
}
void fillArg(Argument* arg, const Attribute& attr) {
std::string name = attr.name().name();
arg->set_name(name);
auto value = attr.value();
// TODO: handle non-float attributes
switch (value->kind()) {
case TK_CONST: {
auto v = value->tree(0)->doubleValue();
auto f = value->tree(1)->stringValue();
if (f == "f")
arg->set_f(v);
else
arg->set_i(v);
} break;
case TK_LIST:
for (auto t : value->trees()) {
auto v = t->tree(0)->doubleValue();
auto f = t->tree(1)->stringValue();
if (f == "f")
arg->add_floats(v);
else
arg->add_ints(v);
}
break;
}
}
template <typename Trees>
std::vector<std::string> getValues(const Trees& trees) {
std::vector<std::string> result;
for (const auto& tree : trees) {
result.push_back(getValue(tree));
}
return result;
}
OperatorDef* emit(const TreeRef& tree) {
switch (tree->kind()) {
case TK_IDENT: {
auto op = cur().add_op();
op->set_type("Copy");
op->add_input(lookup(Ident(tree)));
op->add_output(fresh());
return op;
} break;
case TK_NE:
case TK_EQ:
case '<':
case '>':
case TK_LE:
case TK_GE:
case '-':
case '*':
case '/':
case '+': {
// must be before add_op
auto values = getValues(tree->trees());
auto op = cur().add_op();
op->set_type(operatorName(tree->kind(), tree->trees().size()));
for (auto& v : values) {
op->add_input(v);
}
op->add_output(fresh());
auto broadcast = op->add_arg();
broadcast->set_name("broadcast");
broadcast->set_i(1);
return op;
}
case TK_APPLY: {
auto apply = Apply(tree);
// must be before add_op
auto values = getValues(apply.inputs());
auto op = cur().add_op();
op->set_type(apply.name().name());
for (auto& v : values) {
op->add_input(v);
}
// assume 1 output unless matched to more
op->add_output(fresh());
for (auto attribute : apply.attributes()) {
fillArg(op->add_arg(), attribute);
}
return op;
} break;
case TK_CONST: {
return emitConst(
tree->tree(0)->doubleValue(),
fresh(),
tree->tree(1)->stringValue());
} break;
default:
throw ErrorReport(tree) << "NYI: " << tree;
}
}
OperatorDef* emitConst(
double v,
const std::string& output,
const std::string& type_ident) {
auto op = cur().add_op();
op->set_type("ConstantFill");
auto dtype = op->add_arg();
dtype->set_name("dtype");
auto value = op->add_arg();
value->set_name("value");
if (type_ident == "f") {
dtype->set_i(TensorProto_DataType_FLOAT);
value->set_f(v);
} else if (type_ident == "LL") {
dtype->set_i(TensorProto_DataType_INT64);
value->set_i(v);
} else {
dtype->set_i(TensorProto_DataType_INT32);
value->set_i(v);
}
auto shape = op->add_arg();
shape->set_name("shape");
shape->add_ints(1);
op->add_output(output);
return op;
}
NetDef& cur() {
return *net_def_stack.back();
}
const Def& def;
std::unordered_map<std::string, std::string>
env; // map from name in Def to name in NetDef
std::vector<NetDef*> net_def_stack;
int next_fresh = 0;
};
struct CompilationUnitImpl {
CompilationUnitImpl() {}
void defineFunction(const Def& def) {
if (functions.count(def.name().name()) > 0) {
throw ErrorReport(def) << def.name().name() << " already defined.";
}
DefCompiler c(
def, functions.emplace(def.name().name(), NetDef()).first->second);
c.run();
}
void define(const std::string& str) {
Parser p(str);
while (p.lexer().cur().kind != TK_EOF) {
defineFunction(Def(p.parseFunction()));
}
}
std::unique_ptr<NetBase> createNet(Workspace* ws, const std::string& str) {
if (functions.count(str) == 0)
throw ErrorReport() << "undefined function: " << str << "\n";
auto& net_def = functions[str];
return caffe2::CreateNet(net_def, ws);
}
private:
std::unordered_map<std::string, NetDef> functions;
};
// Creates a compilation unit backed by a new, empty CompilationUnitImpl.
CompilationUnit::CompilationUnit()
    : pImpl(new CompilationUnitImpl()) {
}
// Forwards the script source in `str` to the implementation's define().
void CompilationUnit::define(const std::string& str) {
  pImpl->define(str);
}
// Forwards to the implementation's createNet() for the function named `str`
// and workspace `ws`, and returns the net it produces.
std::unique_ptr<NetBase> CompilationUnit::createNet(
    Workspace* ws,
    const std::string& str) {
  std::unique_ptr<NetBase> net = pImpl->createNet(ws, str);
  return net;
}
// Defined out of line in this file, after the definition of
// CompilationUnitImpl.
CompilationUnit::~CompilationUnit() = default;
} // namespace script
} // namespace caffe2