blob: 4b68b8d2bfb4dce0268d30bf31c564dce3862f81 [file] [log] [blame]
#pragma once
#include "lexer.h"
#include "tree.h"
#include "tree_views.h"
namespace caffe2 {
namespace script {
// Recursive-descent parser for the script language. It pulls tokens from its
// own Lexer `L` and builds TreeRef syntax trees, using the TreeView `create`
// helpers (Ident, Apply, Assign, If, ...) wherever a view type exists.
// Unary/binary operator precedence and associativity are looked up in the
// SharedParserData table returned by sharedParserData().
struct Parser {
  // Builds a parser over the source text `str`.
  explicit Parser(const std::string& str)
      : L(str), shared(sharedParserData()) {}

  // Consumes a TK_IDENT token (via L.expect) and wraps it in an Ident tree.
  TreeRef parseIdent() {
    auto t = L.expect(TK_IDENT);
    // whenever we parse something that has a TreeView type we always
    // use its create method so that the accessors and the constructor
    // of the Compound tree are in the same place.
    return Ident::create(t.range, t.text());
  }

  // Parses the parenthesized argument list that follows the already-parsed
  // callee `ident` and returns an Apply node. `inputs` may already hold
  // leading positional inputs (parseBaseExp seeds it with the receiver of a
  // method call); further positional arguments are appended to it, and its
  // contents are moved into the result. The Apply node (and its input and
  // attribute lists) take the range of the current token, which every
  // caller has checked to be '('.
  TreeRef createApply(TreeRef ident, TreeList& inputs) {
    TreeList attributes;
    auto range = L.cur().range;
    parseOperatorArguments(inputs, attributes);
    return Apply::create(
        range,
        ident,
        List(range, std::move(inputs)),
        List(range, std::move(attributes)));
  }

  // Parses a "base" expression: things like 1.0 or a(4) that are not
  // unary/binary expressions and have higher precedence than all of them.
  // The head is a constant, a parenthesized expression, a cast such as
  // float(x), or an identifier (optionally called). Any number of postfix
  // suffixes follow:
  //   .name        -> Select
  //   .name(args)  -> Apply with the receiver as the first input
  //   [...]        -> Slice or Gather
  TreeRef parseBaseExp() {
    TreeRef prefix;
    switch (L.cur().kind) {
      case TK_NUMBER:
      case TK_TRUE:
      case TK_FALSE: {
        prefix = parseConst();
      } break;
      case '(': {
        L.next();
        prefix = parseExp();
        L.expect(')');
      } break;
      case TK_FLOAT:
      case TK_INT:
      case TK_LONG: {
        // cast: <scalar type keyword> '(' exp ')'
        auto r = L.cur().range;
        auto type = c(L.next().kind, r, {});
        L.expect('(');
        auto exp = parseExp();
        L.expect(')');
        prefix = Cast::create(r, type, exp);
      } break;
      default: {
        prefix = parseIdent();
        if (L.cur().kind == '(') {
          TreeList inputs;
          prefix = createApply(prefix, inputs);
        }
      } break;
    }
    // Postfix suffixes bind left to right: a.b[0].c(1) nests outward.
    while (true) {
      if (L.nextIf('.')) {
        const auto name = parseIdent();
        if (L.cur().kind == '(') {
          // method call: the receiver becomes the first positional input
          TreeList inputs = {prefix};
          prefix = createApply(name, inputs);
        } else {
          prefix = Select::create(name->range(), prefix, name);
        }
      } else if (L.cur().kind == '[') {
        prefix = parseSliceOrGather(prefix);
      } else {
        break;
      }
    }
    return prefix;
  }

  // Parses the operator of an assignment statement.
  // - A plain '=' yields a c('=') node.
  // - A compound operator (TK_PLUS_EQ, TK_MINUS_EQ, TK_TIMES_EQ, TK_DIV_EQ)
  //   yields a node whose kind is the first character of the operator's
  //   text (presumably '+', '-', '*' or '/').
  // - Any other token is handed to L.expect('='), which presumably reports
  //   the error.
  TreeRef parseOptionalReduction() {
    auto r = L.cur().range;
    switch (L.cur().kind) {
      case TK_PLUS_EQ:
      case TK_MINUS_EQ:
      case TK_TIMES_EQ:
      case TK_DIV_EQ: {
        int modifier = L.next().text()[0];
        return c(modifier, r, {});
      }
      default: {
        L.expect('=');
        return c('=', r, {}); // no reduction
      }
    }
  }

