blob: 0cd833dc15e48823a26e80cac96ee5db9eb9072b [file] [log] [blame]
#pragma once
#include "lexer.h"
#include "tree.h"
#include "tree_views.h"
namespace torch {
namespace jit {
namespace script {
// Recursive-descent parser that turns the token stream produced by Lexer
// into TreeRef / TreeView trees. Expressions are parsed with top-down
// precedence parsing driven by the operator tables in SharedParserData.
struct Parser {
  explicit Parser(const std::string& str)
      : L(str), shared(sharedParserData()) {}

  // Consumes one TK_IDENT token and wraps it in an Ident tree.
  Ident parseIdent() {
    auto t = L.expect(TK_IDENT);
    // whenever we parse something that has a TreeView type we always
    // use its create method so that the accessors and the constructor
    // of the Compound tree are in the same place.
    return Ident::create(t.range, t.text());
  }

  // Parses the parenthesized argument list that follows `expr` and builds
  // an Apply node. Arguments of the form `name=value` become attributes;
  // all other arguments become positional inputs.
  TreeRef createApply(Expr expr) {
    TreeList attributes;
    auto range = L.cur().range;
    TreeList inputs;
    parseOperatorArguments(inputs, attributes);
    return Apply::create(
        range,
        expr,
        List<Expr>(makeList(range, std::move(inputs))),
        List<Attribute>(makeList(range, std::move(attributes))));
  }

  // exp | exp, | exp, exp, ... [,]
  // Parses up to, but not including, the token `end`; the caller consumes
  // `end`. A lone expression is returned unchanged. Any comma (including
  // a trailing one, as in `(a,)`) produces a TupleLiteral.
  TreeRef parseExpOrExpTuple(int end) {
    auto prefix = parseExp();
    if (L.cur().kind == ',') {
      std::vector<Expr> exprs = {prefix};
      while (L.nextIf(',')) {
        // A comma directly followed by `end` is a trailing comma and
        // closes the tuple, e.g. `(a,)` or `x = a,`.
        if (L.cur().kind == end)
          break;
        exprs.push_back(parseExp());
      }
      auto list = List<Expr>::create(prefix.range(), exprs);
      prefix = TupleLiteral::create(list.range(), list);
    }
    return prefix;
  }

  // Things like 1.0 or a(4) that are not unary/binary expressions and
  // bind tighter than all of them. An atom (number, True/False/None,
  // parenthesized expression or tuple, list literal, identifier) is
  // followed by any chain of postfix `.name`, `(...)` and `[...]`.
  TreeRef parseBaseExp() {
    TreeRef prefix;
    switch (L.cur().kind) {
      case TK_NUMBER: {
        prefix = parseConst();
      } break;
      case TK_TRUE:
      case TK_FALSE:
      case TK_NONE: {
        auto k = L.cur().kind;
        auto r = L.cur().range;
        prefix = c(k, r, {});
        L.next();
      } break;
      case '(': {
        auto r = L.cur().range;
        L.next();
        if (L.nextIf(')')) {
          // Empty tuple `()`. Anchor its range at the opening paren
          // rather than at the token that follows the closing one.
          std::vector<Expr> vecExpr;
          List<Expr> listExpr = List<Expr>::create(r, vecExpr);
          prefix = TupleLiteral::create(r, listExpr);
          break;
        }
        prefix = parseExpOrExpTuple(')');
        L.expect(')');
      } break;
      case '[': {
        auto list = parseList('[', ',', ']', &Parser::parseExp);
        prefix = ListLiteral::create(list.range(), List<Expr>(list));
      } break;
      default: {
        Ident name = parseIdent();
        prefix = Var::create(name.range(), name);
      } break;
    }
    // Postfix operators associate left-to-right: a.b(c)[d]
    while (true) {
      if (L.nextIf('.')) {
        const auto name = parseIdent();
        prefix = Select::create(name.range(), Expr(prefix), Ident(name));
      } else if (L.cur().kind == '(') {
        prefix = createApply(Expr(prefix));
      } else if (L.cur().kind == '[') {
        prefix = parseSliceOrGather(prefix);
      } else {
        break;
      }
    }
    return prefix;
  }

