#include <google/protobuf/util/json_util.h>
#include <google/protobuf/util/type_resolver_util.h>
#include "torch/csrc/jit/export.h"
#include "torch/csrc/autograd/symbolic.h"
#include "torch/csrc/onnx/onnx.h"
#include "torch/csrc/utils/functional.h"
#include <torch/csrc/jit/assertions.h>
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "caffe2/core/types.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/proto/torch_pb.h"
#include "caffe2/serialize/inline_container.h"
#include "onnx/onnx_pb.h"
#include <ATen/ATen.h>
#include "c10/util/Optional.h"
#include <fstream>
#include <memory>
#include <sstream>
#include <stack>
#include <string>
#include <vector>
namespace torch { namespace jit {
namespace {
namespace onnx_torch = ::torch::onnx;
namespace onnx = ::ONNX_NAMESPACE;
class ScriptModuleSerializer;
std::string getExportableSchemaStringForMethod(const script::Method& method) {
const auto& schema = method.getSchema();
for (const auto& argument : schema.arguments()) {
AT_CHECK(
!argument.default_value(),
"Default arguments in script graphs may currently not be exported.");
}
std::ostringstream stream;
stream << schema;
return stream.str();
}
std::string getNodeStackTraceString(const Node* n) {
std::stringstream ss;
if (n->getSourceLocation()) {
n->getSourceLocation()->highlight(ss);
} else {
ss << "<unknown location>";
}
return ss.str();
}
void validateBlock(Block *b, onnx_torch::OperatorExportTypes operator_export_type) {
for (auto node : b->nodes()) {
for (Block *sub_block : node->blocks()) {
validateBlock(sub_block, operator_export_type);
}
// Macro'ed so we get a marginally better line number on failed export
#define FAIL_EXPORT(name) \
throw std::runtime_error(std::string("ONNX export failed: ") + name + "\n\nGraph we tried to export:\n" + b->owningGraph()->toString());
IR_IF(node, PythonOp)
auto py_node = static_cast<torch::jit::PythonOp*>(value);
FAIL_EXPORT(
"Couldn't export Python operator " + py_node->name() +
"\n\nDefined at:\n" + getNodeStackTraceString(node))
IR_ELSE()
// Special error messages for certain types of operators
if (node->kind() == aten::expand) {
if (operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK) {
WithInsertPoint guard(node);
auto* new_node = b->owningGraph()->insertNode(
b->owningGraph()->create(Symbol(::torch::jit::onnx::ATen), node->inputs(), node->outputs().size()));
for (size_t i = 0; i < node->outputs().size(); ++i) {
node->output(i)->replaceAllUsesWith(new_node->output(i));
}
new_node->s_(Symbol::fromQualString("attr::operator"), "expand");
} else {
FAIL_EXPORT(
"Could not export a broadcasted operation; ONNX likely does not support this form of broadcasting.\n\nBroadcast occurred at:\n" +
getNodeStackTraceString(node));
}
}
if (node->kind() == prim::PackPadded || node->kind() == prim::PadPacked) {
FAIL_EXPORT(
"Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.\n\nUsage of this operation occurred at:\n" +
getNodeStackTraceString(node));
}
bool is_aten_enabled = operator_export_type ==
onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK ||
operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN;
if (!node->kind().is_onnx() && !is_aten_enabled &&
node->kind() != prim::Undefined) {
FAIL_EXPORT(
"Couldn't export operator " + node->kind().toDisplayString() + "\n\nDefined at:\n" +
getNodeStackTraceString(node));
}
IR_END()
#undef FAIL_EXPORT
}
}
void validateGraph(const std::shared_ptr<Graph>& graph, onnx_torch::OperatorExportTypes operator_export_type) {
validateBlock(graph->block(), operator_export_type);
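// The ATen fallback path in validateBlock rewires uses to freshly inserted
// ATen nodes without destroying the originals; dead code elimination
// removes the now-unused originals.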
EliminateDeadCode(graph);
}
class EncoderBase {
public:
EncoderBase(onnx_torch::OperatorExportTypes operator_export_type, bool strip_doc);
onnx::ModelProto get_model_proto() {
return model_proto_;
}
protected:
void EncodeGraph(onnx::GraphProto *graph_proto,
const std::shared_ptr<Graph> &graph,
const std::vector<at::Tensor> &initializers = {});
void EncodeBlock(onnx::GraphProto *graph_proto,
const Block *block,
const std::vector<at::Tensor> &initializers = {});
virtual void EncodeTensor(
onnx::TensorProto* tensor_proto,
const at::Tensor& tensor,
const c10::optional<std::string> external_ref = {}) = 0;
virtual void EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto,
const Value* n) {}
virtual void EncodeValueInfo(onnx::GraphProto *graph_proto,
onnx::ValueInfoProto* v,
const Value* n);
void AddAttribute(onnx::NodeProto *node_proto, const jit::Node *node, const jit::Symbol name);
onnx::ModelProto model_proto_;
size_t num_blocks_;
onnx_torch::OperatorExportTypes operator_export_type_;
bool strip_doc_;
};
onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
switch(at_type) {
case at::kDouble:
return onnx::TensorProto_DataType_DOUBLE;
case at::kFloat:
return onnx::TensorProto_DataType_FLOAT;
case at::kHalf:
return onnx::TensorProto_DataType_FLOAT16;
case at::kByte:
return onnx::TensorProto_DataType_UINT8;
case at::kChar:
return onnx::TensorProto_DataType_INT8;
case at::kShort:
return onnx::TensorProto_DataType_INT16;
case at::kInt:
return onnx::TensorProto_DataType_INT32;
case at::kLong:
return onnx::TensorProto_DataType_INT64;
default:
AT_ERROR("unexpected tensor scalar type");
}
}
EncoderBase::EncoderBase(onnx_torch::OperatorExportTypes operator_export_type, bool strip_doc)
: num_blocks_(0),
operator_export_type_(operator_export_type),
strip_doc_(strip_doc) {
model_proto_.set_producer_name("pytorch");
model_proto_.set_ir_version(onnx::IR_VERSION);
model_proto_.set_producer_version("0.4");
}
void EncoderBase::EncodeValueInfo(
onnx::GraphProto *graph_proto,
onnx::ValueInfoProto* v,
const Value* n) {
v->set_name(n->uniqueName());
onnx::TypeProto* t = v->mutable_type();
onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
onnx::TensorShapeProto* shape = tensor_type->mutable_shape();
if (CompleteTensorTypePtr node_type = n->type()->cast<CompleteTensorType>()) {
const std::vector<std::int64_t>& sizes = node_type->sizes();
for (size_t i = 0; i < sizes.size(); i++) {
shape->add_dim();
shape->mutable_dim(i)->set_dim_value(sizes[i]);
}
tensor_type->set_elem_type(ATenTypeToOnnxType(node_type->scalarType()));
} else {
tensor_type->set_elem_type(onnx::TensorProto_DataType_UNDEFINED);
}
}
void EncoderBase::EncodeGraph(
onnx::GraphProto *graph_proto,
const std::shared_ptr<Graph> &graph,
const std::vector<at::Tensor> &initializers) {
EncodeBlock(graph_proto, graph->block(), initializers);
}
void EncoderBase::EncodeBlock(
onnx::GraphProto *graph_proto, const Block *block,
const std::vector<at::Tensor> &initializers) {
JIT_ASSERT(graph_proto != nullptr);
std::string block_name = "torch-jit-export";
if (num_blocks_) {
block_name += std::to_string(num_blocks_);
}
num_blocks_++;
graph_proto->set_name(block_name);
for (auto input : block->inputs()) {
onnx::ValueInfoProto* v = graph_proto->add_input();
EncodeValueInfo(graph_proto, v, input);
}
for (auto output : block->outputs()) {
onnx::ValueInfoProto* v = graph_proto->add_output();
EncodeValueInfo(graph_proto, v, output);
}
for (auto node : block->nodes()) {
bool is_raw_export = operator_export_type_ == onnx_torch::OperatorExportTypes::RAW;
if (node->kind() == prim::Undefined && !is_raw_export) {
// Undefined nodes are used to implement optional inputs. One
// way to "not provide" an optional input is to create an
// Undefined node, and pass its output as that input.
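// (Below, on the input side, values produced by Undefined nodes are
// emitted as empty-string input names, which is how ONNX denotes an
// omitted optional input.)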
continue;
}
auto p_n = graph_proto->add_node();
if (node->getSourceLocation() && !strip_doc_) {
std::stringstream ss;
node->getSourceLocation()->highlight(ss);
p_n->set_doc_string(ss.str());
}
for(auto input : node->inputs()) {
if (input->node()->kind() == prim::Undefined && !is_raw_export) {
p_n->add_input("");
} else {
p_n->add_input(input->uniqueName());
}
}
for(auto output : node->outputs()) {
p_n->add_output(output->uniqueName());
EncodeIntermediateValueInfo(graph_proto, output);
}
if (is_raw_export) {
JIT_ASSERT(!node->kind().is_onnx());
p_n->set_domain(node->kind().domainString());
}
else if (operator_export_type_ == onnx_torch::OperatorExportTypes::ONNX) {
JIT_ASSERT(node->kind().is_onnx());
}
p_n->set_op_type(node->kind().toUnqualString());
for(auto attr_name : node->attributeNames()) {
AddAttribute(p_n, node, attr_name);
}
if (is_raw_export && node->blocks().size() > 0) {
auto blocks = p_n->add_attribute();
blocks->set_name("_blocks");
blocks->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
for (auto block : node->blocks()) {
auto graph = blocks->add_graphs();
EncodeBlock(graph, block, initializers);
}
}
if (node->kind() == torch::jit::onnx::Loop) {
JIT_ASSERT(node->blocks().size() == 1);
auto body = p_n->add_attribute();
body->set_name("body");
body->set_type(onnx::AttributeProto_AttributeType_GRAPH);
auto g = body->mutable_g();
EncodeBlock(g, node->blocks()[0]);
}
if (node->kind() == torch::jit::onnx::If) {
JIT_ASSERT(node->blocks().size() == 2);
auto true_branch = p_n->add_attribute();
true_branch->set_name("then_branch");
true_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
auto true_g = true_branch->mutable_g();
EncodeBlock(true_g, node->blocks()[0]);
auto false_branch = p_n->add_attribute();
false_branch->set_name("else_branch");
false_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
auto false_g = false_branch->mutable_g();
EncodeBlock(false_g, node->blocks()[1]);
}
}
auto num_initializers = initializers.size();
JIT_ASSERT(block->inputs().size() >= num_initializers);
size_t inputs_count = block->inputs().size() - num_initializers;
for (auto & tensor : initializers) {
// TODO: stop using positions to determine which initializers
// correspond to which inputs
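// (Positional convention: the trailing block inputs line up with the
// initializers in order, so with 3 inputs and 2 initializers, inputs 1
// and 2 are taken to be the initializers.)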
std::string name = graph_proto->input(inputs_count++).name();
auto p = graph_proto->add_initializer();
p->set_name(name);
EncodeTensor(p, tensor, name);
}
}
void EncoderBase::AddAttribute(onnx::NodeProto *node_proto, const jit::Node *node, const jit::Symbol name) {
auto attr = node_proto->add_attribute();
JIT_ASSERT(name.is_attr());
attr->set_name(name.toUnqualString());
switch(node->kindOf(name)) {
case AttributeKind::f:
attr->set_f(node->f(name));
attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
break;
case AttributeKind::fs:
attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
for(auto & v : node->fs(name))
attr->add_floats(v);
break;
case AttributeKind::i:
attr->set_type(onnx::AttributeProto_AttributeType_INT);
attr->set_i(node->i(name));
break;
case AttributeKind::is:
attr->set_type(onnx::AttributeProto_AttributeType_INTS);
for(auto & v : node->is(name))
attr->add_ints(v);
break;
case AttributeKind::s:
attr->set_type(onnx::AttributeProto_AttributeType_STRING);
attr->set_s(node->s(name));
break;
case AttributeKind::ss:
attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
for(auto & v : node->ss(name))
attr->add_strings(v);
break;
case AttributeKind::t: {
attr->set_type(onnx::AttributeProto_AttributeType_TENSOR);
auto t = attr->mutable_t();
EncodeTensor(t, node->t(name));
} break;
case AttributeKind::ts:
attr->set_type(onnx::AttributeProto_AttributeType_TENSORS);
for(auto & v : node->ts(name)) {
auto t = attr->add_tensors();
EncodeTensor(t, v);
}
break;
case AttributeKind::g: {
attr->set_type(onnx::AttributeProto_AttributeType_GRAPH);
auto g = attr->mutable_g();
EncodeGraph(g, node->g(name));
} break;
case AttributeKind::gs:
attr->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
for(auto & v : node->gs(name)) {
auto g = attr->add_graphs();
EncodeGraph(g, v);
}
break;
default:
throw std::runtime_error("unexpected attribute kind");
}
}
class GraphEncoder: public EncoderBase {
public:
GraphEncoder(const std::shared_ptr<Graph> &graph,
int64_t onnx_opset_version,
onnx_torch::OperatorExportTypes operator_export_type,
const std::vector<at::Tensor> &initializers,
bool defer_weight_export,
bool strip_doc);
RawDataExportMap get_raw_data_export_map() {
return raw_data_export_map_;
}
private:
void EncodeTensor(
onnx::TensorProto* tensor_proto,
const at::Tensor& tensor,
const c10::optional<std::string> external_ref = {}) override;
RawDataExportMap raw_data_export_map_;
bool defer_weight_export_;
};
GraphEncoder::GraphEncoder(
const std::shared_ptr<Graph> &graph,
int64_t onnx_opset_version,
onnx_torch::OperatorExportTypes operator_export_type,
const std::vector<at::Tensor> &initializers,
bool defer_weight_export,
bool strip_doc)
: EncoderBase(operator_export_type, strip_doc),
defer_weight_export_(defer_weight_export) {
if (operator_export_type != onnx_torch::OperatorExportTypes::RAW) {
validateGraph(graph, operator_export_type);
}
auto* imp = model_proto_.add_opset_import();
// This is the version of ONNX operator set we are targeting
imp->set_version(onnx_opset_version);
EncodeGraph(model_proto_.mutable_graph(), graph, initializers);
}
void GraphEncoder::EncodeTensor(
onnx::TensorProto* tensor_proto,
const at::Tensor& tensor,
const c10::optional<std::string> external_ref) {
for(auto d : tensor.sizes()) {
tensor_proto->add_dims(d);
}
tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.type().scalarType()));
// CPU's HalfTensor doesn't have contiguous(), so call contiguous() first
// (on the original device) and only then move the tensor to CPU
auto t = tensor.contiguous().cpu();
// Add a buffer to the raw_data_export_map for the caller to dump into an
// external data store. If external_ref is not specified, we instead dump
// the contiguous data into the protobuf itself
if (defer_weight_export_ && external_ref) {
// For now, we use the name of the tensor as the external lookup name to
// avoid ONNX protobuf changes.
JIT_ASSERT(external_ref.value() == tensor_proto->name());
JIT_ASSERT(raw_data_export_map_.count(external_ref.value()) == 0);
raw_data_export_map_[external_ref.value()] = t;
tensor_proto->set_raw_data("__EXTERNAL");
} else {
JIT_ASSERT(t.is_contiguous());
tensor_proto->set_raw_data(std::string(static_cast<char*>(t.data_ptr()), t.type().elementSizeInBytes() * t.numel()));
}
}
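// A usage sketch of deferred weight export (hypothetical caller; write_blob
// is an illustrative stand-in, not part of this file):
//
//   std::string proto;
//   RawDataExportMap raw_map;
//   std::tie(proto, raw_map) = ExportGraph(
//       graph, initializers, /*onnx_opset_version=*/9,
//       /*defer_weight_export=*/true,
//       ::torch::onnx::OperatorExportTypes::ONNX);
//   for (const auto& kv : raw_map) {
//     // kv.first names an initializer whose raw_data was stubbed out with
//     // "__EXTERNAL"; kv.second is the contiguous CPU tensor whose bytes
//     // the caller must store under that name.
//     write_blob(kv.first, kv.second.data_ptr(),
//                kv.second.type().elementSizeInBytes() * kv.second.numel());
//   }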
class MethodEncoder : public EncoderBase {
public:
MethodEncoder(
const script::Method& method,
const ScriptModuleSerializer& serializer);
std::string EncodeMethod(
const script::Method& method,
const std::string& prefix);
private:
void EncodeTensor(
onnx::TensorProto* tensor_proto,
const at::Tensor& tensor,
const c10::optional<std::string> external_ref = {}) override;
void EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto,
const Value* n) override;
void EncodeValueInfo(onnx::GraphProto *graph_proto,
onnx::ValueInfoProto* v,
const Value* n) override;
void EncodeTypeInfo(onnx::GraphProto *graph_proto,
onnx::ValueInfoProto* v,
const TypePtr& type,
const std::string& name);
// the serializer has already serialized all the tensors, and stores
// the tensor and parameter tables
const ScriptModuleSerializer* serializer_;
// Used to create sequential dummy names for node types
size_t type_counter_ = 0;
};
// ScriptModuleSerializer saves script modules to .pt files. The content of
// the file is written with PyTorchStreamWriter; for details please check
// caffe2/serialize/inline_container.h. All the records except the last one
// are tensor data, and the last record is a serialized ModelProto, defined
// in caffe2/proto/torch.proto. ModelProto contains all the metadata of the
// model, and it is serialized as JSON.
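// A sketch of the resulting container layout (illustrative; assumes a
// model with two tensors):
//
//   record 0: raw bytes of tensor 0
//   record 1: raw bytes of tensor 1
//   record 2: ModelDef serialized as JSON (written last by serialize())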
class ScriptModuleSerializer final {
public:
ScriptModuleSerializer(const std::string& filename);
ScriptModuleSerializer(std::ostream* ofs);
void serialize(const script::Module& module);
uint64_t lookupTensorId(const at::Tensor* tensor) const;
const std::string& lookupParamName(const at::Tensor* tensor) const;
private:
void convertToModel(const script::Module& module, torch::ModelDef* model_def);
// add a tensor to the tensorTable
void addTensor(const at::Tensor* tensor);
// recursively collect the tensors in a block and add them to the tensorTable
void findTensorInBlock(const Block& block);
// recursively iterate over the whole module to collect the information of
// tensors and parameters
void collectInfo(const script::Module& module, const std::string& prefix);
// write the content of the tensor to the file/stream, and save the
// offset in the storageMap_
void convertAndWriteTensor(
const at::Tensor& tensor,
caffe2::TensorProto* tensor_proto);
// dump all the tensors in the tensorTable_ to a ModelDef (metadata) and
// the file/stream (the content), assuming all the information of the
// tensors has been collected. This method calls convertAndWriteTensor
// to dump the content of each tensor
void writeTensorTable(torch::ModelDef* model_def);
void convertModule(
const script::Module& module,
const std::string& name,
torch::ModuleDef* module_def);
void convertParameter(
const script::NamedParameter& param,
torch::ParameterDef* param_def);
void convertMethod(
const script::Method& method,
torch::MethodDef* method_def);
std::ofstream ofs_;
PyTorchStreamWriter writer_;
// storage_ptr => record_offset
std::unordered_map<const void*, uint64_t> storageMap_;
// tensor => param name
std::unordered_map<const at::Tensor*, std::string> paramMap_;
// tensor => tensor_id
std::unordered_map<const at::Tensor*, uint64_t> tensorTable_;
// used for generating table id for tensors
uint64_t nextTensorId_ = 0;
};
// MethodEncoder's methods
MethodEncoder::MethodEncoder(
const script::Method& method,
const ScriptModuleSerializer& serializer)
: EncoderBase(onnx_torch::OperatorExportTypes::RAW, false) {
serializer_ = &serializer;
}
std::string MethodEncoder::EncodeMethod(
const script::Method& method,
const std::string& prefix) {
onnx::ModelProto model_proto;
model_proto.set_doc_string("THIS PROTO IS NOT STANDARD ONNX");
auto* node_proto = model_proto.mutable_graph()->add_node();
node_proto->set_name(prefix + method.name());
// We store the schema string in the docstring.
node_proto->set_doc_string(getExportableSchemaStringForMethod(method));
// Store member_inputs of Method in input
for (auto& member_input : method.params()) {
const auto& param_name = serializer_->lookupParamName(member_input);
node_proto->add_input(param_name);
}
auto attr_proto = node_proto->add_attribute();
attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH);
for (auto node : method.graph()->nodes()) {
if (node->kind() == prim::PythonOp) {
auto py_node = static_cast<torch::jit::PythonOp*>(node);
throw std::runtime_error(
"Couldn't export Python operator " + py_node->name() +
"\n\nDefined at:\n" + getNodeStackTraceString(node));
}
}
EncodeBlock(attr_proto->mutable_g(), method.graph()->block(), {});
std::string torch_script;
AT_ASSERT(model_proto.SerializeToString(&torch_script));
return torch_script;
}
void MethodEncoder::EncodeTensor(
onnx::TensorProto* tensor_proto,
const at::Tensor& tensor,
const c10::optional<std::string> external_ref) {
uint64_t tensor_id = serializer_->lookupTensorId(&tensor);
tensor_proto->set_name(c10::to_string(tensor_id));
// No need to store the content of the tensor to the file/stream
// any more, since it is already saved at the beginning of the
// serialization in writeTensorTable
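// The id written above is the key into the ModelDef tensor table produced
// by writeTensorTable, whose entries record where each tensor's bytes live
// in the container.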
}
void MethodEncoder::EncodeIntermediateValueInfo(
onnx::GraphProto* graph_proto,
const Value* n) {
auto v = graph_proto->add_value_info();
EncodeTypeInfo(graph_proto, v, n->type(), n->uniqueName());
}
c10::optional<std::string> getBaseTypeDenotation(TypeKind kind) {
if (kind == TypeKind::NumberType) {
return "NumberType";
} else if (kind == TypeKind::FloatType) {
return "FloatType";
} else if (kind == TypeKind::IntType) {
return "IntType";
} else if (kind == TypeKind::BoolType) {
return "BoolType";
} else if (kind == TypeKind::NoneType) {
return "NoneType";
} else if (kind == TypeKind::GeneratorType) {
return "GeneratorType";
} else if (kind == TypeKind::StringType) {
return "StringType";
}
return c10::nullopt;
}
void MethodEncoder::EncodeTypeInfo(
onnx::GraphProto* graph_proto,
onnx::ValueInfoProto* v,
const TypePtr& type,
const std::string& name) {
v->set_name(name);
onnx::TypeProto* type_proto = v->mutable_type();
onnx::TypeProto_Tensor* tensortype_proto = type_proto->mutable_tensor_type();
onnx::TensorShapeProto* shape_proto = tensortype_proto->mutable_shape();
// Use TypeProto fields to encode types.
// denotation stores the type as a string
auto kind = type->kind();
if (kind == TypeKind::DynamicType) {
type_proto->set_denotation("DynamicType");
tensortype_proto->set_elem_type(onnx::TensorProto_DataType_UNDEFINED);
} else if (kind == TypeKind::TensorType) {
type_proto->set_denotation("TensorType");
// encode the number of dimensions by pushing that number of ones into the shape proto
auto tensor_type = type->expect<TensorType>();
for (int i = 0; i < tensor_type->dim(); i++) {
shape_proto->add_dim();
shape_proto->mutable_dim(i)->set_dim_value(1);
}
tensortype_proto->set_elem_type(ATenTypeToOnnxType(tensor_type->scalarType()));
} else if (kind == TypeKind::CompleteTensorType) {
type_proto->set_denotation("CompleteTensorType");
CompleteTensorTypePtr node_type = type->cast<CompleteTensorType>();
// store the sizes and strides in the dims field of TensorShapeProto
size_t i = 0;
for (auto &size : node_type->sizes()) {
shape_proto->add_dim();
shape_proto->mutable_dim(i)->set_dim_value(size);
i++;
}
for (auto &stride : node_type->strides()) {
shape_proto->add_dim();
shape_proto->mutable_dim(i)->set_dim_value(stride);
i++;
}
tensortype_proto->set_elem_type(ATenTypeToOnnxType(node_type->scalarType()));
} else if (kind == TypeKind::TupleType) {
type_proto->set_denotation("TupleType");
TupleTypePtr node_type = type->cast<TupleType>();
auto elements = node_type->elements();
// Generate a name for and encode each subtype in the value_info field of the GraphProto.
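// Each generated name (e.g. "#0", "#1", ...) is stored as a dim_param
// below and resolved against the matching value_info entry on import.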
for (size_t i = 0; i < elements.size(); i++) {
std::string name = "#" + std::to_string(type_counter_++);
shape_proto->add_dim();
shape_proto->mutable_dim(i)->set_dim_param(name);
onnx::ValueInfoProto* subtype_proto = graph_proto->add_value_info();
EncodeTypeInfo(graph_proto, subtype_proto, elements[i], name);
}
} else if (kind == TypeKind::ListType) {
type_proto->set_denotation("ListType");
ListTypePtr node_type = type->cast<ListType>();
// Generate a name for and encode the subtype in the value_info field of the GraphProto.
std::string name = "#" + std::to_string(type_counter_++);
shape_proto->add_dim();
shape_proto->mutable_dim(0)->set_dim_param(name);
onnx::ValueInfoProto* subtype_proto = graph_proto->add_value_info();
EncodeTypeInfo(graph_proto, subtype_proto, node_type->getElementType(), name);
} else if (kind == TypeKind::VarType) {
type_proto->set_denotation("TypeVar:" + type->expect<VarType>()->name());
} else if (kind == TypeKind::OptionalType) {
type_proto->set_denotation("OptionalType");
OptionalTypePtr node_type = type->cast<OptionalType>();
// Generate a name for and encode each subtype in the value_info field of the GraphProto.
std::string name = "#" + std::to_string(type_counter_++);
shape_proto->add_dim();
shape_proto->mutable_dim(0)->set_dim_param(name);
onnx::ValueInfoProto* subtype_proto = graph_proto->add_value_info();
EncodeTypeInfo(graph_proto, subtype_proto, node_type->getElementType(), name);
} else {
auto denotation = getBaseTypeDenotation(kind);
if (!denotation) {
throw std::runtime_error("unexpected type kind");
}
type_proto->set_denotation(*denotation);
}
}
void MethodEncoder::EncodeValueInfo(
onnx::GraphProto* graph_proto,
onnx::ValueInfoProto* v,
const Value* n) {
EncodeTypeInfo(graph_proto, v, n->type(), n->uniqueName());
}
// ScriptModuleSerializer's methods
ScriptModuleSerializer::ScriptModuleSerializer(const std::string& filename)
: ofs_(
filename,
std::ofstream::out | std::ofstream::trunc | std::ofstream::binary),
writer_(&ofs_) {
// TODO: add proper mmap support; for now we still use the stream writer
}
ScriptModuleSerializer::ScriptModuleSerializer(std::ostream* ofs)
: ofs_(), writer_(ofs) {}
void ScriptModuleSerializer::serialize(const script::Module& module) {
torch::ModelDef model_def;
convertToModel(module, &model_def);
std::string output;
// NB: cannot use MessageToJsonString, since fbcode's protobuf is too old;
// go through BinaryToJsonString with a type resolver instead, which keeps
// the output consistent with MessageToJsonString
std::string url_prefix = "type.googleapis.com";
std::unique_ptr<::google::protobuf::util::TypeResolver> resolver(
::google::protobuf::util::NewTypeResolverForDescriptorPool(
url_prefix, model_def.GetDescriptor()->file()->pool()));
::google::protobuf::util::Status convert_result =
::google::protobuf::util::BinaryToJsonString(
resolver.get(),
url_prefix + "/" + model_def.GetDescriptor()->full_name(),
model_def.SerializeAsString(),
&output);
if (!convert_result.ok()) {
std::stringstream ss;
ss << convert_result;
AT_ERROR(ss.str());
}
auto record_id = writer_.writeRecord(output.data(), output.size());
AT_ASSERT(record_id != 0);
writer_.writeEndOfFile();
}
uint64_t ScriptModuleSerializer::lookupTensorId(
const at::Tensor* tensor) const {
auto it = tensorTable_.find(tensor);
AT_ASSERT(it != tensorTable_.end());
return it->second;
}
const std::string& ScriptModuleSerializer::lookupParamName(
const at::Tensor* tensor) const {
auto it = paramMap_.find(tensor);
AT_ASSERT(it != paramMap_.end());
return it->second;
}
void ScriptModuleSerializer::convertToModel(
const script::Module& module,
torch::ModelDef* model_def) {
model_def->set_name("script-model");
model_def->set_producer_name("pytorch");
model_def->set_producer_version("1.0"); // TODO: set the producer version
// using an appropriate function call
model_def->set_proto_version(torch::ProtoVersion::PROTO_VERSION_NEWEST);
std::string main_module_name = "";
nextTensorId_ = 0;
collectInfo(module, main_module_name);
writeTensorTable(model_def);
convertModule(module, main_module_name, model_def->mutable_main_module());
}
void ScriptModuleSerializer::addTensor(const at::Tensor* tensor) {
if (tensorTable_.find(tensor) == tensorTable_.end()) {
tensorTable_[tensor] = nextTensorId_;
++nextTensorId_;
}
}
void ScriptModuleSerializer::findTensorInBlock(const Block& block) {
for (auto node : block.nodes()) {
for (auto attr_name : node->attributeNames()) {
AT_ASSERT(attr_name.is_attr());
switch (node->kindOf(attr_name)) {
case AttributeKind::f:
case AttributeKind::fs:
case AttributeKind::i:
case AttributeKind::is:
case AttributeKind::s:
case AttributeKind::ss:
break;
case AttributeKind::t: {
const at::Tensor* tensor = &node->t(attr_name);
addTensor(tensor);
} break;
case AttributeKind::ts: {
for (auto& v : node->ts(attr_name)) {
const at::Tensor* tensor = &v;
addTensor(tensor);
}
} break;
case AttributeKind::g: {
findTensorInBlock(*node->g(attr_name)->block());
} break;
case AttributeKind::gs: {
for (auto& v : node->gs(attr_name)) {
findTensorInBlock(*v->block());
}
} break;
default:
AT_ERROR("unexpected attribute kind");
}
}
for (auto b : node->blocks()) {
findTensorInBlock(*b);
}
}
}
void ScriptModuleSerializer::collectInfo(
const script::Module& module,
const std::string& prefix) {
for (const auto& elem : module.get_parameters()) {
const script::NamedParameter& param = elem.value();
paramMap_[param.slot()] = prefix + param.name;
addTensor(param.slot());
}
for (const auto& elem : module.get_methods()) {
findTensorInBlock(*elem.value()->graph()->block());
}
for (const auto& elem : module.get_modules()) {
collectInfo(*elem->module, prefix + elem.key() + ".");
}
}
void ScriptModuleSerializer::convertAndWriteTensor(
const at::Tensor& tensor,
caffe2::TensorProto* tensor_proto) {
auto tensor_it = tensorTable_.find(&tensor);
AT_ASSERT(tensor_it != tensorTable_.end());
tensor_proto->set_name(c10::to_string(tensor_it->second));
for (auto d : tensor.sizes()) {
tensor_proto->add_dims(d);
}
tensor_proto->set_data_type(caffe2::TypeMetaToDataType(
at::scalarTypeToTypeMeta(tensor.type().scalarType())));
tensor_proto->set_storage_type(caffe2::TensorProto_StorageType_EXTERNAL);
caffe2::ExternalDataProto* external_data =
tensor_proto->mutable_external_data();
for (auto s : tensor.strides()) {
external_data->add_strides(s);
}
external_data->set_offset(tensor.storage_offset());
uint64_t record_size =
tensor.type().elementSizeInBytes() * tensor.storage().size();
external_data->set_record_size(record_size);
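// Storages are deduplicated below: tensors that view the same storage
// share a single record in the output file.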
auto* key = tensor.storage().unsafeGetStorageImpl();
auto storage_it = storageMap_.find(key);
if (storage_it == storageMap_.end()) {
// TODO HIP support
uint64_t record_id;
if (tensor.storage().device_type() == at::DeviceType::CUDA) {
// NB: This new tensor is created to support cuda tensors.
// Storages can be mutated when converting tensors from cuda to cpu,
// and we need a cpu tensor to copy data from.
at::Tensor t = at::getType(tensor)
._th_tensor(
tensor.storage(),
/* storageOffset = */ 0,
/* size = */
{static_cast<int64_t>(tensor.storage().size())},
/* stride = */ {1})
.cpu();
AT_ASSERT(
t.type().elementSizeInBytes() * t.storage().size() == record_size);
record_id = writer_.writeRecord(
t.storage().data(),
t.type().elementSizeInBytes() * t.storage().size());
} else {
record_id = writer_.writeRecord(tensor.storage().data(), record_size);
}
external_data->set_record_id(c10::to_string(record_id));
storageMap_[key] = record_id;
} else {
external_data->set_record_id(c10::to_string(storage_it->second));
}
// TODO handle device case, set the device_detail and load to CUDA device
}
void ScriptModuleSerializer::writeTensorTable(torch::ModelDef* model_def) {
// NB: we don't preserve any particular order for tensors in the tensorTable_
for (const auto& kv : tensorTable_) {
auto* tensor_proto = model_def->add_tensors();
convertAndWriteTensor(*kv.first, tensor_proto);
}
}
void ScriptModuleSerializer::convertModule(
const script::Module& module,
const std::string& name,
torch::ModuleDef* module_def) {
module_def->set_name(name);
module_def->set_optimize(module.is_optimized());
for (const auto& elem : module.get_parameters()) {
torch::ParameterDef* param_def = module_def->add_parameters();
convertParameter(elem.value(), param_def);
}
for (auto& elem : module.get_methods()) {
torch::MethodDef* method_def = module_def->add_methods();
convertMethod(*elem.value(), method_def);
}
for (const auto& elem : module.get_modules()) {
torch::ModuleDef* sub_def = module_def->add_submodules();
convertModule(*elem->module, elem.key(), sub_def);
}
}
void ScriptModuleSerializer::convertParameter(
const script::NamedParameter& param,
torch::ParameterDef* param_def) {
param_def->set_name(param.name);
param_def->set_is_buffer(param.is_buffer);
param_def->set_require_gradient(param.slot()->requires_grad());
auto it = tensorTable_.find(param.slot());
AT_ASSERT(it != tensorTable_.end());
param_def->set_tensor_id(c10::to_string(it->second));
}
void ScriptModuleSerializer::convertMethod(
const script::Method& method,
torch::MethodDef* method_def) {
// TODO encode the real torch script instead of ModelProto
MethodEncoder encoder(method, *this);
// we already keep the tree structure in the top level module,
// so pass "" as prefix
std::string torch_script = encoder.EncodeMethod(method, "");
method_def->set_onnx_proto(torch_script);
}
// Pretty printing
constexpr char indent_char = ' ';
constexpr size_t indent_multiplier = 2;
std::string idt(size_t indent) {
return std::string(indent * indent_multiplier, indent_char);
}
std::string nlidt(size_t indent) {
return std::string("\n") + idt(indent);
}
void dump(const onnx::TensorProto& tensor, std::ostream& stream) {
stream << "TensorProto shape: [";
for (int i = 0; i < tensor.dims_size(); ++i) {
stream << tensor.dims(i) << (i == tensor.dims_size() - 1 ? "" : " ");
}
stream << "]";
}
void dump(const onnx::TensorShapeProto& shape, std::ostream& stream) {
for (int i = 0; i < shape.dim_size(); ++i) {
auto &dim = shape.dim(i);
if (dim.has_dim_value()) {
stream << dim.dim_value();
} else {
stream << "?";
}
stream << (i == shape.dim_size() - 1 ? "" : " ");
}
}
void dump(const onnx::TypeProto_Tensor& tensor_type, std::ostream& stream) {
stream << "Tensor dims: ";
dump(tensor_type.shape(), stream);
}
void dump(const onnx::TypeProto& type, std::ostream& stream) {
dump(type.tensor_type(), stream);
}
void dump(const onnx::ValueInfoProto& value_info, std::ostream& stream) {
stream << "{name: \"" << value_info.name()
<< "\", type:";
dump(value_info.type(), stream);
stream << "}";
}
void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent);
void dump(const onnx::AttributeProto& attr, std::ostream& stream, size_t indent) {
stream << "{ name: '" << attr.name() << "', type: ";
if (attr.has_f()) {
stream << "float, value: " << attr.f();
} else if (attr.has_i()) {
stream << "int, value: " << attr.i();
} else if (attr.has_s()) {
stream << "string, value: '" << attr.s() << "'";
} else if (attr.has_g()) {
stream << "graph, value:\n";
dump(attr.g(), stream, indent+1);
stream << nlidt(indent);
} else if (attr.has_t()) {
stream << "tensor, value:";
dump(attr.t(), stream);
} else if (attr.floats_size()) {
stream << "floats, values: [";
for (int i = 0; i < attr.floats_size(); ++i)
stream << attr.floats(i) << (i == attr.floats_size() - 1 ? "" : " ");
stream << "]";
} else if (attr.ints_size()) {
stream << "ints, values: [";
for (int i = 0; i < attr.ints_size(); ++i)
stream << attr.ints(i) << (i == attr.ints_size() - 1 ? "" : " ");
stream << "]";
} else if (attr.strings_size()) {
stream << "strings, values: [";
for (int i = 0; i < attr.strings_size(); ++i)
stream << "'" << attr.strings(i) << "'" << (i == attr.strings_size() - 1 ? "" : " ");
stream << "]";
} else if (attr.tensors_size()) {
stream << "tensors, values: [";
for (auto& t : attr.tensors()) {
dump(t, stream);
}
stream << "]";
} else if (attr.graphs_size()) {
stream << "graphs, values: [";
for (auto& g : attr.graphs()) {
dump(g, stream, indent+1);
}
stream << "]";
} else {
stream << "UNKNOWN";
}
stream << "}";
}
void dump(const onnx::NodeProto& node, std::ostream& stream, size_t indent) {
stream << "Node {type: \"" << node.op_type() << "\", inputs: [";
for (int i = 0; i < node.input_size(); ++i) {
stream << node.input(i) << (i == node.input_size() - 1 ? "" : ",");
}
stream << "], outputs: [";
for (int i = 0; i < node.output_size(); ++i) {
stream << node.output(i) << (i == node.output_size() - 1 ? "" : ",");
}
stream << "], attributes: [";
for (int i = 0; i < node.attribute_size(); ++i) {
dump(node.attribute(i), stream, indent+1);
stream << (i == node.attribute_size() - 1 ? "" : ",");
}
stream << "]}";
}
void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent) {
stream << idt(indent) << "GraphProto {" << nlidt(indent+1)
<< "name: \"" << graph.name() << "\"" << nlidt(indent+1)
<< "inputs: [";
for (int i = 0; i < graph.input_size(); ++i) {
dump(graph.input(i), stream);
stream << (i == graph.input_size() - 1 ? "" : ",");
}
stream << "]" << nlidt(indent+1)
<< "outputs: [";
for (int i = 0; i < graph.output_size(); ++i) {
dump(graph.output(i), stream);
stream << (i == graph.output_size() - 1 ? "" : ",");
}
stream << "]" << nlidt(indent+1)
<< "initializers: [";
for (int i = 0; i < graph.initializer_size(); ++i) {
dump(graph.initializer(i), stream);
stream << (i == graph.initializer_size() - 1 ? "" : ",");
}
stream << "]" << nlidt(indent+1)
<< "nodes: [" << nlidt(indent+2);
for (int i = 0; i < graph.node_size(); ++i) {
dump(graph.node(i), stream, indent+2);
if (i != graph.node_size() - 1) stream << "," << nlidt(indent+2);
}
stream << nlidt(indent+1) << "]\n" << idt(indent) << "}\n";
}
void dump(const onnx::OperatorSetIdProto& operator_set_id, std::ostream& stream) {
stream << "OperatorSetIdProto { domain: " << operator_set_id.domain() << "}";
}
void dump(const onnx::ModelProto& model, std::ostream& stream, size_t indent) {
stream << idt(indent)
<< "ModelProto {" << nlidt(indent+1)
<< "producer_name: \"" << model.producer_name() << "\"" << nlidt(indent+1)
<< "domain: \"" << model.domain() << "\"" << nlidt(indent+1)
<< "doc_string: \"" << model.doc_string() << "\"";
if (model.has_graph()) {
stream << nlidt(indent+1) << "graph:\n";
dump(model.graph(), stream, indent+2);
}
if (model.opset_import_size()) {
stream << idt(indent+1) << "opset_import: [";
for (auto &opset_imp : model.opset_import()) {
dump(opset_imp, stream);
}
stream << "],\n";
}
stream << idt(indent) << "}\n";
}
std::string prettyPrint(const onnx::ModelProto& model) {
std::stringstream ss;
dump(model, ss, 0);
return ss.str();
}
} // namespace
std::string PrettyPrintExportedGraph(
const std::shared_ptr<Graph> &graph,
const std::vector<at::Tensor> &initializers,
int64_t onnx_opset_version,
bool defer_weight_export,
::torch::onnx::OperatorExportTypes operator_export_type,
bool google_printer) {
auto graph_encoder = GraphEncoder(
graph, onnx_opset_version, operator_export_type, initializers, defer_weight_export, true);
if (google_printer) {
return graph_encoder.get_model_proto().DebugString();
}
return prettyPrint(graph_encoder.get_model_proto());
}
// export_raw_ir will export IR ops without turning them into ONNX ops.
// The output will use the ONNX protobuf format, but the ops will not
// conform to the ONNX op specification. Thus, the output will not
// be interpretable by an ONNX-compatible framework. However, PyTorch or
// libtorch will be able to import the IR and play it back.
std::tuple<std::string, RawDataExportMap> ExportGraph(
const std::shared_ptr<Graph> &graph,
const std::vector<at::Tensor> &initializers,
int64_t onnx_opset_version,
bool defer_weight_export,
::torch::onnx::OperatorExportTypes operator_export_type) {
auto graph_encoder = GraphEncoder(
graph, onnx_opset_version, operator_export_type, initializers, defer_weight_export, false);
return std::make_tuple(graph_encoder.get_model_proto().SerializeAsString(),
graph_encoder.get_raw_data_export_map());
}
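// A minimal sketch of raw export (hypothetical caller; RAW keeps the IR
// ops as-is instead of converting them to ONNX):
//
//   std::string proto = std::get<0>(ExportGraph(
//       graph, /*initializers=*/{}, /*onnx_opset_version=*/0,
//       /*defer_weight_export=*/false,
//       ::torch::onnx::OperatorExportTypes::RAW));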
void ExportModule(const script::Module& module, std::ostream& out) {
ScriptModuleSerializer serializer(&out);
serializer.serialize(module);
}
void ExportModule(const script::Module& module, const std::string &filename) {
ScriptModuleSerializer serializer(filename);
serializer.serialize(module);
}
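// A usage sketch (hypothetical caller; the module comes from elsewhere):
//
//   script::Module module = /* built or loaded elsewhere */;
//   ExportModule(module, "model.pt");  // write to a file
//   std::ostringstream buf;
//   ExportModule(module, buf);         // or to any std::ostream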
}}