  // Parses the remainder of a ternary (conditional) expression
  //   true_branch if cond else false_branch
  // after parseExp has already consumed the `if` token. `range` is the
  // range of that `if` token. The false branch is parsed at `binary_prec`
  // so it obeys the same precedence rules as the right operand of any
  // other binary operator. Returns TK_IF_EXPR {cond, true, false}.
  TreeRef
  parseTrinary(TreeRef true_branch, const SourceRange& range, int binary_prec) {
    auto cond = parseExp();
    L.expect(TK_ELSE);
    auto false_branch = parseExp(binary_prec);
    return c(TK_IF_EXPR, range, {cond, true_branch, false_branch});
  }

  // parse the longest expression whose binary operators have
  // precedence strictly greater than 'precedence'
  // precedence == 0 will parse _all_ expressions
  // this is the core loop of 'top-down precedence parsing'
  TreeRef parseExp(int precedence = 0) {
    TreeRef prefix = nullptr;
    int unary_prec;
    if (shared.isUnary(L.cur().kind, &unary_prec)) {
      auto kind = L.cur().kind;
      auto pos = L.cur().range;
      L.next();
      // the operand only absorbs binary operators that bind tighter than
      // the unary operator itself
      prefix = c(kind, pos, {parseExp(unary_prec)});
    } else {
      prefix = parseBaseExp();
    }
    int binary_prec;
    while (shared.isBinary(L.cur().kind, &binary_prec)) {
      if (binary_prec <= precedence) // not allowed to parse something which is
                                     // not greater than 'precedence'
        break;
      int kind = L.cur().kind;
      auto pos = L.cur().range;
      L.next();
      // Lowering the precedence by one lets the right operand absorb another
      // operator of the same precedence, giving a op (b op c).
      if (shared.isRightAssociative(kind))
        binary_prec--;
      // special case for the ternary (conditional) operator
      if (kind == TK_IF) {
        prefix = parseTrinary(prefix, pos, binary_prec);
        continue;
      }
      prefix = c(kind, pos, {prefix, parseExp(binary_prec)});
    }
    return prefix;
  }

  // Parses `begin [elem (sep elem)*] end` and returns a TK_LIST whose range
  // is that of the `begin` token. `parse` is called once per element with
  // the element's zero-based index. An immediately closing `end` yields an
  // empty list.
  TreeRef
  parseList(int begin, int sep, int end, std::function<TreeRef(int)> parse) {
    auto r = L.cur().range;
    L.expect(begin);
    TreeList elements;
    if (L.cur().kind != end) {
      int i = 0;
      do {
        elements.push_back(parse(i++));
      } while (L.nextIf(sep));
    }
    L.expect(end);
    return c(TK_LIST, r, std::move(elements));
  }

  // Parses `elem (sep elem)*` with no enclosing delimiters. At least one
  // element is always parsed, so elements[0] exists; the TK_LIST takes the
  // range of that first element.
  TreeRef parseNonEmptyList(int sep, std::function<TreeRef(int)> parse) {
    TreeList elements;
    int i = 0;
    do {
      elements.push_back(parse(i++));
    } while (L.nextIf(sep));
    return c(TK_LIST, elements[0]->range(), std::move(elements));
  }

  // Parses a parenthesized, comma-separated (possibly empty) list of
  // expressions.
  TreeRef parseExpList() {
    return parseList('(', ',', ')', [&](int) { return parseExp(); });
  }

  // Parses a numeric or boolean constant into TK_CONST {value, type}.
  // Type identifiers:
  //   'b' - boolean
  //   'LL' 64-bit integer
  //   'f' single-precision float
  //   'i' 32-bit integer
  // Without an explicit suffix, 'f' is the default if '.' appears in the
  // number and 'i' otherwise. Each leading '-' token flips the sign.
  // Only the suffixes 'f' and 'LL' may be written explicitly; any other
  // identifier after the number raises an ErrorReport.
  TreeRef parseConst() {
    auto range = L.cur().range;
    if (L.nextIf(TK_TRUE)) {
      return c(TK_CONST, range, {d(1), s("b")});
    } else if (L.nextIf(TK_FALSE)) {
      return c(TK_CONST, range, {d(0), s("b")});
    }
    float mult = 1.0f;
    while (L.nextIf('-')) {
      mult *= -1.0f;
    }
    auto t = L.expect(TK_NUMBER);
    std::string type_ident =
        (t.text().find('.') == std::string::npos) ? "i" : "f";
    if (L.cur().kind == TK_IDENT) {
      Token type_ident_tok = L.expect(TK_IDENT);
      type_ident = type_ident_tok.text();
      if (type_ident != "LL" && type_ident != "f") {
        throw ErrorReport(type_ident_tok)
            << "expected 'f' or 'LL' "
            << "as numeric type identifier but found '" << type_ident << "'";
      }
    }
    return c(TK_CONST, t.range, {d(mult * t.doubleValue()), s(type_ident)});
  }