  // Parses the assignment operator of a statement. `+=`, `-=`, `*=`, `/=`
  // yield a node whose kind is the operator's first character ('+', '-',
  // '*', '/'). Otherwise a plain '=' is required and yields kind '='.
  TreeRef parseOptionalReduction() {
    auto r = L.cur().range;
    switch (L.cur().kind) {
      case TK_PLUS_EQ:
      case TK_MINUS_EQ:
      case TK_TIMES_EQ:
      case TK_DIV_EQ: {
        int modifier = L.next().text()[0];
        return c(modifier, r, {});
      }
      default: {
        L.expect('=');
        return c('=', r, {}); // no reduction
      }
    }
  }

  // Parses the remainder of `true_branch if cond else false_branch` after
  // the `if` token has been consumed. The else-branch is parsed at
  // `binary_prec`, so nested ternaries chain to the right.
  TreeRef
  parseTrinary(TreeRef true_branch, const SourceRange& range, int binary_prec) {
    auto cond = parseExp();
    L.expect(TK_ELSE);
    auto false_branch = parseExp(binary_prec);
    return c(TK_IF_EXPR, range, {cond, true_branch, false_branch});
  }

  // Parses the longest expression whose binary operators have precedence
  // strictly greater than 'precedence'. precedence == 0 parses _all_
  // expressions. This is the core loop of top-down precedence parsing.
  Expr parseExp() { return parseExp(0); }
  Expr parseExp(int precedence) {
    TreeRef prefix = nullptr;
    int unary_prec;
    if (shared.isUnary(L.cur().kind, &unary_prec)) {
      auto kind = L.cur().kind;
      auto pos = L.cur().range;
      L.next();
      auto unary_kind = kind == '*' ? TK_STARRED :
                        kind == '-' ? TK_UNARY_MINUS :
                        kind;
      auto subexp = parseExp(unary_prec);
      if (unary_kind == TK_UNARY_MINUS && subexp.kind() == TK_CONST) {
        // Fold '-' into constant numbers so that attributes can accept
        // things like -1. A constant that is already negative (from
        // `- -1` or `-(-1)`) has its sign dropped instead of becoming
        // the malformed literal "--1".
        std::string text = Const(subexp).text();
        std::string negated = (!text.empty() && text[0] == '-')
            ? text.substr(1)
            : "-" + text;
        prefix = Const::create(subexp.range(), negated);
      } else {
        prefix = c(unary_kind, pos, {subexp});
      }
    } else {
      prefix = parseBaseExp();
    }
    int binary_prec;
    while (shared.isBinary(L.cur().kind, &binary_prec)) {
      if (binary_prec <= precedence) // not allowed to parse something which is
                                     // not greater than 'precedence'
        break;
      int kind = L.cur().kind;
      auto pos = L.cur().range;
      L.next();
      // Lowering the bar by one lets the right operand absorb another
      // operator of the same precedence, giving right associativity.
      if (shared.isRightAssociative(kind))
        binary_prec--;
      // special case for the ternary `a if cond else b` operator
      if (kind == TK_IF) {
        prefix = parseTrinary(prefix, pos, binary_prec);
        continue;
      }
      prefix = c(kind, pos, {prefix, parseExp(binary_prec)});
    }
    return Expr(prefix);
  }

