blob: fd3746e399b0b910dc1760e37a96f39343e97087 [file] [log] [blame]
#include <vector>
#include <stdint.h>
#include <string>
#include <unordered_map>
#include <mutex>
#include <sstream>
#include "ATen/optional.h"
#include "torch/csrc/assertions.h"
#include "torch/csrc/jit/interned_strings.h"
namespace torch { namespace jit {
// TODO: programatically generate this
inline std::string valid_namespaces_str() {
return "'onnx', 'prim', 'aten', 'attr', 'scope'";
}
// Alas, this is only constexpr in C++14
const unique_t unique_start =
std::max({static_cast<unique_t>(aten::_keys::num_symbols),
static_cast<unique_t>(prim::_keys::num_symbols),
static_cast<unique_t>(onnx::_keys::num_symbols),
static_cast<unique_t>(attr::_keys::num_symbols)});
SymbolNamespace parseNamespace(const std::string & s) {
size_t colon_pos = s.find(':');
if (colon_pos == std::string::npos) {
std::ostringstream ss;
ss << "Symbol: illegal unqualified string '" << s << "'; "
<< "all symbols must be qualified, e.g., 'ns::" << s << "'. "
<< "Valid namespaces are: " << valid_namespaces_str();
throw std::runtime_error(ss.str());
}
if (colon_pos == 0) {
std::ostringstream ss;
ss << "Symbol: illegal leading colon in '" << s << "'; "
<< "all symbols must have a non-empty namespace. "
<< "Valid namespaces are: " << valid_namespaces_str();
throw std::runtime_error(ss.str());
}
// attr::x
// ^___ colon_pos
// ^___ s.size()
if (colon_pos + 2 > s.size()) {
std::ostringstream ss;
ss << "Symbol: underlong string '" << s << "'; "
<< "namespace must be followed by double colon and a "
<< "non-empty string.";
throw std::runtime_error(ss.str());
}
if (s[colon_pos + 1] != ':') {
std::ostringstream ss;
ss << "Symbol: invalid use of colons in '" << s << "'; "
<< "namespace must be followed by double colon, not a"
<< "single colon.";
throw std::runtime_error(ss.str());
}
auto ns = s.substr(0, colon_pos);
if (ns == "aten") {
return SymbolNamespace::aten;
} else if (ns == "prim") {
return SymbolNamespace::prim;
} else if (ns == "onnx") {
return SymbolNamespace::onnx;
} else if (ns == "attr") {
return SymbolNamespace::attr;
} else if (ns == "scope") {
return SymbolNamespace::scope;
} else {
std::ostringstream ss;
ss << "Symbol: invalid namespace in '" << s << "'. "
<< "Valid namespaces are: " << valid_namespaces_str();
throw std::runtime_error(ss.str());
}
}
struct InternedStrings {
InternedStrings()
: next_uniq(unique_start) {
#define REGISTER_SYMBOL(ns, s) \
string_to_sym_[#ns "::" #s] = ns::s; \
sym_to_string_[ns::s] = #ns "::" #s;
FORALL_BUILTIN_SYMBOLS(REGISTER_SYMBOL)
#undef REGISTER_SYMBOL
}
Symbol symbol(const std::string & s, at::optional<SymbolNamespace> known_ns = at::nullopt) {
JIT_ASSERT(!known_ns || parseNamespace(s) == known_ns);
std::lock_guard<std::mutex> guard(mutex_);
auto it = string_to_sym_.find(s);
if(it != string_to_sym_.end())
return it->second;
unique_t uniq = next_uniq++;
// TODO: Doing the parsing while we hold the lock is a bit naughty
SymbolNamespace ns = known_ns ? *known_ns : parseNamespace(s);
Symbol sym(ns, uniq);
string_to_sym_[s] = sym;
sym_to_string_[sym] = s;
return sym;
}
const char * string(Symbol sym) {
// Builtin Symbols are also in the maps, but
// we can bypass the need to acquire a lock
// to read the map for Builtins because we already
// know their string value
switch(sym) {
#define DEFINE_CASE(ns, s) \
case ns::s: return #ns "::" #s;
FORALL_BUILTIN_SYMBOLS(DEFINE_CASE)
#undef DEFINE_CASE
default:
return customString(sym);
}
}
private:
const char * customString(Symbol sym) {
std::lock_guard<std::mutex> guard(mutex_);
auto it = sym_to_string_.find(sym);
JIT_ASSERT(it != sym_to_string_.end());
return it->second.c_str();
}
std::unordered_map<std::string, Symbol> string_to_sym_;
std::unordered_map<Symbol, std::string> sym_to_string_;
unique_t next_uniq;
std::mutex mutex_;
};
static InternedStrings & globalStrings() {
static InternedStrings s;
return s;
}
Symbol Symbol::fromQualString(const std::string & s) {
return globalStrings().symbol(s);
}
const char* symbolNamespaceString(SymbolNamespace ns) {
switch (ns) {
case SymbolNamespace::onnx: return "onnx";
case SymbolNamespace::prim: return "prim";
case SymbolNamespace::aten: return "aten";
case SymbolNamespace::attr: return "attr";
case SymbolNamespace::scope: return "scope";
// NB: throwing an exception here causes gcc -O3 to produce far worse code
default: return "";
}
}
// NB: I really wanted to define this as std::strlen(symbolNamespaceString(ns)),
// but gcc -O3 doesn't seem to be smart enough to push the std::strlen into
// the switch statement.
size_t symbolNamespaceLength(SymbolNamespace ns) {
switch (ns) {
case SymbolNamespace::onnx: return 4;
case SymbolNamespace::prim: return 4;
case SymbolNamespace::aten: return 4;
case SymbolNamespace::attr: return 4;
case SymbolNamespace::scope: return 5;
default: return 0;
}
}
const char * Symbol::toUnqualString() const {
return globalStrings().string(*this) + symbolNamespaceLength(ns()) + 2 /* double colon */;
}
const char * Symbol::toQualString() const {
return globalStrings().string(*this);
}
const char * Symbol::toDisplayString() const {
// TODO: Make this actually return something that's "user friendly".
// The trouble is that, for this to be usable in printf-style assert
// statements, this has to return a const char* (whose lifetime is
// global), so we can't actually assemble a string on the fly.
return globalStrings().string(*this);
}
std::string Symbol::domainString() const {
return domain_prefix + symbolNamespaceString(ns());
}
Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) {
std::string qualString = d.substr(domain_prefix.size()) + "::" + s;
return fromQualString(qualString);
}
std::string qualifyString(SymbolNamespace ns, const std::string & s) {
std::string qual_s;
qual_s.reserve(s.size() + 2 /* double colon */ + symbolNamespaceLength(ns));
qual_s.append(symbolNamespaceString(ns));
qual_s.append("::");
qual_s.append(s);
return qual_s;
}
Symbol::Symbol(SymbolNamespace ns, const std::string & s)
: value(globalStrings().symbol(qualifyString(ns, s), ns)) {
}
}}