blob: af3a16393fe9db7a394ecc25322ba006d6726540 [file] [log] [blame]
#pragma once
#include <vector>
#include <cstdint>
#include <string>
#include <memory>
#include <vector>
#include <ATen/ATen.h>
#include "ATen/Utils.h"
#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/jit/interned_strings.h>
namespace torch { namespace jit {
constexpr int max_tensor_display_size = 10;
enum class AttributeKind {
f,fs,i,is,s,ss,t,ts,g,gs
};
static inline const char * toString(AttributeKind kind) {
static const char* names[] = {"f","fs","i","is","s","ss","t","ts","g","gs"};
JIT_ASSERT(size_t(kind) < sizeof(names)/sizeof(AttributeKind));
return names[int(kind)];
}
struct AttributeValue {
AttributeValue(Symbol name)
: name(name) {}
using Ptr = std::unique_ptr<AttributeValue>;
Symbol name;
virtual AttributeKind kind() const = 0;
virtual Ptr clone() const = 0;
virtual ~AttributeValue() = default;
};
template<typename T, AttributeKind Kind>
struct ScalarAttributeValue : public AttributeValue {
using ConstructorType = T;
using ValueType = T;
ScalarAttributeValue(Symbol name, ConstructorType value_)
: AttributeValue(name), value_(std::move(value_)) {}
ValueType & value() {
return value_;
}
Ptr clone() const override {
return Ptr(new ScalarAttributeValue(name, value_));
}
AttributeKind kind() const override { return Kind; }
private:
ValueType value_;
};
template<typename T, AttributeKind Kind>
struct VectorAttributeValue : public AttributeValue {
using ConstructorType = std::vector<T>;
using ValueType = std::vector<T>;
VectorAttributeValue(Symbol name, ConstructorType value_)
: AttributeValue(name), value_(std::move(value_)) {}
ValueType & value() {
return value_;
}
AttributeKind kind() const override { return Kind; }
std::unique_ptr<AttributeValue> clone() const override {
auto copy = value_;
return Ptr(new VectorAttributeValue(name, std::move(copy)));
}
private:
ValueType value_;
};
using FloatAttr = ScalarAttributeValue<double,AttributeKind::f>;
using FloatsAttr = VectorAttributeValue<double,AttributeKind::fs>;
using IntAttr = ScalarAttributeValue<int64_t,AttributeKind::i>;
using IntsAttr = VectorAttributeValue<int64_t,AttributeKind::is>;
using StringAttr = ScalarAttributeValue<std::string,AttributeKind::s>;
using StringsAttr = VectorAttributeValue<std::string,AttributeKind::ss>;
using TensorAttr = ScalarAttributeValue<at::Tensor,AttributeKind::t>;
using TensorsAttr = VectorAttributeValue<at::Tensor,AttributeKind::ts>;
struct Graph;
using GraphAttr = ScalarAttributeValue<std::shared_ptr<Graph>,AttributeKind::g>;
using GraphsAttr = VectorAttributeValue<std::shared_ptr<Graph>,AttributeKind::gs>;
struct AttributeError : public std::exception {
AttributeError(Symbol name, bool defined) {
std::stringstream ss;
if(!defined) {
ss << "required keyword attribute '" << name.toUnqualString() << "' is undefined.";
} else {
ss << "required keyword attribute '" << name.toUnqualString() << "' has the wrong type";
}
msg = ss.str();
}
const char* what() const noexcept override {
return msg.c_str();
}
private:
std::string msg;
};
// CRTP so that Node which inherits Attributes can be return for
// method chaining e.g:
// Node * n = g->create(kSelect)->i_(kOffset,3)->f_(kValue,3.5);
// we return Derived* pointers because Nodes are normally held as pointers.
template<typename Derived>
struct Attributes {
Attributes() = default;
void copyAttributes(const Attributes & rhs) {
values_.clear();
for(auto & i : rhs.values_) {
values_.push_back(i->clone());
}
}
bool hasAttribute(Symbol name) const {
JIT_ASSERT(name.is_attr());
return find(name,false) != values_.end();
}
// We want direct string accessors, as it is nicer to use than
// hasAttribute(Symbol::attr("blah"))
//
// For some reason, &Attributes<Node>::hasAttribute in pybind11 is able to
// give the pybind11 metaprogramming machinery "the right type", but
// the equivalent looking lambda [](Attributes<Node>& a, const std::string&)
// doesn't work! So instead we define the methods on the class so we can
// continue using the old idiom.
bool hasAttributeS(const std::string& name) const {
return hasAttribute(Symbol::attr(name));
}
AttributeKind kindOf(Symbol name) const {
JIT_ASSERT(name.is_attr());
return (*find(name,true))->kind();
}
AttributeKind kindOfS(const std::string& name) const {
return kindOf(Symbol::attr(name));
}
Derived* removeAttribute(Symbol name) {
JIT_ASSERT(name.is_attr());
values_.erase(find(name,true));
return This();
}
Derived* removeAttributeS(const std::string& name) {
return removeAttribute(Symbol::attr(name));
}
bool hasAttributes() const {
return values_.size() > 0;
}
size_t numAttributes() const {
return values_.size();
}
// The names are returned in order, since name actually is the index.
std::vector<Symbol> attributeNames() const {
std::vector<Symbol> names;
for(auto & a : values_)
names.push_back(a->name);
return names;
}
std::vector<const char*> attributeNamesS() const {
std::vector<const char*> names;
for(auto & a : values_)
names.push_back(a->name.toUnqualString());
return names;
}
#define CREATE_ACCESSOR(Kind, method) \
Derived* method##_(Symbol name, Kind##Attr::ConstructorType v) { \
return set<Kind##Attr>(name,std::forward<Kind##Attr::ConstructorType>(v)); \
} \
const Kind##Attr::ValueType& method(Symbol name) const { \
return get<Kind##Attr>(name); \
}
CREATE_ACCESSOR(Float,f)
CREATE_ACCESSOR(Floats,fs)
CREATE_ACCESSOR(String,s)
CREATE_ACCESSOR(Strings,ss)
CREATE_ACCESSOR(Int,i)
CREATE_ACCESSOR(Ints,is)
CREATE_ACCESSOR(Graph,g)
CREATE_ACCESSOR(Graphs,gs)
#undef CREATE_ACCESSOR
// Our Graphs are not very const-correct, so we need to allow returning
// non-const references too
GraphAttr::ValueType& g(Symbol name) {
return get<GraphAttr>(name);
}
// does not use CREATE_ACCESSOR because we need additional asserts
Derived* t_(Symbol name, TensorAttr::ConstructorType v) {
JIT_ASSERT(!v.defined() || !v.is_variable());
return set<TensorAttr>(name,std::forward<TensorAttr::ConstructorType>(v));
}
const TensorAttr::ValueType& t(Symbol name) const {
return get<TensorAttr>(name);
}
Derived* ts_(Symbol name, TensorsAttr::ConstructorType v) {
for(auto & t : v) {
JIT_ASSERT(!t.defined() || !t.is_variable());
}
return set<TensorsAttr>(name,std::forward<TensorsAttr::ConstructorType>(v));
}
const TensorsAttr::ValueType& ts(Symbol name) const {
return get<TensorsAttr>(name);
}
template<typename T>
static void printPrimList(std::ostream & out, const std::vector<T> & items) {
out << "[";
int i = 0;
for(auto & item : items) {
if(i++ > 0)
out << ", ";
out << item;
}
out << "]";
}
static std::string escapeString(std::string s) {
std::vector<char> search = {'\n', '\t', '\v'};
std::vector<std::string> replace = {"\\n", "\\t", "\\v"};
for (size_t i = 0; i < search.size(); i++) {
size_t pos = s.find(search[i]);
while(pos != std::string::npos) {
s.replace(pos, 1, replace[i]);
pos = s.find(search[i], pos + 1);
}
}
return s;
}
void printValue(std::ostream & out, Symbol & name) const {
switch(kindOf(name)) {
case AttributeKind::f:
out << f(name);
break;
case AttributeKind::fs:
printPrimList(out, fs(name));
break;
case AttributeKind::i:
out << i(name);
break;
case AttributeKind::is:
printPrimList(out, is(name));
break;
case AttributeKind::s:
out << "\"" << escapeString(s(name)) << "\"";
break;
case AttributeKind::ss:
printPrimList(out,ss(name));
break;
case AttributeKind::t:
{
at::Tensor tensor = t(name);
// 1-elem tensors are usually boxed scalars, so print them like it
if (tensor.numel() == 1) {
auto scalar_tensor = at::_local_scalar(tensor.view({}));
out << "{";
if (scalar_tensor.isFloatingPoint()) {
out << scalar_tensor.toDouble();
} else {
out << scalar_tensor.toLong();
}
out << "}";
} else if (tensor.numel() <= max_tensor_display_size) {
// TODO: This is awful code. Also it doesn't work on Windows.
std::ostringstream tensor_ss;
tensor_ss << tensor;
std::string tensor_s{tensor_ss.str()};
// Remove newlines
std::replace(tensor_s.begin(), tensor_s.end(), '\n', ' ');
out << tensor_s;
} else {
out << "<Tensor>";
}
break;
}
case AttributeKind::ts:
out << "[<Tensors>]";
break;
case AttributeKind::g:
out << "<Graph>";
break;
case AttributeKind::gs:
out << "[<Graphs>]";
break;
}
}
private:
// UBSAN error: https://github.com/pytorch/pytorch/issues/9055
Derived* This() __ubsan_ignore_vptr__ {
return static_cast<Derived*>(this);
}
template<typename T>
Derived* set(Symbol name, typename T::ConstructorType v) {
JIT_ASSERT(name.is_attr());
auto it = find(name, false);
auto nv = AVPtr(new T(name, std::forward<typename T::ConstructorType>(v)));
if(it == values_.end()) {
values_.push_back(std::move(nv));
} else {
*it = std::move(nv);
}
return This();
}
template<typename T>
typename T::ValueType & get(Symbol name) const {
JIT_ASSERT(name.is_attr());
auto it = find(name, true);
auto* child = dynamic_cast<T*>(it->get());
if(child == nullptr) {
throw AttributeError(name, true);
}
return child->value();
}
using AVPtr = AttributeValue::Ptr;
// NB: For determinism, we use a vector rather than a hash map. This does
// mean that lookups are O(n), so you shouldn't use Attributes to store
// a big pile of messages.
std::vector<AVPtr> values_;
using iterator = std::vector<AVPtr>::iterator;
iterator find(Symbol name, bool required) {
JIT_ASSERT(name.is_attr());
auto it = std::find_if(values_.begin(), values_.end(),[&](const AVPtr & v) {
return v->name == name;
});
if(required && it == values_.end()) {
throw AttributeError(name, false);
}
JIT_ASSERT(!required || it != values_.end());
return it;
}
using const_iterator = std::vector<AVPtr>::const_iterator;
const_iterator find(Symbol name, bool required) const {
JIT_ASSERT(name.is_attr());
auto it = std::find_if(values_.begin(), values_.end(),[&](const AVPtr & v) {
return v->name == name;
});
if(required && it == values_.end()) {
throw AttributeError(name, false);
}
JIT_ASSERT(!required || it != values_.end());
return it;
}
};
}}