blob: 5a1e60e0673bdb5db559baa39cf87f0f9debd5c2 [file] [log] [blame]
#include <google/protobuf/util/json_util.h>
#include <google/protobuf/util/type_resolver_util.h>
#include <ATen/core/functional.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/import.h>
#include <torch/csrc/jit/import_method.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>
#include "caffe2/core/common.h"
#include "caffe2/core/types.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/proto/torch_pb.h"
#include "caffe2/serialize/file_adapter.h"
#include "caffe2/serialize/inline_container.h"
#include "caffe2/serialize/istream_adapter.h"
#include <ATen/ATen.h>
#include <fstream>
#include <string>
#include <unordered_map>
#include <vector>
namespace torch {
namespace jit {
using caffe2::serialize::FileAdapter;
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::ReadAdapterInterface;
namespace {
// Deserializer that loads script modules from .pt files. The file content is
// written using PyTorchStreamWriter; for details of the container format see
// caffe2/serialize/inline_container.h. The model metadata lives in the
// "model.json" record as a JSON-serialized torch::ModelDef (defined in
// caffe2/proto/torch.proto). Tensor payloads, TorchScript method source
// ("arena") records and user "extra/..." files are stored as separate records
// that are looked up by key.
class ScriptModuleDeserializer final {
 public:
  // Reads the archive from the file at `filename`.
  explicit ScriptModuleDeserializer(const std::string& filename);
  // Reads the archive from the stream pointed to by `is`.
  explicit ScriptModuleDeserializer(std::istream* is);
  // Reads the archive through `rai`, taking ownership of the adapter.
  explicit ScriptModuleDeserializer(std::unique_ptr<ReadAdapterInterface> rai);

  // Parses the model metadata and populates the module hierarchy obtained
  // through `module_lookup`. If `device` is set, it overrides the device
  // recorded for every tensor (map_location). Each key of `extra_files` names
  // an "extra/<key>" record whose contents are written into that entry.
  void deserialize(
      ModuleLookup module_lookup,
      c10::optional<at::Device> device,
      script::ExtraFilesMap& extra_files);

 private:
  // Builds the tensor described by `tensor_proto`, reusing storages already
  // loaded for the same record key through `storageMap`.
  at::Tensor loadTensor(
      const torch::TensorDef& tensor_proto,
      std::unordered_map<std::string, at::Storage>& storageMap);
  // Recursively populates the module addressed by moduleStack_.
  void convertModule(const torch::ModuleDef& module_def);
  // Loads every tensor listed in `model_def` into tensor_table_.
  void loadTensorTable(torch::ModelDef* model_def);

