blob: c5083083fe633323ee4ea78ca88dc445b5b203b9 [file] [log] [blame]
#pragma once
#include <memory>
#include <vector>
#include "caffe2/contrib/script/lexer.h"
namespace caffe2 {
namespace script {
// Trees are used to represent all forms of TC IR, pre- and post-typechecking.
// Rather than have a full class hierarchy for all TC statements,
// Trees are a slight variation of Lisp S-expressions.
// For instance, the expression a*b+1 is represented as:
// (+ (* (ident a) (ident b)) (const 1))
// Atoms like 'a', 'b', and '1' are represented by subclasses of Tree which
// define stringValue() and doubleValue().
// Everything else is a Compound object, which has a 'kind' that is a token from
// Lexer.h's TokenKind enum, and contains a list of subtrees.
// As in TokenKind, single-character operators like '+' are represented using
// the character itself, so add.kind() == '+'.
// Compound objects are also always associated with a SourceRange for
// reporting error messages.
// Memory management of trees is done using shared_ptr.
// Forward declaration so the handle/list aliases below can name Tree before
// its definition.
struct Tree;
// Owning, reference-counted handle to a tree node.
using TreeRef = std::shared_ptr<Tree>;
// Ordered list of child nodes.
using TreeList = std::vector<TreeRef>;
// Empty child list; Tree::trees() returns a reference to it for atoms.
static const TreeList empty_trees{};
struct Tree : std::enable_shared_from_this<Tree> {
Tree(int kind_) : kind_(kind_) {}
int kind() const {
return kind_;
}
virtual bool isAtom() const {
return true;
}
virtual const SourceRange& range() const {
throw std::runtime_error("is an Atom");
}
virtual double doubleValue() const {
throw std::runtime_error("not a TK_NUMBER");
}
virtual const std::string& stringValue() const {
throw std::runtime_error("not a TK_STRING");
}
virtual bool boolValue() const {
throw std::runtime_error("not a TK_BOOL");
}
virtual const TreeList& trees() const {
return empty_trees;
}
const TreeRef& tree(size_t i) const {
return trees().at(i);
}
virtual TreeRef map(std::function<TreeRef(TreeRef)> /*fn*/) {
return shared_from_this();
}
template <typename... Args>
void match(int k, Args&... args) {
matchD(k, "unknown", 0, args...);
}
template <typename... Args>
void matchD(int k, const char* filename, int lineno, Args&... args) {
if (kind() != k) {
std::stringstream ss;
ss << filename << ":" << lineno << ": expecting kind '" << kindToString(k)
<< "' but found '" << kind() << "'\n";
range().highlight(ss);
throw std::runtime_error(ss.str());
}
std::initializer_list<TreeRef*> vars = {&args...};
if (vars.size() > trees().size()) {
std::stringstream ss;
ss << filename << ":" << lineno << ": trying to match " << vars.size()
<< " variables against " << trees().size() << " values in list.\n";
range().highlight(ss);
throw std::runtime_error(ss.str());
}
size_t i = 0;
for (TreeRef* v : vars) {
*v = trees()[i++];
}
}
virtual ~Tree() {}
private:
int kind_;
};
struct String : public Tree {
String(const std::string& value_) : Tree(TK_STRING), value_(value_) {}
virtual const std::string& stringValue() const override {
return value_;
}
template <typename... Args>
static TreeRef create(Args&&... args) {
return std::make_shared<String>(std::forward<Args>(args)...);
}
private:
std::string value_;
};
struct Number : public Tree {
Number(double value_) : Tree(TK_NUMBER), value_(value_) {}
virtual double doubleValue() const override {
return value_;
}
template <typename... Args>
static TreeRef create(Args&&... args) {
return std::make_shared<Number>(std::forward<Args>(args)...);
}
private:
double value_;
};
struct Bool : public Tree {
Bool(bool value_) : Tree(TK_BOOL), value_(value_) {}
virtual double doubleValue() const override {
return value_;
}
template <typename... Args>
static TreeRef create(Args&&... args) {
return std::make_shared<Bool>(std::forward<Args>(args)...);
}
private:
bool value_;
};
// Returns a SourceRange in c's file spanning from the smallest start to the
// largest end among c and the ranges of every non-atom tree in `others`.
// Atoms are skipped because they carry no range (Tree::range() throws).
static SourceRange mergeRanges(SourceRange c, const TreeList& others) {
  // Iterate by const reference: copying each shared_ptr would cost an atomic
  // refcount increment/decrement per child for no benefit.
  for (const auto& t : others) {
    if (t->isAtom())
      continue;
    const SourceRange& r = t->range();
    size_t s = std::min(c.start(), r.start());
    size_t e = std::max(c.end(), r.end());
    c = SourceRange(c.file_ptr(), s, e);
  }
  return c;
}
// Non-atom node: a kind, a SourceRange, and an ordered list of subtrees.
struct Compound : public Tree {
  Compound(int kind, const SourceRange& source_range)
      : Tree(kind), range_(source_range) {}

