blob: d8d49807dcbf956ae7dfe205e47ac4086fc04893 [file] [log] [blame]
#include <ATen/core/functional.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/import.h>
#include <torch/csrc/jit/import_export_helpers.h>
#ifndef C10_MOBILE
#include <torch/csrc/jit/import_legacy.h>
#endif
#include <torch/csrc/jit/import_source.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/pickle.h>
#include <torch/csrc/jit/unpickler.h>
#include <torch/csrc/jit/script/script_type_parser.h>
#include <torch/csrc/jit/source_range_serialization.h>
#include "caffe2/serialize/file_adapter.h"
#include "caffe2/serialize/inline_container.h"
#include "caffe2/serialize/istream_adapter.h"
#include <ATen/ATen.h>
#include <fstream>
#include <string>
#include <unordered_map>
#include <vector>
namespace torch {
namespace jit {
using caffe2::serialize::FileAdapter;
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;
void postSetStateValidate(const IValue& v) {
auto obj = v.toObject();
const auto& objType = obj->type();
for (size_t i = 0; i < objType->numAttributes(); i++) {
const auto& attrType = objType->getAttribute(i);
const auto& attrName = objType->getAttributeName(i);
const auto& slot = obj->getSlot(i);
// const auto attrType = objType->getAttribute(i);
// Verify that all the non-optional attributes have been initialized
// TODO: Issue #20497
if (attrType->kind() != TypeKind::OptionalType) {
TORCH_CHECK(
!slot.isNone(),
"The field '",
attrName,
"' was left unitialized after __setstate__, but expected a ",
"value of type '",
attrType->python_str(),
"'");
}
}
}
namespace {
// This is a deserializer class which loads script modules from pt files.
// Content of the file is written using PyTorchStreamWriter, for details please
// check caffe2/serialize/inline_container.h.
// The module is saved in pickle. readArchive() is called to parse and construct
// the constant table and the script module.
class ScriptModuleDeserializer final {
public:
ScriptModuleDeserializer(
std::shared_ptr<script::CompilationUnit> cu,
std::unique_ptr<PyTorchStreamReader> reader)
: compilation_unit_(cu),
reader_(std::move(reader)),
source_importer_(
compilation_unit_,
&constants_table_,
[this](const std::string& qualifier) {
return findSourceInArchiveFromQualifier(
*reader_, export_prefix_, qualifier);
}) {}
script::Module deserialize(
c10::optional<at::Device> device,
script::ExtraFilesMap& extra_files);
private:
IValue readArchive(const std::string& archive_name);
std::shared_ptr<script::CompilationUnit> compilation_unit_;
std::unique_ptr<PyTorchStreamReader> reader_;
c10::optional<at::Device> device_;
std::vector<at::Tensor> constants_table_;
script::SourceImporter source_importer_;
std::string export_prefix_ = "code/";
};
IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) {
std::stringstream picklename;
picklename << archive_name << ".pkl";
at::DataPtr pickle_ptr;
size_t pickle_size;
std::tie(pickle_ptr, pickle_size) = reader_->getRecord(picklename.str());
size_t bytes_read = 0;
auto data = reinterpret_cast<const char*>(pickle_ptr.get());
auto reader = [&](char* buffer, size_t len) -> size_t {
if (bytes_read >= pickle_size) {
return 0;
}
len = std::min(pickle_size - bytes_read, len);
// Copy len bytes into buffer
const char* start = data + bytes_read;
std::memcpy(buffer, start, len);
bytes_read += len;
return len;
};
auto class_resolver = [&](const c10::QualifiedName& qn) {
auto cls = source_importer_.loadNamedType(qn)->expect<ClassType>();
return c10::StrongTypePtr(compilation_unit_, std::move(cls));
};
// Decouple how to get obj from type. In this file it's dependent on
// Method.run() and graph executor, etc.
// For bytecode import we need to decouple these dependencies.
auto obj_loader = [&](at::StrongTypePtr type, IValue input) {
auto cls = type.type_->expect<at::ClassType>();
size_t n = cls->numAttributes();
if (checkHasValidSetGetState(type.type_)) {
auto obj = c10::ivalue::Object::create(type, n);
// XXX: Do not optimize __setstate__, so that we don't try to
// specialize the class before it is initialized.
setGraphExecutorOptimize(false);
(*type.type_->getMethod("__setstate__"))({obj, input});
setGraphExecutorOptimize(true);
postSetStateValidate(obj);
return obj;
} else {
auto dict = std::move(input).toGenericDict();
auto obj = c10::ivalue::Object::create(type, n);
for (size_t i = 0; i < n; ++i) {
obj->setSlot(i, dict.at(cls->getAttributeName(i)));
}
return obj;
}
};
auto read_record = [&](const std::string& name) {
std::stringstream ss;
ss << archive_name << "/" << name;
return std::get<0>(reader_->getRecord(ss.str()));
};
Unpickler unpickler(
reader, std::move(class_resolver), std::move(obj_loader),
std::move(read_record), device_);
return unpickler.parse_ivalue();
}
script::Module ScriptModuleDeserializer::deserialize(
c10::optional<at::Device> device,
script::ExtraFilesMap& extra_files) {
C10_LOG_API_USAGE_ONCE("torch.script.load");
device_ = device;
// Load extra files.
for (const auto& kv : extra_files) {
const std::string& key = "extra/" + kv.first;
if (reader_->hasRecord(key)) {
at::DataPtr meta_ptr;
size_t meta_size;
std::tie(meta_ptr, meta_size) = reader_->getRecord(key);
extra_files[kv.first] =
std::string(static_cast<char*>(meta_ptr.get()), meta_size);
}
}
if (reader_->hasRecord("model.json")) {
#ifndef C10_MOBILE
return torch::jit::LEGACY_deserialize(
compilation_unit_, std::move(reader_), device_);
#else
AT_ERROR("Legacy model format is not supported on mobile.");
#endif
}
auto tuple = readArchive("constants").toTuple();
for (auto constant : tuple->elements()) {
constants_table_.push_back(constant.toTensor());
}
return script::Module(readArchive("data").toObject());
}
} // namespace
script::Module import_ir_module(
std::shared_ptr<script::CompilationUnit> cu,
std::istream& in,
c10::optional<at::Device> device,
script::ExtraFilesMap& extra_files) {
auto reader = torch::make_unique<PyTorchStreamReader>(&in);
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
return deserializer.deserialize(device, extra_files);
}
script::Module import_ir_module(
std::shared_ptr<script::CompilationUnit> cu,
const std::string& filename,
c10::optional<at::Device> device,
script::ExtraFilesMap& extra_files) {
auto reader = torch::make_unique<PyTorchStreamReader>(filename);
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
return deserializer.deserialize(device, extra_files);
}
script::Module import_ir_module(
std::shared_ptr<script::CompilationUnit> cu,
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<at::Device> device,
script::ExtraFilesMap& extra_files) {
auto reader = torch::make_unique<PyTorchStreamReader>(std::move(rai));
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
return deserializer.deserialize(device, extra_files);
}
script::Module load(
std::istream& in,
c10::optional<at::Device> device,
script::ExtraFilesMap& extra_files) {
std::unique_ptr<IStreamAdapter> rai =
caffe2::make_unique<IStreamAdapter>(&in);
auto module = load(std::move(rai), device, extra_files);
return module;
}
script::Module load(
const std::string& filename,
c10::optional<at::Device> device,
script::ExtraFilesMap& extra_files) {
std::unique_ptr<FileAdapter> rai = caffe2::make_unique<FileAdapter>(filename);
auto module = load(std::move(rai), device, extra_files);
return module;
}
script::Module load(
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<c10::Device> device,
script::ExtraFilesMap& extra_files) {
auto reader = torch::make_unique<PyTorchStreamReader>(std::move(rai));
auto cu = std::make_shared<script::CompilationUnit>();
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
return deserializer.deserialize(device, extra_files);
}
} // namespace jit
} // namespace torch