#pragma once
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
#include "torch/csrc/jit/source_location.h"
namespace torch {
namespace jit {
namespace script {
// single-character tokens are represented by the character itself, e.g. '+'
// multi-character tokens need an entry here. If the third entry is not the
// empty string, the lexer uses it to match the token.
// These kinds are also used in Tree.h as the kind of the AST node.
// Some kinds, such as TK_APPLY and TK_LIST, are only used in the AST and are
// never produced by the lexer.
#define TC_FORALL_TOKEN_KINDS(_) \
_(TK_EOF, "eof", "") \
_(TK_WHITESPACE, "whitespace", "") \
_(TK_NUMBER, "number", "") \
_(TK_NEWLINE, "newline", "") \
_(TK_INDENT, "indent", "") \
_(TK_DEDENT, "dedent", "") \
_(TK_WHERE, "where", "where") \
_(TK_FLOAT, "float", "float") \
_(TK_DOUBLE, "double", "double") \
_(TK_LONG, "long", "long") \
_(TK_INT, "int", "int") \
_(TK_DEF, "def", "def") \
_(TK_EQUIVALENT, "equivalent", "<=>") \
_(TK_IDENT, "ident", "") \
_(TK_STRING, "string", "") \
_(TK_CONST, "const", "") \
_(TK_LIST, "list", "") \
_(TK_OPTION, "option", "") \
_(TK_APPLY, "apply", "") \
_(TK_COMPREHENSION, "comprehension", "") \
_(TK_TENSOR_TYPE, "tensor_type", "") \
_(TK_RANGE_CONSTRAINT, "range_constraint", "") \
_(TK_PARAM, "param", "") \
_(TK_INFERRED, "inferred", "") \
_(TK_BOOL, "bool", "") \
_(TK_ACCESS, "access", "") \
_(TK_ASSIGN, "assign", "") \
_(TK_ATTRIBUTE, "attribute", "") \
_(TK_IF, "if", "if") \
_(TK_ELSE, "else", "else") \
_(TK_ELIF, "elif", "elif") \
_(TK_WHILE, "while", "while") \
_(TK_EXPR_STMT, "expression statement", "") \
_(TK_RETURN, "return", "return") \
_(TK_NE, "ne", "!=") \
_(TK_EQ, "eq", "==") \
_(TK_LE, "le", "<=") \
_(TK_GE, "ge", ">=") \
_(TK_IF_EXPR, "if", "") \
_(TK_TRUE, "True", "True") \
_(TK_FALSE, "False", "False") \
_(TK_AND, "and", "and") \
_(TK_OR, "or", "or") \
_(TK_NOT, "not", "not") \
_(TK_CAST, "cast", "") \
_(TK_PLUS_EQ, "+=", "+=") \
_(TK_MINUS_EQ, "-=", "-=") \
_(TK_TIMES_EQ, "*=", "*=") \
_(TK_DIV_EQ, "/=", "/=") \
_(TK_GLOBAL, "global", "global") \
_(TK_BUILT_IN, "built-in", "") \
_(TK_SLICE, "slice", "") \
_(TK_VAR, "variable", "") \
_(TK_GATHER, "gather", "") \
_(TK_NOTHING, "nothing", "") \
_(TK_LIST_LITERAL, "list-literal", "") \
_(TK_FOR, "for", "for") \
_(TK_IN, "in", "in")
static const char* valid_single_char_tokens = "+-*/()[]:,={}><.";
enum TokenKind {
// characters represent themselves, so skip all valid single-char tokens
// before assigning enum values to multi-char tokens.
TK_DUMMY_START = 256,
#define DEFINE_TOKEN(tok, _, _2) tok,
TC_FORALL_TOKEN_KINDS(DEFINE_TOKEN)
#undef DEFINE_TOKEN
};
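// The DEFINE_TOKEN X-macro expands once per row of TC_FORALL_TOKEN_KINDS,
// so the enum above is equivalent to writing each enumerator by hand:
//   enum TokenKind {
//     TK_DUMMY_START = 256,
//     TK_EOF,        // "eof"
//     TK_WHITESPACE, // "whitespace"
//     TK_NUMBER,     // "number"
//     ...
//   };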
std::string kindToString(int kind);
int stringToKind(std::string str);
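// these convert between a kind and its human-readable name (the second
// column of TC_FORALL_TOKEN_KINDS); e.g. kindToString(TK_NE) is expected
// to yield "ne".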
// nested hash tables that indicate char-by-char what is a valid token.
struct TokenTrie;
using TokenTrieRef = std::unique_ptr<TokenTrie>;
struct TokenTrie {
TokenTrie() : kind(0) {}
void insert(const char* str, int tok) {
if (*str == '\0') {
assert(kind == 0);
kind = tok;
return;
}
auto& entry = children[*str];
if (entry == nullptr) {
entry.reset(new TokenTrie());
}
entry->insert(str + 1, tok);
}
int kind; // 0 == invalid token
std::unordered_map<char, TokenTrieRef> children;
};
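// For example, inserting "<=" (TK_LE) and then "<=>" (TK_EQUIVALENT) shares
// the '<' and '=' nodes:
//   root -> '<' -> '=' [kind = TK_LE] -> '>' [kind = TK_EQUIVALENT]
// A longest-match scan can then walk one character at a time, remembering
// the last node whose kind is nonzero.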
// data shared across all TC lexers/parsers; initialized only once.
struct SharedParserData {
SharedParserData() : head(new TokenTrie()) {
// listed in increasing order of precedence
std::vector<std::vector<int>> binary_ops = {
{TK_IF},
{TK_AND, TK_OR},
{}, // reserve a level for unary not
{'<', '>', TK_EQ, TK_LE, TK_GE, TK_NE},
{'+', '-'},
{'*', '/'},
};
std::vector<std::vector<int>> unary_ops = {
{'-'},
};
for (const char* c = valid_single_char_tokens; *c; c++) {
std::string str(1, *c);
head->insert(str.c_str(), *c);
}
#define ADD_CASE(tok, _, tokstring) \
if (*tokstring != '\0') { \
head->insert(tokstring, tok); \
}
TC_FORALL_TOKEN_KINDS(ADD_CASE)
#undef ADD_CASE
// precedence starts at 1 so that 0 is always lower than any real
// precedence
int prec = 1;
for (auto& group : binary_ops) {
for (auto& element : group) {
binary_prec[element] = prec;
}
prec++;
}
// unary ops
for (auto& group : unary_ops) {
for (auto& element : group) {
unary_prec[element] = prec;
}
prec++;
}
// add unary not separately because it slots into the precedence of
// binary operators
unary_prec[TK_NOT] = binary_prec[TK_AND] + 1;
}
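// With the tables above, binary_prec['+'] == 5 and binary_prec['*'] == 6,
// so '*' binds tighter than '+'. Unary '-' gets precedence 7, and 'not'
// slots in at binary_prec[TK_AND] + 1 == 3, between and/or (2) and the
// comparison operators (4).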
bool isNumber(const std::string& str, size_t start, size_t* len) {
char first = str[start];
// strtod allows numbers to start with + or - or nan or inf
// http://en.cppreference.com/w/cpp/string/byte/strtof
// but we want only the number part, otherwise 1+3 will turn into two
// adjacent numbers in the lexer
if (first == '-' || first == '+' || isalpha(first))
return false;
const char* startptr = str.c_str() + start;
char* endptr;
std::strtod(startptr, &endptr);
*len = endptr - startptr;
return *len > 0;
}
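// e.g. "1", "3.5", and "2e-3" are numbers, while "+1", "-1", "inf", and
// "nan" are not, so an input like "1+3" lexes as NUMBER '+' NUMBER rather
// than as two adjacent numbers.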
bool isblank(int n) {
return isspace(n) && n != '\n';
}
// find the longest match of str.substring(pos) against a token; returns
// true on success, filling in kind, start, and len. Steps:
// 1. skip whitespace
// 2. handle comments and newlines
// 3. match a number, identifier, or trie token
bool match(
const std::string& str,
size_t pos,
bool continuation, // are we inside a scope where newlines don't count
// (e.g. inside parens)
bool whitespace_token, // should we treat whitespace as a token
int* kind,
size_t* start,
size_t* len) {
*start = pos;
// skip whitespace
while (pos < str.size() && isblank(str[pos]))
pos++;
// special handling
if (pos < str.size()) {
if (str[pos] == '#') {
// skip comments
while (pos < str.size() && str[pos] != '\n')
pos++;
// tail call, handle whitespace and more comments
return match(
str, pos, continuation, whitespace_token, kind, start, len);
}
if (str[pos] == '\\' && pos + 1 < str.size() && str[pos + 1] == '\n' &&
!whitespace_token) {
return match(str, pos + 2, continuation, false, kind, start, len);
}
if (str[pos] == '\n') {
return match(
str, pos + 1, continuation, !continuation, kind, start, len);
}
}
if (pos == str.size()) {
*kind = TK_EOF;
*start = pos;
*len = 0;
return true;
}
// invariant: the next token is not whitespace or newline
if (whitespace_token) {
*kind = TK_WHITESPACE;
*len = pos - *start;
return true;
}
*start = pos;
// check for a valid number
if (isNumber(str, pos, len)) {
*kind = TK_NUMBER;
return true;
}
// check for either an ident or a token
// ident tracks whether what we have scanned so far could be an identifier
// matched indicates if we have found any match.
bool matched = false;
bool ident = true;
TokenTrie* cur = head.get();
for (size_t i = 0; pos + i < str.size() && (ident || cur != nullptr); i++) {
ident = ident && validIdent(i, str[pos + i]);
if (ident) {
matched = true;
*len = i + 1;
*kind = TK_IDENT;
}
// check for token second, so that e.g. 'max' matches the token TK_MAX
// rather than the identifier 'max'
if (cur) {
auto it = cur->children.find(str[pos + i]);
cur = (it == cur->children.end()) ? nullptr : it->second.get();
if (cur && cur->kind != 0) {
matched = true;
*len = i + 1;
*kind = cur->kind;
}
}
}
return matched;
}
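// For example, on "a<=>b" with pos at the '<', the trie matches '<'
// (len 1), "<=" (TK_LE, len 2), and "<=>" (TK_EQUIVALENT, len 3); the
// loop keeps the longest, so *kind ends up TK_EQUIVALENT.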
bool isUnary(int kind, int* prec) {
auto it = unary_prec.find(kind);
if (it != unary_prec.end()) {
*prec = it->second;
return true;
}
return false;
}
bool isBinary(int kind, int* prec) {
auto it = binary_prec.find(kind);
if (it != binary_prec.end()) {
*prec = it->second;
return true;
}
return false;
}
bool isRightAssociative(int kind) {
switch (kind) {
case '?':
return true;
default:
return false;
}
}
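// A Pratt-style parser loop might query these tables like this
// (sketch only; lexer and min_prec are illustrative names):
//   int prec;
//   if (shared.isBinary(lexer.cur().kind, &prec) && prec > min_prec) {
//     // parse the rhs with limit prec (or prec - 1 if right-associative)
//   }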
private:
bool validIdent(size_t i, char n) {
return isalpha(n) || n == '_' || (i > 0 && isdigit(n));
}
TokenTrieRef head;
std::unordered_map<int, int>
unary_prec; // map from token to its unary precedence
std::unordered_map<int, int>
binary_prec; // map from token to its binary precedence
};
SharedParserData& sharedParserData();
// a range within a shared string 'file_', with functions to help debugging
// by highlighting that range.
struct SourceRange : public SourceLocation {
SourceRange(
const std::shared_ptr<std::string>& file_,
size_t start_,
size_t end_)
: file_(file_), start_(start_), end_(end_) {}
std::string text() const {
return file().substr(start(), end() - start());
}
size_t size() const {
return end() - start();
}
virtual void highlight(std::ostream& out) const override {
const std::string& str = file();
size_t begin = start();
size_t end = start();
while (begin > 0 && str[begin - 1] != '\n')
--begin;
while (end < str.size() && str[end] != '\n')
++end;
out << str.substr(0, end) << "\n";
out << std::string(start() - begin, ' ');
size_t len = std::min(size(), end - start());
out << std::string(len, '~')
<< (len < size() ? "... <--- HERE" : " <--- HERE");
out << str.substr(end);
if (str.size() > 0 && str.back() != '\n')
out << "\n";
}
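// For example, highlighting the 'b' in the one-line file "a = b + c\n"
// prints:
//   a = b + c
//       ~ <--- HERE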
const std::string& file() const {
return *file_;
}
const std::shared_ptr<std::string>& file_ptr() const {
return file_;
}
size_t start() const {
return start_;
}
size_t end() const {
return end_;
}
private:
std::shared_ptr<std::string> file_;
size_t start_;
size_t end_;
};
struct Token {
int kind;
SourceRange range;
Token(int kind, const SourceRange& range) : kind(kind), range(range) {}
std::string text() const {
return range.text();
}
std::string kindString() const {
return kindToString(kind);
}
};
struct Lookahead {
Lookahead(const Token& t) : t(t) {}
Token t;
bool valid = false;
size_t repeat = 0;
};
struct Lexer {
std::shared_ptr<std::string> file;
explicit Lexer(const std::string& str)
: file(std::make_shared<std::string>(str)),
pos(0),
cur_(TK_EOF, SourceRange(file, 0, 0)),
lookahead_(cur_),
repeat(0),
nesting(0),
shared(sharedParserData()) {
auto first_indent = lexRaw(true);
indent_stack.push_back(first_indent.range.size());
next();
}
Token next() {
Token r = cur_;
if (repeat > 0) {
repeat--;
} else if (lookahead_.valid) {
lookahead_.valid = false;
repeat = lookahead_.repeat;
cur_ = lookahead_.t;
} else {
std::tie(cur_, repeat) = lex();
}
return r;
}
bool nextIf(int kind) {
if (cur_.kind != kind)
return false;
next();
return true;
}
[[noreturn]] void reportError(const std::string& what) {
reportError(what, cur_);
}
[[noreturn]] void reportError(const std::string& what, const Token& t) {
std::stringstream ss;
ss << what << ":\n";
t.range.highlight(ss);
throw std::runtime_error(ss.str());
}
[[noreturn]] void expected(const std::string& what, const Token& t) {
std::stringstream ss;
ss << "expected " << what << " but found '" << t.kindString()
<< "' here:\n";
t.range.highlight(ss);
throw std::runtime_error(ss.str());
}
[[noreturn]] void expected(const std::string& what) {
expected(what, cur_);
}
Token expect(int kind) {
if (cur_.kind != kind) {
expected(kindToString(kind));
}
return next();
}
Token& lookahead() {
if (!lookahead_.valid) {
lookahead_.valid = true;
std::tie(lookahead_.t, lookahead_.repeat) = lex();
}
return lookahead_.t;
}
Token& cur() {
return cur_;
}
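// Typical usage (sketch):
//   Lexer lexer("def f(a):\n  return a\n");
//   while (lexer.cur().kind != TK_EOF) {
//     Token t = lexer.next(); // consume the current token and advance
//     ...
//   }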
private:
// token, number of times to repeat it
std::pair<Token, int> lex() {
auto r = lexRaw();
int repeat = 0;
switch (r.kind) {
case '(':
case '[':
case '{':
nesting++;
break;
case ')':
case ']':
case '}':
nesting--;
break;
case TK_WHITESPACE: {
int depth = r.range.size();
if (depth > indent_stack.back()) {
indent_stack.push_back(depth);
r.kind = TK_INDENT;
} else if (depth == indent_stack.back()) {
r.kind = TK_NEWLINE;
} else {
while (indent_stack.back() != depth) {
indent_stack.pop_back();
repeat++;
if (indent_stack.empty()) {
reportError("invalid indent level", r);
}
}
repeat--; // the token returned now counts as the first DEDENT
r.kind = TK_DEDENT;
}
} break;
case TK_EOF:
if (indent_stack.size() > 1) {
r.kind = TK_DEDENT;
indent_stack.pop_back();
}
break;
default:
break;
}
return std::make_pair(r, repeat);
}
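// For example, in:
//   if x:
//     y
//   z
// the whitespace after the newline before 'y' becomes TK_INDENT, and the
// shallower indentation before 'z' becomes TK_DEDENT (one per level popped,
// via the repeat count), mirroring Python's tokenizer.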
Token lexRaw(bool whitespace_token = false) {
int kind;
size_t start;
size_t length;
assert(file);
if (!shared.match(
*file,
pos,
nesting > 0,
whitespace_token,
&kind,
&start,
&length)) {
expected(
"a valid token",
Token((*file)[start], SourceRange(file, start, start + 1)));
}
auto t = Token(kind, SourceRange(file, start, start + length));
pos = start + length;
return t;
}
size_t pos;
Token cur_;
Lookahead lookahead_;
size_t repeat; // number of times to repeat the current token before lexing the next one
size_t nesting; // depth of ( [ { nesting...
std::vector<int> indent_stack; // stack of indentation levels of blocks
SharedParserData& shared;
};
} // namespace script
} // namespace jit
} // namespace torch