  // The stored range is widened (via mergeRanges) to cover the ranges of all
  // non-atom children. range_ is declared before trees_, so it is initialized
  // while `subtrees` is still intact, before the move below.
  Compound(int kind, const SourceRange& source_range, TreeList&& subtrees)
      : Tree(kind),
        range_(mergeRanges(source_range, subtrees)),
        trees_(std::move(subtrees)) {}

  virtual const TreeList& trees() const override {
    return trees_;
  }

  // Factory returning a new Compound as a TreeRef.
  static TreeRef
  create(int kind, const SourceRange& range_, TreeList&& trees_) {
    return std::make_shared<Compound>(kind, range_, std::move(trees_));
  }

  virtual bool isAtom() const override {
    return false;
  }

  // Returns a new Compound of the same kind whose children are fn applied to
  // each child of this node. Its range is this node's range, widened to cover
  // any non-atom mapped children.
  virtual TreeRef map(std::function<TreeRef(TreeRef)> fn) override {
    const TreeList& children = trees();
    TreeList mapped;
    mapped.reserve(children.size());
    for (const auto& t : children) {
      mapped.push_back(fn(t));
    }
    return Compound::create(kind(), range(), std::move(mapped));
  }

  const SourceRange& range() const override {
    return range_;
  }

 private:
  SourceRange range_;
  TreeList trees_;
};
// tree pretty printer
struct pretty_tree {
pretty_tree(const TreeRef& tree, size_t col = 40) : tree(tree), col(col) {}
const TreeRef& tree;
size_t col;
std::unordered_map<TreeRef, std::string> flat_strings;
const std::string& get_flat(const TreeRef& t) {
auto it = flat_strings.find(t);
if (it != flat_strings.end())
return it->second;
std::stringstream out;
switch (t->kind()) {
case TK_NUMBER:
out << t->doubleValue();
break;
case TK_STRING:
out << t->stringValue();
break;
default:
out << "(" << kindToString(t->kind());
for (auto e : t->trees()) {
out << " " << get_flat(e);
}
out << ")";
break;
}
auto it_ = flat_strings.emplace(t, out.str());
return it_.first->second;
}
void print(std::ostream& out, const TreeRef& t, int indent) {
const std::string& s = get_flat(t);
if (indent + s.size() < col || t->isAtom()) {
out << s;
return;
}
std::string k = kindToString(t->kind());
out << "(" << k;
for (auto e : t->trees()) {
out << "\n" << std::string(indent + 2, ' ');
print(out, e, indent + 2);
}
out << ")";
}
};
static inline std::ostream& operator<<(std::ostream& out, pretty_tree t_) {
t_.print(out, t_.tree, 0);
return out << std::endl;
}
static inline std::ostream& operator<<(std::ostream& out, TreeRef t) {
return out << pretty_tree(t);
}
} // namespace script
} // namespace caffe2