#include <ATen/ATen.h>
#include <ATen/core/Dict.h>
#include <torch/csrc/jit/function.h>
#include <torch/csrc/jit/pickler.h>
#include <aten/src/ATen/quantized/Quantizer.h>
#include <string>

namespace torch {
namespace jit {

using ::c10::IValue;

// Protocol 2 is the highest that can be decoded by Python 2
// See https://docs.python.org/3/library/pickle.html#data-stream-format
constexpr static uint8_t PROTOCOL_VERSION = 2;

Pickler::~Pickler() {
  flush();
}

void Pickler::protocol() {
  push<PickleOpCode>(PickleOpCode::PROTO);
  push<uint8_t>(PROTOCOL_VERSION);
}
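
// For illustration, protocol() writes the two-byte header that every stream
// produced by this class starts with: 0x80 0x02 (PROTO, then the version).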

void Pickler::startTuple() {
  // All attributes get pushed into a tuple and their indices saved in the
  // module def
  push<PickleOpCode>(PickleOpCode::MARK);
}

void Pickler::endTuple() {
  push<PickleOpCode>(PickleOpCode::TUPLE);
}

void Pickler::stop() {
  push<PickleOpCode>(PickleOpCode::STOP);
  flush();
}
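
// Taken together, a complete program written through this class looks roughly
// like:
//   PROTO 2 | MARK <pickled values> TUPLE | STOP
// When tensors are pickled literally (no tensor table), their raw bytes are
// written out after STOP by the surrounding serialization code, not by the
// Pickler itself.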

// unmemoized version called by pushIValue
void Pickler::pushIValueImpl(const IValue& ivalue) {
  if (ivalue.isTensor()) {
    pushTensor(ivalue);
  } else if (ivalue.isTuple()) {
    pushTuple(ivalue);
  } else if (ivalue.isDouble()) {
    pushDouble(ivalue.toDouble());
  } else if (ivalue.isInt()) {
    pushInt(ivalue.toInt());
  } else if (ivalue.isBool()) {
    pushBool(ivalue.toBool());
  } else if (ivalue.isString()) {
    pushString(ivalue.toStringRef());
  } else if (ivalue.isGenericDict()) {
    pushDict(ivalue);
  } else if (ivalue.isNone()) {
    push<PickleOpCode>(PickleOpCode::NONE);
  } else if (ivalue.isIntList()) {
    pushSpecializedList(
        ivalue, "build_intlist", [=](const IValue& ivalue) {
          for (const int64_t item : ivalue.toIntListRef()) {
            pushInt(item);
          }
        });
  } else if (ivalue.isTensorList()) {
    pushSpecializedList(
        ivalue, "build_tensorlist", [=](const IValue& ivalue) {
          for (const at::Tensor& item : ivalue.toTensorListRef()) {
            pushIValue(item);
          }
        });
  } else if (ivalue.isDoubleList()) {
    pushSpecializedList(
        ivalue, "build_doublelist", [=](const IValue& ivalue) {
          for (double item : ivalue.toDoubleListRef()) {
            pushDouble(item);
          }
        });
  } else if (ivalue.isBoolList()) {
    pushSpecializedList(
        ivalue, "build_boollist", [=](const IValue& ivalue) {
          for (bool item : ivalue.toBoolList()) {
            pushBool(item);
          }
        });
    // note: isGenericList must be after isIntList and friends because
    // isGenericList is true for all lists.
  } else if (ivalue.isGenericList()) {
    pushGenericList(ivalue);
  } else if (ivalue.isObject()) {
    auto obj = ivalue.toObject();
    auto type = obj->type();
    if (memoized_class_types_ != nullptr) {
      // Memoize every class type the Pickler encounters.
      // This is used to make sure we capture all the run-time types
      // and serialize them properly for class/interface polymorphism
      memoized_class_types_->emplace_back(type);
    }
    pushGlobal(type->name()->prefix(), type->name()->name());
    push<PickleOpCode>(PickleOpCode::EMPTY_TUPLE);
    push<PickleOpCode>(PickleOpCode::NEWOBJ);
    if (checkHasValidSetGetState(type)) {
      Function* getstate = type->getMethod("__getstate__");
      pushIValue((*getstate)({obj}));
    } else {
      push<PickleOpCode>(PickleOpCode::EMPTY_DICT);
      push<PickleOpCode>(PickleOpCode::MARK);
      for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
        pushString(type->getAttributeName(i));
        pushIValue(obj->getSlot(i));
      }
      push<PickleOpCode>(PickleOpCode::SETITEMS);
    }
    push<PickleOpCode>(PickleOpCode::BUILD);
  } else if (ivalue.isDevice()) {
    pushDevice(ivalue);
  } else {
    AT_ERROR("Unknown IValue type for pickling: ", ivalue.tagKind());
  }
}

void Pickler::pushDevice(const IValue& ivalue) {
  auto device = ivalue.toDevice();
  auto it = memoized_devices_map_.find(device.str());
  if (it == memoized_devices_map_.end()) {
    pushGlobal("torch", "device");
    pushString(ivalue.toDevice().str());
    push<PickleOpCode>(PickleOpCode::TUPLE1);
    push<PickleOpCode>(PickleOpCode::REDUCE);
    memoized_devices_map_[device.str()] = pushNextBinPut();
  } else {
    pushBinGet(it->second);
  }
}
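
// For example, the first time torch.device("cuda:0") is pickled, the emitted
// opcodes are roughly:
//   GLOBAL 'torch\ndevice\n' BINPUT  BINUNICODE 'cuda:0' BINPUT
//   TUPLE1 REDUCE BINPUT
// so the unpickler evaluates torch.device('cuda:0') once; later occurrences
// collapse to a single BINGET.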

void Pickler::pushIValue(const IValue& ivalue) {
  bool shouldMemoizeByPointer =
      ivalue.isPtrType() && !ivalue.isString() && ivalue.use_count() > 1;

  // Mutable ivalues are memoized by pointer equality, which we handle at this
  // outer granularity. Immutable ivalues are memoized by value equality, which
  // is handled in the type-specific handlers inside pushIValueImpl.
  if (shouldMemoizeByPointer) {
    const void* ptr = ivalue.internalToPointer();
    TORCH_CHECK(
        ptr != nullptr,
        "Pickler cannot memoize ",
        ivalue.tagKind(),
        " IValue ",
        ivalue);
    auto memo_entry = memoized_ivalue_map_.find(ptr);
    if (memo_entry != memoized_ivalue_map_.end()) {
      // This value has already been pushed, just do a BINGET
      pushBinGet(memo_entry->second);
      return;
    }

    pushIValueImpl(ivalue);

    memoized_ivalues_.push_back(ivalue);
    memoized_ivalue_map_[ivalue.internalToPointer()] = pushNextBinPut();
  } else {
    pushIValueImpl(ivalue);
  }
}
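
// For example, a tuple holding the same (shared) list object twice emits the
// list's opcodes once, followed by BINPUT, and a single BINGET for the second
// occurrence, so aliasing survives a round trip through the unpickler.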

void Pickler::pushInt(int64_t n) {
  if (n >= std::numeric_limits<uint8_t>::min() &&
      n <= std::numeric_limits<uint8_t>::max()) {
    push<PickleOpCode>(PickleOpCode::BININT1);
    push<uint8_t>(n);
  } else if (
      n >= std::numeric_limits<uint16_t>::min() &&
      n <= std::numeric_limits<uint16_t>::max()) {
    push<PickleOpCode>(PickleOpCode::BININT2);
    push<uint16_t>(n);
  } else if (
      n >= std::numeric_limits<int32_t>::min() &&
      n <= std::numeric_limits<int32_t>::max()) {
    push<PickleOpCode>(PickleOpCode::BININT);
    push<int32_t>(n);
  } else {
    // Push 8 byte integer
    push<PickleOpCode>(PickleOpCode::LONG1);
    push<uint8_t>(8);
    push<int64_t>(n);
  }
}
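
// Example encodings. push<T>() copies raw host bytes and the pickle format is
// little endian, so the multi-byte payloads below assume a little-endian host:
//   pushInt(5)         -> BININT1 0x05
//   pushInt(1000)      -> BININT2 0xe8 0x03
//   pushInt(-1)        -> BININT  0xff 0xff 0xff 0xff
//   pushInt(1LL << 40) -> LONG1   0x08 <eight bytes, two's complement>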

void Pickler::pushBool(bool value) {
  push<PickleOpCode>(value ? PickleOpCode::NEWTRUE : PickleOpCode::NEWFALSE);
}

void Pickler::pushBinGet(uint32_t memo_id) {
  if (memo_id <= std::numeric_limits<uint8_t>::max()) {
    push<PickleOpCode>(PickleOpCode::BINGET);
    push<uint8_t>(memo_id);
  } else {
    // Memoized too many items, issue a LONG_BINGET instead
    push<PickleOpCode>(PickleOpCode::LONG_BINGET);
    push<uint32_t>(memo_id);
  }
}

// unmemoized encoding of a string
void Pickler::pushStringImpl(const std::string& string) {
  push<PickleOpCode>(PickleOpCode::BINUNICODE);
  push<uint32_t>(string.size());
  pushBytes(string);
}

void Pickler::pushString(const std::string& string) {
  auto it = memoized_strings_map_.find(string);
  if (it == memoized_strings_map_.end()) {
    pushStringImpl(string);
    memoized_strings_map_[string] = pushNextBinPut();
  } else {
    pushBinGet(it->second);
  }
}
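
// For illustration, the first pushString("cpu") emits
//   BINUNICODE <uint32 length 3> 'cpu' BINPUT
// and every later pushString("cpu") collapses to a single BINGET.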

void Pickler::pushStorageOfTensor(const at::Tensor& tensor) {
  const at::Storage& storage = tensor.storage();
  void* addr = storage.unsafeGetStorageImpl();
  auto it = memoized_storage_map_.find(addr);
  if (it != memoized_storage_map_.end()) {
    pushBinGet(it->second);
    return;
  }

  // Tuple for persistent_load
  push<PickleOpCode>(PickleOpCode::MARK);
  // typename
  pushString("storage");
  // data_type
  std::string data_type =
      std::string(toString(tensor.scalar_type())).append("Storage");
  pushGlobal("torch", data_type);
  // root_key
  pushString(c10::to_string(tensor_data_.size()));
  // location
  pushString(tensor.device().str());
  // size
  pushInt(tensor.storage().size());

  push<PickleOpCode>(PickleOpCode::TUPLE);
  push<PickleOpCode>(PickleOpCode::BINPERSID);

  // TODO: Skip this if not writing tensors
  memoized_storage_map_[addr] = pushNextBinPut();
  tensor_data_.push_back(getWriteableTensorData(tensor));
}
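
// The MARK ... TUPLE BINPERSID sequence above hands the unpickler's
// persistent_load hook a tuple shaped like, e.g.:
//   ('storage', torch.FloatStorage, '0', 'cpu', 16)
// i.e. (typename, data_type, root_key, location, size). The storage bytes
// themselves travel separately via tensor_data_, outside the pickle program.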

void Pickler::pushBytes(const std::string& string) {
  static const size_t kSmallStr = 32;
  if (string.size() <= kSmallStr &&
      bufferPos_ + string.size() <= buffer_.size()) {
    // Small string that fits: buffer the data.
    memcpy(buffer_.data() + bufferPos_, string.data(), string.size());
    bufferPos_ += string.size();
  } else {
    // Otherwise, first flush, then write directly.
    flush();
    writer_(string.data(), string.size());
  }
}

void Pickler::pushGlobal(
    const std::string& module_name,
    const std::string& class_name) {
  std::string key;
  key.reserve(module_name.size() + class_name.size() + 2);
  key.append(module_name).append("\n").append(class_name).append("\n");
  auto memo_entry = memoized_globals_map_.find(key);
  if (memo_entry == memoized_globals_map_.end()) {
    push<PickleOpCode>(PickleOpCode::GLOBAL);
    pushBytes(key);
    // Push BINPUT without adding anything to the memoized_ivalues_
    size_t memo_id = pushNextBinPut();
    memoized_globals_map_.insert({key, memo_id});
  } else {
    pushBinGet(memo_entry->second);
  }
}
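
// e.g. pushGlobal("torch", "device") writes
//   GLOBAL 'torch\ndevice\n'
// on first use (module and name are newline-terminated, per the pickle
// format) and a BINGET of the saved memo id afterwards.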

void Pickler::pushTensor(const IValue& ivalue) {
  if (tensor_table_ == nullptr) {
    pushLiteralTensor(ivalue);
  } else {
    pushTensorReference(ivalue);
  }
}

void Pickler::pushLiteralTensor(const IValue& ivalue) {
  // In contrast to tensor references, literal tensors are included in the
  // pickle program binary blob. They are written to the file after the STOP
  // opcode. They can't be included in the pickle program itself without a
  // bunch of extra machinery since byte strings are limited to 4 GB.
  //
  // The format here is the same one used by `torch.save()`. The code for the
  // format can be found in `torch/serialization.py`.
  auto tensor = ivalue.toTensor();
  bool quantized = tensor.is_quantized();
  // The arguments to _rebuild_tensor_v2 are:
  //   storage, storage_offset, size, stride, requires_grad, backward_hooks
  // (_rebuild_qtensor additionally takes the quantizer params, pushed below
  // between stride and requires_grad)
  pushGlobal(
      "torch._utils", quantized ? "_rebuild_qtensor" : "_rebuild_tensor_v2");

  push<PickleOpCode>(PickleOpCode::MARK);

  pushStorageOfTensor(tensor);

  // storage offset
  pushInt(tensor.storage_offset());

  // size
  push<PickleOpCode>(PickleOpCode::MARK);
  for (auto size : tensor.sizes()) {
    pushInt(size);
  }
  push<PickleOpCode>(PickleOpCode::TUPLE);

  // stride
  push<PickleOpCode>(PickleOpCode::MARK);
  for (auto stride : tensor.strides()) {
    pushInt(stride);
  }
  push<PickleOpCode>(PickleOpCode::TUPLE);

  if (quantized) {
    push<PickleOpCode>(PickleOpCode::MARK);
    pushGlobal("torch", toString(tensor.qscheme()));
    // tuple of (qscheme, scale, zp) or (qscheme, scales, zps, axis)
    switch (tensor.qscheme()) {
      case at::kPerTensorAffine:
        pushDouble(tensor.q_scale());
        pushInt(tensor.q_zero_point());
        break;
      case at::kPerChannelAffine: {
        const auto* quantizer = static_cast<at::PerChannelAffineQuantizer*>(
            tensor.quantizer().get());
        pushIValue(c10::List<double>(quantizer->scales()));
        pushIValue(c10::List<int64_t>(quantizer->zero_points()));
        pushInt(quantizer->axis());
      } break;
      default:
        TORCH_CHECK(
            false,
            "Unsupported tensor quantization type in serialization ",
            toString(tensor.qscheme()));
        break;
    }
    push<PickleOpCode>(PickleOpCode::TUPLE);
  }

  // requires_grad
  pushIValue(tensor.requires_grad());

  // backward_hooks
  pushGlobal("collections", "OrderedDict");
  push<PickleOpCode>(PickleOpCode::EMPTY_TUPLE);
  // Construct the collections.OrderedDict for the backward_hooks
  push<PickleOpCode>(PickleOpCode::REDUCE);

  push<PickleOpCode>(PickleOpCode::TUPLE);

  // Call torch._utils._rebuild_tensor_v2
  push<PickleOpCode>(PickleOpCode::REDUCE);
}
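
// The program above is equivalent to the unpickler evaluating, for the
// non-quantized case:
//   torch._utils._rebuild_tensor_v2(
//       storage, storage_offset, size, stride, requires_grad, backward_hooks)
// with the quantizer params tuple inserted after stride for qtensors.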

void Pickler::pushSpecializedList(
    const IValue& ivalue,
    const char* list_name,
    const std::function<void(const IValue&)>& item_pusher) {
  pushGlobal("torch.jit._pickle", list_name);

  // Reduce arguments are spread (e.g. `*args`) before calling the global,
  // so wrap in a tuple
  push<PickleOpCode>(PickleOpCode::MARK);

  push<PickleOpCode>(PickleOpCode::EMPTY_LIST);
  // Mark list
  push<PickleOpCode>(PickleOpCode::MARK);

  // Add all items
  item_pusher(ivalue);

  // Finish list
  push<PickleOpCode>(PickleOpCode::APPENDS);

  // Finish tuple
  push<PickleOpCode>(PickleOpCode::TUPLE);

  // Call reduce
  push<PickleOpCode>(PickleOpCode::REDUCE);
}
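
// e.g. an int list [1, 2, 3] unpickles as the call
//   torch.jit._pickle.build_intlist([1, 2, 3])
// which restores the specialized list type that a plain pickled list would
// lose.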

static inline double swapDouble(double value) {
  const char* bytes = reinterpret_cast<const char*>(&value);
  double flipped;
  char* out_bytes = reinterpret_cast<char*>(&flipped);
  for (size_t i = 0; i < sizeof(double); ++i) {
    out_bytes[i] = bytes[sizeof(double) - i - 1];
  }
  return flipped;
}

void Pickler::pushDouble(double value) {
  push<PickleOpCode>(PickleOpCode::BINFLOAT);
  // Python pickle format is big endian, swap.
  push<double>(swapDouble(value));
}
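
// e.g. pushDouble(1.0) emits BINFLOAT followed by the big-endian IEEE 754
// bytes 3f f0 00 00 00 00 00 00 (the byte swap assumes a little-endian host).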

void Pickler::pushLong(const std::string& data) {
  uint64_t size = data.size();

  TORCH_INTERNAL_ASSERT(
      size <= std::numeric_limits<uint8_t>::max(),
      "Cannot pickle a long larger than 255 bytes");
  push<PickleOpCode>(PickleOpCode::LONG1);
  push<uint8_t>(size);
  pushBytes(data);
}

void Pickler::pushTensorReference(const IValue& ivalue) {
  pushGlobal("torch.jit._pickle", "build_tensor_from_id");
  tensor_table_->push_back(ivalue.toTensor());
  int64_t tensor_id = tensor_table_->size() - 1;
  // Reduce arguments are spread (e.g. `*args`) before calling the global,
  // so wrap in a tuple
  push<PickleOpCode>(PickleOpCode::MARK);
  pushIValue(tensor_id);
  push<PickleOpCode>(PickleOpCode::TUPLE);

  push<PickleOpCode>(PickleOpCode::REDUCE);
}

void Pickler::pushEmptyDict() {
  push<PickleOpCode>(PickleOpCode::EMPTY_DICT);
}

void Pickler::pushDict(const IValue& ivalue) {
  pushEmptyDict();
  // iterationOrder() sorts the dict items so that keys are emitted in a
  // deterministic order
  auto dict_items = iterationOrder(ivalue.toGenericDict());
  if (dict_items.size() == 0) {
    return;
  }

  push<PickleOpCode>(PickleOpCode::MARK);

  for (const auto& pair : dict_items) {
    pushIValue(pair.first);
    pushIValue(pair.second);
  }

  push<PickleOpCode>(PickleOpCode::SETITEMS);
}
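
// A two-entry dict therefore pickles as, roughly:
//   EMPTY_DICT MARK <key1> <value1> <key2> <value2> SETITEMS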

size_t Pickler::pushNextBinPut() {
  if (memo_id_ <= std::numeric_limits<uint8_t>::max()) {
    push<PickleOpCode>(PickleOpCode::BINPUT);
    push<uint8_t>(memo_id_);
  } else {
    // Memoized too many items, issue a LONG_BINPUT instead
    push<PickleOpCode>(PickleOpCode::LONG_BINPUT);
    push<uint32_t>(memo_id_);
  }
  AT_ASSERT(memo_id_ <= std::numeric_limits<uint32_t>::max());
  ++memo_id_;
  return memo_id_ - 1;
}

void Pickler::pushGenericList(const IValue& ivalue) {
  auto list = ivalue.toGenericListRef();
  push<PickleOpCode>(PickleOpCode::EMPTY_LIST);

  push<PickleOpCode>(PickleOpCode::MARK);

  for (const IValue& item : list) {
    pushIValue(item);
  }

  push<PickleOpCode>(PickleOpCode::APPENDS);
}

void Pickler::pushTuple(const IValue& ivalue) {
  auto tuple = ivalue.toTuple();
  auto tuple_size = tuple->elements().size();

  switch (tuple_size) {
    case 0: {
      push<PickleOpCode>(PickleOpCode::EMPTY_TUPLE);
    } break;
    case 1: {
      pushIValue(tuple->elements()[0]);
      push<PickleOpCode>(PickleOpCode::TUPLE1);
    } break;
    case 2: {
      pushIValue(tuple->elements()[0]);
      pushIValue(tuple->elements()[1]);
      push<PickleOpCode>(PickleOpCode::TUPLE2);
    } break;
    case 3: {
      pushIValue(tuple->elements()[0]);
      pushIValue(tuple->elements()[1]);
      pushIValue(tuple->elements()[2]);
      push<PickleOpCode>(PickleOpCode::TUPLE3);
    } break;
    default: {
      push<PickleOpCode>(PickleOpCode::MARK);
      for (const IValue& item : tuple->elements()) {
        pushIValue(item);
      }
      push<PickleOpCode>(PickleOpCode::TUPLE);
    } break;
  }
}
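
// e.g. the tuple (1, 2) pickles as BININT1 1 BININT1 2 TUPLE2, while tuples
// longer than three elements fall back to MARK <items> TUPLE.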

WriteableTensorData getWriteableTensorData(const at::Tensor& tensor) {
  WriteableTensorData result;
  result.tensor_ = tensor;
  result.size_ = tensor.element_size() * tensor.storage().size();
  // TODO HIP support
  if (tensor.storage().device_type() == at::DeviceType::CUDA) {
    // NB: This new tensor is created to support cuda tensors.
    // Storages can be mutated when converting tensors from cuda to cpu,
    // and we need a cpu tensor to copy data from.
    result.tensor_ = at::empty({0}, tensor.options())
                         .set_(
                             tensor.storage(),
                             /* storage_offset = */ 0,
                             /* size = */
                             {static_cast<int64_t>(tensor.storage().size())},
                             /* stride = */ {1})
                         .cpu();
    TORCH_CHECK(
        result.tensor_.element_size() * result.tensor_.storage().size() ==
            result.size_,
        "Storage tensor size did not match record size");
  }
  return result;
}

bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls) {
  // Check that the schemas for __getstate__ and __setstate__ are correct
  auto getstate = cls->getMethod("__getstate__");
  if (getstate == nullptr) {
    return false;
  }
  auto get_schema = getstate->getSchema();

  // Check __getstate__
  // __getstate__ is expected to be (self) -> T
  TORCH_CHECK(
      get_schema.arguments().size() == 1,
      "'__getstate__' must have 'self' as its only argument, but found ",
      get_schema.arguments().size(),
      " arguments");
  TORCH_CHECK(
      get_schema.returns().size() == 1,
      "'__getstate__' must return 1 value, but found ",
      get_schema.returns().size());

  // Check __setstate__ if the method exists
  // __setstate__ is expected to be (self, T) -> None
  auto setstate = cls->getMethod("__setstate__");
  if (!setstate) {
    return false;
  }
  auto set_schema = setstate->getSchema();

  TORCH_CHECK(
      set_schema.arguments().size() == 2,
      "'__setstate__' must have 'self' and the state as its "
      "only arguments, but found ",
      set_schema.arguments().size(),
      " arguments");
  TORCH_CHECK(
      set_schema.returns().size() == 1,
      "'__setstate__' must return None, but found ",
      set_schema.returns().size(),
      " return values");
  TORCH_CHECK(
      set_schema.returns().at(0).type()->isSubtypeOf(NoneType::get()),
      "'__setstate__' must return None, but found value of type ",
      set_schema.returns().at(0).type()->python_str());

  // Check that the return type of __getstate__ matches the input to
  // __setstate__
  auto get_type = get_schema.returns().at(0).type();
  auto set_type = set_schema.arguments().at(1).type();

  TORCH_CHECK(
      get_type->isSubtypeOf(set_type),
      "'__getstate__'s return type (",
      get_type->python_str(),
      ") does not match '__setstate__'s argument type (",
      set_type->python_str(),
      ")");

  return true;
}
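
// For illustration, a class passes this check with a TorchScript pair like:
//   def __getstate__(self) -> Tuple[Tensor, int]: ...
//   def __setstate__(self, state: Tuple[Tensor, int]) -> None: ...
// It returns false when either method is missing, and throws when the schemas
// are malformed or __getstate__'s return type doesn't match __setstate__'s
// state argument.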

} // namespace jit
} // namespace torch