| #include <torch/csrc/jit/pickler.h> |
| |
| #include <ATen/ATen.h> |
| #include <ATen/core/Dict.h> |
| |
| #include <algorithm> |
| #include <limits> |
| #include <sstream> |
| #include <string> |
| |
| namespace torch { |
| namespace jit { |
| |
| using ::c10::IValue; |
| |
| // Protocol 2 is the highest that can be decoded by Python 2 |
| // See https://docs.python.org/3/library/pickle.html#data-stream-format |
| constexpr static uint8_t PROTOCOL_VERSION = 2; |
| |
| PicklerClass getClass(const std::string& str) { |
| if (str == "build_tensor_from_id") { |
| return PicklerClass::TENSOR; |
| } else if (str == "build_intlist") { |
| return PicklerClass::INTLIST; |
| } else if (str == "build_tensorlist") { |
| return PicklerClass::TENSORLIST; |
| } else if (str == "build_doublelist") { |
| return PicklerClass::DOUBLELIST; |
| } else if (str == "build_boollist") { |
| return PicklerClass::BOOLLIST; |
| } |
| |
| // TODO [unpickler refactor] Remove these legacy class names |
| if (str == "TensorID") { |
| return PicklerClass::TENSOR; |
| } else if (str == "IntList") { |
| return PicklerClass::INTLIST; |
| } |
| AT_ERROR("Unknown class name for unpickler: ", str); |
| } |
| |
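| // These names (like the module name below) carry a trailing newline because |
| // they are written verbatim after a GLOBAL opcode, whose arguments are |
| // newline-terminated |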
| const std::string& getClassName(PicklerClass cls) { |
| static const std::string tensor_class("build_tensor_from_id\n"); |
| static const std::string intlist_class("build_intlist\n"); |
| static const std::string tensorlist_class("build_tensorlist\n"); |
| static const std::string doublelist_class("build_doublelist\n"); |
| static const std::string boollist_class("build_boollist\n"); |
| switch (cls) { |
| case PicklerClass::TENSOR: |
| return tensor_class; |
| case PicklerClass::INTLIST: |
| return intlist_class; |
| case PicklerClass::TENSORLIST: |
| return tensorlist_class; |
| case PicklerClass::DOUBLELIST: |
| return doublelist_class; |
| case PicklerClass::BOOLLIST: |
| return boollist_class; |
| default: |
| AT_ERROR("Unknown class for pickler"); |
| } |
| } |
| |
| const std::string& getModuleName() { |
| static const std::string module_name("torch.jit._pickle\n"); |
| return module_name; |
| } |
| |
| const std::vector<char>& Pickler::stack() { |
| return stack_; |
| } |
| |
| void Pickler::start() { |
| push<OpCode>(OpCode::PROTO); |
| push<uint8_t>(PROTOCOL_VERSION); |
| } |
| |
| void Pickler::startTuple() { |
| // All attributes get pushed into a tuple and their indices saved in the |
| // module def |
| push<OpCode>(OpCode::MARK); |
| } |
| |
| void Pickler::endTuple() { |
| push<OpCode>(OpCode::TUPLE); |
| } |
| |
| void Pickler::finish() { |
| push<OpCode>(OpCode::STOP); |
| |
| // Add the binary data for all the tensors to be included in the same binary |
| // TODO: The pickler should be refactored to write directly to an output |
| // stream instead of staging everything in the stack_ buffer |
| if (!literal_tensors_.empty()) { |
| // As another pickle program in the same binary archive, add a list of |
| // keys for each tensor (see torch/serialization.py) |
| start(); |
| push<OpCode>(OpCode::MARK); |
| for (const auto& tensor : literal_tensors_) { |
| std::string key = std::to_string(getStorageKey(tensor)); |
| push<OpCode>(OpCode::BINUNICODE); |
| push<uint32_t>(key.size()); |
| pushString(key); |
| } |
| push<OpCode>(OpCode::TUPLE); |
| push<OpCode>(OpCode::STOP); |
| |
| // Now dump the tensor binary data |
| for (const auto& tensor : literal_tensors_) { |
| pushTensorData(tensor); |
| } |
| } |
| } |
| |
| void Pickler::pushTensorData(const at::Tensor& tensor) { |
| // Write the number of elements first, followed by the raw storage bytes |
| auto numel = tensor.numel(); |
| auto numel_ptr = reinterpret_cast<const char*>(&numel); |
| stack_.insert(stack_.end(), numel_ptr, numel_ptr + sizeof(numel)); |
| |
| uint64_t record_size; |
| at::Tensor storage_tensor; |
| std::tie(storage_tensor, record_size) = getWriteableTensor(tensor); |
| auto storage_byte_ptr = reinterpret_cast<uint8_t*>(storage_tensor.storage().data()); |
| stack_.insert(stack_.end(), storage_byte_ptr, storage_byte_ptr + record_size); |
| } |
| |
| void Pickler::pushMetadata() { |
| // Output data to match torch.save, see torch/serialization.py for details |
| // Magic number (0x1950a86a20f9469cfc6c) |
| start(); |
| push<OpCode>(OpCode::LONG1); |
| // LONG1 size (10 bytes) |
| pushString("\x0a"); |
| // LONG1 data: the magic number in little-endian byte order |
| pushString("\x6c\xfc\x9c\x46\xf9\x20\x6a\xa8\x50\x19"); |
| push<OpCode>(OpCode::STOP); |
| |
| // Protocol version (1001) as a little-endian 2-byte unsigned int |
| start(); |
| push<OpCode>(OpCode::BININT2); |
| pushString("\xe9\x03"); |
| push<OpCode>(OpCode::STOP); |
| |
| // sys_info isn't used during deserialization, so it can be left empty |
| start(); |
| push<OpCode>(OpCode::EMPTY_DICT); |
| push<OpCode>(OpCode::STOP); |
| } |
| |
| void Pickler::addIValue(const IValue& ivalue) { |
| // Check if reference ivalue has been saved before |
| const void* ivalue_ptr = getPointer(ivalue); |
| if (ivalue_ptr) { |
| auto memo_entry = memo_map_.find(ivalue_ptr); |
| if (memo_entry != memo_map_.end()) { |
| // This value has already been pushed, just do a BINGET |
| pushBinGet(memo_entry->second); |
| return; |
| } |
| } |
| |
| if (ivalue.isTensor()) { |
| pushTensor(ivalue); |
| } else if (ivalue.isTuple()) { |
| pushTuple(ivalue); |
| } else if (ivalue.isDouble()) { |
| pushDouble(ivalue); |
| } else if (ivalue.isInt()) { |
| pushInt(ivalue); |
| } else if (ivalue.isBool()) { |
| if (ivalue.toBool()) { |
| push<OpCode>(OpCode::NEWTRUE); |
| } else { |
| push<OpCode>(OpCode::NEWFALSE); |
| } |
| } else if (ivalue.isString()) { |
| pushMemoizedString(ivalue); |
| } else if (ivalue.isGenericList()) { |
| pushGenericList(ivalue); |
| } else if (ivalue.isGenericDict()) { |
| pushDict(ivalue); |
| } else if (ivalue.isNone()) { |
| push<OpCode>(OpCode::NONE); |
| } else if (ivalue.isIntList()) { |
| pushSpecializedList( |
| ivalue, PicklerClass::INTLIST, [this](const IValue& ivalue) { |
| for (const auto& item : ivalue.toIntListRef()) { |
| addIValue(item); |
| } |
| }); |
| } else if (ivalue.isTensorList()) { |
| pushSpecializedList( |
| ivalue, PicklerClass::TENSORLIST, [this](const IValue& ivalue) { |
| for (const auto& item : ivalue.toTensorListRef()) { |
| addIValue(item); |
| } |
| }); |
| } else if (ivalue.isDoubleList()) { |
| pushSpecializedList( |
| ivalue, PicklerClass::DOUBLELIST, [this](const IValue& ivalue) { |
| for (const auto& item : ivalue.toDoubleListRef()) { |
| addIValue(item); |
| } |
| }); |
| } else if (ivalue.isBoolList()) { |
| pushSpecializedList( |
| ivalue, PicklerClass::BOOLLIST, [this](const IValue& ivalue) { |
| for (const auto& item : ivalue.toBoolListRef()) { |
| addIValue(bool(item)); |
| } |
| }); |
| } else { |
| AT_ERROR("Unknown IValue type for pickling: ", ivalue.tagKind()); |
| } |
| } |
| |
| /// Returns a void* uniquely identifying this IValue's data, or nullptr for |
| /// non-container types. pushMemoization keeps a reference to every memoized |
| /// IValue, so the returned pointers stay valid for the Pickler's lifetime. |
| const void* Pickler::getPointer(const IValue& ivalue) { |
| if (ivalue.isGenericDict()) { |
| return ivalue.toGenericDict().get(); |
| } else if (ivalue.isGenericList()) { |
| return ivalue.toGenericList().get(); |
| } else if (ivalue.isTuple()) { |
| return ivalue.toTuple().get(); |
| } else if (ivalue.isString()) { |
| return ivalue.toString().get(); |
| } else if (ivalue.isIntList()) { |
| return ivalue.toIntList().get(); |
| } else if (ivalue.isTensorList()) { |
| return ivalue.toTensorList().get(); |
| } else if (ivalue.isDoubleList()) { |
| return ivalue.toDoubleList().get(); |
| } else if (ivalue.isBoolList()) { |
| return ivalue.toBoolList().get(); |
| } |
| |
| return nullptr; |
| } |
| |
| void Pickler::pushInt(const IValue& ivalue) { |
| auto n = ivalue.toInt(); |
| if (n >= 0 && n <= std::numeric_limits<uint8_t>::max()) { |
| // BININT1 pushes a 1-byte unsigned integer |
| push<OpCode>(OpCode::BININT1); |
| push<uint8_t>(n); |
| } else if ( |
| n >= std::numeric_limits<int32_t>::min() && |
| n <= std::numeric_limits<int32_t>::max()) { |
| push<OpCode>(OpCode::BININT); |
| push<int32_t>(n); |
| } else { |
| // Push an 8 byte signed integer |
| push<OpCode>(OpCode::LONG1); |
| push<uint8_t>(8); |
| push<int64_t>(n); |
| } |
| } |
| |
| void Pickler::pushBinGet(uint32_t memo_id) { |
| if (memo_id <= std::numeric_limits<uint8_t>::max()) { |
| push<OpCode>(OpCode::BINGET); |
| push<uint8_t>(memo_id); |
| } else { |
| // Memoized too many items, issue a LONG_BINGET instead |
| push<OpCode>(OpCode::LONG_BINGET); |
| push<uint32_t>(memo_id); |
| } |
| } |
| |
| void Pickler::pushMemoizedString(const IValue& ivalue) { |
| const auto& string = ivalue.toStringRef(); |
| |
| push<OpCode>(OpCode::BINUNICODE); |
| push<uint32_t>(string.size()); |
| pushString(string); |
| pushMemoization(ivalue); |
| } |
| |
| void Pickler::pushString(const std::string& string) { |
| stack_.insert(stack_.end(), string.begin(), string.end()); |
| } |
| |
| void Pickler::pushGlobal(const std::string& name_temp) { |
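| // name_temp must already contain the trailing '\n's that terminate GLOBAL's |
| // module and class name arguments (see getClassName) |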
| auto memo_entry = memoized_strings_map_.find(name_temp); |
| if (memo_entry == memoized_strings_map_.end()) { |
| push<OpCode>(OpCode::GLOBAL); |
| pushString(name_temp); |
| |
| // Push BINPUT without adding anything to the memo_map_ |
| size_t memo_id = pushNextBinPut(); |
| memoized_strings_map_.insert({name_temp, memo_id}); |
| } else { |
| pushBinGet(memo_entry->second); |
| } |
| } |
| |
| void Pickler::pushTensor(const IValue& ivalue) { |
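| // When a tensor table is provided, tensors are pickled as references into |
| // it; otherwise the tensor data is inlined into the binary after the pickle |
| // program |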
| if (tensor_table_ == nullptr) { |
| pushLiteralTensor(ivalue); |
| } else { |
| pushTensorReference(ivalue); |
| } |
| } |
| |
| void Pickler::pushLiteralTensor(const IValue& ivalue) { |
| // In contrast to tensor references, literal tensors are included in the |
| // pickle program binary blob. They are written to the file after the STOP |
| // opcode. They can't be included in the pickle program itself without a bunch |
| // of extra machinery since byte strings are limited to 4 GB. |
| // |
| // The format here is the same one used by `torch.save()`. The code for the |
| // format can be found in `torch/serialization.py`. |
| auto tensor = ivalue.toTensor(); |
| |
| // The arguments to this function are: |
| // storage, storage_offset, size, stride, requires_grad, backward_hooks |
| pushGlobal("torch._utils\n_rebuild_tensor_v2\n"); |
| push<OpCode>(OpCode::MARK); |
| |
| // Tuple for persistent_load |
| push<OpCode>(OpCode::MARK); |
| // typename |
| pushMemoizedString(std::string("storage")); |
| // data_type |
| std::stringstream data_type; |
| data_type << "torch\n" << toString(tensor.scalar_type()) << "Storage\n"; |
| pushGlobal(data_type.str()); |
| // root_key |
| pushMemoizedString(std::to_string(getStorageKey(tensor))); |
| // location |
| pushMemoizedString(std::string("cpu")); |
| // size |
| pushInt(tensor.numel()); |
| // view_metadata |
| push<OpCode>(OpCode::NONE); |
| push<OpCode>(OpCode::TUPLE); |
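| // BINPERSID passes the tuple above to persistent_load, which resolves it |
| // to a storage object |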
| push<OpCode>(OpCode::BINPERSID); |
| |
| // storage offset |
| int64_t storage_offset = 0; |
| pushInt(storage_offset); |
| |
| // size |
| push<OpCode>(OpCode::MARK); |
| for (auto size : tensor.sizes()) { |
| pushInt(size); |
| } |
| push<OpCode>(OpCode::TUPLE); |
| |
| // stride |
| push<OpCode>(OpCode::MARK); |
| for (auto stride : tensor.strides()) { |
| pushInt(stride); |
| } |
| push<OpCode>(OpCode::TUPLE); |
| |
| // requires_grad |
| addIValue(tensor.requires_grad()); |
| |
| // backward_hooks |
| pushGlobal("collections\nOrderedDict\n"); |
| push<OpCode>(OpCode::EMPTY_TUPLE); |
| // Construct the collections.OrderedDict for the backward_hooks |
| push<OpCode>(OpCode::REDUCE); |
| |
| push<OpCode>(OpCode::TUPLE); |
| |
| // Call torch._utils._rebuild_tensor_v2 |
| push<OpCode>(OpCode::REDUCE); |
| |
| // Store tensor so it can be placed into the binary after the pickle program |
| literal_tensors_.push_back(ivalue.toTensor()); |
| } |
| |
| void Pickler::pushClass(PicklerClass cls) { |
| pushGlobal(getModuleName() + getClassName(cls)); |
| } |
| |
| void Pickler::pushTensorReference(const IValue& ivalue) { |
| pushClass(PicklerClass::TENSOR); |
| tensor_table_->push_back(ivalue.toTensor()); |
| int64_t tensor_id = tensor_table_->size() - 1; |
| // Reduce arguments are spread (e.g. `*args`) before calling the global, |
| // so wrap in a tuple |
| push<OpCode>(OpCode::MARK); |
| addIValue(tensor_id); |
| push<OpCode>(OpCode::TUPLE); |
| |
| push<OpCode>(OpCode::REDUCE); |
| } |
| |
| void Pickler::pushSpecializedList( |
| const IValue& ivalue, |
| PicklerClass cls, |
| const std::function<void(const IValue&)>& item_pusher) { |
| pushClass(cls); |
| |
| // Reduce arguments are spread (e.g. `*args`) before calling the global, |
| // so wrap in a tuple |
| push<OpCode>(OpCode::MARK); |
| |
| push<OpCode>(OpCode::EMPTY_LIST); |
| // Mark list |
| push<OpCode>(OpCode::MARK); |
| |
| // Add all items |
| item_pusher(ivalue); |
| |
| // Finish list |
| push<OpCode>(OpCode::APPENDS); |
| |
| // Finish tuple |
| push<OpCode>(OpCode::TUPLE); |
| |
| // Call reduce |
| push<OpCode>(OpCode::REDUCE); |
| pushMemoization(ivalue); |
| } |
| |
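| // BINFLOAT stores an 8 byte IEEE-754 double in big-endian byte order, so |
| // the host bytes are reversed here (this assumes a little-endian host) |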
| void Pickler::pushDouble(const IValue& ivalue) { |
| double value = ivalue.toDouble(); |
| AT_ASSERT(sizeof(double) == 8); |
| char* bytes = reinterpret_cast<char*>(&value); |
| |
| push<OpCode>(OpCode::BINFLOAT); |
| for (size_t i = 0; i < 8; ++i) { |
| push<uint8_t>(bytes[8 - i - 1]); |
| } |
| } |
| |
| void Pickler::pushDict(const IValue& ivalue) { |
| push<OpCode>(OpCode::EMPTY_DICT); |
| pushMemoization(ivalue); |
| |
| push<OpCode>(OpCode::MARK); |
| |
| // Use the dict's iteration order so keys are emitted deterministically |
| auto dict_items = ivalue.toGenericDict()->iterationOrder(); |
| for (const auto& pair : dict_items) { |
| addIValue(pair.first); |
| addIValue(pair.second); |
| } |
| |
| push<OpCode>(OpCode::SETITEMS); |
| } |
| |
| void Pickler::pushMemoization(const void* item) { |
| TORCH_CHECK(item != nullptr, "Pickler cannot memoize a nullptr"); |
| memo_map_[item] = pushNextBinPut(); |
| } |
| |
| size_t Pickler::pushNextBinPut() { |
| if (memo_id_ <= std::numeric_limits<uint8_t>::max()) { |
| push<OpCode>(OpCode::BINPUT); |
| push<uint8_t>(memo_id_); |
| } else { |
| // Memoized too many items, issue a LONG_BINPUT instead |
| push<OpCode>(OpCode::LONG_BINPUT); |
| push<uint32_t>(memo_id_); |
| } |
| AT_ASSERT(memo_id_ <= std::numeric_limits<uint32_t>::max()); |
| ++memo_id_; |
| return memo_id_ - 1; |
| } |
| |
| void Pickler::pushMemoization(const IValue& ivalue) { |
| auto ptr = getPointer(ivalue); |
| TORCH_CHECK( |
| ptr != nullptr, |
| "Pickler cannot memoize ", |
| ivalue.tagKind(), |
| " IValue ", |
| ivalue); |
| // Hold a reference so the pointer stays valid for the Pickler's lifetime |
| memoized_ivalues_.push_back(ivalue); |
| pushMemoization(ptr); |
| } |
| |
| void Pickler::pushGenericList(const IValue& ivalue) { |
| auto list = ivalue.toGenericListRef(); |
| push<OpCode>(OpCode::EMPTY_LIST); |
| pushMemoization(ivalue); |
| |
| push<OpCode>(OpCode::MARK); |
| |
| for (const auto& item : list) { |
| addIValue(item); |
| } |
| |
| push<OpCode>(OpCode::APPENDS); |
| } |
| |
| void Pickler::pushTuple(const IValue& ivalue) { |
| // TODO: Small tuple unrolling (e.g. TUPLE3) |
| push<OpCode>(OpCode::MARK); |
| auto tuple = ivalue.toTuple()->elements(); |
| |
| for (const auto& item : tuple) { |
| addIValue(item); |
| } |
| |
| push<OpCode>(OpCode::TUPLE); |
| pushMemoization(ivalue); |
| } |
| |
| std::vector<IValue> Unpickler::parse_ivalue_list() { |
| run(); |
| TORCH_CHECK( |
| stack_.size() == 1, |
| "Unpickler expected 1 element on the stack, but found ", |
| stack_.size()); |
| |
| auto value = stack_[0].ivalue(); |
| if (value.isGenericList()) { |
| // TODO [unpickler refactor] |
| return value.toGenericListRef(); |
| } |
| return value.toTuple()->elements(); |
| } |
| |
| double Unpickler::readFloat() { |
| AT_ASSERT(sizeof(double) == 8); |
| AT_ASSERT(bytes_ + 8 <= end_ptr_); |
| double result; |
| |
| // Pickle floats are big endian, so reverse the bytes |
| std::reverse_copy( |
| reinterpret_cast<const char*>(bytes_), |
| reinterpret_cast<const char*>(bytes_ + 8), |
| reinterpret_cast<char*>(&result)); |
| |
| bytes_ += 8; |
| return result; |
| } |
| |
| void Unpickler::run() { |
| // Expect a PROTO opcode and protocol number at the start of blob |
| TORCH_CHECK( |
| readOpCode() == OpCode::PROTO, |
| "Expected PROTO opcode at the start" |
| " of pickle archive"); |
| uint8_t protocol = read<uint8_t>(); |
| TORCH_CHECK( |
| protocol == 2, |
| "Only Pickle protocol 2 is supported, found protocol = ", |
| protocol); |
| |
| while (bytes_ < end_ptr_) { |
| OpCode opcode = readInstruction(); |
| if (opcode == OpCode::STOP) { |
| return; |
| } |
| last_opcode_ = opcode; |
| } |
| |
| AT_ERROR("Overran buffer while unpickling data, didn't find STOP opcode"); |
| } |
| |
| OpCode Unpickler::readInstruction() { |
| auto opcode = readOpCode(); |
| switch (opcode) { |
| case OpCode::EMPTY_LIST: { |
| if (last_opcode_ == OpCode::NEWOBJ) { |
| // TODO [unpickler refactor] remove this case |
| // It's a list specialization, the enum ID of which is on the stack |
| TORCH_CHECK( |
| stack_.size() > 0, |
| "Unpickler found an empty stack when it expected a value"); |
| auto value = stack_.back().ivalue().toInt(); |
| TORCH_CHECK( |
| value >= 0 && value <= std::numeric_limits<uint8_t>::max(), |
| "Unpickler could not decode PicklerClass for ", |
| value); |
| PicklerClass cls = static_cast<PicklerClass>(uint8_t(value)); |
| if (cls == PicklerClass::INTLIST) { |
| stack_.emplace_back(std::vector<int64_t>()); |
| } |
| } else if (stack_.size() > 0 && stack_.back().pickler_class_opt()) { |
| // Check if we're in a GLOBAL opcode and if so, if it's a list |
| // specialization |
| if (stack_.back().pickler_class() == PicklerClass::INTLIST) { |
| stack_.emplace_back(std::vector<int64_t>()); |
| } else if (stack_.back().pickler_class() == PicklerClass::TENSORLIST) { |
| stack_.emplace_back(std::vector<at::Tensor>()); |
| } else if (stack_.back().pickler_class() == PicklerClass::DOUBLELIST) { |
| stack_.emplace_back(std::vector<double>()); |
| } else if (stack_.back().pickler_class() == PicklerClass::BOOLLIST) { |
| stack_.emplace_back(std::vector<bool>()); |
| } else { |
| AT_ERROR("Unknown list specialization"); |
| } |
| } else { |
| stack_.emplace_back(std::vector<IValue>()); |
| } |
| } break; |
| case OpCode::EMPTY_TUPLE: { |
| stack_.emplace_back(c10::ivalue::Tuple::create({})); |
| } break; |
| case OpCode::BINPUT: { |
| size_t memo_id = read<uint8_t>(); |
| if (memo_table_.capacity() <= memo_id) { |
| memo_table_.reserve(1 + 2 * memo_id); |
| } |
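| // Memo ids are assigned sequentially by the pickler, so appending keeps |
| // memo_table_ indexed by id (the same holds for LONG_BINPUT below) |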
| memo_table_.push_back(stack_.back()); |
| } break; |
| case OpCode::LONG_BINPUT: { |
| TORCH_CHECK( |
| std::numeric_limits<size_t>::max() >= |
| std::numeric_limits<uint32_t>::max(), |
| "Found a LONG_BINPUT opcode, but size_t on this system is " |
| "not big enough to decode it"); |
| size_t memo_id = read<uint32_t>(); |
| if (memo_table_.capacity() <= memo_id) { |
| memo_table_.reserve(1 + 2 * memo_id); |
| } |
| memo_table_.push_back(stack_.back()); |
| } break; |
| case OpCode::MARK: { |
| // Mark location of the container ivalue in the stack |
| marks_.push_back(stack_.size()); |
| } break; |
| case OpCode::NEWTRUE: { |
| stack_.emplace_back(true); |
| } break; |
| case OpCode::NEWFALSE: { |
| stack_.emplace_back(false); |
| } break; |
| case OpCode::NONE: { |
| stack_.emplace_back(IValue()); |
| } break; |
| case OpCode::BININT1: { |
| // BININT1 is a 1-byte unsigned integer |
| uint8_t value = read<uint8_t>(); |
| stack_.emplace_back(int64_t(value)); |
| } break; |
| case OpCode::BININT: { |
| int32_t value = read<int32_t>(); |
| stack_.emplace_back(int64_t(value)); |
| } break; |
| case OpCode::LONG1: { |
| // Only read LONG1s with 8 as the length |
| uint8_t length = read<uint8_t>(); |
| AT_ASSERT(length == 8); |
| stack_.emplace_back(int64_t(read<int64_t>())); |
| } break; |
| case OpCode::BINUNICODE: { |
| uint32_t length = read<uint32_t>(); |
| const char* characters = reinterpret_cast<const char*>(bytes_); |
| AT_ASSERT(bytes_ + length <= end_ptr_); |
| bytes_ += length; |
| stack_.emplace_back(std::string(characters, /*n=*/length)); |
| } break; |
| case OpCode::BINFLOAT: |
| stack_.emplace_back(readFloat()); |
| break; |
| case OpCode::TUPLE: { |
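| // Collect everything back to the last MARK into a tuple |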
| size_t start = marks_.back(); |
| marks_.pop_back(); |
| auto tuple = c10::ivalue::Tuple::create({}); |
| tuple->elements().reserve(stack_.size() - start); |
| auto start_it = stack_.begin() + start; |
| for (auto it = start_it; it != stack_.end(); ++it) { |
| tuple->elements().emplace_back(it->ivalue()); |
| } |
| stack_.erase(start_it, stack_.end()); |
| stack_.emplace_back(IValue(tuple)); |
| } break; |
| case OpCode::EMPTY_DICT: |
| stack_.emplace_back(c10::impl::make_generic_dict()); |
| break; |
| case OpCode::APPENDS: { |
| readList(); |
| } break; |
| case OpCode::SETITEMS: { |
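| // Consume the key/value pairs above the last MARK into the dict that |
| // sits just below the MARK |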
| size_t start = marks_.back(); |
| marks_.pop_back(); |
| auto dict = stack_.at(start - 1).ivalue().toGenericDict(); |
| for (size_t i = start; i < stack_.size(); i += 2) { |
| dict->elements().insert_or_assign(stack_[i].ivalue(), stack_[i + 1].ivalue()); |
| } |
| stack_.erase(stack_.begin() + start, stack_.end()); |
| } break; |
| case OpCode::BINGET: { |
| stack_.push_back(memo_table_.at(read<uint8_t>())); |
| } break; |
| case OpCode::LONG_BINGET: { |
| stack_.push_back(memo_table_.at(read<uint32_t>())); |
| } break; |
| case OpCode::STOP: |
| break; |
| case OpCode::GLOBAL: { |
| // Read the module name; it is only needed to detect the legacy format |
| auto module_name = readString(); |
| // TODO [unpickler refactor] __main__ isn't emitted by the pickler anymore |
| if (module_name == "__main__") { |
| // Legacy format: push the class id as an integer, consumed by BUILD |
| stack_.emplace_back(static_cast<uint8_t>(getClass(readString()))); |
| } else { |
| // Push the class id, consumed later by REDUCE or EMPTY_LIST |
| stack_.emplace_back(getClass(readString())); |
| } |
| } break; |
| case OpCode::NEWOBJ: { |
| // pop empty tuple |
| stack_.pop_back(); |
| } break; |
| case OpCode::BUILD: { |
| // TODO: [unpickler refactor] |
| auto setitem_data = stack_.back().ivalue(); |
| stack_.pop_back(); |
| |
| auto class_name = |
| static_cast<PicklerClass>(uint8_t(stack_.back().ivalue().toInt())); |
| stack_.pop_back(); |
| |
| switch (class_name) { |
| case PicklerClass::TENSOR: |
| stack_.emplace_back(tensor_table_->at(setitem_data.toInt())); |
| break; |
| case PicklerClass::INTLIST: |
| stack_.emplace_back(setitem_data); |
| break; |
| default: |
| AT_ERROR("Unknown pickler class id"); |
| } |
| } break; |
| case OpCode::REDUCE: { |
| // Pop reduce arg off the stack |
| auto data = stack_.back().ivalue().toTuple(); |
| stack_.pop_back(); |
| |
| // Remove GLOBAL from stack |
| auto class_name = stack_.back().pickler_class(); |
| stack_.pop_back(); |
| |
| switch (class_name) { |
| case PicklerClass::TENSOR: |
| stack_.emplace_back( |
| tensor_table_->at(data->elements().at(0).toInt())); |
| break; |
| case PicklerClass::INTLIST: |
| stack_.emplace_back(data->elements().at(0).toIntListRef()); |
| break; |
| case PicklerClass::TENSORLIST: |
| stack_.emplace_back(data->elements().at(0).toTensorListRef()); |
| break; |
| case PicklerClass::DOUBLELIST: |
| stack_.emplace_back(data->elements().at(0).toDoubleListRef()); |
| break; |
| case PicklerClass::BOOLLIST: |
| stack_.emplace_back(data->elements().at(0).toBoolListRef()); |
| break; |
| default: |
| AT_ERROR("Unknown pickler class id"); |
| } |
| } break; |
| default: |
| AT_ERROR( |
| "Unknown opcode for unpickling at ", |
| reinterpret_cast<const void*>(bytes_), |
| ": ", |
| int(static_cast<uint8_t>(opcode))); |
| } |
| return opcode; |
| } |
| |
| // Pop all the list items off of the stack and append them to the list at the |
| // corresponding MARK |
| void Unpickler::readList() { |
| size_t start = marks_.back(); |
| marks_.pop_back(); |
| auto list_ivalue = stack_.at(start - 1).ivalue(); |
| auto num_elements = stack_.size() - start; |
| auto elements = at::ArrayRef<StackItem>(stack_).slice(start); |
| if (list_ivalue.isIntList()) { |
| auto& list = list_ivalue.toIntList()->elements(); |
| list.reserve(num_elements); |
| for (const auto& elem : elements) { |
| list.emplace_back(elem.ivalue().toInt()); |
| } |
| } else if (list_ivalue.isTensorList()) { |
| auto& list = list_ivalue.toTensorList()->elements(); |
| list.reserve(num_elements); |
| for (const auto& elem : elements) { |
| list.emplace_back(elem.ivalue().toTensor()); |
| } |
| } else if (list_ivalue.isDoubleList()) { |
| auto& list = list_ivalue.toDoubleList()->elements(); |
| list.reserve(num_elements); |
| for (const auto& elem : elements) { |
| list.emplace_back(elem.ivalue().toDouble()); |
| } |
| } else if (list_ivalue.isBoolList()) { |
| auto& list = list_ivalue.toBoolList()->elements(); |
| list.reserve(num_elements); |
| for (const auto& elem : elements) { |
| list.push_back(elem.ivalue().toBool()); |
| } |
| } else if (list_ivalue.isGenericList()) { |
| auto& list = list_ivalue.toGenericList()->elements(); |
| list.reserve(num_elements); |
| for (const auto& elem : elements) { |
| list.emplace_back(elem.ivalue()); |
| } |
| } else { |
| AT_ERROR("Unknown IValue list kind: ", list_ivalue.tagKind()); |
| } |
| |
| stack_.erase(stack_.begin() + start, stack_.end()); |
| } |
| |
| inline bool is_valid_python_id_char(char c) { |
| return c == '_' || c == '.' || (c >= '0' && c <= '9') || |
| (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'); |
| } |
| |
| // Read a newline terminated string |
| std::string Unpickler::readString() { |
| const char* chars = reinterpret_cast<const char*>(bytes_); |
| const char* char_end_ptr = reinterpret_cast<const char*>(end_ptr_); |
| size_t n = 0; |
| while (true) { |
| // Check bounds before every read in case there is no terminating '\n' |
| TORCH_CHECK( |
| chars + n < char_end_ptr, |
| "Unpickler overran buffer while reading a string (expected a newline)"); |
| char c = chars[n]; |
| if (c == '\n') { |
| break; |
| } |
| |
| TORCH_CHECK( |
| is_valid_python_id_char(c), |
| "Found character '", |
| uint8_t(c), |
| "' in string, " |
| "strings must be qualified Python identifiers"); |
| |
| // Increment after the checks so the newline is excluded from the string |
| ++n; |
| } |
| |
| // Increment by string length + newline char |
| bytes_ += n + 1; |
| return std::string(chars, n); |
| } |
| |
| OpCode Unpickler::readOpCode() { |
| return static_cast<OpCode>(read<uint8_t>()); |
| } |
| |
| std::pair<at::Tensor, uint64_t> getWriteableTensor(const at::Tensor& tensor) { |
| at::Tensor storage_tensor = tensor; |
| uint64_t record_size = tensor.element_size() * tensor.storage().size(); |
| // TODO HIP support |
| if (tensor.storage().device_type() == at::DeviceType::CUDA) { |
| // NB: This new tensor is created to support CUDA tensors. Storages can be |
| // mutated when converting tensors from CUDA to CPU, and a CPU tensor is |
| // needed to copy the data from. |
| storage_tensor = at::empty({0}, tensor.options()) |
| .set_( |
| tensor.storage(), |
| /* storage_offset = */ 0, |
| /* size = */ |
| {static_cast<int64_t>(tensor.storage().size())}, |
| /* stride = */ {1}) |
| .cpu(); |
| TORCH_CHECK( |
| storage_tensor.element_size() * storage_tensor.storage().size() == |
| record_size, |
| "Storage tensor size did not match record size"); |
| } |
| |
| return std::make_pair(storage_tensor, record_size); |
| } |
| |
| uint64_t getStorageKey(const at::Tensor& tensor) { |
| // The StorageImpl pointer is unique per storage, so its address serves as |
| // a key that is stable for the lifetime of this process |
| at::StorageImpl* storage_ptr = tensor.storage().unsafeGetStorageImpl(); |
| return reinterpret_cast<uintptr_t>(storage_ptr); |
| } |
| |
| } // namespace jit |
| } // namespace torch |