  // Parses `begin elem (sep elem)* end` using `parse` for each element.
  // Passing TK_NOTHING for `begin` or `end` skips consuming that
  // delimiter. If the current token already equals `end`, the list is
  // empty.
  template<typename T>
  List<T> parseList(int begin, int sep, int end, T (Parser::*parse)()) {
    auto r = L.cur().range;
    if (begin != TK_NOTHING)
      L.expect(begin);
    std::vector<T> elements;
    if (L.cur().kind != end) {
      do {
        elements.push_back((this->*parse)());
      } while (L.nextIf(sep));
    }
    if (end != TK_NOTHING)
      L.expect(end);
    return List<T>::create(r, elements);
  }

  // Consumes a TK_NUMBER token and keeps its source text verbatim.
  Const parseConst() {
    auto t = L.expect(TK_NUMBER);
    return Const::create(t.range, t.text());
  }

  // Attribute values (the `v` in `f(name=v)`) are arbitrary expressions.
  Expr parseAttributeValue() {
    return parseExp();
  }

  // Parses `( arg, arg, ... )`. An argument that starts with
  // `identifier =` is appended to `attributes`; any other argument is
  // appended to `inputs`. The result is written through the out-parameters.
  void parseOperatorArguments(TreeList& inputs, TreeList& attributes) {
    L.expect('(');
    if (L.cur().kind != ')') {
      do {
        if (L.cur().kind == TK_IDENT && L.lookahead().kind == '=') {
          auto ident = parseIdent();
          L.expect('=');
          auto v = parseAttributeValue();
          attributes.push_back(Attribute::create(ident.range(), Ident(ident), v));
        } else {
          inputs.push_back(parseExp());
        }
      } while (L.nextIf(','));
    }
    L.expect(')');
  }

  // Parses the subscript applied to `value`.
  // OK: [a] (gather), [a:], [:a], [a:b], [:] (slice)
  // Not OK: []
  TreeRef parseSliceOrGather(TreeRef value) {
    const auto range = L.cur().range;
    L.expect('[');
    // `first` will either be the gather indices, or the start of the slice.
    TreeRef first, second;
    // Here we can either have a colon (which starts a slice), or an expression.
    // If an expression, we don't know yet if it will be a slice or a gather.
    if (L.cur().kind != ':') {
      first = parseExp();
      if (L.nextIf(']')) {
        return Gather::create(range, Expr(value), Expr(first));
      } else {
        first = c(TK_OPTION, range, {first});
      }
    } else {
      first = c(TK_OPTION, range, {});
    }
    L.expect(':');
    // Now we *may* have an expression.
    if (L.cur().kind != ']') {
      second = c(TK_OPTION, range, {parseExp()});
    } else {
      second = c(TK_OPTION, range, {});
    }
    L.expect(']');
    return Slice::create(range, Expr(value), Maybe<Expr>(first), Maybe<Expr>(second));
  }

  // Parses a function parameter: an identifier, given a TensorType
  // anchored at the parameter's position.
  TreeRef parseParam() {
    auto typ = TensorType::create(L.cur().range);
    auto ident = parseIdent();
    return Param::create(typ.range(), Ident(ident), Type(typ));
  }

  // The left-hand side `list` has already been parsed, because an
  // expression can also appear alone on a line:
  //   first[,other,lhs] (=|+=|-=|*=|/=) rhs NEWLINE
  Assign parseAssign(List<Expr> list) {
    auto red = parseOptionalReduction();
    auto rhs = parseExpOrExpTuple(TK_NEWLINE);
    L.expect(TK_NEWLINE);
    return Assign::create(list.range(), list, AssignKind(red), Expr(rhs));
  }

