Revert D16675418: [jit] Add Pickler C++ API
Differential Revision:
D16675418
Original commit changeset: 76543c81ac67
fbshipit-source-id: f0249d16d363c4ecbceecd1bf610dc280e659cc0
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 456f143..d4d77ec 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -371,7 +371,6 @@
${TORCH_SRC_DIR}/csrc/jit/graph_executor.cpp
${TORCH_SRC_DIR}/csrc/jit/import_source.cpp
${TORCH_SRC_DIR}/csrc/jit/import.cpp
- ${TORCH_SRC_DIR}/csrc/jit/pickle.cpp
${TORCH_SRC_DIR}/csrc/jit/import_export_helpers.cpp
${TORCH_SRC_DIR}/csrc/jit/interpreter.cpp
${TORCH_SRC_DIR}/csrc/jit/constants.cpp
diff --git a/test/cpp/api/jit.cpp b/test/cpp/api/jit.cpp
index a008d3e..3d65cd9 100644
--- a/test/cpp/api/jit.cpp
+++ b/test/cpp/api/jit.cpp
@@ -1,7 +1,6 @@
#include <gtest/gtest.h>
#include <torch/jit.h>
-#include <torch/script.h>
#include <torch/types.h>
#include <string>
@@ -111,18 +110,3 @@
0, module->run_method("optional_tuple_op", torch::jit::IValue()).toInt());
}
-
-TEST(TorchScriptTest, TestPickle) {
- torch::IValue float_value(2.3);
-
- // TODO: when tensors are stored in the pickle, delete this
- std::vector<at::Tensor> tensor_table;
- auto data = torch::jit::pickle(float_value, &tensor_table);
-
- std::vector<torch::IValue> ivalues =
- torch::jit::unpickle(data.data(), data.size());
-
- double diff = ivalues.at(0).toDouble() - float_value.toDouble();
- double eps = 0.0001;
- ASSERT_TRUE(diff < eps && diff > -eps);
-}
diff --git a/tools/build_variables.py b/tools/build_variables.py
index d1df889..096f483 100644
--- a/tools/build_variables.py
+++ b/tools/build_variables.py
@@ -64,7 +64,6 @@
"torch/csrc/jit/pickler.cpp",
"torch/csrc/jit/graph_executor.cpp",
"torch/csrc/jit/import.cpp",
- "torch/csrc/jit/pickle.cpp",
"torch/csrc/jit/import_export_helpers.cpp",
"torch/csrc/jit/interpreter.cpp",
"torch/csrc/jit/ir.cpp",
diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp
index 28ec783..071aaa7 100644
--- a/torch/csrc/jit/export.cpp
+++ b/torch/csrc/jit/export.cpp
@@ -10,7 +10,7 @@
#include <torch/csrc/jit/import_export_helpers.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/python_print.h>
-#include <torch/csrc/jit/pickle.h>
+#include <torch/csrc/jit/pickler.h>
#include <torch/csrc/jit/source_range_serialization.h>
#include <caffe2/core/types.h>
@@ -784,8 +784,16 @@
void ScriptModuleSerializer::writePickleArchive(
const std::string& name,
const std::vector<IValue>& ivalues) {
- auto data = pickle(c10::ivalue::Tuple::create(ivalues), &tensor_table_);
- writer_.writeRecord(name, data.data(), data.size(), /*compress=*/true);
+ Pickler pickler(&tensor_table_);
+ pickler.protocol();
+ pickler.startTuple();
+ for (const IValue& ivalue : ivalues) {
+ pickler.pushIValue(ivalue);
+ }
+ pickler.endTuple();
+ pickler.stop();
+ writer_.writeRecord(name, pickler.stack().data(), pickler.stack().size(),
+ /*compress=*/true);
}
void ScriptModuleSerializer::convertModule(
@@ -870,7 +878,8 @@
module_def->mutable_torchscript_debug_arena();
SourceRangePickler source_range_pickler;
- const auto& range_data = source_range_pickler.pickle(source_ranges);
+ source_range_pickler.pickle(source_ranges);
+ const auto& range_data = source_range_pickler.get_data();
std::stringstream debug_filename;
debug_filename << "debug/" << module_name.str() << ".pkl";
writer_.writeRecord(
diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp
index 3684f59..09c3d8b 100644
--- a/torch/csrc/jit/import.cpp
+++ b/torch/csrc/jit/import.cpp
@@ -7,7 +7,7 @@
#include <torch/csrc/jit/import_export_helpers.h>
#include <torch/csrc/jit/import_source.h>
#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/pickle.h>
+#include <torch/csrc/jit/pickler.h>
#include <torch/csrc/jit/script/script_type_parser.h>
#include <torch/csrc/jit/source_range_serialization.h>
#include <torch/csrc/jit/source_range_serialization_impl.h>
@@ -75,7 +75,7 @@
script::Module convertModule(const torch::ModuleDef& module_def);
void loadTensorTable(torch::ModelDef* model_def);
- IValue loadPickleArchive(const std::string& name);
+ std::vector<IValue> loadPickleArchive(const std::string& name);
void importCallback(const std::string& qualifier);
void moduleSetState(const script::Module& module, IValue state);
@@ -142,12 +142,8 @@
}
loadTensorTable(&model_def);
- if (model_def.proto_version() == 2) {
- auto list = loadPickleArchive("attributes.pkl").toGenericList();
- pickled_ivalues_.insert(pickled_ivalues_.end(), list.begin(), list.end());
- } else if (model_def.proto_version() >= 3) {
- pickled_ivalues_ =
- loadPickleArchive("attributes.pkl").toTuple()->elements();
+ if (model_def.proto_version() >= 2) {
+ pickled_ivalues_ = loadPickleArchive("attributes.pkl");
}
return convertModule(module_def);
@@ -160,12 +156,12 @@
}
}
-IValue ScriptModuleDeserializer::loadPickleArchive(const std::string& name) {
+std::vector<IValue> ScriptModuleDeserializer::loadPickleArchive(const std::string& name) {
at::DataPtr attributes_ptr;
size_t attributes_size;
std::tie(attributes_ptr, attributes_size) = reader_->getRecord(name);
- auto ivalue = unpickle(
- reinterpret_cast<const char*>(attributes_ptr.get()),
+ Unpickler unpickler(
+ attributes_ptr.get(),
attributes_size,
&tensor_table_,
[&](const c10::QualifiedName& qn) {
@@ -173,7 +169,7 @@
return c10::StrongTypePtr(
compilation_unit_, compilation_unit_->get_class(qn));
});
- return ivalue;
+ return unpickler.parse_ivalue_list();
}
at::Tensor ScriptModuleDeserializer::loadTensor(
diff --git a/torch/csrc/jit/pickle.cpp b/torch/csrc/jit/pickle.cpp
deleted file mode 100644
index e2f173f..0000000
--- a/torch/csrc/jit/pickle.cpp
+++ /dev/null
@@ -1,82 +0,0 @@
-#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <ATen/core/ivalue.h>
-#include <torch/csrc/jit/pickle.h>
-#include <torch/csrc/jit/pickler.h>
-
-
-namespace torch {
-namespace jit {
-
-void pickle(
- std::function<void(const char*, size_t)> writer,
- const IValue& ivalue,
- std::vector<at::Tensor>* tensor_table) {
- Pickler pickler(std::move(writer), tensor_table);
-
- if (tensor_table == nullptr) {
- // No tensor table provided, so tensors will be stored directly in the blob.
- // Add torch.save metadata so these tensors can be de-serialized later
- pickler.torchSaveStart();
- }
-
- pickler.protocol();
- pickler.pushIValue(ivalue);
- pickler.stop();
-
- if (tensor_table == nullptr) {
- // No tensor table provided, so tensors will be stored directly in the blob.
- // Add torch.save metadata so these tensors can be de-serialized later
- pickler.torchSaveStop();
- }
-}
-
-std::vector<char> pickle(
- const IValue& ivalue,
- std::vector<at::Tensor>* tensor_table) {
- std::vector<char> data;
-
- pickle(
- [&](const char* bytes, size_t len) {
- data.insert(data.end(), bytes, bytes + len);
- },
- ivalue,
- tensor_table);
-
- return data;
-}
-
-IValue unpickle(
- std::function<void(char*, size_t)> reader,
- std::function<bool()> bounds_checker,
- std::vector<at::Tensor>* tensor_table,
- ClassResolver class_resolver) {
- Unpickler unpickler(
- std::move(reader),
- std::move(bounds_checker),
- tensor_table,
- std::move(class_resolver));
- return unpickler.parse_ivalue();
-}
-
-IValue unpickle(
- const char* data,
- size_t size,
- std::vector<at::Tensor>* tensor_table,
- ClassResolver class_resolver) {
- size_t bytes_read = 0;
- return unpickle(
- [&](char* buffer, size_t len) {
- // Copy len bytes into buffer
- const char* start = data + bytes_read;
- std::memcpy(buffer, start, len);
- bytes_read += len;
- },
- [&]() {
- return bytes_read < size;
- },
- tensor_table,
- std::move(class_resolver));
-}
-
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/pickle.h b/torch/csrc/jit/pickle.h
deleted file mode 100644
index 41a1032..0000000
--- a/torch/csrc/jit/pickle.h
+++ /dev/null
@@ -1,79 +0,0 @@
-#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <ATen/core/ivalue.h>
-#include <torch/csrc/jit/pickler.h>
-
-
-namespace torch {
-namespace jit {
-
-/// Save a `torch::IValue` in a format compatible with Python's `pickle` module
-///
-/// If present, `tensor_table` is a pointer to a table in which tensors that
-/// are contained within `ivalue` are stored, and the bytes returned by the
-/// pickler will only include references to these tensors in the table. This can
-/// be used to keep the binary blob size small.
-/// If not provided, tensors are stored in the same byte stream as the pickle
-/// data, similar to `torch.save()` in eager Python.
-///
-/// Pickled values can be loaded in Python and C++:
-/// \rst
-/// .. code-block:: cpp
-///
-/// torch::IValue float_value(2.3);
-///
-/// // TODO: when tensors are stored in the pickle, delete this
-/// std::vector<at::Tensor> tensor_table;
-/// auto data = torch::jit::pickle(float_value, &tensor_table);
-///
-/// std::vector<torch::IValue> ivalues =
-/// torch::jit::unpickle(data.data(), data.size());
-///
-/// .. code-block:: python
-///
-/// values = torch.load('data.pkl')
-/// print(values)
-///
-/// \endrst
-TORCH_API std::vector<char> pickle(
- const IValue& ivalue,
- std::vector<at::Tensor>* tensor_table = nullptr);
-
-/// Pickle an IValue by calling a function to handle writing the data.
-///
-/// `writer` is a function that takes in a pointer to a chunk of memory and its
-/// size and consumes it.
-///
-/// See `jit::pickle` for more details.
-TORCH_API void pickle(
- std::function<void(const char* data_start, size_t data_len)> writer,
- const IValue& ivalue,
- std::vector<at::Tensor>* tensor_table = nullptr);
-
-/// `reader` is a function that takes in a size to read from some pickled
-/// binary. `reader` should remember where it last read.
-///
-/// `bounds_checker` is a function that returns `true` if the reader can read
-/// more data, and `false` if it cannot (i.e. if a stream has hit its end of
-/// file)
-///
-/// See `torch::pickle` for details.
-TORCH_API IValue unpickle(
- std::function<const char*(size_t)> reader,
- std::function<bool()> bounds_chcker,
- std::vector<at::Tensor>* tensor_table = nullptr,
- ClassResolver class_resolver = nullptr);
-
-/// Decode a chunk of memory containing pickled data into its `torch::IValue`s.
-///
-/// If any `torch::IValue`s in the pickled data are `Object`s, then a
-/// `class_resolver` function must be provided.
-///
-/// See `torch::pickle` for details.
-TORCH_API IValue unpickle(
- const char* data,
- size_t size,
- std::vector<at::Tensor>* tensor_table = nullptr,
- ClassResolver class_resolver = nullptr);
-
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/pickler.cpp b/torch/csrc/jit/pickler.cpp
index 01a7703..bc6be68 100644
--- a/torch/csrc/jit/pickler.cpp
+++ b/torch/csrc/jit/pickler.cpp
@@ -52,6 +52,10 @@
}
}
+const std::vector<char>& Pickler::stack() {
+ return stack_;
+}
+
void Pickler::protocol() {
push<OpCode>(OpCode::PROTO);
push<uint8_t>(PROTOCOL_VERSION);
@@ -85,7 +89,6 @@
push<uint32_t>(key.size());
pushBytes(key);
}
-
push<OpCode>(OpCode::TUPLE);
stop();
@@ -93,7 +96,7 @@
for (const auto& data : tensor_data_) {
// first dump size
push<size_t>(data.numel());
- writer_(data.data(), data.sizeInBytes());
+ stack_.insert(stack_.end(), data.data(), data.data() + data.sizeInBytes());
}
}
@@ -301,7 +304,7 @@
}
void Pickler::pushBytes(const std::string& string) {
- writer_(string.data(), string.size());
+ stack_.insert(stack_.end(), string.begin(), string.end());
}
void Pickler::pushGlobal(
@@ -489,45 +492,49 @@
push<OpCode>(OpCode::TUPLE);
}
-IValue Unpickler::parse_ivalue() {
+std::vector<IValue> Unpickler::parse_ivalue_list() {
run();
TORCH_CHECK(
stack_.size() == 1,
"Unpickler expected 1 element on the stack, but found ",
stack_.size());
- return stack_[0];
+ auto value = stack_[0];
+ if (value.isGenericList()) {
+ // TODO [unpickler refactor]
+ return value.toGenericListRef().vec();
+ }
+ return value.toTuple()->elements();
}
double Unpickler::readFloat() {
AT_ASSERT(sizeof(double) == 8);
- double big_endian = read<double>();
- double little_endian;
+ AT_ASSERT(bytes_ + 8 < end_ptr_);
+ double result;
// Pickle floats are big endian, so reverse the bytes
- auto big_endian_ptr = reinterpret_cast<const char*>(&big_endian);
std::reverse_copy(
- big_endian_ptr,
- big_endian_ptr + sizeof(big_endian),
- reinterpret_cast<char*>(&little_endian));
+ reinterpret_cast<const char*>(bytes_),
+ reinterpret_cast<const char*>(bytes_ + 8),
+ reinterpret_cast<char*>(&result));
- return little_endian;
+ bytes_ += 8;
+ return result;
}
void Unpickler::run() {
// Expect a PROTO opcode and protocol number at the start of blob
- auto opcode = readOpCode();
TORCH_CHECK(
- opcode == OpCode::PROTO,
+ readOpCode() == OpCode::PROTO,
"Expected PROTO opcode at the start"
- " of pickle archive, found ", int(static_cast<uint8_t>(opcode)));
+ " of pickle archive");
uint8_t protocol = read<uint8_t>();
TORCH_CHECK(
protocol == 2,
"Only Pickle protocol 2 is supported, found protocol = ",
protocol);
- while (bounds_checker_()) {
+ while (bytes_ < end_ptr_) {
OpCode opcode = readInstruction();
if (opcode == OpCode::STOP) {
return;
@@ -620,12 +627,15 @@
case OpCode::LONG1: {
// Only read LONG1s with 8 as the length
uint8_t length = read<uint8_t>();
- TORCH_CHECK(length == 8, "Expected length to be 8, got ", int(length));
+ AT_ASSERT(length == 8);
stack_.emplace_back(int64_t(read<int64_t>()));
} break;
case OpCode::BINUNICODE: {
uint32_t length = read<uint32_t>();
- stack_.emplace_back(readBytes(length));
+ const char* characters = reinterpret_cast<const char*>(bytes_);
+ AT_ASSERT(bytes_ + length < end_ptr_);
+ bytes_ += length;
+ stack_.emplace_back(std::string(characters, /*n=*/length));
} break;
case OpCode::BINFLOAT:
stack_.emplace_back(readFloat());
@@ -695,10 +705,6 @@
stack_.pop_back();
switch (pickler_class) {
case PicklerClass::TENSOR:
- TORCH_CHECK(
- tensor_table_,
- "Found a tensor table reference but Pickler"
- " has no tensor table\n");
stack_.emplace_back(tensor_table_->at(data.toInt()));
break;
case PicklerClass::INTLIST:
@@ -766,21 +772,13 @@
"Unknown opcode for unpickling at ",
reinterpret_cast<void*>(opcode),
": ",
- int(static_cast<uint8_t>(opcode)));
+ static_cast<uint8_t>(opcode));
}
return opcode;
}
-// Read a number of bytes from the input stream
-std::string Unpickler::readBytes(size_t length) {
- std::string data(length, 0);
- // This is fine since C++11 has contiguous strings
- reader_(&data[0], length);
- return data;
-}
-
-// Pop all the list items off of the stack and append them to the list at
-// the corresponding MARK
+// Pop all the list items off of the stack and append them to the list at the
+// corresponding MARK
void Unpickler::readList() {
size_t start = marks_.back();
marks_.pop_back();
@@ -831,24 +829,33 @@
// Read a newline terminated string
std::string Unpickler::readString() {
- std::stringstream ss;
+ const char* chars = reinterpret_cast<const char*>(bytes_);
+ const char* char_end_ptr = reinterpret_cast<const char*>(end_ptr_);
+ size_t n = 0;
while (true) {
- char c = read<char>();
+ char c = chars[n];
if (c == '\n') {
break;
}
- ss << c;
-
// Simple check just in case there is no terminating '\n'
TORCH_CHECK(
is_valid_python_id_char(c),
"Found character '",
- int(uint8_t(c)),
- "' in string, ",
+ uint8_t(c),
+ "' in string, "
"strings must be qualified Python identifiers");
+
+ // Increment after to exclude newline from string
+ ++n;
+ TORCH_CHECK(
+ chars + n < char_end_ptr,
+ "Unpickler overran buffer while reading a string (expected a newline)");
}
- return ss.str();
+
+ // Increment by string length + newline char
+ bytes_ += n + 1;
+ return std::string(chars, n);
}
OpCode Unpickler::readOpCode() {
diff --git a/torch/csrc/jit/pickler.h b/torch/csrc/jit/pickler.h
index 9c05ae5..65919ba 100644
--- a/torch/csrc/jit/pickler.h
+++ b/torch/csrc/jit/pickler.h
@@ -11,9 +11,6 @@
namespace torch {
namespace jit {
-using ClassResolver =
- std::function<c10::StrongTypePtr(const c10::QualifiedName&)>;
-
// See Python's pickletools.py for a detailed description of each of these codes
enum class OpCode : char {
MARK = '(',
@@ -125,10 +122,10 @@
TH_DISALLOW_COPY_AND_ASSIGN(Pickler);
public:
- Pickler(
- std::function<void(const char*, size_t)> writer,
- std::vector<at::Tensor>* tensor_table = nullptr)
- : writer_(writer), tensor_table_(tensor_table) {}
+ Pickler(std::vector<at::Tensor>* tensor_table = nullptr)
+ : tensor_table_(tensor_table) {}
+
+ const std::vector<char>& stack();
// Push protocol onto the stack
void protocol();
@@ -189,12 +186,9 @@
template <typename T>
void push(typename std::common_type<T>::type value) {
const char* begin = reinterpret_cast<const char*>(&value);
- writer_(begin, sizeof(T));
+ stack_.insert(stack_.end(), begin, begin + sizeof(T));
}
- // Stream to write binary data to
- std::function<void(const char*, size_t)> writer_;
-
// Stack of opcodes/data
std::vector<char> stack_;
@@ -234,29 +228,32 @@
public:
Unpickler(
- std::function<void(char*, size_t)> reader,
- std::function<bool()> bounds_checker,
+ const void* data,
+ size_t size,
const std::vector<at::Tensor>* tensor_table,
- ClassResolver class_resolver)
- : reader_(reader),
- bounds_checker_(bounds_checker),
+ std::function<c10::StrongTypePtr(const c10::QualifiedName&)>
+ class_resolver)
+ : bytes_(static_cast<const uint8_t*>(data)),
+ end_ptr_(bytes_ + size),
tensor_table_(tensor_table),
- class_resolver_(std::move(class_resolver)) {}
+ class_resolver_(class_resolver) {}
- IValue parse_ivalue();
+ std::vector<IValue> parse_ivalue_list();
private:
// No arguments ensures that a template arugment must be specified
// so that the number of bytes read / type read is explicit
template <typename T>
T read() {
+ TORCH_CHECK(
+ bytes_ + sizeof(T) <= end_ptr_,
+ "Unpickler overran buffer while reading a value");
T item;
- reader_(reinterpret_cast<char*>(&item), sizeof(item));
+ std::memcpy(&item, bytes_, sizeof(T));
+ bytes_ += sizeof(T);
return item;
}
- std::string readBytes(size_t num_bytes);
-
double readFloat();
OpCode readInstruction();
OpCode readOpCode();
@@ -265,24 +262,18 @@
void setInput(size_t memo_id);
void run();
- // Returns a pointer to the number of bytes requested. This should state-fully
- // remember how many bytes have been read
- std::function<void(char*, size_t)> reader_;
-
- // Check if the stream has gone past its size
- std::function<bool()> bounds_checker_;
-
std::vector<IValue> stack_;
-
// globals are represented on the stack as IValue integer indices
// into this list
std::vector<std::function<void(void)>> globals_;
std::vector<IValue> memo_table_;
std::vector<size_t> marks_;
+ const uint8_t* bytes_;
+ const uint8_t* end_ptr_;
const std::vector<at::Tensor>* tensor_table_;
// optionally nullptr, needs to be present for creating classes
- ClassResolver class_resolver_;
+ std::function<c10::StrongTypePtr(const c10::QualifiedName&)> class_resolver_;
IValue empty_tuple_;
};
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index 2c1e466..9efcfcc 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -9,7 +9,7 @@
#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>
-#include <torch/csrc/jit/pickle.h>
+#include <torch/csrc/jit/pickler.h>
#include <torch/csrc/jit/print_handler.h>
#include <torch/csrc/jit/profiling_record.h>
#include <torch/csrc/jit/script/compilation_unit.h>
@@ -625,14 +625,19 @@
"aten::save(t item, str filename) -> ()",
[](Stack& stack) {
auto filename = pop(stack).toStringRef();
- auto ivalue = pop(stack);
+ auto value = pop(stack);
// Pickle the tensor
- auto data = pickle({ivalue});
+ Pickler p;
+ p.torchSaveStart();
+ p.protocol();
+ p.pushIValue(value);
+ p.stop();
+ p.torchSaveStop();
// Write file
std::fstream output(filename, std::ios::out | std::ios::binary);
- output.write(data.data(), data.size());
+ output.write(p.stack().data(), p.stack().size());
return 0;
},
aliasAnalysisFromSchema()),
diff --git a/torch/csrc/jit/script/module.cpp b/torch/csrc/jit/script/module.cpp
index 5763ba2..f976d6f 100644
--- a/torch/csrc/jit/script/module.cpp
+++ b/torch/csrc/jit/script/module.cpp
@@ -240,7 +240,6 @@
}
return result;
}
-
std::pair<std::shared_ptr<Graph>, std::vector<at::Tensor>> Method::_lowered_graph() {
auto result = lower_graph(owner().module_object(), *graph());
return std::make_pair(result.first, loadTensors(result.second));
diff --git a/torch/csrc/jit/source_range.h b/torch/csrc/jit/source_range.h
index b24d012..11c6779 100644
--- a/torch/csrc/jit/source_range.h
+++ b/torch/csrc/jit/source_range.h
@@ -151,7 +151,7 @@
bool operator!=(const SourceRange& rhs) const {
return !(*this == rhs);
}
-
+
c10::optional<SourceRange> findSourceRangeThatGenerated() const {
if (!source_) {
return c10::nullopt;
diff --git a/torch/csrc/jit/source_range_serialization.cpp b/torch/csrc/jit/source_range_serialization.cpp
index 12a8edb..0c3041f 100644
--- a/torch/csrc/jit/source_range_serialization.cpp
+++ b/torch/csrc/jit/source_range_serialization.cpp
@@ -2,7 +2,7 @@
#include <torch/csrc/jit/source_range_serialization_impl.h>
#include <ATen/core/ivalue.h>
-#include <torch/csrc/jit/pickle.h>
+#include <torch/csrc/jit/pickler.h>
namespace torch {
namespace jit {
@@ -83,22 +83,23 @@
}
SourceRangePickler::SourceRangePickler()
- : srs(new SourceRangeSerializer()) {}
+ : p(new Pickler()), srs(new SourceRangeSerializer()) {}
-std::vector<char> SourceRangePickler::pickle(const SourceRangeRecords& ranges) {
- std::vector<c10::IValue> ivalues;
+void SourceRangePickler::pickle(const SourceRangeRecords& ranges) {
+ p->protocol();
+ p->startTuple();
for (const auto& range : ranges) {
std::vector<c10::IValue> row_elems{(int64_t)range.bytes,
srs->serialize(range.range)};
- ivalues.emplace_back(c10::ivalue::Tuple::create(std::move(row_elems)));
+ p->pushIValue(c10::ivalue::Tuple::create(std::move(row_elems)));
}
- std::vector<at::Tensor> table;
- auto ivalue = c10::ivalue::Tuple::create(std::move(ivalues));
- auto result = jit::pickle(ivalue, &table);
- TORCH_CHECK(table.size() == 0, "Expected 0 tensors to be written");
- return result;
+ p->endTuple();
+ p->stop();
}
+const std::vector<char>& SourceRangePickler::get_data() {
+ return p->stack();
+}
ConcreteSourceRangeUnpickler::ConcreteSourceRangeUnpickler(
at::DataPtr&& data,
@@ -113,9 +114,8 @@
return;
}
- auto ivalues = jit::unpickle(reinterpret_cast<const char*>(data.get()), size)
- .toTuple()
- ->elements();
+ Unpickler up(data.get(), size, nullptr, nullptr);
+ auto ivalues = up.parse_ivalue_list();
unpickled_records = std::make_shared<SourceRangeRecords>();
for (auto& val : ivalues) {
@@ -149,4 +149,4 @@
}
} // namespace jit
-} // namespace torch
+} // namespace torch
\ No newline at end of file
diff --git a/torch/csrc/jit/source_range_serialization.h b/torch/csrc/jit/source_range_serialization.h
index a45f4ca..349ef16 100644
--- a/torch/csrc/jit/source_range_serialization.h
+++ b/torch/csrc/jit/source_range_serialization.h
@@ -21,9 +21,12 @@
public:
SourceRangePickler();
- std::vector<char> pickle(const SourceRangeRecords& ranges);
+ void pickle(const SourceRangeRecords& ranges);
+
+ const std::vector<char>& get_data();
private:
+ std::shared_ptr<Pickler> p;
std::shared_ptr<SourceRangeSerializer> srs;
};
@@ -36,4 +39,4 @@
};
} // namespace jit
-} // namespace torch
+} // namespace torch
\ No newline at end of file
diff --git a/torch/script.h b/torch/script.h
index 274609d..8c8cc5f 100644
--- a/torch/script.h
+++ b/torch/script.h
@@ -4,6 +4,5 @@
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/import.h>
-#include <torch/csrc/jit/pickle.h>
#include <ATen/ATen.h>