#include "torch/csrc/jit/export.h"
#include "torch/csrc/onnx/onnx.h"
#include "torch/csrc/autograd/symbolic.h"
#include "torch/csrc/utils/python_numbers.h"
#include "torch/csrc/utils/python_strings.h"
#include "torch/csrc/Exceptions.h"
#include "torch/csrc/autograd/functions/convolution.h"
#include "torch/csrc/utils/functional.h"
#include <ATen/ATen.h>
#include <fstream>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;

namespace torch { namespace jit {

namespace {
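
// Values are referred to by name in the exported protobuf; node_name
// returns the unique name the JIT IR assigned to this value.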
std::string node_name(Node* n) {
  return n->uniqueName();
}

void encodeGraph(onnx::GraphProto * p_g, const std::shared_ptr<Graph> & g, const std::vector<at::Tensor> & initializers);
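
// Serializes a single tensor: its shape, an ONNX data type, and the raw
// bytes of a contiguous CPU copy. Not every ATen type has an exact ONNX
// counterpart here: double and half are both converted to 32-bit float,
// and signed char is stored as a byte tensor tagged INT8.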
void encodeTensor(onnx::TensorProto * p, const at::Tensor & tensor) {
  for(auto d : tensor.sizes()) {
    p->add_dims(d);
  }
  at::ScalarType at_type;
  onnx::DataType onnx_type;
  switch(tensor.type().scalarType()) {
    case at::kDouble:
    case at::kFloat:
    case at::kHalf:
      // Only 32-bit float is emitted: double and half are converted.
      onnx_type = onnx::kFLOAT;
      at_type = at::kFloat;
      break;
    case at::kByte:
    case at::kChar:
      onnx_type = onnx::kINT8;
      at_type = at::kByte;
      break;
    case at::kShort:
      onnx_type = onnx::kINT16;
      at_type = at::kShort;
      break;
    case at::kInt:
      onnx_type = onnx::kINT32;
      at_type = at::kInt;
      break;
    case at::kLong:
      onnx_type = onnx::kINT64;
      at_type = at::kLong;
      break;
    default:
      jit::barf("unexpected tensor scalar type");
      break;
  }
  p->set_data_type(onnx_type);
  // Materialize a contiguous CPU copy in the chosen type so its raw bytes
  // can be written directly into the proto's raw_data field.
  at::Tensor cont = tensor.toType(at::CPU(at_type)).contiguous();
  p->set_raw_data(cont);
}
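
// Copies one JIT attribute onto the NodeProto, dispatching on the
// attribute's kind (float/int/string/tensor/graph, scalar or list).
// Subgraph attributes are encoded recursively, with no initializers.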
void addAttribute(onnx::NodeProto * n_p, jit::Node * n, jit::Symbol name) {
  auto attr = n_p->add_attribute();
  attr->set_name(jit::symbolToString(name));
  switch(n->kindOf(name)) {
    case AttributeKind::f:
      attr->set_f(n->f(name));
      break;
    case AttributeKind::fs:
      for(auto & v : n->fs(name))
        attr->add_floats(v);
      break;
    case AttributeKind::i:
      attr->set_i(n->i(name));
      break;
    case AttributeKind::is:
      for(auto & v : n->is(name))
        attr->add_ints(v);
      break;
    case AttributeKind::s:
      attr->set_s(n->s(name));
      break;
    case AttributeKind::ss:
      for(auto & v : n->ss(name))
        attr->add_strings(v);
      break;
    case AttributeKind::t: {
      auto t = attr->mutable_t();
      encodeTensor(t, n->t(name));
    } break;
    case AttributeKind::ts:
      for(auto & v : n->ts(name)) {
        auto t = attr->add_tensors();
        encodeTensor(t, v);
      }
      break;
    case AttributeKind::g: {
      auto g = attr->mutable_g();
      encodeGraph(g, n->g(name), {});
    } break;
    case AttributeKind::gs:
      for(auto & v : n->gs(name)) {
        auto g = attr->add_graphs();
        encodeGraph(g, v, {});
      }
      break;
  }
}
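
// Encodes a whole graph: value names for the graph's inputs and outputs,
// one NodeProto per IR node (Select nodes and dead Undefined nodes are
// skipped), and the initializer tensors, which are matched to graph
// inputs by position (see the TODO below).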
void encodeGraph(onnx::GraphProto * p_g, const std::shared_ptr<Graph> & g, const std::vector<at::Tensor> & initializers) {
  for (auto input : g->inputs()) {
    p_g->add_input(node_name(input));
  }
  for (auto output : g->outputs()) {
    p_g->add_output(node_name(output));
  }
  for (auto node : g->nodes()) {
    if (node->kind() == kSelect) {
      // No select nodes in ONNX: instead we make use of the select
      // invariant (each Select's name is emitted among the outputs of
      // the multi-return node it selects from).
      continue;
    }
    if (node->kind() == kUndefined && node->uses().empty()) {
      // Undefined nodes never show up in ONNX; they're just a tool
      // to help symbolics do the right thing.
      continue;
    }
    auto p_n = p_g->add_node();
    for(auto input : node->inputs()) {
      p_n->add_input(node_name(input));
    }
    for(auto output : node->outputs()) {
      p_n->add_output(node_name(output));
    }
    p_n->set_op_type(symbolToString(node->kind()));
    for(auto attr_name : node->attributeNames()) {
      addAttribute(p_n, node, attr_name);
    }
  }
  int inputs_count = 0;
  for (auto & tensor : initializers) {
    // TODO: stop using positions to determine which initializers
    // match to which inputs
    std::string name = p_g->input(inputs_count++);
    auto p = p_g->add_initializer();
    p->set_name(name);
    encodeTensor(p, tensor);
  }
}
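
// Rejects graphs we cannot express in ONNX: PythonOp/CppOp nodes that were
// not lowered by a symbolic, and the PyTorch-only AddConstant node.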
void checkGraph(const std::shared_ptr<Graph>& graph) {
  for (auto node : graph->nodes()) {
    if (node->kind() == kPythonOp || node->kind() == kCppOp) {
#define GET_NAME(T) dynamic_cast<T*>(node)->name()
      auto name = node->kind() == kPythonOp ? GET_NAME(PythonOp) : GET_NAME(CppOp);
      throw std::runtime_error(std::string("Couldn't export ") + name + " function - "
                               "maybe it doesn't implement a symbolic definition?");
#undef GET_NAME
    }
    if (node->kind() == kAddConstant) {
      throw std::runtime_error("can't serialize PyTorch-only node AddConstant (not implemented yet)");
    }
  }
}

} // anonymous namespace
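
// Entry point for export: validates the graph, encodes it into a
// GraphProto, and serializes that with nanopb into a std::string of
// protobuf bytes. A minimal usage sketch (`graph` and `parameters` are
// whatever the caller has on hand):
//   std::string proto_bytes = ExportGraph(graph, parameters);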
std::string ExportGraph(const std::shared_ptr<Graph>& graph,
                        const std::vector<at::Tensor> & initializers) {
  checkGraph(graph);
  onnx::GraphProto graph_proto;
  graph_proto.set_name("torch-jit-export");
  // Set up nanopb callbacks and compute the amount of space needed to store
  // the resulting protobuf
  encodeGraph(&graph_proto, graph, initializers);
  size_t out_size;
  if (!pb_get_encoded_size(&out_size, onnx_GraphProto_fields, &graph_proto.proto)) {
    throw std::runtime_error("couldn't compute the size of the serialized protobuf");
  }
  // Allocate storage and export the graph
  std::string out(out_size, '\0');
  pb_ostream_t ostream = pb_ostream_from_buffer(reinterpret_cast<pb_byte_t *>(&out[0]), out_size);
  if (!pb_encode(&ostream, onnx_GraphProto_fields, &graph_proto.proto)) {
    throw std::runtime_error("failed to serialize the graph protobuf");
  }
  return out;
}
}} // namespace torch::jit