  caffe2::serialize::PyTorchStreamReader reader_;
  // this is a hack to make sure the script module created in C++ is the
  // same as created in Python
  ModuleLookup moduleLookup_;
  // Optional device override (map_location) applied to every loaded tensor.
  c10::optional<at::Device> device_;
  // Qualified path of the submodule currently being converted.
  std::vector<std::string> moduleStack_;
  // Loaded tensors. ParameterDef::tensor_id indexes into this table, and the
  // table is also passed to import_methods.
  std::vector<at::Tensor> tensor_table_;
};
ScriptModuleDeserializer::ScriptModuleDeserializer(const std::string& filename)
: reader_(filename.c_str()) {
// TODO appropriate support for mmap, right now still use stream reader
}
ScriptModuleDeserializer::ScriptModuleDeserializer(std::istream* is)
: reader_(is) {}
ScriptModuleDeserializer::ScriptModuleDeserializer(
std::unique_ptr<ReadAdapterInterface> rai)
: reader_(std::move(rai)) {}
// Reads the "model.json" record, transcodes it into a torch::ModelDef, loads
// the requested extra files, builds the tensor table and finally populates
// the module hierarchy through `module_lookup`.
//
// @param module_lookup maps a qualified submodule path to the Module to fill.
// @param device if set, overrides the device recorded for every tensor.
// @param extra_files on input, its keys name "extra/<key>" records to read;
//        on output, each value holds the corresponding record's contents.
void ScriptModuleDeserializer::deserialize(
    ModuleLookup module_lookup,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  torch::ModelDef model_def;
  at::DataPtr data_ptr;
  size_t data_size;
  std::tie(data_ptr, data_size) = reader_.getRecord("model.json");

  // NB: cannot use JsonStringToMessage, since fbcode's protobuf is too old.
  // Instead, transcode JSON -> binary protobuf with a TypeResolver and parse
  // the binary form. The URL prefix is kept consistent with the one
  // JsonStringToMessage uses.
  const std::string url_prefix = "type.googleapis.com";
  std::unique_ptr<::google::protobuf::util::TypeResolver> resolver(
      ::google::protobuf::util::NewTypeResolverForDescriptorPool(
          url_prefix, model_def.GetDescriptor()->file()->pool()));
  const std::string json_string(
      static_cast<const char*>(data_ptr.get()), data_size);
  std::string binary_string;
  const auto convert_result = ::google::protobuf::util::JsonToBinaryString(
      resolver.get(),
      url_prefix + "/" + model_def.GetDescriptor()->full_name(),
      json_string,
      &binary_string);
  if (!convert_result.ok()) {
    std::stringstream ss;
    ss << convert_result;
    AT_ERROR(ss.str());
  }
  // Parse outside the assertion so the side effect never depends on how the
  // assertion macro is defined.
  const bool parsed = model_def.ParseFromString(binary_string);
  AT_ASSERTM(parsed, "JSON transcoder produced invalid protobuf output.");

  moduleLookup_ = std::move(module_lookup);
  device_ = device;
  const auto& module_def = model_def.main_module();

  // Load extra files: replace each requested entry's value with the contents
  // of its "extra/<key>" record. Assigning through the element does not
  // change the map's structure, so iteration stays valid.
  for (auto& kv : extra_files) {
    const std::string key = "extra/" + kv.first;
    at::DataPtr meta_ptr;
    size_t meta_size;
    std::tie(meta_ptr, meta_size) = reader_.getRecord(key);
    kv.second = std::string(static_cast<const char*>(meta_ptr.get()), meta_size);
  }

  loadTensorTable(&model_def);
  // TODO: this can be simplified when C++/Python interop lands,
  // and the submodules would be created as the same in either C++ or Python
  convertModule(module_def);
}
// Deserializes every TensorDef in `model_def` and appends the results to
// tensor_table_ in declaration order. Starting from an empty table,
// tensor_table_[i] corresponds to model_def->tensors(i). Records referenced by
// several TensorDefs are read only once, because loadTensor caches their
// storages in `storageMap`.
void ScriptModuleDeserializer::loadTensorTable(torch::ModelDef* model_def) {
  std::unordered_map<std::string, at::Storage> storageMap;
  // The final size is known up front, so grow the table only once.
  tensor_table_.reserve(tensor_table_.size() + model_def->tensors_size());
  for (const torch::TensorDef& tensor : model_def->tensors()) {
    tensor_table_.emplace_back(loadTensor(tensor, storageMap));
  }
}
// Materializes one TensorDef as an autograd Variable.
//
// The tensor's bytes come from the record named by tensor_proto.data().key().
// Each record is read at most once: its storage is cached in `storageMap`, so
// TensorDefs that view the same record share one storage. If a device
// override (map_location) was supplied, it replaces the device recorded in the
// TensorDef. Only CPU and CUDA are supported; CUDA storages are produced by
// copying the CPU-loaded data to the target device.
//
// Raises (via AT_ERROR) when the target device is unsupported, or when a
// cached storage lives on a different device than the one now requested.
at::Tensor ScriptModuleDeserializer::loadTensor(
    const torch::TensorDef& tensor_proto,
    std::unordered_map<std::string, at::Storage>& storageMap) {
  std::vector<int64_t> dims(
      tensor_proto.dims().begin(), tensor_proto.dims().end());
  std::vector<int64_t> strides(
      tensor_proto.strides().begin(), tensor_proto.strides().end());
  auto type = at::typeMetaToScalarType(
      caffe2::DataTypeToTypeMeta(tensor_proto.data_type()));
  const std::string& record_key = tensor_proto.data().key();
  AT_ASSERT(tensor_proto.has_device() && !tensor_proto.device().empty());
  at::Device device(tensor_proto.device());
  if (device_.has_value()) {
    // override the device, if user provides map_location
    device = device_.value();
  }

  auto storage_it = storageMap.find(record_key);
  if (storage_it == storageMap.end()) {
    at::DataPtr storage_ptr;
    uint64_t record_size;
    std::tie(storage_ptr, record_size) = reader_.getRecord(record_key);
    const auto type_meta = at::CPU(type).typeMeta();
    auto cpu_storage = at::Storage(
        type_meta,
        std::move(storage_ptr),
        record_size / type_meta.itemsize(),
        nullptr); // NB: we didn't set any allocator for the tensor
    if (device.type() == at::DeviceType::CPU) {
      storage_it =
          storageMap.insert(std::make_pair(record_key, std::move(cpu_storage)))
              .first;
    } else if (device.type() == at::DeviceType::CUDA) {
      at::Tensor cpu_tensor =
          at::empty({0}, at::CPU(type).options()).set_(cpu_storage);
      at::Storage cuda_storage =
          cpu_tensor.to(device, cpu_tensor.scalar_type()).storage();
      storage_it =
          storageMap.insert(std::make_pair(record_key, std::move(cuda_storage)))
              .first;
    } else {
      AT_ERROR(
          "supported devices include CPU and CUDA, however got ",
          at::DeviceTypeName(device.type(), false));
    }
  }

  // A record shared by several TensorDefs must resolve to the same device.
  // The index is only compared when the requested device specifies one.
  if (storage_it->second.device().type() != device.type() ||
      (device.has_index() &&
       storage_it->second.device().index() != device.index())) {
    std::stringstream oss;
    oss << "storage previously was specified with device "
        << storage_it->second.device() << " but now is specified with device "
        << device;
    AT_ERROR(oss.str());
  }

  at::Tensor result;
  if (device.type() == at::DeviceType::CPU) {
    result =
        at::empty({0}, at::CPU(type).options())
            .set_(storage_it->second, tensor_proto.offset(), dims, strides);
  } else if (device.type() == at::DeviceType::CUDA) {
    result =
        at::empty({0}, at::CUDA(type).options())
            .set_(storage_it->second, tensor_proto.offset(), dims, strides);
  }
  AT_ASSERT(result.defined());
  result = autograd::make_variable(result, tensor_proto.requires_grad());
  return result;
}
// Populates the Module that moduleLookup_ returns for the current
// moduleStack_ path. It proceeds in three steps:
//   1. Recurse depth-first into each submodule. The child's name is pushed
//      onto moduleStack_ for the duration of the recursive call.
//   2. Register parameters/buffers, taken from tensor_table_ by tensor_id.
//   3. If the ModuleDef has a TorchScript arena record, import its methods
//      with tensor_table_ as the constant table.
void ScriptModuleDeserializer::convertModule(
    const torch::ModuleDef& module_def) {
  std::shared_ptr<script::Module> module = moduleLookup_(moduleStack_);
  module->set_optimized(module_def.optimize());

  for (const torch::ModuleDef& child : module_def.submodules()) {
    moduleStack_.emplace_back(child.name());
    convertModule(child);
    moduleStack_.pop_back();
  }

  for (const torch::ParameterDef& param : module_def.parameters()) {
    at::Tensor& slot = tensor_table_.at(param.tensor_id());
    module->register_parameter(param.name(), slot, param.is_buffer());
  }

  if (!module_def.has_torchscript_arena()) {
    return;
  }
  at::DataPtr arena_ptr;
  size_t arena_size;
  std::tie(arena_ptr, arena_size) =
      reader_.getRecord(module_def.torchscript_arena().key());
  std::string arena_source(
      static_cast<const char*>(arena_ptr.get()), arena_size);
  import_methods(module, arena_source, tensor_table_);
}
} // namespace
// Deserializes the archive read from `in` into the modules produced by
// `module_lookup`, optionally remapping tensors to `device` and filling
// `extra_files`.
void import_ir_module(
    ModuleLookup module_lookup,
    std::istream& in,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  ScriptModuleDeserializer{&in}.deserialize(module_lookup, device, extra_files);
}

// Same as above, but reads the archive from the file at `filename`.
void import_ir_module(
    ModuleLookup module_lookup,
    const std::string& filename,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  ScriptModuleDeserializer{filename}.deserialize(
      module_lookup, device, extra_files);
}

// Same as above, but reads the archive through the caller-supplied adapter,
// whose ownership is transferred to the deserializer.
void import_ir_module(
    ModuleLookup module_lookup,
    std::unique_ptr<ReadAdapterInterface> rai,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  ScriptModuleDeserializer{std::move(rai)}.deserialize(
      module_lookup, device, extra_files);
}
// Loads a script module from `in` by wrapping the stream in an IStreamAdapter
// and delegating to the ReadAdapterInterface overload.
std::shared_ptr<script::Module> load(
    std::istream& in,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  return load(caffe2::make_unique<IStreamAdapter>(&in), device, extra_files);
}
// Loads a script module from the file at `filename` by wrapping it in a
// FileAdapter and delegating to the ReadAdapterInterface overload.
std::shared_ptr<script::Module> load(
    const std::string& filename,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  return load(caffe2::make_unique<FileAdapter>(filename), device, extra_files);
}
std::shared_ptr<script::Module> load(
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<c10::Device> device,
script::ExtraFilesMap& extra_files) {
auto module = std::make_shared<script::Module>();
auto module_lookup = [&](const std::vector<std::string>& qualified_name) {
std::shared_ptr<script::Module> curr = module;
for (const auto& name : qualified_name) {
if (curr->find_module(name) == nullptr) {
curr->register_module(name, std::make_shared<script::Module>());
}
curr = curr->get_module(name);
}
return curr;
};
ScriptModuleDeserializer deserializer(std::move(rai));
deserializer.deserialize(module_lookup, device, extra_files);
return module;
}
} // namespace jit
} // namespace torch