Clean up some logic in the pickler
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22882
Test Plan: Imported from OSS
Differential Revision: D16270332
Pulled By: zdevito
fbshipit-source-id: 714f293493965b13e471945fde11831a04875604
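
This refactor pulls memoization out of the individual push methods: addIValue() now checks isPtrType(), emits a BINGET for values it has already written, and otherwise calls the unmemoized addIValueImpl() and records the result with a BINPUT. Strings get the same treatment through pushString()/pushStringImpl(), while pushBytes() (previously pushString()) just appends raw bytes to the stream. Globals move to their own memoized_globals_map_, and memoized_ivalues_ keeps every memoized IValue alive so a freed pointer address cannot be reused and aliased to an earlier entry.

The standalone sketch below illustrates the BINPUT/BINGET memoization pattern outside the Pickler; it is not part of this diff. The MiniMemoizer class and writeString() helper are hypothetical, the opcode byte values follow the pickle protocol, and the 4-byte BINUNICODE length prefix is omitted for brevity.

// Illustrative sketch only: the BINPUT/BINGET pattern that addIValue() and
// pushString() now apply around their unmemoized *Impl() counterparts.
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

namespace {

enum class OpCode : char { BINPUT = 'q', BINGET = 'h', BINUNICODE = 'X' };

class MiniMemoizer {
 public:
  // Emit a string, reusing an earlier encoding via BINGET when possible.
  void writeString(const std::string& s) {
    auto it = memo_.find(s);
    if (it != memo_.end()) {
      // Already pickled once: refer back to the memo slot instead of
      // re-encoding the payload (what pushBinGet does in the Pickler).
      stack_.push_back(static_cast<char>(OpCode::BINGET));
      stack_.push_back(static_cast<char>(it->second));
      return;
    }
    // First occurrence: encode the payload (length prefix omitted here),
    // then tag it with BINPUT so later occurrences can reference it
    // (what pushNextBinPut does in the Pickler).
    stack_.push_back(static_cast<char>(OpCode::BINUNICODE));
    stack_.insert(stack_.end(), s.begin(), s.end());
    stack_.push_back(static_cast<char>(OpCode::BINPUT));
    stack_.push_back(static_cast<char>(next_memo_id_));
    memo_[s] = next_memo_id_++;
  }

  size_t size() const { return stack_.size(); }

 private:
  std::vector<char> stack_;                       // byte stream being built
  std::unordered_map<std::string, uint8_t> memo_; // payload -> memo id
  uint8_t next_memo_id_ = 0;                      // one-byte ids only, for brevity
};

} // namespace

int main() {
  MiniMemoizer m;
  m.writeString("storage"); // encoded in full, then BINPUT 0
  m.writeString("cpu");     // encoded in full, then BINPUT 1
  m.writeString("storage"); // emitted as BINGET 0 only
  std::cout << "byte stream length: " << m.size() << "\n";
  return 0;
}

With this scheme, a repeated string costs only a BINGET plus a memo id, which is presumably why the tensor metadata strings ("storage", "cpu", storage keys) now go through the memoizing pushString() instead of being re-encoded each time.
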
diff --git a/torch/csrc/jit/pickler.cpp b/torch/csrc/jit/pickler.cpp
index e324343..8700931 100644
--- a/torch/csrc/jit/pickler.cpp
+++ b/torch/csrc/jit/pickler.cpp
@@ -96,7 +96,7 @@
std::string key = std::to_string(getStorageKey(tensor));
push<OpCode>(OpCode::BINUNICODE);
push<uint32_t>(key.size());
- pushString(key);
+ pushBytes(key);
}
push<OpCode>(OpCode::TUPLE);
push<OpCode>(OpCode::STOP);
@@ -127,15 +127,15 @@
start();
push<OpCode>(OpCode::LONG1);
// LONG1 size
- pushString("\x0a");
+ pushBytes("\x0a");
// LONG1 data
- pushString("\x6c\xfc\x9c\x46\xf9\x20\x6a\xa8\x50\x19");
+ pushBytes("\x6c\xfc\x9c\x46\xf9\x20\x6a\xa8\x50\x19");
push<OpCode>(OpCode::STOP);
// Protocol Version (1001)
start();
push<OpCode>(OpCode::BININT2);
- pushString("\xe9\x03");
+ pushBytes("\xe9\x03");
push<OpCode>(OpCode::STOP);
// sys_info, this isn't actually used in de-serialization so we can leave this
@@ -145,18 +145,8 @@
push<OpCode>(OpCode::STOP);
}
-void Pickler::addIValue(const IValue& ivalue) {
- // Check if reference ivalue has been saved before
- const void* ivalue_ptr = getPointer(ivalue);
- if (ivalue_ptr) {
- auto memo_entry = memo_map_.find(ivalue_ptr);
- if (memo_entry != memo_map_.end()) {
- // This value has already been pushed, just do a BINGET
- pushBinGet(memo_entry->second);
- return;
- }
- }
-
+// unmemoized version called by addIValue
+void Pickler::addIValueImpl(const IValue& ivalue) {
if (ivalue.isTensor()) {
pushTensor(ivalue);
} else if (ivalue.isTuple()) {
@@ -172,7 +162,7 @@
push<OpCode>(OpCode::NEWFALSE);
}
} else if (ivalue.isString()) {
- pushMemoizedString(ivalue);
+ pushStringImpl(ivalue.toStringRef());
} else if (ivalue.isGenericList()) {
pushGenericList(ivalue);
} else if (ivalue.isGenericDict()) {
@@ -212,18 +202,28 @@
}
}
-/// Returns a void* uniquely identifying this IValue's data. For non-containers,
-/// returns nullptr. Also adds the ivalue to the Pickler's list of memoized
-/// IValues so the pointers are guaranteed to be valid for the Pickler's
-/// lifetime.
-const void* Pickler::getPointer(const IValue& ivalue) {
- if (ivalue.isGenericDict() || ivalue.isGenericList() || ivalue.isTuple()
- || ivalue.isString() || ivalue.isIntList() || ivalue.isTensorList()
- || ivalue.isDoubleList() || ivalue.isBoolList()) {
- return ivalue.internalToPointer();
+void Pickler::addIValue(const IValue& ivalue) {
+ // Check if reference ivalue has been saved before
+ if (ivalue.isPtrType()) {
+ const void* ptr = ivalue.internalToPointer();
+ TORCH_CHECK(
+ ptr != nullptr,
+ "Pickler cannot memoize ",
+ ivalue.tagKind(),
+ " IValue ",
+ ivalue);
+ auto memo_entry = memoized_ivalue_map_.find(ptr);
+ if (memo_entry != memoized_ivalue_map_.end()) {
+ // This value has already been pushed, just do a BINGET
+ pushBinGet(memo_entry->second);
+ return;
+ }
}
-
- return nullptr;
+ addIValueImpl(ivalue);
+ if (ivalue.isPtrType()) {
+ memoized_ivalues_.push_back(ivalue);
+ memoized_ivalue_map_[ivalue.internalToPointer()] = pushNextBinPut();
+ }
}
void Pickler::pushInt(const IValue& ivalue) {
@@ -256,28 +256,36 @@
}
}
-void Pickler::pushMemoizedString(const IValue& ivalue) {
- const auto& string = ivalue.toStringRef();
-
+// unmemoized encoding of a string
+void Pickler::pushStringImpl(const std::string& string) {
push<OpCode>(OpCode::BINUNICODE);
push<uint32_t>(string.size());
- pushString(string);
- pushMemoization(ivalue);
+ pushBytes(string);
}
void Pickler::pushString(const std::string& string) {
+ auto it = memoized_strings_map_.find(string);
+ if (it == memoized_strings_map_.end()) {
+ pushStringImpl(string);
+ memoized_strings_map_[string] = pushNextBinPut();
+ } else {
+ pushBinGet(it->second);
+ }
+}
+
+void Pickler::pushBytes(const std::string& string) {
stack_.insert(stack_.end(), string.begin(), string.end());
}
void Pickler::pushGlobal(const std::string& name_temp) {
- auto memo_entry = memoized_strings_map_.find(name_temp);
- if (memo_entry == memoized_strings_map_.end()) {
+ auto memo_entry = memoized_globals_map_.find(name_temp);
+ if (memo_entry == memoized_globals_map_.end()) {
push<OpCode>(OpCode::GLOBAL);
- pushString(name_temp);
+ pushBytes(name_temp);
- // Push BINPUT without adding anything to the memo_map_
+ // Push BINPUT without adding anything to the memoized_ivalues_
size_t memo_id = pushNextBinPut();
- memoized_strings_map_.insert({name_temp, memo_id});
+ memoized_globals_map_.insert({name_temp, memo_id});
} else {
pushBinGet(memo_entry->second);
}
@@ -309,15 +317,15 @@
// Tuple for persistent_load
push<OpCode>(OpCode::MARK);
// typename
- pushMemoizedString(std::string("storage"));
+ pushString("storage");
// data_type
std::stringstream data_type;
data_type << "torch\n" << toString(tensor.scalar_type()) << "Storage\n";
pushGlobal(data_type.str());
// root_key
- pushMemoizedString(std::to_string(getStorageKey(tensor)));
+ pushString(std::to_string(getStorageKey(tensor)));
// location
- pushMemoizedString(std::string("cpu"));
+ pushString("cpu");
// size
pushInt(tensor.numel());
// view_metadata
@@ -403,7 +411,6 @@
// Call reduce
push<OpCode>(OpCode::REDUCE);
- pushMemoization(ivalue);
}
void Pickler::pushDouble(const IValue& ivalue) {
@@ -419,7 +426,6 @@
void Pickler::pushDict(const IValue& ivalue) {
push<OpCode>(OpCode::EMPTY_DICT);
- pushMemoization(ivalue);
push<OpCode>(OpCode::MARK);
@@ -433,11 +439,6 @@
push<OpCode>(OpCode::SETITEMS);
}
-void Pickler::pushMemoization(const void* item) {
- TORCH_CHECK(item != nullptr, "Pickler cannot memoize a nullptr");
- memo_map_[item] = pushNextBinPut();
-}
-
size_t Pickler::pushNextBinPut() {
if (memo_id_ <= std::numeric_limits<uint8_t>::max()) {
push<OpCode>(OpCode::BINPUT);
@@ -452,22 +453,9 @@
return memo_id_ - 1;
}
-void Pickler::pushMemoization(const IValue& ivalue) {
- auto ptr = getPointer(ivalue);
- memoized_ivalues_.push_back(ivalue);
- TORCH_CHECK(
- ptr != nullptr,
- "Pickler cannot memoize ",
- ivalue.tagKind(),
- " IValue ",
- ivalue)
- pushMemoization(ptr);
-}
-
void Pickler::pushGenericList(const IValue& ivalue) {
auto list = ivalue.toGenericListRef();
push<OpCode>(OpCode::EMPTY_LIST);
- pushMemoization(ivalue);
push<OpCode>(OpCode::MARK);
@@ -488,7 +476,6 @@
}
push<OpCode>(OpCode::TUPLE);
- pushMemoization(ivalue);
}
std::vector<IValue> Unpickler::parse_ivalue_list() {
diff --git a/torch/csrc/jit/pickler.h b/torch/csrc/jit/pickler.h
index 2775471..80213ff 100644
--- a/torch/csrc/jit/pickler.h
+++ b/torch/csrc/jit/pickler.h
@@ -126,6 +126,7 @@
void endTuple();
private:
+ void addIValueImpl(const IValue& ivalue);
void pushDict(const IValue& ivalue);
void pushDouble(const IValue& ivalue);
void pushGenericList(const IValue& ivalue);
@@ -134,10 +135,12 @@
void pushList(const IValue& ivalue);
void pushLiteralTensor(const IValue& ivalue);
void pushMemoization(const IValue& ivalue);
- void pushMemoizedString(const IValue& ivalue);
void pushTensor(const IValue& ivalue);
void pushTensorReference(const IValue& ivalue);
void pushTuple(const IValue& ivalue);
+ void pushString(const std::string& string);
+ // unmemoized version
+ void pushStringImpl(const std::string& string);
void pushBinGet(uint32_t memo_id);
void pushClass(PicklerClass cls);
@@ -146,8 +149,8 @@
PicklerClass cls,
const std::function<void(const IValue&)>& item_pusher);
void pushGlobal(const std::string& name);
- void pushMemoization(const void* item);
- void pushString(const std::string& string);
+ // raw string data is appended directly to the byte stream
+ void pushBytes(const std::string& string);
void pushTensorData(const at::Tensor& tensor);
// Add a BINPUT op and return the memoization id used
@@ -168,10 +171,6 @@
// Stack of opcodes/data
std::vector<char> stack_;
- // Memoization of IValues that have been written (index in table is used for
- // BINPUT opcodes) to enable shared references
- std::unordered_map<const void*, uint32_t> memo_map_;
-
// External table of tensors to serialize. If this is missing, then tensors
// are serialized directly into the pickle
std::vector<at::Tensor>* tensor_table_;
@@ -183,9 +182,17 @@
// and only memoize those)
uint32_t memo_id_ = 0;
- // When arbitrary (maybe temporary) values are saved, keep them here so they
- // can be memoized correctly
- std::vector<c10::IValue> memoized_ivalues_;
+ // Memoization of IValues that have been written (index in table is used for
+ // BINPUT opcodes) to enable shared references
+ std::unordered_map<const void*, uint32_t> memoized_ivalue_map_;
+
+ // because we de-dup ivalues based on their raw pointer address in the above
+ // map we need to keep all the memoized values alive during the pickle.
+ // Otherwise, it is possible that a raw address gets reused for another
+ // object, and we will alias it to the old object at that address.
+ std::vector<IValue> memoized_ivalues_;
+
+ std::unordered_map<std::string, uint32_t> memoized_globals_map_;
std::unordered_map<std::string, uint32_t> memoized_strings_map_;
};
diff --git a/torch/csrc/jit/source_range_serialization_impl.h b/torch/csrc/jit/source_range_serialization_impl.h
index 1e1bb10..2d0896b 100644
--- a/torch/csrc/jit/source_range_serialization_impl.h
+++ b/torch/csrc/jit/source_range_serialization_impl.h
@@ -1,6 +1,5 @@
#pragma once
-#include <torch/csrc/jit/pickler.h>
#include <torch/csrc/jit/source_range_serialization.h>
namespace torch {