  // Parses the value of a keyword attribute. This is either a single
  // constant or a bracketed, comma-separated list of constants.
  TreeRef parseAttributeValue() {
    int kind = L.cur().kind;
    switch (kind) {
      case '[':
        return parseList('[', ',', ']', [&](int) { return parseConst(); });
      default:
        return parseConst();
    }
  }

  // Parses `( arg, ... )` for an operator application. An argument of the
  // form `ident = value` (detected with one token of lookahead) becomes an
  // Attribute appended to `attributes`; any other argument is parsed as an
  // expression and appended to `inputs`.
  void parseOperatorArguments(TreeList& inputs, TreeList& attributes) {
    L.expect('(');
    if (L.cur().kind != ')') {
      do {
        if (L.cur().kind == TK_IDENT && L.lookahead().kind == '=') {
          auto ident = parseIdent();
          L.expect('=');
          auto v = parseAttributeValue();
          attributes.push_back(Attribute::create(ident->range(), ident, v));
        } else {
          inputs.push_back(parseExp());
        }
      } while (L.nextIf(','));
    }
    L.expect(')');
  }

  // Parses a subscript applied to `value`; all produced nodes take the
  // range of the opening '['.
  // OK: [a] (gather), [a:], [:a], [a:b], [:] (slice)
  // Not OK: []
  TreeRef parseSliceOrGather(TreeRef value) {
    const auto range = L.cur().range;
    L.expect('[');
    // `first` will either be the gather indices, or the start of the slice.
    TreeRef first, second;
    // Here we can either have a colon (which starts a slice), or an expression.
    // If an expression, we don't know yet if it will be a slice or a gather.
    if (L.cur().kind != ':') {
      first = parseExp();
      if (L.nextIf(']')) {
        return Gather::create(range, value, first);
      } else {
        first = c(TK_OPTION, range, {first});
      }
    } else {
      first = c(TK_OPTION, range, {});
    }
    L.expect(':');
    // Now we *may* have an expression.
    if (L.cur().kind != ']') {
      second = c(TK_OPTION, range, {parseExp()});
    } else {
      second = c(TK_OPTION, range, {});
    }
    L.expect(']');
    return Slice::create(range, value, first, second);
  }

  // Parses a parenthesized, comma-separated (possibly empty) list of
  // identifiers.
  TreeRef parseIdentList() {
    return parseList('(', ',', ')', [&](int) { return parseIdent(); });
  }

  // Parses a function parameter written as `type name`. If what was parsed
  // as the type's scalar part is an identifier and is not followed by
  // another identifier, it was actually the parameter name with no type.
  // That case yields a Param with a TK_INFERRED type.
  TreeRef parseParam() {
    auto typ = parseType();
    if (L.cur().kind != TK_IDENT && typ->trees()[0]->kind() == TK_IDENT) {
      // oops, it wasn't a type but just a param without any type specified
      return Param::create(
          typ->range(), typ->trees()[0], c(TK_INFERRED, typ->range(), {}));
    }
    auto ident = parseIdent();
    return Param::create(typ->range(), ident, typ);
  }

  // TODO: these functions should be unnecessary, but we currently do not
  // emit a TK_NEWLINE before a series of TK_DEDENT tokens
  // so if we see a TK_DEDENT then we know a newline must have happened and
  // ignore it. The real fix is to patch the lexer so TK_NEWLINE does get
  // emitted before a TK_DEDENT.

  // Consumes the end of a statement line: a TK_NEWLINE, or nothing when the
  // current token is a TK_DEDENT (see the TODO above).
  void expectEndOfLine() {
    if (L.cur().kind != TK_DEDENT)
      L.expect(TK_NEWLINE);
  }
  // True when the current token ends a statement line (TK_NEWLINE or
  // TK_DEDENT); does not consume it.
  bool isEndOfLine() {
    return L.cur().kind == TK_NEWLINE || L.cur().kind == TK_DEDENT;
  }

  // 'first' has already been parsed since expressions can exist
  // alone on a line:
  // first[,other,lhs] = rhs
  // The Assign node takes the range of the target list.
  TreeRef parseAssign(TreeRef first) {
    TreeRef list = parseOneOrMoreExp(first);
    auto red = parseOptionalReduction();
    auto rhs = parseExp();
    expectEndOfLine();
    return Assign::create(list->range(), list, red, rhs);
  }

