#include <google/protobuf/util/json_util.h>
#include <google/protobuf/util/type_resolver_util.h>
#include <torch/csrc/autograd/symbolic.h>
#include <torch/csrc/jit/export.h>
#include <torch/csrc/onnx/onnx.h>
#include <ATen/core/functional.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/import_export_helpers.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/python_print.h>
#include <torch/csrc/jit/pickle.h>
#include <torch/csrc/jit/source_range_serialization.h>
#include <torch/csrc/jit/instruction.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <caffe2/core/types.h>
#include <caffe2/proto/caffe2_pb.h>
#include <caffe2/proto/torch_pb.h>
#include <caffe2/serialize/inline_container.h>
#include <onnx/onnx_pb.h>
#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <fstream>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <vector>
namespace torch {
namespace jit {
char const * toString(OpCode op);
namespace {
namespace onnx_torch = ::torch::onnx;
namespace onnx = ::ONNX_NAMESPACE;
namespace {
ExportModuleExtraFilesHook& GetExtraFilesHook() {
static ExportModuleExtraFilesHook func = nullptr;
return func;
}
}
class ScriptModuleSerializer;
std::string getNodeStackTraceString(const Node* n) {
return n->sourceRange().str();
}
void validateBlock(
Block* b,
onnx_torch::OperatorExportTypes operator_export_type) {
for (auto node : b->nodes()) {
for (Block* sub_block : node->blocks()) {
validateBlock(sub_block, operator_export_type);
}
// Macro'ed so we get a marginally better line number on failed export
#define FAIL_EXPORT(name) \
throw std::runtime_error( \
std::string("ONNX export failed: ") + name + \
"\n\nGraph we tried to export:\n" + b->owningGraph()->toString());
if (node->kind() == prim::PythonOp) {
auto py_node = static_cast<PythonOp*>(node);
FAIL_EXPORT(
"Couldn't export Python operator " + py_node->name() +
"\n\nDefined at:\n" + getNodeStackTraceString(node))
} else {
// Special handling and error messages for certain kinds of operators
if (node->kind() == aten::expand) {
if (operator_export_type ==
onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK) {
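// Rather than failing the export, rewrite aten::expand into an ATen
// fallback node (an onnx::ATen node with attr::operator = "expand").
// The original node is left unused and cleaned up by the dead code
// elimination pass that runs after validation.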
WithInsertPoint guard(node);
auto* new_node =
b->owningGraph()->insertNode(b->owningGraph()->create(
Symbol(::c10::onnx::ATen),
node->inputs(),
node->outputs().size()));
for (size_t i = 0; i < node->outputs().size(); ++i) {
node->output(i)->replaceAllUsesWith(new_node->output(i));
}
new_node->s_(Symbol::fromQualString("attr::operator"), "expand");
}
}
if (node->kind() == prim::PackPadded || node->kind() == prim::PadPacked) {
FAIL_EXPORT(
"Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.\n\nUsage of this operation occurred at:\n" +
getNodeStackTraceString(node));
}
bool is_aten_enabled = operator_export_type ==
onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK ||
operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN;
if (!node->kind().is_onnx() && !node->kind().is_caffe2() &&
!is_aten_enabled && !node->mustBeNone()) {
FAIL_EXPORT(
"Couldn't export operator " + node->kind().toDisplayString() +
"\n\nDefined at:\n" + getNodeStackTraceString(node));
}
}
#undef FAIL_EXPORT
}
}
void validateGraph(
const std::shared_ptr<Graph>& graph,
onnx_torch::OperatorExportTypes operator_export_type) {
validateBlock(graph->block(), operator_export_type);
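// validateBlock may have rewritten nodes (e.g. aten::expand under
// ATEN_FALLBACK), leaving the original nodes dead.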
// This is run on an ONNX graph, which has no side effects, so dead code
// elimination may delete nodes even if they are marked as having them.
EliminateDeadCode(graph->block(), true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}
class EncoderBase {
public:
EncoderBase(
onnx_torch::OperatorExportTypes operator_export_type,
bool strip_doc);
onnx::ModelProto get_model_proto() {
return model_proto_;
}
protected:
// Using std::map instead of std::unordered_map for initializers
// in EncodeGraph so that the order in which initializers
// get written to the ONNX graph is always deterministic and
// predictable. While this is not an ONNX requirement, it is needed
// for testing purposes in tests that use _export_to_pretty_string()
// for validating ONNX graphs.
void EncodeGraph(
onnx::GraphProto* graph_proto,
const std::shared_ptr<Graph>& graph,
const std::map<std::string, at::Tensor>& initializers =
std::map<std::string, at::Tensor>(),
const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes =
std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>(),
bool keep_initializers_as_inputs = true);
void EncodeBlock(
onnx::GraphProto* graph_proto,
const Block* block,
const std::map<std::string, at::Tensor>& initializers =
std::map<std::string, at::Tensor>(),
const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes =
std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>(),
bool keep_initializers_as_inputs = true);
virtual void EncodeTensor(
onnx::TensorProto* tensor_proto,
const at::Tensor& tensor,
const c10::optional<std::string> external_ref = {}) = 0;
virtual void EncodeIntermediateValueInfo(
onnx::GraphProto* graph_proto,
const Value* n){}
virtual void EncodeValueInfo(
onnx::GraphProto* graph_proto,
onnx::ValueInfoProto* v,
const Value* n,
const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes =
std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>());
void AddAttribute(
onnx::NodeProto* node_proto,
const jit::Node* node,
const jit::Symbol name);
onnx::ModelProto model_proto_;
size_t num_blocks_;
onnx_torch::OperatorExportTypes operator_export_type_;
bool strip_doc_;
std::set<std::string> domains_;
};
onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
switch (at_type) {
case at::kDouble:
return onnx::TensorProto_DataType_DOUBLE;
case at::kFloat:
return onnx::TensorProto_DataType_FLOAT;
case at::kHalf:
return onnx::TensorProto_DataType_FLOAT16;
case at::kByte:
return onnx::TensorProto_DataType_UINT8;
case at::kChar:
return onnx::TensorProto_DataType_INT8;
case at::kShort:
return onnx::TensorProto_DataType_INT16;
case at::kInt:
return onnx::TensorProto_DataType_INT32;
case at::kLong:
return onnx::TensorProto_DataType_INT64;
case at::kBool:
return onnx::TensorProto_DataType_BOOL;
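// Quantized types map to their underlying integer representations;
// scale and zero_point are not part of this dtype mapping.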
case at::kQInt8:
return onnx::TensorProto_DataType_INT8;
case at::kQUInt8:
return onnx::TensorProto_DataType_UINT8;
case at::kQInt32:
return onnx::TensorProto_DataType_INT32;
default:
AT_ERROR("unexpected tensor scalar type");
}
}
EncoderBase::EncoderBase(
onnx_torch::OperatorExportTypes operator_export_type,
bool strip_doc)
: num_blocks_(0),
operator_export_type_(operator_export_type),
strip_doc_(strip_doc) {
model_proto_.set_producer_name("pytorch");
// We pin the IR version to 4 (01/22/2019) instead of using
// onnx::IR_VERSION so that test_operators.py stays stable.
// Only bump it when necessary.
model_proto_.set_ir_version(4);
// TODO: set the producer version using appropriate function call
model_proto_.set_producer_version("1.3");
}
void EncoderBase::EncodeValueInfo(
onnx::GraphProto* graph_proto,
onnx::ValueInfoProto* v,
const Value* n,
const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes) {
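// dynamic_axes maps a value's debug name to {dim index -> symbolic dim
// name}; e.g. {"input_1": {0: "batch"}} exports dim 0 of input_1 as a
// named dynamic dimension instead of a fixed size.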
std::string name = n->debugName();
v->set_name(name);
if (TensorTypePtr node_type = n->type()->cast<TensorType>()) {
if (!node_type->isComplete()) {
return;
}
onnx::TypeProto* t = v->mutable_type();
onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
onnx::TensorShapeProto* shape = tensor_type->mutable_shape();
std::vector<std::int64_t> sizes =
node_type->sizes().concrete_sizes().value();
for (size_t i = 0; i < sizes.size(); i++) {
shape->add_dim();
if ((dynamic_axes.find(name) != dynamic_axes.end()) &&
(dynamic_axes.at(name).find(i) != dynamic_axes.at(name).end())) {
shape->mutable_dim(i)->set_dim_param(dynamic_axes.at(name).at(i));
} else {
shape->mutable_dim(i)->set_dim_value(sizes[i]);
}
}
tensor_type->set_elem_type(
ATenTypeToOnnxType(node_type->scalarType().value()));
} else if (BoolTypePtr node_type = n->type()->cast<BoolType>()) {
onnx::TypeProto* t = v->mutable_type();
onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
tensor_type->set_elem_type(ATenTypeToOnnxType(at::kBool));
}
}
void EncoderBase::EncodeGraph(
onnx::GraphProto* graph_proto,
const std::shared_ptr<Graph>& graph,
const std::map<std::string, at::Tensor>& initializers,
const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes,
bool keep_initializers_as_inputs) {
EncodeBlock(graph_proto, graph->block(), initializers, dynamic_axes, keep_initializers_as_inputs);
}
void EncoderBase::EncodeBlock(
onnx::GraphProto* graph_proto,
const Block* block,
const std::map<std::string, at::Tensor>& initializers,
const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes,
bool keep_initializers_as_inputs) {
AT_ASSERT(graph_proto != nullptr);
std::string block_name = "torch-jit-export";
if (num_blocks_) {
block_name += std::to_string(num_blocks_);
}
num_blocks_++;
graph_proto->set_name(block_name);
// Since ONNX IR version 4, initializers do not have to
// be a subset of graph inputs. We use the keep_initializers_as_inputs
// argument to determine whether to add initializers as inputs.
// If keep_initializers_as_inputs=false, we add only non-parameter
// inputs as inputs to the ONNX graph, not the initializers
// (parameters). If keep_initializers_as_inputs=true, we add
// initializers as inputs too. Setting keep_initializers_as_inputs=false
// allows better optimizations, such as constant folding, on ONNX
// graphs by backends/optimizers.
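// For example, for a graph with inputs {x, W, b} where W and b appear in
// initializers: keep_initializers_as_inputs=true exports graph inputs
// [x, W, b], while keep_initializers_as_inputs=false exports only [x].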
if (keep_initializers_as_inputs) {
for (auto input : block->inputs()) {
onnx::ValueInfoProto* v = graph_proto->add_input();
EncodeValueInfo(graph_proto, v, input, dynamic_axes);
}
} else {
for (auto input : block->inputs()) {
auto it = initializers.find(input->debugName());
if (it == initializers.end()) {
onnx::ValueInfoProto* v = graph_proto->add_input();
EncodeValueInfo(graph_proto, v, input, dynamic_axes);
}
}
}
for (auto output : block->outputs()) {
onnx::ValueInfoProto* v = graph_proto->add_output();
EncodeValueInfo(graph_proto, v, output, dynamic_axes);
}
for (auto node : block->nodes()) {
bool is_raw_export =
operator_export_type_ == onnx_torch::OperatorExportTypes::RAW;
if (node->mustBeNone() && !is_raw_export) {
// None nodes are used to implement optional inputs. One
// way to "not provide" an optional input is to create a
// None node and pass its output as that input.
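// Such omitted inputs show up downstream as empty-string input names
// (see the add_input("") call below).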
continue;
}
auto p_n = graph_proto->add_node();
if (!strip_doc_) {
p_n->set_doc_string(node->sourceRange().str());
}
for (auto input : node->inputs()) {
if (input->node()->mustBeNone() && !is_raw_export) {
p_n->add_input("");
} else {
p_n->add_input(input->debugName());
}
}
for (auto output : node->outputs()) {
p_n->add_output(output->debugName());
EncodeIntermediateValueInfo(graph_proto, output);
}
if (!node->kind().is_onnx()) {
p_n->set_domain(node->kind().domainString());
domains_.insert(node->kind().domainString());
}
if (is_raw_export) {
AT_ASSERT(!node->kind().is_onnx());
} else if (operator_export_type_ == onnx_torch::OperatorExportTypes::ONNX) {
AT_ASSERT(
!node->kind().is_aten() && !node->kind().is_prim() &&
!node->kind().is_attr());
}
p_n->set_op_type(node->kind().toUnqualString());
for (auto attr_name : node->attributeNames()) {
AddAttribute(p_n, node, attr_name);
}
if (is_raw_export && node->blocks().size() > 0) {
auto blocks = p_n->add_attribute();
blocks->set_name("_blocks");
blocks->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
for (auto block : node->blocks()) {
auto graph = blocks->add_graphs();
EncodeBlock(graph, block, initializers);
}
}
if (node->kind() == ::c10::onnx::Loop) {
AT_ASSERT(node->blocks().size() == 1);
auto body = p_n->add_attribute();
body->set_name("body");
body->set_type(onnx::AttributeProto_AttributeType_GRAPH);
auto g = body->mutable_g();
EncodeBlock(g, node->blocks()[0]);
}
if (node->kind() == ::c10::onnx::If) {
AT_ASSERT(node->blocks().size() == 2);
auto true_branch = p_n->add_attribute();
true_branch->set_name("then_branch");
true_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
auto true_g = true_branch->mutable_g();
EncodeBlock(true_g, node->blocks()[0]);
auto false_branch = p_n->add_attribute();
false_branch->set_name("else_branch");
false_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
auto false_g = false_branch->mutable_g();
EncodeBlock(false_g, node->blocks()[1]);
}
}
AT_ASSERT(block->inputs().size() >= initializers.size());
for (auto& name_tensor_pair : initializers) {
auto p = graph_proto->add_initializer();
p->set_name(name_tensor_pair.first);
EncodeTensor(p, name_tensor_pair.second, name_tensor_pair.first);
}
}
void EncoderBase::AddAttribute(
onnx::NodeProto* node_proto,
const jit::Node* node,
const jit::Symbol name) {
auto attr = node_proto->add_attribute();
AT_ASSERT(name.is_attr());
attr->set_name(name.toUnqualString());
switch (node->kindOf(name)) {
case AttributeKind::f:
attr->set_f(node->f(name));
attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
break;
case AttributeKind::fs:
attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
for (auto& v : node->fs(name))
attr->add_floats(v);
break;
case AttributeKind::i:
attr->set_type(onnx::AttributeProto_AttributeType_INT);
attr->set_i(node->i(name));
break;
case AttributeKind::is:
attr->set_type(onnx::AttributeProto_AttributeType_INTS);
for (auto& v : node->is(name))
attr->add_ints(v);
break;
case AttributeKind::s:
attr->set_type(onnx::AttributeProto_AttributeType_STRING);
attr->set_s(node->s(name));
break;
case AttributeKind::ss:
attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
for (auto& v : node->ss(name))
attr->add_strings(v);
break;
case AttributeKind::t: {
attr->set_type(onnx::AttributeProto_AttributeType_TENSOR);
auto t = attr->mutable_t();
EncodeTensor(t, node->t(name));
} break;
case AttributeKind::ts:
attr->set_type(onnx::AttributeProto_AttributeType_TENSORS);
for (auto& v : node->ts(name)) {
auto t = attr->add_tensors();
EncodeTensor(t, v);
}
break;
case AttributeKind::g: {
attr->set_type(onnx::AttributeProto_AttributeType_GRAPH);
auto g = attr->mutable_g();
EncodeGraph(g, node->g(name));
} break;
case AttributeKind::gs:
attr->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
for (auto& v : node->gs(name)) {
auto g = attr->add_graphs();
EncodeGraph(g, v);
}
break;
default:
throw std::runtime_error("unexpected attribute kind");
}
}
class GraphEncoder : public EncoderBase {
public:
GraphEncoder(
const std::shared_ptr<Graph>& graph,
int64_t onnx_opset_version,
onnx_torch::OperatorExportTypes operator_export_type,
const std::map<std::string, at::Tensor>& initializers,
const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes,
bool defer_weight_export,
bool strip_doc,
bool keep_initializers_as_inputs);
RawDataExportMap get_raw_data_export_map() {
return raw_data_export_map_;
}
private:
void EncodeTensor(
onnx::TensorProto* tensor_proto,
const at::Tensor& tensor,
const c10::optional<std::string> external_ref = {}) override;
RawDataExportMap raw_data_export_map_;
bool defer_weight_export_;
};
GraphEncoder::GraphEncoder(
const std::shared_ptr<Graph>& graph,
int64_t onnx_opset_version,
onnx_torch::OperatorExportTypes operator_export_type,
const std::map<std::string, at::Tensor>& initializers,
const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes,
bool defer_weight_export,
bool strip_doc,
bool keep_initializers_as_inputs)
: EncoderBase(operator_export_type, strip_doc),
defer_weight_export_(defer_weight_export) {
if (operator_export_type != onnx_torch::OperatorExportTypes::RAW) {
validateGraph(graph, operator_export_type);
}
auto* imp = model_proto_.add_opset_import();
// This is the version of ONNX operator set we are targeting
imp->set_version(onnx_opset_version);
EncodeGraph(model_proto_.mutable_graph(), graph, initializers, dynamic_axes, keep_initializers_as_inputs);
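// Every non-ONNX domain recorded during encoding needs its own
// opset_import entry; version 0 serves as a placeholder for these
// custom domains.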
for (const std::string& domain : domains_) {
auto* opset = model_proto_.add_opset_import();
opset->set_domain(domain);
opset->set_version(0);
}
}
void GraphEncoder::EncodeTensor(
onnx::TensorProto* tensor_proto,
const at::Tensor& tensor,
const c10::optional<std::string> external_ref) {
for (auto d : tensor.sizes()) {
tensor_proto->add_dims(d);
}
tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.scalar_type()));
at::Tensor t;
// Call contiguous() before moving to CPU, since CPU's HalfTensor doesn't have contiguous()
// TODO We don't call .cpu() on quantized tensors as it fails when calling
// aten::empty() on quantized tensors beyond certain size. Issue #29435.
if (tensor.is_quantized()) {
t = tensor.contiguous();
} else {
t = tensor.contiguous().cpu();
}
// Add a buffer to the raw_data_export_map for the caller to dump into an
// external data store. If external_ref is not specified, we instead dump
// the contiguous data into the protobuf itself
if (defer_weight_export_ && external_ref) {
// For now, we use the name of the tensor as the external lookup name to
// avoid ONNX protobuf changes.
AT_ASSERT(external_ref.value() == tensor_proto->name());
AT_ASSERT(raw_data_export_map_.count(external_ref.value()) == 0);
raw_data_export_map_[external_ref.value()] = t;
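// "__EXTERNAL" is only a sentinel; the actual bytes travel to the
// caller through raw_data_export_map_.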
tensor_proto->set_raw_data("__EXTERNAL");
} else {
AT_ASSERT(t.is_contiguous());
tensor_proto->set_raw_data(std::string(
static_cast<char*>(t.data_ptr()), t.element_size() * t.numel()));
}
}
class ScriptModuleSerializer {
public:
explicit ScriptModuleSerializer(const std::string& filename)
: writer_(filename) {}
explicit ScriptModuleSerializer(
const std::function<size_t(const void *, size_t)>& writer_func)
: writer_(writer_func) {}
void serialize(
const script::Module& module,
const script::ExtraFilesMap& extra_files,
bool bytecode_format) {
C10_LOG_API_USAGE_ONCE("torch.script.save");
writeExtraFiles(module, extra_files);
// Serialize the model object
writeArchive("data", module._ivalue());
// Then we serialize all code info.
writeCode(module.type());
// The tensor constants from the code are written to a separate archive
// so loading the code does not depend on loading the data
std::vector<IValue> ivalue_constants(
constant_table_.begin(), constant_table_.end());
writeArchive("constants", c10::ivalue::Tuple::create(ivalue_constants));
if (bytecode_format) {
writeByteCode(module);
}
}
private:
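// Archive layout: an archive named "foo" is written as the pickled
// IValue in "foo.pkl", plus one raw record per tensor payload
// ("foo/0", "foo/1", ...).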
void writeArchive(const std::string& archive_name, const IValue& value) {
std::vector<char> data;
// Vector to capture the run-time class types encountered while pickling
// the IValues
std::vector<c10::ClassTypePtr> memoizedClassTypes;
Pickler data_pickle(
[&](const char* buf, size_t size) {
data.insert(data.end(), buf, buf + size);
},
nullptr,
&memoizedClassTypes);
data_pickle.protocol();
data_pickle.pushIValue(value);
data_pickle.stop();
size_t i = 0;
std::string prefix = archive_name + "/";
for (const auto& td : data_pickle.tensorData()) {
std::string fname = prefix + std::to_string(i++);
writer_.writeRecord(fname, td.data(), td.sizeInBytes());
}
std::string fname = archive_name + ".pkl";
writer_.writeRecord(fname, data.data(), data.size());
// serialize all the captured run-time class types
for (const c10::ClassTypePtr& wroteType : memoizedClassTypes) {
convertNamedType(wroteType);
}
}
void writeExtraFiles(
const script::Module& module,
const script::ExtraFilesMap& extra_files) {
// Write out extra files.
for (const auto& kv : extra_files) {
const std::string key = "extra/" + kv.first;
writer_.writeRecord(key, kv.second.data(), kv.second.size());
}
auto hook = GetExtraFilesHook();
if (hook) {
script::ExtraFilesMap hook_files = hook(module);
for (const auto& kv : hook_files) {
const std::string key = "extra/" + kv.first;
writer_.writeRecord(key, kv.second.data(), kv.second.size());
}
}
}
void writeCode(const at::NamedTypePtr& root_type) {
class_deps_.push_back(root_type);
for (size_t i = 0; i < class_deps_.size(); ++i) {
// note: convertNamedType may extend class_deps_, so re-checking
// .size() is necessary
convertNamedType(class_deps_[i]);
}
// Mapping of filename => src. We need this because multiple classes may go
// in the same file (e.g. foo.bar.Baz and foo.bar.Qux)
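// e.g. a qualifier such as "__torch__.foo" is expected to map to the
// archive path "code/__torch__/foo.py".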
for (auto& item : file_streams_) {
const std::string filename = qualifierToArchivePath(item.key(), "code/");
std::string src = item.value().str();
// Only compress these records if they're not tiny.
// The cpu cost of generating zip datastructs and compressing isn't
// well-spent for very small records.
static constexpr size_t kMinToCompress = 200;
writer_.writeRecord(
filename, src.c_str(), src.size(),
src.size() > kMinToCompress /*compress*/);
// Write out the debug information
std::string debugFilename = filename + ".debug_pkl";
SourceRangePickler source_range_pickler;
auto range_data =
source_range_pickler.pickle(item.value().ranges());
writer_.writeRecord(
debugFilename,
range_data.data(),
range_data.size(),
range_data.size() > kMinToCompress /*compress*/);
}
}
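// The "bytecode" archive is a tuple with one (qualified_name, table)
// entry per method, where each table is roughly:
//   (("instructions", ((op_name, X, N), ...)),
//    ("operators", ((name, overload_name), ...)),
//    ("constants", (...)),
//    ("register_size", n))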
void writeByteCode(const script::Module& module) {
auto methods = module.get_methods();
std::vector<c10::IValue> elements;
for (const auto& method : methods) {
const auto& func = method.function();
auto graph = func.graph()->copy();
Inline(*graph);
torch::jit::Code code(graph);
// Make a copy of opnames. Some of them may be changed for mobile later.
std::vector<c10::OperatorName> opnames;
for (size_t i = 0; i < code.instructions().size(); ++i) {
Instruction ins = code.instructions()[i];
if (ins.op == OP) {
auto node = code.instructions_source()[i];
opnames.emplace_back(node->schema().operator_name());
}
}
// instructions
std::vector<IValue> inss;
for (size_t i = 0; i < code.instructions().size(); ++i) {
Instruction ins = code.instructions()[i];
TORCH_CHECK(isOpSupportedInMobile(ins.op), toString(ins.op),
" is not supported in mobile module.");
if (ins.op == OP) {
if (opnames[ins.X].name == "prim::ListConstruct" ||
opnames[ins.X].name == "prim::TupleConstruct" ||
opnames[ins.X].name == "prim::TupleUnpack" ||
opnames[ins.X].name == "aten::format") {
auto node = code.instructions_source()[i];
ins.op = OPN;
if (opnames[ins.X].name == "prim::TupleUnpack") {
ins.N = node->outputs().size();
} else {
ins.N = node->inputs().size();
}
if (opnames[ins.X].name == "prim::ListConstruct") {
ListTypePtr lt = node->output()->type()->expect<ListType>();
if (lt->getElementType() == IntType::get()) {
opnames[ins.X].overload_name = "int";
} else if (lt->getElementType() == FloatType::get()) {
opnames[ins.X].overload_name = "float";
} else if (lt->getElementType() == BoolType::get()) {
opnames[ins.X].overload_name = "bool";
} else if (lt->getElementType()->isSubtypeOf(TensorType::get())) {
opnames[ins.X].overload_name = "Tensor";
} else {
opnames[ins.X].overload_name = "generic";
}
} else if (opnames[ins.X].name == "prim::TupleConstruct" &&
node->output()->type()->expect<TupleType>()->name().has_value()) {
AT_WARN("Named tuple is serialized as an unnamed tuple.");
}
}
}
std::vector<IValue> insv{toString(ins.op), ins.X, ins.N};
inss.emplace_back(c10::ivalue::Tuple::create(std::move(insv)));
}
auto instructions = c10::ivalue::Tuple::create(std::move(inss));
auto named_ins = c10::ivalue::Tuple::create({"instructions", instructions});
// operators
std::vector<IValue> opss;
for (const auto& opname : opnames) {
opss.emplace_back(c10::ivalue::Tuple::create({opname.name, opname.overload_name}));
}
auto operators = c10::ivalue::Tuple::create(std::move(opss));
auto named_ops = c10::ivalue::Tuple::create({"operators", operators});
// constants
auto constants = c10::ivalue::Tuple::create(code.constant_table());
auto named_consts = c10::ivalue::Tuple::create({"constants", constants});
// since the register location is embedded into the bytecode, pass the register size
auto named_regsize = c10::ivalue::Tuple::create({"register_size",
static_cast<int>(code.register_size())});
auto element = c10::ivalue::Tuple::create({named_ins, named_ops, named_consts, named_regsize});
elements.push_back(c10::ivalue::Tuple::create({func.qualname().qualifiedName(), element}));
}
auto telements = c10::ivalue::Tuple::create(std::move(elements));
writeArchive("bytecode", telements);
}
void convertNamedType(const c10::NamedTypePtr& class_type) {
if (converted_types_.count(class_type)) {
return;
}
converted_types_.insert(class_type);
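// Types are grouped by qualifier (e.g. "__torch__.foo" for
// "__torch__.foo.Bar"); all types sharing a qualifier are printed into
// the same source file.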
std::string qualifier = class_type->name()->prefix();
PythonPrint* pp = file_streams_.find(qualifier);
if (!pp) {
pp = &file_streams_.insert(
qualifier,
PythonPrint(
constant_table_, class_deps_, /*enforce_importable=*/true));
pp->LEGACY_printOpVersion();
}
pp->printNamedType(class_type);
}
caffe2::serialize::PyTorchStreamWriter writer_;
std::vector<at::Tensor> constant_table_;
std::unordered_set<c10::NamedTypePtr> converted_types_;
std::vector<c10::NamedTypePtr> class_deps_;
// qualifier, e.g. '__torch__.Bar' -> PythonPrint for the file that will be
// created
OrderedDict<std::string, PythonPrint> file_streams_;
bool bytecode_format_;
};
// Pretty printing for ONNX
constexpr char indent_char = ' ';
constexpr size_t indent_multiplier = 2;
std::string idt(size_t indent) {
return std::string(indent * indent_multiplier, indent_char);
}
std::string nlidt(size_t indent) {
return std::string("\n") + idt(indent);
}
void dump(const onnx::TensorProto& tensor, std::ostream& stream) {
stream << "TensorProto shape: [";
for (int i = 0; i < tensor.dims_size(); ++i) {
stream << tensor.dims(i) << (i == tensor.dims_size() - 1 ? "" : " ");
}
stream << "]";
}
void dump(const onnx::TensorShapeProto& shape, std::ostream& stream) {
for (int i = 0; i < shape.dim_size(); ++i) {
auto& dim = shape.dim(i);
if (dim.has_dim_value()) {
stream << dim.dim_value();
} else {
stream << "?";
}
stream << (i == shape.dim_size() - 1 ? "" : " ");
}
}
void dump(const onnx::TypeProto_Tensor& tensor_type, std::ostream& stream) {
stream << "Tensor dims: ";
dump(tensor_type.shape(), stream);
}
void dump(const onnx::TypeProto& type, std::ostream& stream) {
dump(type.tensor_type(), stream);
}
void dump(const onnx::ValueInfoProto& value_info, std::ostream& stream) {
stream << "{name: \"" << value_info.name() << "\", type:";
dump(value_info.type(), stream);
stream << "}";
}
void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent);
void dump(
const onnx::AttributeProto& attr,
std::ostream& stream,
size_t indent) {
stream << "{ name: '" << attr.name() << "', type: ";
if (attr.has_f()) {
stream << "float, value: " << attr.f();
} else if (attr.has_i()) {
stream << "int, value: " << attr.i();
} else if (attr.has_s()) {
stream << "string, value: '" << attr.s() << "'";
} else if (attr.has_g()) {
stream << "graph, value:\n";
dump(attr.g(), stream, indent + 1);
stream << nlidt(indent);
} else if (attr.has_t()) {
stream << "tensor, value:";
dump(attr.t(), stream);
} else if (attr.floats_size()) {
stream << "floats, values: [";
for (int i = 0; i < attr.floats_size(); ++i)
stream << attr.floats(i) << (i == attr.floats_size() - 1 ? "" : " ");
stream << "]";
} else if (attr.ints_size()) {
stream << "ints, values: [";
for (int i = 0; i < attr.ints_size(); ++i)
stream << attr.ints(i) << (i == attr.ints_size() - 1 ? "" : " ");
stream << "]";
} else if (attr.strings_size()) {
stream << "strings, values: [";
for (int i = 0; i < attr.strings_size(); ++i)
stream << "'" << attr.strings(i) << "'"
<< (i == attr.strings_size() - 1 ? "" : " ");
stream << "]";
} else if (attr.tensors_size()) {
stream << "tensors, values: [";
for (auto& t : attr.tensors()) {
dump(t, stream);
}
stream << "]";
} else if (attr.graphs_size()) {
stream << "graphs, values: [";
for (auto& g : attr.graphs()) {
dump(g, stream, indent + 1);
}
stream << "]";
} else {
stream << "UNKNOWN";
}
stream << "}";
}
void dump(const onnx::NodeProto& node, std::ostream& stream, size_t indent) {
stream << "Node {type: \"" << node.op_type() << "\", inputs: [";
for (int i = 0; i < node.input_size(); ++i) {
stream << node.input(i) << (i == node.input_size() - 1 ? "" : ",");
}
stream << "], outputs: [";
for (int i = 0; i < node.output_size(); ++i) {
stream << node.output(i) << (i == node.output_size() - 1 ? "" : ",");
}
stream << "], attributes: [";
for (int i = 0; i < node.attribute_size(); ++i) {
dump(node.attribute(i), stream, indent + 1);
stream << (i == node.attribute_size() - 1 ? "" : ",");
}
stream << "]}";
}
void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent) {
stream << idt(indent) << "GraphProto {" << nlidt(indent + 1) << "name: \""
<< graph.name() << "\"" << nlidt(indent + 1) << "inputs: [";
for (int i = 0; i < graph.input_size(); ++i) {
dump(graph.input(i), stream);
stream << (i == graph.input_size() - 1 ? "" : ",");
}
stream << "]" << nlidt(indent + 1) << "outputs: [";
for (int i = 0; i < graph.output_size(); ++i) {
dump(graph.output(i), stream);
stream << (i == graph.output_size() - 1 ? "" : ",");
}
stream << "]" << nlidt(indent + 1) << "initializers: [";
for (int i = 0; i < graph.initializer_size(); ++i) {
dump(graph.initializer(i), stream);
stream << (i == graph.initializer_size() - 1 ? "" : ",");
}
stream << "]" << nlidt(indent + 1) << "nodes: [" << nlidt(indent + 2);
for (int i = 0; i < graph.node_size(); ++i) {
dump(graph.node(i), stream, indent + 2);
if (i != graph.node_size() - 1)
stream << "," << nlidt(indent + 2);
}
stream << nlidt(indent + 1) << "]\n" << idt(indent) << "}\n";
}
void dump(
const onnx::OperatorSetIdProto& operator_set_id,
std::ostream& stream) {
stream << "OperatorSetIdProto { domain: " << operator_set_id.domain() << "}";
}
void dump(const onnx::ModelProto& model, std::ostream& stream, size_t indent) {
stream << idt(indent) << "ModelProto {" << nlidt(indent + 1)
<< "producer_name: \"" << model.producer_name() << "\""
<< nlidt(indent + 1) << "domain: \"" << model.domain() << "\""
<< nlidt(indent + 1) << "doc_string: \"" << model.doc_string() << "\"";
if (model.has_graph()) {
stream << nlidt(indent + 1) << "graph:\n";
dump(model.graph(), stream, indent + 2);
}
if (model.opset_import_size()) {
stream << idt(indent + 1) << "opset_import: [";
for (auto& opset_imp : model.opset_import()) {
dump(opset_imp, stream);
}
stream << "],\n";
}
stream << idt(indent) << "}\n";
}
std::string prettyPrint(const onnx::ModelProto& model) {
std::ostringstream ss;
dump(model, ss, 0);
return ss.str();
}
} // namespace
void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook) {
GetExtraFilesHook() = hook;
}
std::string pretty_print_onnx(
const std::shared_ptr<Graph>& graph,
const std::map<std::string, at::Tensor>& initializers,
int64_t onnx_opset_version,
bool defer_weight_export,
::torch::onnx::OperatorExportTypes operator_export_type,
bool google_printer,
bool keep_initializers_as_inputs) {
auto graph_encoder = GraphEncoder(
graph,
onnx_opset_version,
operator_export_type,
initializers,
std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>{},
defer_weight_export,
true,
keep_initializers_as_inputs);
if (google_printer) {
return graph_encoder.get_model_proto().DebugString();
}
return prettyPrint(graph_encoder.get_model_proto());
}
// export_raw_ir will export IR ops without turning them into ONNX ops.
// The output will use the ONNX protobuf format, but the ops will not
// conform to the ONNX op specification. Thus, the output will not
// be interpretable by an ONNX-compatible framework. However, PyTorch or
// libtorch will be able to import the IR and play it back.
std::tuple<std::string, RawDataExportMap> export_onnx(
const std::shared_ptr<Graph>& graph,
const std::map<std::string, at::Tensor>& initializers,
int64_t onnx_opset_version,
const std::unordered_map<std::string, std::unordered_map<std::int64_t, std::string>>& dynamic_axes,
bool defer_weight_export,
::torch::onnx::OperatorExportTypes operator_export_type,
bool strip_doc_string,
bool keep_initializers_as_inputs) {
auto graph_encoder = GraphEncoder(
graph,
onnx_opset_version,
operator_export_type,
initializers,
dynamic_axes,
defer_weight_export,
strip_doc_string,
keep_initializers_as_inputs);
return std::make_tuple(
graph_encoder.get_model_proto().SerializeAsString(),
graph_encoder.get_raw_data_export_map());
}
void ExportModule(
const script::Module& module,
std::ostream& out,
const script::ExtraFilesMap& extra_files,
bool bytecode_format) {
ScriptModuleSerializer serializer(
[&](const void* buf, size_t nbytes) -> size_t {
out.write(static_cast<const char *>(buf), nbytes);
return !out ? 0 : nbytes;
});
serializer.serialize(module, extra_files, bytecode_format);
}
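// A minimal usage sketch for the stream overload above (hypothetical
// caller code):
//   script::Module m = ...;
//   std::ofstream out("model.pt", std::ios::binary);
//   ExportModule(m, out, script::ExtraFilesMap(), /*bytecode_format=*/false);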
void ExportModule(
const script::Module& module,
const std::string& filename,
const script::ExtraFilesMap& extra_files,
bool bytecode_format) {
ScriptModuleSerializer serializer(filename);
serializer.serialize(module, extra_files, bytecode_format);
}
void ExportModule(
const script::Module& module,
const std::function<size_t(const void*, size_t)>& writer_func,
const script::ExtraFilesMap& extra_files,
bool bytecode_format) {
ScriptModuleSerializer serializer(writer_func);
serializer.serialize(module, extra_files, bytecode_format);
}
} // namespace jit
} // namespace torch