  // Parses one statement. Compound statements are dispatched on their
  // leading keyword. Anything else is an expression list that is either
  // the target of an assignment or a standalone expression statement.
  TreeRef parseStmt() {
    switch (L.cur().kind) {
      case TK_IF:
        return parseIf();
      case TK_WHILE:
        return parseWhile();
      case TK_FOR:
        return parseFor();
      case TK_GLOBAL: {
        auto range = L.next().range;
        auto idents = parseList(TK_NOTHING, ',', TK_NOTHING, &Parser::parseIdent);
        L.expect(TK_NEWLINE);
        return Global::create(range, idents);
      }
      case TK_RETURN: {
        auto range = L.next().range;
        // XXX: TK_NEWLINE makes it accept an empty list
        auto values = parseList(TK_NOTHING, ',', TK_NEWLINE, &Parser::parseExp);
        return Return::create(range, values);
      }
      default: {
        // With end == TK_NOTHING, parseList always parses at least one
        // element, so exprs[0] below is valid.
        List<Expr> exprs = parseList(TK_NOTHING, ',', TK_NOTHING, &Parser::parseExp);
        if (L.cur().kind != TK_NEWLINE) {
          return parseAssign(exprs);
        } else {
          L.expect(TK_NEWLINE);
          return ExprStmt::create(exprs[0].range(), exprs);
        }
      }
    }
  }

  // Parses `( ident, ... )` if present. Otherwise returns an empty
  // TK_LIST without consuming any tokens.
  TreeRef parseOptionalIdentList() {
    TreeRef list = nullptr;
    if (L.cur().kind == '(') {
      list = parseList('(', ',', ')', &Parser::parseIdent);
    } else {
      list = c(TK_LIST, L.cur().range, {});
    }
    return list;
  }

  // if cond: <block> [else: <block>]. A missing else yields an empty
  // statement list.
  TreeRef parseIf() {
    auto r = L.cur().range;
    L.expect(TK_IF);
    auto cond = parseExp();
    L.expect(':');
    auto true_branch = parseStatements();
    auto false_branch = makeList(L.cur().range, {});
    if (L.nextIf(TK_ELSE)) {
      L.expect(':');
      false_branch = parseStatements();
    }
    return If::create(r, Expr(cond), List<Stmt>(true_branch), List<Stmt>(false_branch));
  }

  // while cond: <block>
  TreeRef parseWhile() {
    auto r = L.cur().range;
    L.expect(TK_WHILE);
    auto cond = parseExp();
    L.expect(':');
    auto body = parseStatements();
    return While::create(r, Expr(cond), List<Stmt>(body));
  }

  // for targets in iterables: <block>
  TreeRef parseFor() {
    auto r = L.cur().range;
    L.expect(TK_FOR);
    auto targets = parseList(TK_NOTHING, ',', TK_NOTHING, &Parser::parseExp);
    L.expect(TK_IN);
    auto itrs = parseList(TK_NOTHING, ',', TK_NOTHING, &Parser::parseExp);
    L.expect(':');
    auto body = parseStatements();
    return For::create(r, targets, itrs, body);
  }

  // Parses an indented block: TK_INDENT stmt+ TK_DEDENT. At least one
  // statement is required.
  TreeRef parseStatements() {
    auto r = L.cur().range;
    L.expect(TK_INDENT);
    TreeList stmts;
    while (true) {
      stmts.push_back(parseStmt());
      if (L.nextIf(TK_DEDENT))
        break;
    }
    return c(TK_LIST, r, std::move(stmts));
  }

  // def name(param, ...): <block>
  TreeRef parseFunction() {
    L.expect(TK_DEF);
    auto name = parseIdent();
    auto paramlist = parseList('(', ',', ')', &Parser::parseParam);
    L.expect(':');
    auto stmts_list = parseStatements();
    return Def::create(name.range(), Ident(name), List<Param>(paramlist),
                       List<Stmt>(stmts_list));
  }

  // Gives callers access to the underlying token stream.
  Lexer& lexer() {
    return L;
  }

 private:
  // short helpers to create nodes
  TreeRef c(int kind, const SourceRange& range, TreeList&& trees) {
    return Compound::create(kind, range, std::move(trees));
  }
  TreeRef makeList(const SourceRange& range, TreeList&& trees) {
    return c(TK_LIST, range, std::move(trees));
  }
  Lexer L;
  SharedParserData& shared;
};
} // namespace script
} // namespace jit
} // namespace torch