blob: e2f173fa028c385f19100e01ff6e2239a2f8ab48 [file] [log] [blame]
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/pickle.h>
#include <torch/csrc/jit/pickler.h>

#include <cstring>
#include <stdexcept>
#include <string>
namespace torch {
namespace jit {
/// Serializes `ivalue` with a Pickler, streaming every produced byte chunk
/// through `writer`.
///
/// @param writer        sink receiving (pointer, length) chunks of output.
/// @param ivalue        the value to serialize.
/// @param tensor_table  handed to the Pickler as-is. When it is null, tensor
///                      data is stored directly in the blob, so the pickle
///                      program is bracketed by torchSaveStart() /
///                      torchSaveStop() to emit the torch.save metadata that
///                      is needed to deserialize those tensors later.
void pickle(
    std::function<void(const char*, size_t)> writer,
    const IValue& ivalue,
    std::vector<at::Tensor>* tensor_table) {
  // No external table means tensors must be embedded in the blob itself.
  const bool embed_tensors_in_blob = (tensor_table == nullptr);

  Pickler pickler(std::move(writer), tensor_table);

  if (embed_tensors_in_blob) {
    // Leading torch.save metadata, written before the pickle program.
    pickler.torchSaveStart();
  }

  // The pickle program itself: protocol header, the value, then STOP.
  pickler.protocol();
  pickler.pushIValue(ivalue);
  pickler.stop();

  if (embed_tensors_in_blob) {
    // Trailing torch.save section, written after the pickle program has
    // been terminated.
    pickler.torchSaveStop();
  }
}
/// Convenience overload: pickles `ivalue` into an in-memory byte buffer.
///
/// Each chunk produced by the streaming `pickle` overload is appended, in
/// order, to the returned vector. `tensor_table` is forwarded unchanged.
std::vector<char> pickle(
    const IValue& ivalue,
    std::vector<at::Tensor>* tensor_table) {
  std::vector<char> buffer;
  auto append_chunk = [&buffer](const char* chunk, size_t chunk_len) {
    buffer.insert(buffer.end(), chunk, chunk + chunk_len);
  };
  pickle(append_chunk, ivalue, tensor_table);
  return buffer;
}
/// Deserializes one IValue from a pickle stream.
///
/// All arguments are moved/forwarded into an Unpickler, whose parse_ivalue()
/// result is returned.
///
/// @param reader          callback that fills a buffer with the next bytes
///                        of the stream (presumably exactly `len` bytes;
///                        the contract is defined by Unpickler — confirm).
/// @param bounds_checker  callback reporting whether more input remains.
/// @param tensor_table    forwarded to the Unpickler unchanged.
/// @param class_resolver  forwarded to the Unpickler unchanged.
IValue unpickle(
    std::function<void(char*, size_t)> reader,
    std::function<bool()> bounds_checker,
    std::vector<at::Tensor>* tensor_table,
    ClassResolver class_resolver) {
  Unpickler parser{
      std::move(reader),
      std::move(bounds_checker),
      tensor_table,
      std::move(class_resolver)};
  return parser.parse_ivalue();
}
/// Deserializes one IValue from an in-memory pickle of `size` bytes that
/// starts at `data`.
///
/// Bytes are consumed sequentially from the buffer. The buffer's contents
/// are not validated here, so a truncated or malformed pickle could ask the
/// reader for more bytes than remain. Such a read throws std::runtime_error
/// instead of reading past the end of `data`.
///
/// @param data            start of the serialized bytes.
/// @param size            number of valid bytes at `data`.
/// @param tensor_table    forwarded to the streaming `unpickle` overload.
/// @param class_resolver  forwarded to the streaming `unpickle` overload.
/// @throws std::runtime_error if a read would extend past `data + size`.
IValue unpickle(
    const char* data,
    size_t size,
    std::vector<at::Tensor>* tensor_table,
    ClassResolver class_resolver) {
  // Invariant: bytes_read <= size (enforced by the check in the reader).
  size_t bytes_read = 0;
  return unpickle(
      [&](char* buffer, size_t len) {
        // Compare against the remaining byte count rather than computing
        // bytes_read + len, which could wrap around for a huge `len`.
        if (len > size - bytes_read) {
          throw std::runtime_error(
              "Unpickler: attempted to read " + std::to_string(len) +
              " bytes at offset " + std::to_string(bytes_read) +
              " from a pickle of only " + std::to_string(size) + " bytes");
        }
        // Copy len bytes into buffer
        std::memcpy(buffer, data + bytes_read, len);
        bytes_read += len;
      },
      [&]() {
        // True while there are unread bytes left in the buffer.
        return bytes_read < size;
      },
      tensor_table,
      std::move(class_resolver));
}
} // namespace jit
} // namespace torch