  // Parses one statement. The forms are:
  // - `if`
  // - `while`
  // - `global a, b, ...`
  // - an expression standing alone on its line
  // - an assignment (an expression followed by anything other than
  //   end-of-line)
  TreeRef parseStmt() {
    switch (L.cur().kind) {
      case TK_IF:
        return parseIf();
      case TK_WHILE:
        return parseWhile();
      case TK_GLOBAL: {
        auto range = L.next().range;
        std::vector<TreeRef> idents;
        do {
          idents.push_back(parseIdent());
        } while (L.nextIf(','));
        expectEndOfLine();
        return c(TK_GLOBAL, range, std::move(idents));
      }
      default: {
        auto r = parseExp();
        if (!isEndOfLine()) {
          return parseAssign(r);
        } else {
          expectEndOfLine();
          return r;
        }
      }
    }
  }

  // Parses a scalar type. This is one of the keywords int/float/long/double
  // (returned as a leaf node of that token kind), or else an identifier.
  TreeRef parseScalarType() {
    switch (L.cur().kind) {
      case TK_INT:
      case TK_FLOAT:
      case TK_LONG:
      case TK_DOUBLE: {
        auto t = L.next();
        return c(t.kind, t.range, {});
      }
      default:
        return parseIdent();
    }
  }

  // Parses an optional parenthesized identifier list. If the current token
  // is not '(', returns an empty TK_LIST at the current token's range.
  TreeRef parseOptionalIdentList() {
    TreeRef list = nullptr;
    if (L.cur().kind == '(') {
      list = parseIdentList();
    } else {
      list = c(TK_LIST, L.cur().range, {});
    }
    return list;
  }

  // Parses `scalar_type [ '(' ident, ... ')' ]` into a TensorType whose
  // range is that of the scalar type.
  TreeRef parseType() {
    auto st = parseScalarType();
    auto list = parseOptionalIdentList();
    return TensorType::create(st->range(), st, list);
  }

  // 'first' has already been parsed, add the rest
  // if they exist
  // first[, the, rest]
  TreeRef parseOneOrMoreExp(TreeRef first) {
    TreeList list{first};
    while (L.nextIf(',')) {
      list.push_back(parseExp());
    }
    // Anchor the list at its first element, as parseNonEmptyList does, so
    // that an Assign built from it (parseAssign) points at the start of the
    // target list rather than at its last element.
    const auto range = list.front()->range();
    return List(range, std::move(list));
  }

  // Parses `if cond: <block> [else: <block>]`. A missing else branch is an
  // empty TK_LIST.
  TreeRef parseIf() {
    auto r = L.cur().range;
    L.expect(TK_IF);
    auto cond = parseExp();
    L.expect(':');
    auto true_branch = parseStatements();
    auto false_branch = List(L.cur().range, {});
    if (L.nextIf(TK_ELSE)) {
      L.expect(':');
      false_branch = parseStatements();
    }
    return If::create(r, cond, true_branch, false_branch);
  }

  // Parses `while cond: <block>`.
  TreeRef parseWhile() {
    auto r = L.cur().range;
    L.expect(TK_WHILE);
    auto cond = parseExp();
    L.expect(':');
    auto body = parseStatements();
    return While::create(r, cond, body);
  }

  // Parses an indented block. This is a TK_INDENT, one or more statements,
  // and the TK_DEDENT that closes the block (which is consumed). The list
  // takes the range of the TK_INDENT token.
  TreeRef parseStatements() {
    auto r = L.cur().range;
    L.expect(TK_INDENT);
    TreeList stmts;
    while (true) {
      stmts.push_back(parseStmt());
      if (L.nextIf(TK_DEDENT))
        break;
    }
    return c(TK_LIST, r, std::move(stmts));
  }

  // Parses a function definition
  //   def name(param, ...) -> (param, ...): <block>
  // and returns a Def node whose range is that of the function name.
  TreeRef parseFunction() {
    L.expect(TK_DEF);
    auto name = parseIdent();
    auto paramlist =
        parseList('(', ',', ')', [&](int) { return parseParam(); });
    L.expect(TK_ARROW);
    auto retlist =
        parseList('(', ',', ')', [&](int) { return parseParam(); });
    L.expect(':');
    auto stmts_list = parseStatements();
    return Def::create(name->range(), name, paramlist, retlist, stmts_list);
  }

  // Direct access to the underlying lexer.
  Lexer& lexer() {
    return L;
  }

 private:
  // short helpers to create nodes
  // numeric leaf
  TreeRef d(double v) {
    return Number::create(v);
  }
  // string leaf
  TreeRef s(const std::string& s) {
    return String::create(s);
  }
  // compound node of the given kind with children `trees`
  TreeRef c(int kind, const SourceRange& range, TreeList&& trees) {
    return Compound::create(kind, range, std::move(trees));
  }
  // TK_LIST compound node
  TreeRef List(const SourceRange& range, TreeList&& trees) {
    return c(TK_LIST, range, std::move(trees));
  }

  Lexer L;
  SharedParserData& shared;
};
} // namespace script
} // namespace caffe2