thread safe interned_strings
diff --git a/torch/csrc/jit/interned_strings.cpp b/torch/csrc/jit/interned_strings.cpp
index c8144e0..f63ccde 100644
--- a/torch/csrc/jit/interned_strings.cpp
+++ b/torch/csrc/jit/interned_strings.cpp
@@ -2,6 +2,7 @@
#include <stdint.h>
#include <string>
#include <unordered_map>
+#include <mutex>
#include "torch/csrc/jit/interned_strings.h"
#include "torch/csrc/jit/assert.h"
@@ -17,6 +18,7 @@
#undef REGISTER_SYMBOL
}
Symbol symbol(const std::string & s) {
+ std::lock_guard<std::mutex> guard(mutex_);
auto it = string_to_sym_.find(s);
if(it != string_to_sym_.end())
return it->second;
@@ -25,15 +27,27 @@
sym_to_string_[k] = s;
return k;
}
- const std::string & string(Symbol sym) {
- auto it = sym_to_string_.find(sym);
- JIT_ASSERT(it != sym_to_string_.end());
- return it->second;
+ const char * string(Symbol sym) {
+ switch(sym) {
+ #define DEFINE_CASE(s) \
+ case k##s: return #s;
+ FORALL_BUILTIN_SYMBOLS(DEFINE_CASE)
+ #undef DEFINE_CASE
+ default:
+ return customString(sym);
+ }
}
private:
+ const char * customString(Symbol sym) {
+ std::lock_guard<std::mutex> guard(mutex_);
+ auto it = sym_to_string_.find(sym);
+ JIT_ASSERT(it != sym_to_string_.end());
+ return it->second.c_str();
+ }
std::unordered_map<std::string, Symbol> string_to_sym_;
std::unordered_map<Symbol, std::string> sym_to_string_;
Symbol next_sym;
+ std::mutex mutex_;
};
static InternedStrings & globalStrings() {
@@ -41,7 +55,7 @@
return s;
}
-const std::string & symbolToString(Symbol s) {
+const char * symbolToString(Symbol s) {
return globalStrings().string(s);
}
Symbol stringToSymbol(const std::string & s) {
diff --git a/torch/csrc/jit/interned_strings.h b/torch/csrc/jit/interned_strings.h
index 26f1927..2fb3a93 100644
--- a/torch/csrc/jit/interned_strings.h
+++ b/torch/csrc/jit/interned_strings.h
@@ -37,7 +37,7 @@
kLastSymbol, //where we start counting for new symbols
};
-const std::string & symbolToString(Symbol s);
+const char * symbolToString(Symbol s);
Symbol stringToSymbol(const std::string & s);
}}
diff --git a/torch/csrc/jit/test_jit.cpp b/torch/csrc/jit/test_jit.cpp
index 6962218..112c376 100644
--- a/torch/csrc/jit/test_jit.cpp
+++ b/torch/csrc/jit/test_jit.cpp
@@ -197,12 +197,12 @@
assert(kParam == stringToSymbol("Param"));
assert(kReturn == stringToSymbol("Return"));
- assert(symbolToString(kReturn) == "Return");
+ assert(symbolToString(kReturn) == std::string("Return"));
assert(stringToSymbol("What") == kLastSymbol);
assert(stringToSymbol("What2") == kLastSymbol+1);
assert(stringToSymbol("What") == kLastSymbol);
assert(stringToSymbol("What2") == kLastSymbol+1);
- assert(symbolToString(kLastSymbol+1) == "What2");
+ assert(symbolToString(kLastSymbol+1) == std::string("What2"));
}