Clean up some logic in pickler

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22882

This centralizes the Pickler's memoization logic: addIValue now checks
ivalue.isPtrType() and handles the BINGET/BINPUT bookkeeping itself, deferring
the actual encoding to a new addIValueImpl. This replaces the hand-written
type list in getPointer and the pushMemoization calls scattered through the
individual push functions. The raw byte append is renamed from pushString to
pushBytes, and pushString now memoizes string contents, replacing
pushMemoizedString. GLOBAL entries get their own memoized_globals_map_, and
memoized IValues are kept alive in memoized_ivalues_ so the raw pointer keys
in memoized_ivalue_map_ stay valid for the lifetime of the pickle. The
pickler.h include is also dropped from source_range_serialization_impl.h.
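
The heart of the change is that de-duplication now happens in one place, keyed
on the IValue's internal pointer, with the memoized values kept alive so those
pointers stay valid. Below is a minimal, self-contained sketch of that
BINPUT/BINGET pattern; Value and MiniPickler are simplified stand-ins for
illustration, not the real torch::jit::Pickler API.

```cpp
// Sketch of the memoization pattern this diff moves into Pickler::addIValue.
// Value and MiniPickler are simplified stand-ins, not the real torch::jit types.
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct Value {
  // A heap-allocated payload stands in for IValue's pointer-backed types
  // (strings, lists, dicts, tuples) for which isPtrType() would return true.
  std::shared_ptr<std::string> payload;
  const void* ptr() const { return payload.get(); }
};

class MiniPickler {
 public:
  void add(const Value& v) {
    // Memoization wrapper: emit BINGET if this exact object was seen before.
    auto it = memo_map_.find(v.ptr());
    if (it != memo_map_.end()) {
      ops_.push_back("BINGET " + std::to_string(it->second));
      return;
    }
    addImpl(v);                     // unmemoized encoding (addIValueImpl)
    memoized_.push_back(v);         // keep alive so the pointer key stays valid
    memo_map_[v.ptr()] = next_id_;  // remember the BINPUT id for later BINGETs
    ops_.push_back("BINPUT " + std::to_string(next_id_++));
  }

  void dump() const {
    for (const auto& op : ops_) {
      std::cout << op << '\n';
    }
  }

 private:
  void addImpl(const Value& v) {
    ops_.push_back("BINUNICODE \"" + *v.payload + "\"");
  }

  std::vector<std::string> ops_;
  std::unordered_map<const void*, uint32_t> memo_map_;
  std::vector<Value> memoized_;  // mirrors memoized_ivalues_ in the diff
  uint32_t next_id_ = 0;
};

int main() {
  Value s{std::make_shared<std::string>("storage")};
  MiniPickler p;
  p.add(s);  // first use: BINUNICODE "storage", then BINPUT 0
  p.add(s);  // same object again: collapses to BINGET 0
  p.dump();
}
```

pushString gets the same treatment for string contents, except that
memoized_strings_map_ is keyed on the string value rather than a pointer.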

Test Plan: Imported from OSS

Differential Revision: D16270332

Pulled By: zdevito

fbshipit-source-id: 714f293493965b13e471945fde11831a04875604
diff --git a/torch/csrc/jit/pickler.cpp b/torch/csrc/jit/pickler.cpp
index e324343..8700931 100644
--- a/torch/csrc/jit/pickler.cpp
+++ b/torch/csrc/jit/pickler.cpp
@@ -96,7 +96,7 @@
       std::string key = std::to_string(getStorageKey(tensor));
       push<OpCode>(OpCode::BINUNICODE);
       push<uint32_t>(key.size());
-      pushString(key);
+      pushBytes(key);
     }
     push<OpCode>(OpCode::TUPLE);
     push<OpCode>(OpCode::STOP);
@@ -127,15 +127,15 @@
   start();
   push<OpCode>(OpCode::LONG1);
   // LONG1 size
-  pushString("\x0a");
+  pushBytes("\x0a");
   // LONG1 data
-  pushString("\x6c\xfc\x9c\x46\xf9\x20\x6a\xa8\x50\x19");
+  pushBytes("\x6c\xfc\x9c\x46\xf9\x20\x6a\xa8\x50\x19");
   push<OpCode>(OpCode::STOP);
 
   // Protocol Version (1001)
   start();
   push<OpCode>(OpCode::BININT2);
-  pushString("\xe9\x03");
+  pushBytes("\xe9\x03");
   push<OpCode>(OpCode::STOP);
 
   // sys_info, this isn't actually used in de-serialization so we can leave this
@@ -145,18 +145,8 @@
   push<OpCode>(OpCode::STOP);
 }
 
-void Pickler::addIValue(const IValue& ivalue) {
-  // Check if reference ivalue has been saved before
-  const void* ivalue_ptr = getPointer(ivalue);
-  if (ivalue_ptr) {
-    auto memo_entry = memo_map_.find(ivalue_ptr);
-    if (memo_entry != memo_map_.end()) {
-      // This value has already been pushed, just do a BINGET
-      pushBinGet(memo_entry->second);
-      return;
-    }
-  }
-
+// unmemoized version called by addIValue
+void Pickler::addIValueImpl(const IValue& ivalue) {
   if (ivalue.isTensor()) {
     pushTensor(ivalue);
   } else if (ivalue.isTuple()) {
@@ -172,7 +162,7 @@
       push<OpCode>(OpCode::NEWFALSE);
     }
   } else if (ivalue.isString()) {
-    pushMemoizedString(ivalue);
+    pushStringImpl(ivalue.toStringRef());
   } else if (ivalue.isGenericList()) {
     pushGenericList(ivalue);
   } else if (ivalue.isGenericDict()) {
@@ -212,18 +202,28 @@
   }
 }
 
-/// Returns a void* uniquely identifying this IValue's data. For non-containers,
-/// returns nullptr. Also adds the ivalue to the Pickler's list of memoized
-/// IValues so the pointers are guaranteed to be valid for the Pickler's
-/// lifetime.
-const void* Pickler::getPointer(const IValue& ivalue) {
-  if (ivalue.isGenericDict() || ivalue.isGenericList() || ivalue.isTuple()
-    || ivalue.isString() || ivalue.isIntList() || ivalue.isTensorList()
-    || ivalue.isDoubleList() || ivalue.isBoolList()) {
-      return ivalue.internalToPointer();
+void Pickler::addIValue(const IValue& ivalue) {
+  // Check if reference ivalue has been saved before
+  if (ivalue.isPtrType()) {
+    const void* ptr = ivalue.internalToPointer();
+    TORCH_CHECK(
+        ptr != nullptr,
+        "Pickler cannot memoize ",
+        ivalue.tagKind(),
+        " IValue ",
+        ivalue);
+    auto memo_entry = memoized_ivalue_map_.find(ptr);
+    if (memo_entry != memoized_ivalue_map_.end()) {
+      // This value has already been pushed, just do a BINGET
+      pushBinGet(memo_entry->second);
+      return;
+    }
   }
-
-  return nullptr;
+  addIValueImpl(ivalue);
+  if (ivalue.isPtrType()) {
+    memoized_ivalues_.push_back(ivalue);
+    memoized_ivalue_map_[ivalue.internalToPointer()] = pushNextBinPut();
+  }
 }
 
 void Pickler::pushInt(const IValue& ivalue) {
@@ -256,28 +256,36 @@
   }
 }
 
-void Pickler::pushMemoizedString(const IValue& ivalue) {
-  const auto& string = ivalue.toStringRef();
-
+// unmemoized encoding of a string
+void Pickler::pushStringImpl(const std::string& string) {
   push<OpCode>(OpCode::BINUNICODE);
   push<uint32_t>(string.size());
-  pushString(string);
-  pushMemoization(ivalue);
+  pushBytes(string);
 }
 
 void Pickler::pushString(const std::string& string) {
+  auto it = memoized_strings_map_.find(string);
+  if (it == memoized_strings_map_.end()) {
+    pushStringImpl(string);
+    memoized_strings_map_[string] = pushNextBinPut();
+  } else {
+    pushBinGet(it->second);
+  }
+}
+
+void Pickler::pushBytes(const std::string& string) {
   stack_.insert(stack_.end(), string.begin(), string.end());
 }
 
 void Pickler::pushGlobal(const std::string& name_temp) {
-  auto memo_entry = memoized_strings_map_.find(name_temp);
-  if (memo_entry == memoized_strings_map_.end()) {
+  auto memo_entry = memoized_globals_map_.find(name_temp);
+  if (memo_entry == memoized_globals_map_.end()) {
     push<OpCode>(OpCode::GLOBAL);
-    pushString(name_temp);
+    pushBytes(name_temp);
 
-    // Push BINPUT without adding anything to the memo_map_
+    // Push BINPUT without adding anything to the memoized_ivalues_
     size_t memo_id = pushNextBinPut();
-    memoized_strings_map_.insert({name_temp, memo_id});
+    memoized_globals_map_.insert({name_temp, memo_id});
   } else {
     pushBinGet(memo_entry->second);
   }
@@ -309,15 +317,15 @@
   // Tuple for persistent_load
   push<OpCode>(OpCode::MARK);
   // typename
-  pushMemoizedString(std::string("storage"));
+  pushString("storage");
   // data_type
   std::stringstream data_type;
   data_type << "torch\n" << toString(tensor.scalar_type()) << "Storage\n";
   pushGlobal(data_type.str());
   // root_key
-  pushMemoizedString(std::to_string(getStorageKey(tensor)));
+  pushString(std::to_string(getStorageKey(tensor)));
   // location
-  pushMemoizedString(std::string("cpu"));
+  pushString("cpu");
   // size
   pushInt(tensor.numel());
   // view_metadata
@@ -403,7 +411,6 @@
 
   // Call reduce
   push<OpCode>(OpCode::REDUCE);
-  pushMemoization(ivalue);
 }
 
 void Pickler::pushDouble(const IValue& ivalue) {
@@ -419,7 +426,6 @@
 
 void Pickler::pushDict(const IValue& ivalue) {
   push<OpCode>(OpCode::EMPTY_DICT);
-  pushMemoization(ivalue);
 
   push<OpCode>(OpCode::MARK);
 
@@ -433,11 +439,6 @@
   push<OpCode>(OpCode::SETITEMS);
 }
 
-void Pickler::pushMemoization(const void* item) {
-  TORCH_CHECK(item != nullptr, "Pickler cannot memoize a nullptr");
-  memo_map_[item] = pushNextBinPut();
-}
-
 size_t Pickler::pushNextBinPut() {
   if (memo_id_ <= std::numeric_limits<uint8_t>::max()) {
     push<OpCode>(OpCode::BINPUT);
@@ -452,22 +453,9 @@
   return memo_id_ - 1;
 }
 
-void Pickler::pushMemoization(const IValue& ivalue) {
-  auto ptr = getPointer(ivalue);
-  memoized_ivalues_.push_back(ivalue);
-  TORCH_CHECK(
-      ptr != nullptr,
-      "Pickler cannot memoize ",
-      ivalue.tagKind(),
-      " IValue ",
-      ivalue)
-  pushMemoization(ptr);
-}
-
 void Pickler::pushGenericList(const IValue& ivalue) {
   auto list = ivalue.toGenericListRef();
   push<OpCode>(OpCode::EMPTY_LIST);
-  pushMemoization(ivalue);
 
   push<OpCode>(OpCode::MARK);
 
@@ -488,7 +476,6 @@
   }
 
   push<OpCode>(OpCode::TUPLE);
-  pushMemoization(ivalue);
 }
 
 std::vector<IValue> Unpickler::parse_ivalue_list() {
diff --git a/torch/csrc/jit/pickler.h b/torch/csrc/jit/pickler.h
index 2775471..80213ff 100644
--- a/torch/csrc/jit/pickler.h
+++ b/torch/csrc/jit/pickler.h
@@ -126,6 +126,7 @@
   void endTuple();
 
  private:
+  void addIValueImpl(const IValue& ivalue);
   void pushDict(const IValue& ivalue);
   void pushDouble(const IValue& ivalue);
   void pushGenericList(const IValue& ivalue);
@@ -134,10 +135,12 @@
   void pushList(const IValue& ivalue);
   void pushLiteralTensor(const IValue& ivalue);
   void pushMemoization(const IValue& ivalue);
-  void pushMemoizedString(const IValue& ivalue);
   void pushTensor(const IValue& ivalue);
   void pushTensorReference(const IValue& ivalue);
   void pushTuple(const IValue& ivalue);
+  void pushString(const std::string& string);
+  // unmemoized version
+  void pushStringImpl(const std::string& string);
 
   void pushBinGet(uint32_t memo_id);
   void pushClass(PicklerClass cls);
@@ -146,8 +149,8 @@
       PicklerClass cls,
       const std::function<void(const IValue&)>& item_pusher);
   void pushGlobal(const std::string& name);
-  void pushMemoization(const void* item);
-  void pushString(const std::string& string);
+  // raw string data is appended directly to the byte stream
+  void pushBytes(const std::string& string);
   void pushTensorData(const at::Tensor& tensor);
 
   // Add a BINPUT op and return the memoization id used
@@ -168,10 +171,6 @@
   // Stack of opcodes/data
   std::vector<char> stack_;
 
-  // Memoization of IValues that have been written (index in table is used for
-  // BINPUT opcodes) to enable shared references
-  std::unordered_map<const void*, uint32_t> memo_map_;
-
   // External table of tensors to serialize. If this is missing, then tensors
   // are serialized directly into the pickle
   std::vector<at::Tensor>* tensor_table_;
@@ -183,9 +182,17 @@
   // and only memoize those)
   uint32_t memo_id_ = 0;
 
-  // When arbitrary (maybe temporary) values are saved, keep them here so they
-  // can be memoized correctly
-  std::vector<c10::IValue> memoized_ivalues_;
+  // Memoization of IValues that have been written (index in table is used for
+  // BINPUT opcodes) to enable shared references
+  std::unordered_map<const void*, uint32_t> memoized_ivalue_map_;
+
+  // Because we de-dup ivalues based on their raw pointer address in the above
+  // map, we need to keep all the memoized values alive during the pickle.
+  // Otherwise, it is possible that a raw address gets reused for another
+  // object, and we would alias it to the old object at that address.
+  std::vector<IValue> memoized_ivalues_;
+
+  std::unordered_map<std::string, uint32_t> memoized_globals_map_;
   std::unordered_map<std::string, uint32_t> memoized_strings_map_;
 };
 
diff --git a/torch/csrc/jit/source_range_serialization_impl.h b/torch/csrc/jit/source_range_serialization_impl.h
index 1e1bb10..2d0896b 100644
--- a/torch/csrc/jit/source_range_serialization_impl.h
+++ b/torch/csrc/jit/source_range_serialization_impl.h
@@ -1,6 +1,5 @@
 #pragma once
 
-#include <torch/csrc/jit/pickler.h>
 #include <torch/csrc/jit/source_range_serialization.h>
 
 namespace torch {