|  | #include <vector> | 
|  | #include <stdint.h> | 
|  | #include <string> | 
|  | #include <unordered_map> | 
|  | #include <mutex> | 
|  | #include <sstream> | 
|  | #include "ATen/optional.h" | 
|  | #include "torch/csrc/assertions.h" | 
|  | #include "torch/csrc/jit/interned_strings.h" | 
|  |  | 
|  | namespace torch { namespace jit { | 
|  |  | 
|  | // TODO: programatically generate this | 
|  | inline std::string valid_namespaces_str() { | 
|  | return "'onnx', 'prim', 'aten', 'attr', 'scope'"; | 
|  | } | 
|  |  | 
|  | // Alas, this is only constexpr in C++14 | 
|  | const unique_t unique_start = | 
|  | std::max({static_cast<unique_t>(aten::_keys::num_symbols), | 
|  | static_cast<unique_t>(prim::_keys::num_symbols), | 
|  | static_cast<unique_t>(onnx::_keys::num_symbols), | 
|  | static_cast<unique_t>(attr::_keys::num_symbols)}); | 
|  |  | 
|  | SymbolNamespace parseNamespace(const std::string & s) { | 
|  | size_t colon_pos = s.find(':'); | 
|  | if (colon_pos == std::string::npos) { | 
|  | std::ostringstream ss; | 
|  | ss << "Symbol: illegal unqualified string '" << s << "'; " | 
|  | << "all symbols must be qualified, e.g., 'ns::" << s << "'. " | 
|  | << "Valid namespaces are: " << valid_namespaces_str(); | 
|  | throw std::runtime_error(ss.str()); | 
|  | } | 
|  | if (colon_pos == 0) { | 
|  | std::ostringstream ss; | 
|  | ss << "Symbol: illegal leading colon in '" << s << "'; " | 
|  | << "all symbols must have a non-empty namespace. " | 
|  | << "Valid namespaces are: " << valid_namespaces_str(); | 
|  | throw std::runtime_error(ss.str()); | 
|  | } | 
|  | // attr::x | 
|  | //     ^___ colon_pos | 
|  | //        ^___ s.size() | 
|  | if (colon_pos + 2 > s.size()) { | 
|  | std::ostringstream ss; | 
|  | ss << "Symbol: underlong string '" << s << "'; " | 
|  | << "namespace must be followed by double colon and a " | 
|  | << "non-empty string."; | 
|  | throw std::runtime_error(ss.str()); | 
|  | } | 
|  | if (s[colon_pos + 1] != ':') { | 
|  | std::ostringstream ss; | 
|  | ss << "Symbol: invalid use of colons in '" << s << "'; " | 
|  | << "namespace must be followed by double colon, not a" | 
|  | << "single colon."; | 
|  | throw std::runtime_error(ss.str()); | 
|  | } | 
|  | auto ns = s.substr(0, colon_pos); | 
|  | if (ns == "aten") { | 
|  | return SymbolNamespace::aten; | 
|  | } else if (ns == "prim") { | 
|  | return SymbolNamespace::prim; | 
|  | } else if (ns == "onnx") { | 
|  | return SymbolNamespace::onnx; | 
|  | } else if (ns == "attr") { | 
|  | return SymbolNamespace::attr; | 
|  | } else if (ns == "scope") { | 
|  | return SymbolNamespace::scope; | 
|  | } else { | 
|  | std::ostringstream ss; | 
|  | ss << "Symbol: invalid namespace in '" << s << "'. " | 
|  | << "Valid namespaces are: " << valid_namespaces_str(); | 
|  | throw std::runtime_error(ss.str()); | 
|  | } | 
|  | } | 
|  |  | 
|  | struct InternedStrings { | 
|  | InternedStrings() | 
|  | : next_uniq(unique_start) { | 
|  | #define REGISTER_SYMBOL(ns, s) \ | 
|  | string_to_sym_[#ns "::" #s] = ns::s; \ | 
|  | sym_to_string_[ns::s] = #ns "::" #s; | 
|  | FORALL_BUILTIN_SYMBOLS(REGISTER_SYMBOL) | 
|  | #undef REGISTER_SYMBOL | 
|  | } | 
|  | Symbol symbol(const std::string & s, at::optional<SymbolNamespace> known_ns = at::nullopt) { | 
|  | JIT_ASSERT(!known_ns || parseNamespace(s) == known_ns); | 
|  | std::lock_guard<std::mutex> guard(mutex_); | 
|  | auto it = string_to_sym_.find(s); | 
|  | if(it != string_to_sym_.end()) | 
|  | return it->second; | 
|  | unique_t uniq = next_uniq++; | 
|  | // TODO: Doing the parsing while we hold the lock is a bit naughty | 
|  | SymbolNamespace ns = known_ns ? *known_ns : parseNamespace(s); | 
|  | Symbol sym(ns, uniq); | 
|  | string_to_sym_[s] = sym; | 
|  | sym_to_string_[sym] = s; | 
|  | return sym; | 
|  | } | 
|  | const char * string(Symbol sym) { | 
|  | // Builtin Symbols are also in the maps, but | 
|  | // we can bypass the need to acquire a lock | 
|  | // to read the map for Builtins because we already | 
|  | // know their string value | 
|  | switch(sym) { | 
|  | #define DEFINE_CASE(ns, s) \ | 
|  | case ns::s: return #ns "::" #s; | 
|  | FORALL_BUILTIN_SYMBOLS(DEFINE_CASE) | 
|  | #undef DEFINE_CASE | 
|  | default: | 
|  | return customString(sym); | 
|  | } | 
|  | } | 
|  | private: | 
|  | const char * customString(Symbol sym) { | 
|  | std::lock_guard<std::mutex> guard(mutex_); | 
|  | auto it = sym_to_string_.find(sym); | 
|  | JIT_ASSERT(it != sym_to_string_.end()); | 
|  | return it->second.c_str(); | 
|  | } | 
|  | std::unordered_map<std::string, Symbol> string_to_sym_; | 
|  | std::unordered_map<Symbol, std::string> sym_to_string_; | 
|  | unique_t next_uniq; | 
|  | std::mutex mutex_; | 
|  | }; | 
|  |  | 
|  | static InternedStrings & globalStrings() { | 
|  | static InternedStrings s; | 
|  | return s; | 
|  | } | 
|  |  | 
|  | Symbol Symbol::fromQualString(const std::string & s) { | 
|  | return globalStrings().symbol(s); | 
|  | } | 
|  |  | 
|  | const char* symbolNamespaceString(SymbolNamespace ns) { | 
|  | switch (ns) { | 
|  | case SymbolNamespace::onnx: return "onnx"; | 
|  | case SymbolNamespace::prim: return "prim"; | 
|  | case SymbolNamespace::aten: return "aten"; | 
|  | case SymbolNamespace::attr: return "attr"; | 
|  | case SymbolNamespace::scope: return "scope"; | 
|  | // NB: throwing an exception here causes gcc -O3 to produce far worse code | 
|  | default: return ""; | 
|  | } | 
|  | } | 
|  |  | 
|  | // NB: I really wanted to define this as std::strlen(symbolNamespaceString(ns)), | 
|  | // but gcc -O3 doesn't seem to be smart enough to push the std::strlen into | 
|  | // the switch statement. | 
|  | size_t symbolNamespaceLength(SymbolNamespace ns) { | 
|  | switch (ns) { | 
|  | case SymbolNamespace::onnx: return 4; | 
|  | case SymbolNamespace::prim: return 4; | 
|  | case SymbolNamespace::aten: return 4; | 
|  | case SymbolNamespace::attr: return 4; | 
|  | case SymbolNamespace::scope: return 5; | 
|  | default: return 0; | 
|  | } | 
|  | } | 
|  |  | 
|  | const char * Symbol::toUnqualString() const { | 
|  | return globalStrings().string(*this) + symbolNamespaceLength(ns()) + 2 /* double colon */; | 
|  | } | 
|  |  | 
|  | const char * Symbol::toQualString() const { | 
|  | return globalStrings().string(*this); | 
|  | } | 
|  |  | 
|  | const char * Symbol::toDisplayString() const { | 
|  | // TODO: Make this actually return something that's "user friendly". | 
|  | // The trouble is that, for this to be usable in printf-style assert | 
|  | // statements, this has to return a const char* (whose lifetime is | 
|  | // global), so we can't actually assemble a string on the fly. | 
|  | return globalStrings().string(*this); | 
|  | } | 
|  |  | 
|  | std::string qualifyString(SymbolNamespace ns, const std::string & s) { | 
|  | std::string qual_s; | 
|  | qual_s.reserve(s.size() + 2 /* double colon */ + symbolNamespaceLength(ns)); | 
|  | qual_s.append(symbolNamespaceString(ns)); | 
|  | qual_s.append("::"); | 
|  | qual_s.append(s); | 
|  | return qual_s; | 
|  | } | 
|  |  | 
|  | Symbol::Symbol(SymbolNamespace ns, const std::string & s) | 
|  | : value(globalStrings().symbol(qualifyString(ns, s), ns)) { | 
|  | } | 
|  |  | 
|  | }} |