| #pragma once |
| #include "lexer.h" |
| #include "tree.h" |
| #include "tree_views.h" |
| |
| namespace torch { |
| namespace jit { |
| namespace script { |
| |
| |
| Decl mergeTypesFromTypeComment(Decl decl, Decl type_annotation_decl, bool is_method) { |
| auto expected_num_annotations = decl.params().size(); |
| if (is_method) { |
| // `self` argument |
| expected_num_annotations -= 1; |
| } |
| if (expected_num_annotations != type_annotation_decl.params().size()) { |
| throw ErrorReport(type_annotation_decl.range()) << "Number of type annotations (" |
| << type_annotation_decl.params().size() << ") did not match the number of " |
| << "function parameters (" << expected_num_annotations << ")"; |
| } |
| auto old = decl.params(); |
| auto _new = type_annotation_decl.params(); |
| // Merge signature idents and ranges with annotation types |
| |
| std::vector<Param> new_params; |
| size_t i = is_method ? 1 : 0; |
| size_t j = 0; |
| if (is_method) { |
| new_params.push_back(old[0]); |
| } |
| for (; i < decl.params().size(); ++i, ++j) { |
| new_params.push_back(Param::create(old[i].range(), old[i].ident(), _new[j].type())); |
| } |
| return Decl::create(decl.range(), List<Param>::create(decl.range(), new_params), type_annotation_decl.return_type()); |
| } |
| |
| struct Parser { |
| explicit Parser(const std::string& str) |
| : L(str), shared(sharedParserData()) {} |
| |
| Ident parseIdent() { |
| auto t = L.expect(TK_IDENT); |
| // whenever we parse something that has a TreeView type we always |
| // use its create method so that the accessors and the constructor |
| // of the Compound tree are in the same place. |
| return Ident::create(t.range, t.text()); |
| } |
| TreeRef createApply(Expr expr) { |
| TreeList attributes; |
| auto range = L.cur().range; |
| TreeList inputs; |
| parseOperatorArguments(inputs, attributes); |
| return Apply::create( |
| range, |
| expr, |
| List<Expr>(makeList(range, std::move(inputs))), |
| List<Attribute>(makeList(range, std::move(attributes)))); |
| } |
| // exp | expr, | expr, expr, ... |
| TreeRef parseExpOrExpTuple(int end) { |
| auto prefix = parseExp(); |
| if(L.cur().kind == ',') { |
| std::vector<Expr> exprs = { prefix }; |
| while(L.cur().kind != end) { |
| L.expect(','); |
| exprs.push_back(parseExp()); |
| } |
| auto list = List<Expr>::create(prefix.range(), exprs); |
| prefix = TupleLiteral::create(list.range(), list); |
| } |
| return prefix; |
| } |
| // things like a 1.0 or a(4) that are not unary/binary expressions |
| // and have higher precedence than all of them |
| TreeRef parseBaseExp() { |
| TreeRef prefix; |
| switch (L.cur().kind) { |
| case TK_NUMBER: { |
| prefix = parseConst(); |
| } break; |
| case TK_TRUE: |
| case TK_FALSE: |
| case TK_NONE: { |
| auto k = L.cur().kind; |
| auto r = L.cur().range; |
| prefix = c(k, r, {}); |
| L.next(); |
| } break; |
| case '(': { |
| L.next(); |
| if (L.nextIf(')')) { |
| /// here we have the empty tuple case |
| std::vector<Expr> vecExpr; |
| List<Expr> listExpr = List<Expr>::create(L.cur().range, vecExpr); |
| prefix = TupleLiteral::create(L.cur().range, listExpr); |
| break; |
| } |
| prefix = parseExpOrExpTuple(')'); |
| L.expect(')'); |
| } break; |
| case '[': { |
| auto list = parseList('[', ',', ']', &Parser::parseExp); |
| prefix = ListLiteral::create(list.range(), List<Expr>(list)); |
| } break; |
| case TK_STRINGLITERAL: { |
| prefix = parseStringLiteral(); |
| } break; |
| default: { |
| Ident name = parseIdent(); |
| prefix = Var::create(name.range(), name); |
| } break; |
| } |
| while (true) { |
| if (L.nextIf('.')) { |
| const auto name = parseIdent(); |
| prefix = Select::create(name.range(), Expr(prefix), Ident(name)); |
| } else if (L.cur().kind == '(') { |
| prefix = createApply(Expr(prefix)); |
| } else if (L.cur().kind == '[') { |
| prefix = parseSubscript(prefix); |
| } else { |
| break; |
| } |
| } |
| return prefix; |
| } |
| TreeRef parseOptionalReduction() { |
| auto r = L.cur().range; |
| switch (L.cur().kind) { |
| case TK_PLUS_EQ: |
| case TK_MINUS_EQ: |
| case TK_TIMES_EQ: |
| case TK_DIV_EQ: { |
| int modifier = L.next().text()[0]; |
| return c(modifier, r, {}); |
| } break; |
| default: { |
| L.expect('='); |
| return c('=', r, {}); // no reduction |
| } break; |
| } |
| } |
| TreeRef |
| parseTrinary(TreeRef true_branch, const SourceRange& range, int binary_prec) { |
| auto cond = parseExp(); |
| L.expect(TK_ELSE); |
| auto false_branch = parseExp(binary_prec); |
| return c(TK_IF_EXPR, range, {cond, true_branch, false_branch}); |
| } |
| // parse the longest expression whose binary operators have |
| // precedence strictly greater than 'precedence' |
| // precedence == 0 will parse _all_ expressions |
| // this is the core loop of 'top-down precedence parsing' |
| Expr parseExp() { return parseExp(0); } |
| Expr parseExp(int precedence) { |
| TreeRef prefix = nullptr; |
| int unary_prec; |
| if (shared.isUnary(L.cur().kind, &unary_prec)) { |
| auto kind = L.cur().kind; |
| auto pos = L.cur().range; |
| L.next(); |
| auto unary_kind = kind == '*' ? TK_STARRED : |
| kind == '-' ? TK_UNARY_MINUS : |
| kind; |
| auto subexp = parseExp(unary_prec); |
| // fold '-' into constant numbers, so that attributes can accept |
| // things like -1 |
| if(unary_kind == TK_UNARY_MINUS && subexp.kind() == TK_CONST) { |
| prefix = Const::create(subexp.range(), "-" + Const(subexp).text()); |
| } else { |
| prefix = c(unary_kind, pos, {subexp}); |
| } |
| } else { |
| prefix = parseBaseExp(); |
| } |
| int binary_prec; |
| while (shared.isBinary(L.cur().kind, &binary_prec)) { |
| if (binary_prec <= precedence) // not allowed to parse something which is |
| // not greater than 'precedence' |
| break; |
| |
| int kind = L.cur().kind; |
| auto pos = L.cur().range; |
| L.next(); |
| if (shared.isRightAssociative(kind)) |
| binary_prec--; |
| |
| // special case for trinary operator |
| if (kind == TK_IF) { |
| prefix = parseTrinary(prefix, pos, binary_prec); |
| continue; |
| } |
| |
| prefix = c(kind, pos, {prefix, parseExp(binary_prec)}); |
| } |
| return Expr(prefix); |
| } |
| template<typename T> |
| List<T> parseList(int begin, int sep, int end, T (Parser::*parse)()) { |
| auto r = L.cur().range; |
| if (begin != TK_NOTHING) |
| L.expect(begin); |
| std::vector<T> elements; |
| if (L.cur().kind != end) { |
| do { |
| elements.push_back((this->*parse)()); |
| } while (L.nextIf(sep)); |
| } |
| if (end != TK_NOTHING) |
| L.expect(end); |
| return List<T>::create(r, elements); |
| } |
| |
| Const parseConst() { |
| auto range = L.cur().range; |
| auto t = L.expect(TK_NUMBER); |
| return Const::create(t.range, t.text()); |
| } |
| |
| bool isCharCount(char c, const std::string& str, size_t start, int len) { |
| //count checks from [start, start + len) |
| return start + len <= str.size() && std::count(str.begin() + start, str.begin() + start + len, c) == len; |
| } |
| |
| std::string parseString(const SourceRange& range, const std::string &str) { |
| int quote_len = isCharCount(str[0], str, 0, 3) ? 3 : 1; |
| auto ret_str = str.substr(quote_len, str.size() - quote_len * 2); |
| size_t pos = ret_str.find('\\'); |
| while(pos != std::string::npos) { |
| //invariant: pos has to escape a character because it is a valid string |
| char c = ret_str[pos + 1]; |
| switch (ret_str[pos + 1]) { |
| case '\\': |
| case '\'': |
| case '\"': |
| case '\n': |
| break; |
| case 'a': |
| c = '\a'; |
| break; |
| case 'b': |
| c = '\b'; |
| break; |
| case 'f': |
| c = '\f'; |
| break; |
| case 'n': |
| c = '\n'; |
| break; |
| case 'v': |
| c = '\v'; |
| break; |
| default: |
| throw ErrorReport(range) << " octal and hex escaped sequences are not supported"; |
| } |
| ret_str.replace(pos, /* num to erase */ 2, /* num copies */ 1, c); |
| pos = ret_str.find('\\', pos + 1); |
| } |
| return ret_str; |
| } |
| |
| StringLiteral parseStringLiteral() { |
| auto range = L.cur().range; |
| std::stringstream ss; |
| while(L.cur().kind == TK_STRINGLITERAL) |
| ss << parseString(L.cur().range, L.next().text()); |
| return StringLiteral::create(range, ss.str()); |
| } |
| |
| Expr parseAttributeValue() { |
| return parseExp(); |
| } |
| |
| void parseOperatorArguments(TreeList& inputs, TreeList& attributes) { |
| L.expect('('); |
| if (L.cur().kind != ')') { |
| do { |
| if (L.cur().kind == TK_IDENT && L.lookahead().kind == '=') { |
| auto ident = parseIdent(); |
| L.expect('='); |
| auto v = parseAttributeValue(); |
| attributes.push_back(Attribute::create(ident.range(), Ident(ident), v)); |
| } else { |
| inputs.push_back(parseExp()); |
| } |
| } while (L.nextIf(',')); |
| } |
| L.expect(')'); |
| } |
| |
| // Parse expr's of the form [a:], [:b], [a:b], [:] |
| Expr parseSubscriptExp() { |
| TreeRef first, second; |
| auto range = L.cur().range; |
| if (L.cur().kind != ':') { |
| first = parseExp(); |
| } |
| if (L.nextIf(':')) { |
| if (L.cur().kind != ',' && L.cur().kind != ']') { |
| second = parseExp(); |
| } |
| auto maybe_first = first ? Maybe<Expr>::create(range, Expr(first)) : Maybe<Expr>::create(range); |
| auto maybe_second = second ? Maybe<Expr>::create(range, Expr(second)) : Maybe<Expr>::create(range); |
| return SliceExpr::create(range, maybe_first, maybe_second); |
| } else { |
| return Expr(first); |
| } |
| } |
| |
| TreeRef parseSubscript(TreeRef value) { |
| const auto range = L.cur().range; |
| |
| auto subscript_exprs = parseList('[', ',', ']', &Parser::parseSubscriptExp); |
| return Subscript::create(range, Expr(value), subscript_exprs); |
| } |
| |
| TreeRef parseParam() { |
| auto ident = parseIdent(); |
| TreeRef type; |
| if (L.nextIf(':')) { |
| type = parseExp(); |
| } else { |
| type = Var::create(L.cur().range, Ident::create(L.cur().range, "Tensor")); |
| } |
| return Param::create(type->range(), Ident(ident), Expr(type)); |
| } |
| |
| Param parseBareTypeAnnotation() { |
| auto type = parseExp(); |
| return Param::create(type.range(), Ident::create(type.range(), ""), type); |
| } |
| |
| TreeRef parseTypeComment(bool parse_full_line=false) { |
| auto range = L.cur().range; |
| if (parse_full_line) { |
| L.expect(TK_TYPE_COMMENT); |
| } |
| auto param_types = parseList('(', ',', ')', &Parser::parseBareTypeAnnotation); |
| TreeRef return_type; |
| if (L.nextIf(TK_ARROW)) { |
| return_type = Maybe<Expr>::create(L.cur().range, parseExp()); |
| } else { |
| return_type = Maybe<Expr>::create(L.cur().range); |
| } |
| if (!parse_full_line) |
| L.expect(TK_NEWLINE); |
| return Decl::create(range, param_types, Maybe<Expr>(return_type)); |
| } |
| |
| // 'first' has already been parsed since expressions can exist |
| // alone on a line: |
| // first[,other,lhs] = rhs |
| Assign parseAssign(List<Expr> list) { |
| auto red = parseOptionalReduction(); |
| auto rhs = parseExpOrExpTuple(TK_NEWLINE); |
| L.expect(TK_NEWLINE); |
| return Assign::create(list.range(), list, AssignKind(red), Expr(rhs)); |
| } |
| TreeRef parseStmt() { |
| switch (L.cur().kind) { |
| case TK_IF: |
| return parseIf(); |
| case TK_WHILE: |
| return parseWhile(); |
| case TK_FOR: |
| return parseFor(); |
| case TK_GLOBAL: { |
| auto range = L.next().range; |
| auto idents = parseList(TK_NOTHING, ',', TK_NOTHING, &Parser::parseIdent); |
| L.expect(TK_NEWLINE); |
| return Global::create(range, idents); |
| } |
| case TK_RETURN: { |
| auto range = L.next().range; |
| // XXX: TK_NEWLINE makes it accept an empty list |
| auto values = parseList(TK_NOTHING, ',', TK_NEWLINE, &Parser::parseExp); |
| return Return::create(range, values); |
| } |
| default: { |
| List<Expr> exprs = parseList(TK_NOTHING, ',', TK_NOTHING, &Parser::parseExp); |
| if (L.cur().kind != TK_NEWLINE) { |
| return parseAssign(exprs); |
| } else { |
| L.expect(TK_NEWLINE); |
| return ExprStmt::create(exprs[0].range(), exprs); |
| } |
| } |
| } |
| } |
| TreeRef parseOptionalIdentList() { |
| TreeRef list = nullptr; |
| if (L.cur().kind == '(') { |
| list = parseList('(', ',', ')', &Parser::parseIdent); |
| } else { |
| list = c(TK_LIST, L.cur().range, {}); |
| } |
| return list; |
| } |
| TreeRef parseIf() { |
| auto r = L.cur().range; |
| L.expect(TK_IF); |
| auto cond = parseExp(); |
| L.expect(':'); |
| auto true_branch = parseStatements(); |
| auto false_branch = makeList(L.cur().range, {}); |
| if (L.nextIf(TK_ELSE)) { |
| L.expect(':'); |
| false_branch = parseStatements(); |
| } |
| return If::create(r, Expr(cond), List<Stmt>(true_branch), List<Stmt>(false_branch)); |
| } |
| TreeRef parseWhile() { |
| auto r = L.cur().range; |
| L.expect(TK_WHILE); |
| auto cond = parseExp(); |
| L.expect(':'); |
| auto body = parseStatements(); |
| return While::create(r, Expr(cond), List<Stmt>(body)); |
| } |
| TreeRef parseFor() { |
| auto r = L.cur().range; |
| L.expect(TK_FOR); |
| auto targets = parseList(TK_NOTHING, ',', TK_NOTHING, &Parser::parseExp); |
| L.expect(TK_IN); |
| auto itrs = parseList(TK_NOTHING, ',', TK_NOTHING, &Parser::parseExp); |
| L.expect(':'); |
| auto body = parseStatements(); |
| return For::create(r, targets, itrs, body); |
| } |
| |
| TreeRef parseStatements(bool expect_indent=true) { |
| auto r = L.cur().range; |
| if (expect_indent) |
| L.expect(TK_INDENT); |
| TreeList stmts; |
| for (size_t i=0; ; ++i) { |
| auto stmt = parseStmt(); |
| stmts.push_back(stmt); |
| if (L.nextIf(TK_DEDENT)) |
| break; |
| } |
| return c(TK_LIST, r, std::move(stmts)); |
| } |
| Decl parseDecl() { |
| auto paramlist = parseList('(', ',', ')', &Parser::parseParam); |
| // Parse return type annotation |
| TreeRef return_type; |
| if (L.nextIf(TK_ARROW)) { |
| // Exactly one expression for return type annotation |
| return_type = Maybe<Expr>::create(L.cur().range, parseExp()); |
| } else { |
| // Default to returning single tensor. TODO: better sentinel value? |
| return_type = Maybe<Expr>::create(L.cur().range); |
| } |
| L.expect(':'); |
| return Decl::create(paramlist.range(), List<Param>(paramlist), Maybe<Expr>(return_type)); |
| } |
| |
| TreeRef parseFunction(bool is_method) { |
| L.expect(TK_DEF); |
| auto name = parseIdent(); |
| auto decl = parseDecl(); |
| |
| // Handle type annotations specified in a type comment as the first line of |
| // the function. |
| L.expect(TK_INDENT); |
| if (L.nextIf(TK_TYPE_COMMENT)) { |
| auto type_annotation_decl = Decl(parseTypeComment()); |
| decl = mergeTypesFromTypeComment(decl, type_annotation_decl, is_method); |
| } |
| |
| auto stmts_list = parseStatements(false); |
| return Def::create(name.range(), Ident(name), Decl(decl), |
| List<Stmt>(stmts_list)); |
| } |
| Lexer& lexer() { |
| return L; |
| } |
| |
| private: |
| // short helpers to create nodes |
| TreeRef c(int kind, const SourceRange& range, TreeList&& trees) { |
| return Compound::create(kind, range, std::move(trees)); |
| } |
| TreeRef makeList(const SourceRange& range, TreeList&& trees) { |
| return c(TK_LIST, range, std::move(trees)); |
| } |
| Lexer L; |
| SharedParserData& shared; |
| }; |
| } // namespace script |
| } // namespace jit |
| } // namespace torch |