blob: 4af5c8032edd36115c40e54a87a237674f950efe [file] [log] [blame]
#include <torch/csrc/jit/irparser.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/script/lexer.h>
#include <torch/csrc/jit/script/parse_string_literal.h>
#include <torch/csrc/jit/script/schema_type_parser.h>
#include <string>
#include <vector>
namespace torch {
namespace jit {
namespace script {
struct VarWithType;
struct ParsedLiteral;
class IRParser {
friend void parseIR(const std::string& str, torch::jit::Graph* graph);
IRParser(const std::string& str, torch::jit::Graph* graph)
: L(str),
g(graph),
type_parser(L, /*parse_complete_tensor_types*/ true) {}
std::string parseVar();
VarWithType parseVarWithType();
ParsedLiteral parseScalarLiteral(Node* n);
void parse();
void parseGraphInputs();
void parseReturnOperator();
void parseBlocks(Node* parentNode);
void parseBlock(Node* parentNode);
void parseBlockInputs(Block* b);
void parseBlockOutputs(Block* b);
void parseOperatorsList(Block* b);
void parseOperator(Block* b);
void parseOperatorOutputs(std::vector<VarWithType>* outs);
std::string parseOperatorName();
void parseOperatorInputs(Node* n);
void parseAttrs(Node* n);
void parseAttr(Node* n);
void parseList(
int begin,
int sep,
int end,
const std::function<void()>& callback);
torch::jit::script::Lexer L;
torch::jit::Graph* g = nullptr;
std::unordered_map<std::string, Value*> vmap;
SchemaTypeParser type_parser;
};
struct ParsedLiteral {
ParsedLiteral() = default;
AttributeKind k = AttributeKind::t;
int64_t i = 0;
std::string s = "";
double f = 0.0;
std::vector<int64_t> is;
std::vector<std::string> ss;
std::vector<double> fs;
};
struct VarWithType {
VarWithType() = default;
std::string name;
TypePtr type;
};
void parseIR(const std::string& str, torch::jit::Graph* graph) {
torch::jit::script::IRParser p(str, graph);
p.parse();
}
VarWithType IRParser::parseVarWithType() {
VarWithType r;
r.name = parseVar();
r.type = TensorType::get();
if (L.nextIf(':')) {
auto type_alias = type_parser.parseType();
AT_ASSERTM(!type_alias.second, "Parsing IR with Alias Info not handled");
r.type = type_alias.first;
}
return r;
}
std::string IRParser::parseVar() {
L.expect('%');
if (L.cur().kind == TK_IDENT) {
auto name = L.expect(TK_IDENT).text();
if (L.cur().kind == TK_NUMBER) {
auto suffix = L.expect(TK_NUMBER).text();
AT_ASSERT(suffix[0] == '.');
name += suffix;
}
return name;
} else {
return L.expect(TK_NUMBER).text();
}
}
void IRParser::parseOperatorOutputs(std::vector<VarWithType>* outs) {
if (L.cur().kind != '%') {
return;
}
parseList(TK_NOTHING, ',', TK_NOTHING, [&] {
outs->push_back(parseVarWithType());
});
L.expect('=');
}
// Parse string or numeric literal and return it along with its type.
ParsedLiteral IRParser::parseScalarLiteral(Node* n) {
auto token = L.cur();
std::string str;
ParsedLiteral r;
switch (token.kind) {
case TK_STRINGLITERAL:
r.k = AttributeKind::s;
r.s = parseStringLiteral(token.range, token.text());
L.next();
return r;
case '-':
str = "-";
L.next();
L.expect(TK_NUMBER);
// Fallthrough
case TK_NUMBER:
str += L.cur().text();
if (str.find('.') != std::string::npos ||
str.find('e') != std::string::npos) {
r.k = AttributeKind::f;
r.f = std::stod(str);
} else {
r.k = AttributeKind::i;
r.i = std::stoll(str);
}
L.next();
return r;
default:
throw ErrorReport(token.range)
<< "Could not parse literal" << token.text();
}
}
/** \brief Parse attribute and add it to the node N.
*
* The function determines the attribute type (string, int, float, list of
* strings, list of ints, list of floats, and a list of tensors (currently only
* for empty lists)).
* An attribute looks like the following:
* AttrName=AttrValue
* Where AttrValue can be a list or a scalar literal, e.g.:
* size = 27
* name = "Bob"
* coefs = [1.2, 3.4, 0.6]
*/
void IRParser::parseAttr(Node* n) {
std::string attrname = L.expect(TK_IDENT).text();
L.expect('=');
if (L.cur().kind == '[') {
// list
AttributeKind k = AttributeKind::ts;
std::vector<int64_t> is;
std::vector<std::string> ss;
std::vector<double> fs;
int elem_num = 0;
parseList('[', ',', ']', [&] {
ParsedLiteral r = parseScalarLiteral(n);
switch (r.k) {
case AttributeKind::s:
ss.push_back(r.s);
AT_ASSERT(!elem_num++ || k == AttributeKind::ss);
k = AttributeKind::ss;
break;
case AttributeKind::i:
is.push_back(r.i);
AT_ASSERT(!elem_num++ || k == AttributeKind::is);
k = AttributeKind::is;
break;
case AttributeKind::f:
fs.push_back(r.f);
AT_ASSERT(!elem_num++ || k == AttributeKind::fs);
k = AttributeKind::fs;
break;
default:
throw ErrorReport(L.cur().range) << "Unexpected attr type";
}
});
switch (k) {
case AttributeKind::ts:
n->ts_(Symbol::attr(attrname), {});
break;
case AttributeKind::ss:
n->ss_(Symbol::attr(attrname), ss);
break;
case AttributeKind::fs:
n->fs_(Symbol::attr(attrname), fs);
break;
case AttributeKind::is:
n->is_(Symbol::attr(attrname), is);
break;
default:
throw ErrorReport(L.cur().range) << "Unexpected attr type";
}
} else {
// scalar
ParsedLiteral r = parseScalarLiteral(n);
switch (r.k) {
case AttributeKind::s:
n->s_(Symbol::attr(attrname), r.s);
break;
case AttributeKind::i:
n->i_(Symbol::attr(attrname), r.i);
break;
case AttributeKind::f:
n->f_(Symbol::attr(attrname), r.f);
break;
default:
throw ErrorReport(L.cur().range) << "Unexpected attr type";
}
return;
}
}
void IRParser::parseAttrs(Node* n) {
parseList('[', ',', ']', [&] { parseAttr(n); });
}
void IRParser::parseOperatorInputs(Node* n) {
if (L.cur().kind == '[') {
parseAttrs(n);
}
parseList('(', ',', ')', [&] {
std::string var_name = parseVar();
AT_ASSERT(vmap.count(var_name));
n->addInput(vmap[var_name]);
});
}
void IRParser::parseBlocks(Node* parentNode) {
L.expect(TK_INDENT);
while (L.cur().kind != TK_DEDENT) {
parseBlock(parentNode);
}
L.expect(TK_DEDENT);
}
void IRParser::parseBlockInputs(Block* b) {
parseList('(', ',', ')', [&] {
VarWithType v = parseVarWithType();
// If the name isn't valid, don't use it
std::string uniq_name = Value::isValidName(v.name) ? v.name : "";
vmap[v.name] = b->addInput(uniq_name);
vmap[v.name]->setType(v.type);
});
}
void IRParser::parseBlockOutputs(Block* b) {
L.expect(TK_ARROW);
parseList('(', ',', ')', [&] {
std::string var_name = parseVar();
AT_ASSERT(vmap.count(var_name));
b->registerOutput(vmap[var_name]);
});
L.expect(TK_NEWLINE);
L.expect(TK_DEDENT);
}
/** \brief Parse a block.
*
* It should look like the following:
* blockName(input1, input2, input3, ...):
* op1
* op2
* ...
* opN
* -> (output1, output2, output3, ...)
*/
void IRParser::parseBlock(Node* parentNode) {
Block* b = parentNode->addBlock();
L.expect(TK_IDENT).text(); // Block name is not used anywhere.
parseBlockInputs(b);
L.expect(':');
parseOperatorsList(b);
parseBlockOutputs(b);
}
/** \brief Parse a list of statements.
*
* It is expected to be delimited by TK_NEWLINE and end with TK_RETURN or
* TK_ARROW.
*/
void IRParser::parseOperatorsList(Block* b) {
L.expect(TK_INDENT);
while (L.cur().kind != TK_ARROW && L.cur().kind != TK_RETURN) {
parseOperator(b);
}
}
std::string IRParser::parseOperatorName() {
std::string name = L.expect(TK_IDENT).text();
L.expect(':');
L.expect(':');
name += "::" + L.expect(TK_IDENT).text();
return name;
}
/** \brief Parse a statement.
*
* It should look like the following:
* <outputs> = NodeName[<attributes>](<inputs>)
* <blocks>
* Outputs, blocks and attributes are optional.
*/
void IRParser::parseOperator(Block* b) {
// Parse lefthand side.
std::vector<VarWithType> outs;
parseOperatorOutputs(&outs);
// Parse the name and create the corresponding node in the graph.
std::string name = parseOperatorName();
Node* n = g->create(Symbol::fromQualString(name), {}, outs.size());
// Parse attributes and inputs.
parseOperatorInputs(n);
// Register outputs.
int idx = 0;
for (const VarWithType& v : outs) {
vmap[v.name] = n->outputs()[idx++];
vmap[v.name]->setType(v.type);
}
// Insert the new node into block B.
b->appendNode(n);
// If the statement has nested blocks, parse them:
if (L.cur().kind == TK_INDENT) {
parseBlocks(n);
}
L.nextIf(TK_NEWLINE);
}
void IRParser::parseGraphInputs() {
parseList('(', ',', ')', [&] {
VarWithType v = parseVarWithType();
// If the name isn't valid, don't use it
std::string uniq_name = Value::isValidName(v.name) ? v.name : "";
vmap[v.name] = g->addInput(uniq_name);
vmap[v.name]->setType(v.type);
});
}
/** \brief Parse return statement.
*
* It should look like the following:
* return (x : TypeX, y : TypeY, z, ...)
*/
void IRParser::parseReturnOperator() {
L.expect(TK_RETURN);
// Parse output names and types
parseList('(', ',', ')', [&] {
std::string var_name = parseVar();
// Outputs should already be in VMAP, otherwise we're trying to return
// undefined value.
AT_ASSERT(vmap.count(var_name));
g->registerOutput(vmap.at(var_name));
});
// Consume ending tokens
if (L.cur().kind != TK_EOF) {
L.expect(TK_NEWLINE);
L.expect(TK_DEDENT);
}
}
/** \brief Parse entire graph.
*
* It should look like the following:
* graphName (input1, input2, ... inputN):
* op1
* op2
* ...
* opN
* return (output1, output2, ... outputN)
*/
void IRParser::parse() {
// Parse graph definition, it should look like the following:
// graphName (input1, input2, ... inputN):
std::string graphName = L.expect(TK_IDENT).text();
parseGraphInputs();
L.expect(':');
// After the definition we should have a list of statements, parse it:
parseOperatorsList(g->block());
// The last statement should be return, which specifies graph outputs
parseReturnOperator();
}
void IRParser::parseList(
int begin,
int sep,
int end,
const std::function<void()>& callback) {
if (begin != TK_NOTHING) {
L.expect(begin);
}
if (L.cur().kind != end) {
do {
callback();
} while (L.nextIf(sep));
}
if (end != TK_NOTHING) {
L.expect(end);
}
}
} // namespace script
} // namespace jit
} // namespace torch