Export modules in IR with Google protobuf
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9746
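
A minimal usage sketch, distilled from the tests added below (export() and
the private torch._C._jit_import_module binding are the Python surface this
change introduces):

    import torch
    import torch.nn as nn

    class M(torch.jit.ScriptModule):
        def __init__(self):
            super(M, self).__init__(False)
            self.weight = nn.Parameter(torch.randn(2))

        @torch.jit.script_method
        def forward(self, thing):
            return self.weight + thing

    m_orig = M()
    # Serialize: returns (module proto bytes, {storage name: raw bytes})
    m_export, storage_map = m_orig.export()

    # Rebuild an equivalent module from the serialized form
    m_import = torch.jit.ScriptModule()
    torch._C._jit_import_module(m_import, m_export, storage_map)
    assert torch.equal(m_orig(torch.ones(2)), m_import(torch.ones(2)))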
Differential Revision: D9110006
Pulled By: li-roy
fbshipit-source-id: 8b9744c042f822fdfe959a7a7fef3d0baff4f639
diff --git a/test/test_jit.py b/test/test_jit.py
index c05ed16..79a0a68 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -3603,6 +3603,139 @@
self.assertEqual(1, foo3(a))
self.assertEqual(2, foo3(b))
+ def test_script_module_export_submodule(self):
+ class M1(torch.jit.ScriptModule):
+ def __init__(self):
+ super(M1, self).__init__(False)
+ self.weight = nn.Parameter(torch.randn(2))
+
+ @torch.jit.script_method
+ def forward(self, thing):
+ return self.weight + thing
+
+ class M2(torch.jit.ScriptModule):
+ def __init__(self):
+ super(M2, self).__init__(False)
+ # test submodule
+ self.sub = M1()
+ self.weight = nn.Parameter(torch.randn(2, 3))
+ self.bias = nn.Parameter(torch.randn(2))
+ self.define("""
+ def hi(self, a):
+ return self.weight.mm(a)
+ """)
+
+ @torch.jit.script_method
+ def doit(self, input):
+ return self.weight.mm(input)
+
+ @torch.jit.script_method
+ def doit2(self, input):
+ return self.weight.mm(input)
+
+ @torch.jit.script_method
+ def doit3(self, input):
+ return input + torch.ones([1], dtype=torch.double)
+
+ @torch.jit.script_method
+ def forward(self, input):
+ a = self.doit(input)
+ b = self.doit2(input)
+ c = self.hi(input)
+ return a + b + self.bias + c
+
+ m_orig = M2()
+ m_import = torch.jit.ScriptModule()
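+        # export() returns the serialized module proto (bytes) and a dict
+        # mapping storage names to raw tensor bytes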
+ m_export, storage_map = m_orig.export()
+ torch._C._jit_import_module(m_import, m_export, storage_map)
+
+ input = torch.randn(3, 2)
+ self.assertEqual(m_orig.doit(input), m_import.doit(input))
+ self.assertEqual(m_orig.hi(input), m_import.hi(input))
+ self.assertEqual(m_orig.doit3(input), m_import.doit3(input))
+ self.assertEqual(m_orig.forward(input), m_import.forward(input))
+
+ @skipIfNoTorchVision
+ def test_script_module_export_resnet18(self):
+ x = torch.ones(1, 3, 224, 224)
+        m_orig = torch.jit.trace(x)(torchvision.models.resnet18())
+ m_import = torch.jit.ScriptModule()
+ m_export, storage_map = m_orig.export()
+ torch._C._jit_import_module(m_import, m_export, storage_map)
+
+ input = torch.randn(1, 3, 224, 224, requires_grad=True)
+ output_orig = m_orig(input)
+ output_orig.sum().backward()
+ grad_orig = input.grad.clone()
+ input.grad.zero_()
+
+ output_import = m_import(input)
+ output_import.sum().backward()
+ grad_import = input.grad.clone()
+
+ self.assertEqual(output_orig, output_import)
+ self.assertEqual(grad_orig, grad_import)
+
+ def test_script_module_export_tensor_type(self):
+ class M(torch.jit.ScriptModule):
+
+ def __init__(self, type):
+ super(M, self).__init__(False)
+ self.param = torch.nn.Parameter(torch.zeros((5, 5), dtype=type).random_())
+
+ @torch.jit.script_method
+ def foo(self):
+ return self.param
+
+ for type in [torch.float, torch.double]:
+ m_orig = M(type)
+ m_import = torch.jit.ScriptModule()
+ m_export, storage_map = m_orig.export()
+ torch._C._jit_import_module(m_import, m_export, storage_map)
+ self.assertEqual(m_orig.foo(), m_import.foo())
+ self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
+
+    @unittest.skipIf(not RUN_CUDA, "testing CUDA tensors requires CUDA")
+ def test_script_module_export_tensor_cuda(self):
+ class M(torch.jit.ScriptModule):
+
+ def __init__(self):
+ super(M, self).__init__(False)
+ self.param = torch.nn.Parameter(torch.zeros((5, 5), device='cuda').random_())
+
+ @torch.jit.script_method
+ def foo(self):
+ return self.param
+
+ m_orig = M()
+ m_import = torch.jit.ScriptModule()
+ m_export, storage_map = m_orig.export()
+ torch._C._jit_import_module(m_import, m_export, storage_map)
+ self.assertTrue(m_import.foo().device == torch.device('cpu'))
+ self.assertEqual(m_orig.foo(), m_import.foo())
+ self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
+
+ def test_script_module_export_shared_storage(self):
+ class M(torch.jit.ScriptModule):
+
+ def __init__(self):
+ super(M, self).__init__(False)
+ self.param1 = torch.nn.Parameter(torch.rand(5, 5))
+ self.param2 = torch.nn.Parameter(self.param1[3])
+ self.param3 = torch.nn.Parameter(torch.rand(5, 5))
+
+ @torch.jit.script_method
+ def foo(self):
+ return self.param1 + self.param2 + self.param3
+
+ m_orig = M()
+ m_import = torch.jit.ScriptModule()
+ m_export, storage_map = m_orig.export()
+ torch._C._jit_import_module(m_import, m_export, storage_map)
+ self.assertEqual(m_orig.foo(), m_import.foo())
+ self.assertTrue(m_import.param1.storage().data_ptr() == m_import.param2.storage().data_ptr())
+ self.assertTrue(m_import.param1.storage().data_ptr() != m_import.param3.storage().data_ptr())
+
def test_onnx_export_script_module(self):
class ModuleToExport(torch.jit.ScriptModule):
def __init__(self):
diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp
index 20208af..7174fd0 100644
--- a/torch/csrc/jit/export.cpp
+++ b/torch/csrc/jit/export.cpp
@@ -22,317 +22,7 @@
namespace onnx_torch = ::torch::onnx;
namespace onnx = ::ONNX_NAMESPACE;
-std::string value_name(Value* n) {
- return n->uniqueName();
-}
-
-struct ExportContext {
- size_t num_blocks = 0;
- onnx_torch::OperatorExportTypes operator_export_type;
-};
-
-void encodeGraph(onnx::GraphProto * p_g, const std::shared_ptr<Graph> & g,
- const std::vector<at::Tensor> & initializers,
- ExportContext *ctx, RawDataExportMap* raw_data_export_map=nullptr);
-
-void encodeBlock(onnx::GraphProto * p_g, Block *b,
- const std::vector<at::Tensor> & initializers,
- ExportContext *ctx, RawDataExportMap* raw_data_export_map);
-
-void encodeTensor(onnx::TensorProto * p, const at::Tensor & tensor,
- at::optional<std::string> external_ref={},
- RawDataExportMap* raw_data_export_map = nullptr) {
- for(auto d : tensor.sizes()) {
- p->add_dims(d);
- }
- onnx::TensorProto_DataType onnx_type;
- // Most integral types and float16 need to be serialized as int32
- at::ScalarType cast_type = tensor.type().scalarType();
- switch(tensor.type().scalarType()) {
- case at::kDouble:
- onnx_type = onnx::TensorProto_DataType_DOUBLE;
- break;
- case at::kFloat:
- onnx_type = onnx::TensorProto_DataType_FLOAT;
- break;
- case at::kHalf:
- onnx_type = onnx::TensorProto_DataType_FLOAT16;
- cast_type = at::kInt;
- break;
- case at::kByte:
- onnx_type = onnx::TensorProto_DataType_UINT8;
- cast_type = at::kInt;
- break;
- case at::kChar:
- onnx_type = onnx::TensorProto_DataType_INT8;
- cast_type = at::kInt;
- break;
- case at::kShort:
- onnx_type = onnx::TensorProto_DataType_INT16;
- cast_type = at::kInt;
- break;
- case at::kInt:
- onnx_type = onnx::TensorProto_DataType_INT32;
- break;
- case at::kLong:
- onnx_type = onnx::TensorProto_DataType_INT64;
- break;
- default:
- AT_ERROR("unexpected tensor scalar type");
- break;
- }
- p->set_data_type(onnx_type);
- // CPU's HalfTensor doesn't have contiguous(), so first calling contiguous()
- auto t = tensor.contiguous().toBackend(at::kCPU).toType(cast_type);
- // Add a buffer to the raw_data_export_map for the caller to dump into an
- // external data store. If external_ref is not specified, we instead dump
- // the contiguous data into the protobuf itself
- if (external_ref) {
- // For now, we use the name of the tensor as the external lookup name to
- // avoid ONNX protobuf changes.
- JIT_ASSERT(external_ref.value() == p->name());
- JIT_ASSERT(raw_data_export_map != nullptr);
- JIT_ASSERT(raw_data_export_map->count(external_ref.value()) == 0);
- (*raw_data_export_map)[external_ref.value()] = t;
- p->set_raw_data("__EXTERNAL");
- } else {
- JIT_ASSERT(t.is_contiguous());
- p->set_raw_data(std::string(static_cast<char*>(t.data_ptr()), t.type().elementSizeInBytes() * t.numel()));
- }
-}
-
-void addAttribute(onnx::NodeProto * n_p, jit::Node * n, jit::Symbol name, ExportContext *ctx) {
- auto attr = n_p->add_attribute();
- JIT_ASSERT(name.is_attr());
- attr->set_name(name.toUnqualString());
- switch(n->kindOf(name)) {
- case AttributeKind::f:
- attr->set_f(n->f(name));
- attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
- break;
- case AttributeKind::fs:
- attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
- for(auto & v : n->fs(name))
- attr->add_floats(v);
- break;
- case AttributeKind::i:
- attr->set_type(onnx::AttributeProto_AttributeType_INT);
- attr->set_i(n->i(name));
- break;
- case AttributeKind::is:
- attr->set_type(onnx::AttributeProto_AttributeType_INTS);
- for(auto & v : n->is(name))
- attr->add_ints(v);
- break;
- case AttributeKind::s:
- attr->set_type(onnx::AttributeProto_AttributeType_STRING);
- attr->set_s(n->s(name));
- break;
- case AttributeKind::ss:
- attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
- for(auto & v : n->ss(name))
- attr->add_strings(v);
- break;
- case AttributeKind::t: {
- attr->set_type(onnx::AttributeProto_AttributeType_TENSOR);
- auto t = attr->mutable_t();
- encodeTensor(t, n->t(name));
- } break;
- case AttributeKind::ts:
- attr->set_type(onnx::AttributeProto_AttributeType_TENSORS);
- for(auto & v : n->ts(name)) {
- auto t = attr->add_tensors();
- encodeTensor(t, v);
- }
- break;
- case AttributeKind::g: {
- attr->set_type(onnx::AttributeProto_AttributeType_GRAPH);
- auto g = attr->mutable_g();
- encodeGraph(g, n->g(name), {}, ctx, nullptr);
- } break;
- case AttributeKind::gs:
- attr->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
- for(auto & v : n->gs(name)) {
- auto g = attr->add_graphs();
- encodeGraph(g, v, {}, ctx, nullptr);
- }
- break;
- }
-}
-
-void encodeTypeProtoTensorType(onnx::TypeProto_Tensor* tensor_type, Value* n) {
- onnx::TensorShapeProto* shape = tensor_type->mutable_shape();
- if (TensorTypePtr node_type = n->type()->cast<TensorType>()) {
- const std::vector<std::int64_t>& sizes = node_type->sizes();
- for (size_t i = 0; i < sizes.size(); i++) {
- shape->add_dim();
- shape->mutable_dim(i)->set_dim_value(sizes[i]);
- }
- onnx::TensorProto_DataType onnx_type;
- switch(node_type->scalarType()) {
- case at::kDouble:
- onnx_type = onnx::TensorProto_DataType_DOUBLE;
- break;
- case at::kFloat:
- onnx_type = onnx::TensorProto_DataType_FLOAT;
- break;
- case at::kHalf:
- onnx_type = onnx::TensorProto_DataType_FLOAT16;
- break;
- case at::kByte:
- onnx_type = onnx::TensorProto_DataType_UINT8;
- break;
- case at::kChar:
- onnx_type = onnx::TensorProto_DataType_INT8;
- break;
- case at::kShort:
- onnx_type = onnx::TensorProto_DataType_INT16;
- break;
- case at::kInt:
- onnx_type = onnx::TensorProto_DataType_INT32;
- break;
- case at::kLong:
- onnx_type = onnx::TensorProto_DataType_INT64;
- break;
- default:
- AT_ERROR("unexpected tensor scalar type");
- break;
- }
- tensor_type->set_elem_type(onnx_type);
- }
-}
-
-void encodeValueInfo(onnx::ValueInfoProto* v, Value* n) {
- v->set_name(value_name(n));
- onnx::TypeProto* t = v->mutable_type();
- onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
- encodeTypeProtoTensorType(tensor_type, n);
-}
-
-void encodeGraph(onnx::GraphProto * p_g, const std::shared_ptr<Graph>& g,
- const std::vector<at::Tensor> & initializers,
- ExportContext *ctx, RawDataExportMap* raw_data_export_map) {
- encodeBlock(p_g, g->block(), initializers, ctx, raw_data_export_map);
-}
-
-void encodeBlock(onnx::GraphProto * p_g, Block *b,
- const std::vector<at::Tensor> & initializers,
- ExportContext *ctx, RawDataExportMap* raw_data_export_map) {
- JIT_ASSERT(p_g != nullptr);
- std::string block_name = "torch-jit-export";
- if (ctx->num_blocks) {
- block_name += std::to_string(ctx->num_blocks);
- }
- ctx->num_blocks++;
- p_g->set_name(block_name);
-
- for (auto input : b->inputs()) {
- onnx::ValueInfoProto* v = p_g->add_input();
- encodeValueInfo(v, input);
- }
- for (auto output : b->outputs()) {
- onnx::ValueInfoProto* v = p_g->add_output();
- encodeValueInfo(v, output);
- }
- for (auto node : b->nodes()) {
- bool is_raw_export = ctx->operator_export_type == onnx_torch::OperatorExportTypes::RAW;
- if (node->kind() == prim::Undefined && !is_raw_export) {
- // Undefined nodes are used to implement optional inputs. One
- // way to "not provide" an optional input is to create an
- // Undefined node, and pass its output as that input.
- continue;
- }
- auto p_n = p_g->add_node();
- if (node->getSourceLocation()) {
- std::stringstream ss;
- node->getSourceLocation()->highlight(ss);
- p_n->set_doc_string(ss.str());
- }
- for(auto input : node->inputs()) {
- if (input->node()->kind() == prim::Undefined && !is_raw_export) {
- p_n->add_input("");
- } else {
- p_n->add_input(value_name(input));
- }
- }
- for(auto output : node->outputs()) {
- p_n->add_output(value_name(output));
- }
- if (is_raw_export) {
- JIT_ASSERT(!node->kind().is_onnx());
- p_n->set_domain(node->kind().domainString());
- }
- else if (ctx->operator_export_type != onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK) {
- JIT_ASSERT(node->kind().is_onnx());
- }
- p_n->set_op_type(node->kind().toUnqualString());
- for(auto attr_name : node->attributeNames()) {
- addAttribute(p_n, node, attr_name, ctx);
- }
- if (is_raw_export && node->blocks().size() > 0) {
- auto blocks = p_n->add_attribute();
- blocks->set_name("_blocks");
- blocks->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
- for (auto block : node->blocks()) {
- auto graph = blocks->add_graphs();
- encodeBlock(graph, block, initializers, ctx, raw_data_export_map);
- }
- }
- if (node->kind() == torch::jit::onnx::Loop) {
- JIT_ASSERT(node->blocks().size() == 1);
-
- auto body = p_n->add_attribute();
- body->set_name("body");
- body->set_type(onnx::AttributeProto_AttributeType_GRAPH);
- auto g = body->mutable_g();
- encodeBlock(g, node->blocks()[0], {}, ctx, raw_data_export_map);
- }
- if (node->kind() == torch::jit::onnx::If) {
- JIT_ASSERT(node->blocks().size() == 2);
-
- auto true_branch = p_n->add_attribute();
- true_branch->set_name("then_branch");
- true_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
- auto true_g = true_branch->mutable_g();
- encodeBlock(true_g, node->blocks()[0], {}, ctx, raw_data_export_map);
-
- auto false_branch = p_n->add_attribute();
- false_branch->set_name("else_branch");
- false_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
- auto false_g = false_branch->mutable_g();
- encodeBlock(false_g, node->blocks()[1], {}, ctx, raw_data_export_map);
- }
- }
- auto num_initializers = initializers.size();
- JIT_ASSERT(b->inputs().size() >= num_initializers);
- size_t inputs_count = b->inputs().size() - num_initializers;
- for (auto & tensor : initializers) {
- // TODO: stop using positions to determine which initializers
- // match to which inputs
- std::string name = p_g->input(inputs_count++).name();
- auto p = p_g->add_initializer();
- p->set_name(name);
- if (raw_data_export_map) {
- encodeTensor(p, tensor, name, raw_data_export_map);
- } else {
- encodeTensor(p, tensor, {});
- }
- }
-}
-
-void encodeModel(onnx::ModelProto* p_m, const std::shared_ptr<Graph>& g,
- const std::vector<at::Tensor>& initializers,
- RawDataExportMap* raw_data_export_map = nullptr,
- onnx_torch::OperatorExportTypes operator_export_type
- = onnx_torch::OperatorExportTypes::ONNX) {
- onnx::GraphProto* p_g = p_m->mutable_graph();
- ExportContext ctx;
- ctx.operator_export_type = operator_export_type;
- encodeGraph(p_g, g, initializers, &ctx, raw_data_export_map);
-}
-
-namespace {
-std::string getNodeStackTraceString(Node* n) {
+std::string getNodeStackTraceString(const Node* n) {
std::stringstream ss;
if (n->getSourceLocation()) {
n->getSourceLocation()->highlight(ss);
@@ -341,7 +31,6 @@
}
return ss.str();
}
-} // namespace
void validateGraph(const std::shared_ptr<Graph>& graph, onnx_torch::OperatorExportTypes operator_export_type) {
for (auto node : graph->nodes()) {
@@ -376,6 +65,591 @@
}
}
+class EncoderBase {
+ public:
+ EncoderBase(onnx::ModelProto *model_proto,
+ onnx_torch::OperatorExportTypes operator_export_type,
+ bool defer_weight_export = false);
+
+ RawDataExportMap get_raw_data_export_map() {
+ return raw_data_export_map_;
+ }
+
+ protected:
+ void EncodeGraph(onnx::GraphProto *graph_proto,
+ const std::shared_ptr<Graph> &graph,
+ const std::vector<at::Tensor> &initializers = {});
+
+ void EncodeBlock(onnx::GraphProto *graph_proto,
+ const Block *block,
+ const std::vector<at::Tensor> &initializers = {});
+
+ virtual void EncodeTensor(onnx::TensorProto *tensor_proto,
+ const at::Tensor &tensor,
+ const at::optional<std::string> external_ref = {});
+
+ virtual void EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto,
+ const Value* n) {};
+
+ virtual void EncodeValueInfo(onnx::GraphProto *graph_proto,
+ onnx::ValueInfoProto* v,
+ const Value* n);
+
+ void AddAttribute(onnx::NodeProto *node_proto, const jit::Node *node, const jit::Symbol name);
+
+ size_t num_blocks_;
+ bool defer_weight_export_;
+ onnx_torch::OperatorExportTypes operator_export_type_;
+ RawDataExportMap raw_data_export_map_;
+};
+
+onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
+ switch(at_type) {
+ case at::kDouble:
+ return onnx::TensorProto_DataType_DOUBLE;
+ case at::kFloat:
+ return onnx::TensorProto_DataType_FLOAT;
+ case at::kHalf:
+ return onnx::TensorProto_DataType_FLOAT16;
+ case at::kByte:
+ return onnx::TensorProto_DataType_UINT8;
+ case at::kChar:
+ return onnx::TensorProto_DataType_INT8;
+ case at::kShort:
+ return onnx::TensorProto_DataType_INT16;
+ case at::kInt:
+ return onnx::TensorProto_DataType_INT32;
+ case at::kLong:
+ return onnx::TensorProto_DataType_INT64;
+ default:
+ AT_ERROR("unexpected tensor scalar type");
+ }
+}
+
+EncoderBase::EncoderBase(
+ onnx::ModelProto *model_proto,
+ onnx_torch::OperatorExportTypes operator_export_type,
+ bool defer_weight_export)
+ : num_blocks_(0),
+ defer_weight_export_(defer_weight_export),
+ operator_export_type_(operator_export_type) {
+ model_proto->set_producer_name("pytorch");
+ model_proto->set_ir_version(onnx::IR_VERSION);
+ model_proto->set_producer_version("0.3");
+}
+
+void EncoderBase::EncodeValueInfo(
+ onnx::GraphProto *graph_proto,
+ onnx::ValueInfoProto* v,
+ const Value* n) {
+ v->set_name(n->uniqueName());
+ onnx::TypeProto* t = v->mutable_type();
+ onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
+
+ onnx::TensorShapeProto* shape = tensor_type->mutable_shape();
+ if (TensorTypePtr node_type = n->type()->cast<TensorType>()) {
+ const std::vector<std::int64_t>& sizes = node_type->sizes();
+ for (size_t i = 0; i < sizes.size(); i++) {
+ shape->add_dim();
+ shape->mutable_dim(i)->set_dim_value(sizes[i]);
+ }
+ tensor_type->set_elem_type(ATenTypeToOnnxType(node_type->scalarType()));
+ }
+}
+
+void EncoderBase::EncodeGraph(
+ onnx::GraphProto *graph_proto,
+ const std::shared_ptr<Graph> &graph,
+ const std::vector<at::Tensor> &initializers) {
+ EncodeBlock(graph_proto, graph->block(), initializers);
+}
+
+void EncoderBase::EncodeBlock(
+ onnx::GraphProto *graph_proto, const Block *block,
+ const std::vector<at::Tensor> &initializers) {
+ JIT_ASSERT(graph_proto != nullptr);
+ std::string block_name = "torch-jit-export";
+ if (num_blocks_) {
+ block_name += std::to_string(num_blocks_);
+ }
+ num_blocks_++;
+ graph_proto->set_name(block_name);
+
+ for (auto input : block->inputs()) {
+ onnx::ValueInfoProto* v = graph_proto->add_input();
+ EncodeValueInfo(graph_proto, v, input);
+ }
+ for (auto output : block->outputs()) {
+ onnx::ValueInfoProto* v = graph_proto->add_output();
+ EncodeValueInfo(graph_proto, v, output);
+ }
+ for (auto node : block->nodes()) {
+ bool is_raw_export = operator_export_type_ == onnx_torch::OperatorExportTypes::RAW;
+ if (node->kind() == prim::Undefined && !is_raw_export) {
+ // Undefined nodes are used to implement optional inputs. One
+ // way to "not provide" an optional input is to create an
+ // Undefined node, and pass its output as that input.
+ continue;
+ }
+ auto p_n = graph_proto->add_node();
+ if (node->getSourceLocation()) {
+ std::stringstream ss;
+ node->getSourceLocation()->highlight(ss);
+ p_n->set_doc_string(ss.str());
+ }
+ for(auto input : node->inputs()) {
+ if (input->node()->kind() == prim::Undefined && !is_raw_export) {
+ p_n->add_input("");
+ } else {
+ p_n->add_input(input->uniqueName());
+ }
+ }
+ for(auto output : node->outputs()) {
+ p_n->add_output(output->uniqueName());
+ EncodeIntermediateValueInfo(graph_proto, output);
+ }
+ if (is_raw_export) {
+ JIT_ASSERT(!node->kind().is_onnx());
+ p_n->set_domain(node->kind().domainString());
+    } else if (operator_export_type_ == onnx_torch::OperatorExportTypes::ONNX) {
+ JIT_ASSERT(node->kind().is_onnx());
+ }
+ p_n->set_op_type(node->kind().toUnqualString());
+ for(auto attr_name : node->attributeNames()) {
+ AddAttribute(p_n, node, attr_name);
+ }
+ if (is_raw_export && node->blocks().size() > 0) {
+ auto blocks = p_n->add_attribute();
+ blocks->set_name("_blocks");
+ blocks->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
+ for (auto block : node->blocks()) {
+ auto graph = blocks->add_graphs();
+ EncodeBlock(graph, block, initializers);
+ }
+ }
+ if (node->kind() == torch::jit::onnx::Loop) {
+ JIT_ASSERT(node->blocks().size() == 1);
+
+ auto body = p_n->add_attribute();
+ body->set_name("body");
+ body->set_type(onnx::AttributeProto_AttributeType_GRAPH);
+ auto g = body->mutable_g();
+ EncodeBlock(g, node->blocks()[0]);
+ }
+ if (node->kind() == torch::jit::onnx::If) {
+ JIT_ASSERT(node->blocks().size() == 2);
+
+ auto true_branch = p_n->add_attribute();
+ true_branch->set_name("then_branch");
+ true_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
+ auto true_g = true_branch->mutable_g();
+ EncodeBlock(true_g, node->blocks()[0]);
+
+ auto false_branch = p_n->add_attribute();
+ false_branch->set_name("else_branch");
+ false_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
+ auto false_g = false_branch->mutable_g();
+ EncodeBlock(false_g, node->blocks()[1]);
+ }
+ }
+ auto num_initializers = initializers.size();
+ JIT_ASSERT(block->inputs().size() >= num_initializers);
+ size_t inputs_count = block->inputs().size() - num_initializers;
+ for (auto & tensor : initializers) {
+ // TODO: stop using positions to determine which initializers
+ // match to which inputs
+ std::string name = graph_proto->input(inputs_count++).name();
+ auto p = graph_proto->add_initializer();
+ p->set_name(name);
+ EncodeTensor(p, tensor, name);
+ }
+}
+
+void EncoderBase::AddAttribute(onnx::NodeProto *node_proto, const jit::Node *node, const jit::Symbol name) {
+ auto attr = node_proto->add_attribute();
+ JIT_ASSERT(name.is_attr());
+ attr->set_name(name.toUnqualString());
+ switch(node->kindOf(name)) {
+ case AttributeKind::f:
+ attr->set_f(node->f(name));
+ attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
+ break;
+ case AttributeKind::fs:
+ attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
+ for(auto & v : node->fs(name))
+ attr->add_floats(v);
+ break;
+ case AttributeKind::i:
+ attr->set_type(onnx::AttributeProto_AttributeType_INT);
+ attr->set_i(node->i(name));
+ break;
+ case AttributeKind::is:
+ attr->set_type(onnx::AttributeProto_AttributeType_INTS);
+ for(auto & v : node->is(name))
+ attr->add_ints(v);
+ break;
+ case AttributeKind::s:
+ attr->set_type(onnx::AttributeProto_AttributeType_STRING);
+ attr->set_s(node->s(name));
+ break;
+ case AttributeKind::ss:
+ attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
+ for(auto & v : node->ss(name))
+ attr->add_strings(v);
+ break;
+ case AttributeKind::t: {
+ attr->set_type(onnx::AttributeProto_AttributeType_TENSOR);
+ auto t = attr->mutable_t();
+ EncodeTensor(t, node->t(name));
+ } break;
+ case AttributeKind::ts:
+ attr->set_type(onnx::AttributeProto_AttributeType_TENSORS);
+ for(auto & v : node->ts(name)) {
+ auto t = attr->add_tensors();
+ EncodeTensor(t, v);
+ }
+ break;
+ case AttributeKind::g: {
+ attr->set_type(onnx::AttributeProto_AttributeType_GRAPH);
+ auto g = attr->mutable_g();
+ EncodeGraph(g, node->g(name));
+ } break;
+ case AttributeKind::gs:
+ attr->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
+ for(auto & v : node->gs(name)) {
+ auto g = attr->add_graphs();
+ EncodeGraph(g, v);
+ }
+ break;
+ default:
+ throw std::runtime_error("unexpected attribute kind");
+ }
+}
+
+void EncoderBase::EncodeTensor(
+ onnx::TensorProto *tensor_proto,
+ const at::Tensor &tensor,
+ const at::optional<std::string> external_ref) {
+ for(auto d : tensor.sizes()) {
+ tensor_proto->add_dims(d);
+ }
+ tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.type().scalarType()));
+  // CPU HalfTensor doesn't support contiguous(), so call contiguous() before moving to CPU
+ auto t = tensor.contiguous().toBackend(at::kCPU);
+ // Add a buffer to the raw_data_export_map for the caller to dump into an
+ // external data store. If external_ref is not specified, we instead dump
+ // the contiguous data into the protobuf itself
+ if (defer_weight_export_ && external_ref) {
+ // For now, we use the name of the tensor as the external lookup name to
+ // avoid ONNX protobuf changes.
+ JIT_ASSERT(external_ref.value() == tensor_proto->name());
+ JIT_ASSERT(raw_data_export_map_.count(external_ref.value()) == 0);
+ raw_data_export_map_[external_ref.value()] = t;
+ tensor_proto->set_raw_data("__EXTERNAL");
+ } else {
+ JIT_ASSERT(t.is_contiguous());
+ tensor_proto->set_raw_data(std::string(static_cast<char*>(t.data_ptr()), t.type().elementSizeInBytes() * t.numel()));
+ }
+}
+
+class GraphEncoder: public EncoderBase {
+ public:
+ GraphEncoder(onnx::ModelProto *model_proto,
+ const std::shared_ptr<Graph> &graph,
+ int64_t onnx_opset_version,
+ onnx_torch::OperatorExportTypes operator_export_type,
+ const std::vector<at::Tensor> &initializers,
+ bool defer_weight_export);
+
+};
+
+GraphEncoder::GraphEncoder(
+ onnx::ModelProto *model_proto,
+ const std::shared_ptr<Graph> &graph,
+ int64_t onnx_opset_version,
+ onnx_torch::OperatorExportTypes operator_export_type,
+ const std::vector<at::Tensor> &initializers,
+ bool defer_weight_export)
+ : EncoderBase(model_proto, operator_export_type, defer_weight_export) {
+ if (operator_export_type != onnx_torch::OperatorExportTypes::RAW) {
+ validateGraph(graph, operator_export_type);
+ }
+
+ auto* imp = model_proto->add_opset_import();
+ // This is the version of ONNX operator set we are targeting
+ imp->set_version(onnx_opset_version);
+
+ EncodeGraph(model_proto->mutable_graph(), graph, initializers);
+}
+
+class ModuleEncoder: public EncoderBase {
+ public:
+ ModuleEncoder(onnx::ModelProto *model_proto,
+ const std::shared_ptr<script::Module> &module);
+
+ private:
+ void EncodeModule(onnx::GraphProto *graph_proto, const std::shared_ptr<script::Module> &module);
+
+ void EncodeParameters(onnx::GraphProto *graph_proto,
+ const std::shared_ptr<script::Module> &module,
+ const std::string prefix);
+
+ void EncodeParameter(onnx::TensorProto *tensor_proto,
+ const script::NamedParameter ¶meter,
+ const std::string prefix);
+
+ void EncodeMethods(onnx::GraphProto *graph_proto,
+ const std::shared_ptr<script::Module> &module,
+ const std::string prefix);
+
+ void EncodeMethod(onnx::NodeProto *node_proto,
+ const std::unique_ptr<script::Method> &method,
+ const std::string prefix);
+
+ virtual void EncodeTensor(onnx::TensorProto *tensor_proto,
+ const at::Tensor &tensor,
+ const at::optional<std::string> external_ref) override;
+
+ virtual void EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto,
+ const Value* n) override;
+
+ virtual void EncodeValueInfo(onnx::GraphProto *graph_proto,
+ onnx::ValueInfoProto* v,
+ const Value* n) override;
+
+ void EncodeTypeInfo(onnx::GraphProto *graph_proto,
+ onnx::ValueInfoProto* v,
+ const TypePtr& type,
+ const std::string& name);
+
+  // Used to deduplicate tensor storages (data pointer -> assigned storage name)
+ std::unordered_map<const void*, std::string> storage_dedup_map_;
+
+ // Used to keep track of Parameter names so Methods can refer to them
+ std::unordered_map<at::Tensor*, std::string> parameter_map_;
+
+ // Used to create sequential tensor storage names
+ size_t storage_counter_ = 0;
+
+ // Used to create sequential dummy names for node types
+ size_t type_counter_ = 0;
+};
+
+ModuleEncoder::ModuleEncoder(
+ onnx::ModelProto *model_proto,
+ const std::shared_ptr<script::Module> &module)
+ : EncoderBase(model_proto,
+ onnx_torch::OperatorExportTypes::RAW,
+ /*defer_weight_export*/ true) {
+ model_proto->set_doc_string("THIS PROTO IS NOT STANDARD ONNX");
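+  // Layout: parameters are encoded as initializers and methods as nodes,
+  // with names prefixed by the submodule path (e.g. "sub.weight").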
+ EncodeModule(model_proto->mutable_graph(), module);
+}
+
+void ModuleEncoder::EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto, const Value *n) {
+ auto v = graph_proto->add_value_info();
+ EncodeTypeInfo(graph_proto, v, n->type(), n->uniqueName());
+}
+
+void ModuleEncoder::EncodeTypeInfo(
+ onnx::GraphProto *graph_proto,
+ onnx::ValueInfoProto* v,
+ const TypePtr& type,
+ const std::string& name) {
+ v->set_name(name);
+ onnx::TypeProto* type_proto = v->mutable_type();
+ onnx::TypeProto_Tensor* tensortype_proto = type_proto->mutable_tensor_type();
+ onnx::TensorShapeProto* shape_proto = tensortype_proto->mutable_shape();
+
+  // Use TypeProto fields to encode types: the denotation field stores the
+  // type kind as a string.
+ auto kind = type->kind();
+ if (kind == TypeKind::DynamicType) {
+ type_proto->set_denotation("DynamicType");
+ } else if (kind == TypeKind::TensorType) {
+ type_proto->set_denotation("TensorType");
+ TensorTypePtr node_type = type->cast<TensorType>();
+ const std::vector<std::int64_t>& sizes = node_type->sizes();
+
+ // store the sizes and strides in the dims field of TensorShapeProto
+ for (size_t i = 0; i < sizes.size(); i++) {
+ shape_proto->add_dim();
+ shape_proto->mutable_dim(i)->set_dim_value(sizes[i]);
+ }
+ const std::vector<std::int64_t>& strides = node_type->strides();
+ for (size_t i = 0; i < strides.size(); i++) {
+ shape_proto->add_dim();
+ shape_proto->mutable_dim(i)->set_dim_value(strides[i]);
+ }
+ tensortype_proto->set_elem_type(ATenTypeToOnnxType(node_type->scalarType()));
+ } else if (kind == TypeKind::TupleType) {
+ type_proto->set_denotation("TupleType");
+ TupleTypePtr node_type = type->cast<TupleType>();
+ auto elements = node_type->elements();
+
+ // Generate a name for and encode each subtype in the value_info field of the GraphProto.
+ for (size_t i = 0; i < elements.size(); i++) {
+ std::string name = "#" + std::to_string(type_counter_++);
+ shape_proto->add_dim();
+ shape_proto->mutable_dim(i)->set_dim_param(name);
+ onnx::ValueInfoProto* subtype_proto = graph_proto->add_value_info();
+ EncodeTypeInfo(graph_proto, subtype_proto, elements[i], name);
+ }
+ } else if (kind == TypeKind::ListType) {
+ type_proto->set_denotation("ListType");
+ ListTypePtr node_type = type->cast<ListType>();
+
+ // Generate a name for and encode the subtype in the value_info field of the GraphProto.
+ std::string name = "#" + std::to_string(type_counter_++);
+ shape_proto->add_dim();
+ shape_proto->mutable_dim(0)->set_dim_param(name);
+ onnx::ValueInfoProto* subtype_proto = graph_proto->add_value_info();
+ EncodeTypeInfo(graph_proto, subtype_proto, node_type->getElementType(), name);
+ } else if (kind == TypeKind::NumberType) {
+ type_proto->set_denotation("NumberType");
+ } else if (kind == TypeKind::FloatType) {
+ type_proto->set_denotation("FloatType");
+ } else if (kind == TypeKind::IntType) {
+ type_proto->set_denotation("IntType");
+ } else if (kind == TypeKind::NoneType) {
+ type_proto->set_denotation("NoneType");
+  } else {
+ throw std::runtime_error("unexpected type kind");
+ }
+}
+
+void ModuleEncoder::EncodeValueInfo(
+ onnx::GraphProto *graph_proto,
+ onnx::ValueInfoProto* v,
+ const Value* n) {
+ EncodeTypeInfo(graph_proto, v, n->type(), n->uniqueName());
+}
+
+void ModuleEncoder::EncodeModule(
+ onnx::GraphProto *graph_proto,
+ const std::shared_ptr<script::Module> &module) {
+ EncodeParameters(graph_proto, module, "");
+ EncodeMethods(graph_proto, module, "");
+}
+
+void ModuleEncoder::EncodeParameters(
+ onnx::GraphProto *graph_proto,
+ const std::shared_ptr<script::Module> &module,
+ const std::string prefix) {
+  // Encode each parameter as an initializer in the proto
+ for (auto ¶meter : module->get_parameters()) {
+ auto tensor_proto = graph_proto->add_initializer();
+ EncodeParameter(tensor_proto, parameter.value, prefix);
+ }
+
+ for (auto &submodule : module->get_modules()) {
+ EncodeParameters(graph_proto, submodule.value.module, prefix + submodule.key + ".");
+ }
+}
+
+void ModuleEncoder::EncodeParameter(
+ onnx::TensorProto *tensor_proto,
+ const script::NamedParameter ¶meter,
+ const std::string prefix) {
+ auto tensor = parameter.slot();
+  // Names are prefixed by the submodule path, e.g. submodule_foo.parameter_bar
+ auto name = prefix + parameter.name;
+
+ tensor_proto->set_name(name);
+ parameter_map_[tensor] = name;
+
+ // Parameters have these fields, but tensors do not
+ tensor_proto->add_int64_data(parameter.is_buffer);
+ tensor_proto->add_int64_data(tensor->requires_grad());
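+  // EncodeTensor will append storage_offset and strides after these two flags.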
+
+ EncodeTensor(tensor_proto, *tensor, name);
+}
+
+void ModuleEncoder::EncodeMethods(
+ onnx::GraphProto *graph_proto,
+ const std::shared_ptr<script::Module> &module,
+ const std::string prefix) {
+  // Encode each method as a node in the proto
+ for (auto &method : module->get_methods()) {
+ auto node_proto = graph_proto->add_node();
+ EncodeMethod(node_proto, method.value, prefix);
+ }
+
+ for (auto &submodule : module->get_modules()) {
+ EncodeMethods(graph_proto, submodule.value.module, prefix + submodule.key + ".");
+ }
+}
+
+void ModuleEncoder::EncodeMethod(
+ onnx::NodeProto *node_proto,
+ const std::unique_ptr<script::Method> &method,
+ const std::string prefix) {
+ node_proto->set_name(prefix + method->name());
+
+  // Store the Method's member_inputs as the node's inputs
+ for (auto &member_input : method->params()) {
+ auto it = parameter_map_.find(member_input);
+ JIT_ASSERT(it != parameter_map_.end());
+ node_proto->add_input(it->second);
+ }
+
+ auto attr_proto = node_proto->add_attribute();
+ attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH);
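+  // The method body is stored as a single graph-typed attribute (attribute 0).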
+
+ for (auto node : method->graph()->nodes()) {
+ if (node->kind() == prim::PythonOp) {
+ auto py_node = static_cast<torch::jit::PythonOp*>(node);
+ throw std::runtime_error(
+ "Couldn't export Python operator " + py_node->name() +
+ "\n\nDefined at:\n" + getNodeStackTraceString(node));
+ }
+ }
+ EncodeBlock(attr_proto->mutable_g(), method->graph()->block(), {});
+}
+
+void ModuleEncoder::EncodeTensor(
+ onnx::TensorProto *tensor_proto,
+ const at::Tensor &tensor,
+ const at::optional<std::string> external_ref = {}) {
+ for (auto &d : tensor.sizes()) {
+ tensor_proto->add_dims(d);
+ }
+ tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.type().scalarType()));
+
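+  // View metadata: int64_data holds [storage_offset, stride_0, ..., stride_n-1].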
+ tensor_proto->add_int64_data(tensor.storage_offset());
+ for (auto &d : tensor.strides()) {
+ tensor_proto->add_int64_data(d);
+ }
+
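+  // The doc_string field is repurposed to name the backing storage, so that
+  // tensors sharing one storage are serialized once and re-linked on import.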
+ auto storage_ptr = tensor.storage()->pImpl()->data();
+ auto dedup_it = storage_dedup_map_.find(storage_ptr);
+ if (dedup_it != storage_dedup_map_.end()) {
+ tensor_proto->set_doc_string(dedup_it->second);
+ } else {
+ std::string name;
+ if (external_ref) {
+ name = external_ref.value();
+ } else {
+ name = "$" + std::to_string(storage_counter_++);
+ }
+ tensor_proto->set_doc_string(name);
+ JIT_ASSERT(raw_data_export_map_.count(name) == 0);
+ storage_dedup_map_[storage_ptr] = name;
+
+    // NB: A new flat tensor is created here to support CUDA tensors:
+    // storages can be mutated when converting tensors from CUDA to CPU,
+    // and we need a CPU tensor to copy the data from.
+ auto t = tensor.type().tensor(
+ *tensor.storage(),
+ /* storageOffset = */ 0,
+ /* size = */ { tensor.numel() },
+ /* strides = */ { 1 })
+ .toBackend(at::kCPU);
+ raw_data_export_map_[name] = t;
+ }
+}
+
// Pretty printing
namespace {
constexpr char indent_char = ' ';
@@ -551,46 +825,8 @@
dump(model, ss, 0);
return ss.str();
}
-
}
-namespace {
-
-RawDataExportMap ToModelProto(
- const std::shared_ptr<Graph>& graph,
- const std::vector<at::Tensor> & initializers,
- int64_t onnx_opset_version,
- bool defer_weight_export,
- onnx_torch::OperatorExportTypes operator_export_type,
- onnx::ModelProto *model_proto) {
- if (operator_export_type != onnx_torch::OperatorExportTypes::RAW) {
- validateGraph(graph, operator_export_type);
- }
-
- model_proto->set_producer_name("pytorch");
- model_proto->set_producer_version("0.3");
- model_proto->set_ir_version(onnx::IR_VERSION);
- auto* imp = model_proto->add_opset_import();
- // This is the version of ONNX operator set we are targeting
- imp->set_version(onnx_opset_version);
-
- // Map {external_data_ref -> raw data} for external serialization of weights
- RawDataExportMap raw_data_export_map;
-
- // Set up nanopb callbacks and compute the amount of space needed to store
- // the resulting protobuf
- if (defer_weight_export) {
- encodeModel(model_proto, graph, initializers, &raw_data_export_map, operator_export_type);
- } else {
- encodeModel(model_proto, graph, initializers, nullptr, operator_export_type);
- }
-
- return raw_data_export_map;
-}
-
-} // namespace
-
-
std::string PrettyPrintExportedGraph(
const std::shared_ptr<Graph>& graph,
const std::vector<at::Tensor> & initializers,
@@ -598,10 +834,8 @@
bool defer_weight_export,
::torch::onnx::OperatorExportTypes operator_export_type) {
::ONNX_NAMESPACE::ModelProto model_proto;
- RawDataExportMap raw_data_export_map;
- raw_data_export_map = ToModelProto(
- graph, initializers, onnx_opset_version, defer_weight_export, operator_export_type,
- &model_proto);
+ auto graph_encoder = GraphEncoder(
+ &model_proto, graph, onnx_opset_version, operator_export_type, initializers, defer_weight_export);
return prettyPrint(model_proto);
}
@@ -617,11 +851,15 @@
bool defer_weight_export,
::torch::onnx::OperatorExportTypes operator_export_type) {
::ONNX_NAMESPACE::ModelProto model_proto;
- RawDataExportMap raw_data_export_map;
- raw_data_export_map = ToModelProto(
- graph, initializers, onnx_opset_version, defer_weight_export, operator_export_type,
- &model_proto);
- return std::make_tuple(model_proto.SerializeAsString(), raw_data_export_map);
+ auto graph_encoder = GraphEncoder(
+ &model_proto, graph, onnx_opset_version, operator_export_type, initializers, defer_weight_export);
+ return std::make_tuple(model_proto.SerializeAsString(), graph_encoder.get_raw_data_export_map());
+}
+
+std::tuple<std::string, RawDataExportMap> ExportModule(const std::shared_ptr<script::Module>& module) {
+ ::ONNX_NAMESPACE::ModelProto model_proto;
+ auto module_encoder = ModuleEncoder(&model_proto, module);
+ return std::make_tuple(model_proto.SerializeAsString(), module_encoder.get_raw_data_export_map());
}
}}
diff --git a/torch/csrc/jit/export.h b/torch/csrc/jit/export.h
index d0c6212..9457762 100644
--- a/torch/csrc/jit/export.h
+++ b/torch/csrc/jit/export.h
@@ -1,6 +1,7 @@
#pragma once
#include "torch/csrc/jit/ir.h"
+#include "torch/csrc/jit/script/module.h"
#include "torch/csrc/onnx/onnx.h"
namespace torch { namespace jit {
@@ -32,4 +33,7 @@
::torch::onnx::OperatorExportTypes operator_export_type
= ::torch::onnx::OperatorExportTypes::ONNX);
+TORCH_API std::tuple<std::string, RawDataExportMap> ExportModule(
+ const std::shared_ptr<script::Module>& module);
+
}}
diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp
index a453925..6d8a4f1 100644
--- a/torch/csrc/jit/import.cpp
+++ b/torch/csrc/jit/import.cpp
@@ -16,44 +16,55 @@
namespace {
-// IR graph construction
-
namespace onnx = ::ONNX_NAMESPACE;
-at::Tensor buildTensor(const onnx::TensorProto& tensor_proto) {
+// IR graph construction
- at::Tensor tensor;
+class DecoderBase {
+ protected:
+ virtual std::shared_ptr<Graph> buildGraph(const onnx::GraphProto& graph_proto);
- switch(tensor_proto.data_type()) {
+ void buildBlock(const onnx::GraphProto& graph_proto, Block* block,
+ std::unordered_map<std::string, Value*>& value_map);
+
+ void buildBlocks(const std::vector<onnx::GraphProto>& graphs_, Node* node,
+ std::unordered_map<std::string, Value*>& value_map);
+
+ virtual void buildValue(Value* value, const onnx::ValueInfoProto& valueinfo_proto) {};
+
+ virtual void buildIntermediateValue(Value* value, const std::string& name) {};
+
+  at::ScalarType onnxTypeToATenType(onnx::TensorProto_DataType onnx_type);
+
+ virtual at::Tensor buildTensor(const onnx::TensorProto& tensor_proto);
+};
+
+at::ScalarType DecoderBase::onnxTypeToATenType(onnx::TensorProto_DataType onnx_type) {
+ switch(onnx_type) {
case onnx::TensorProto_DataType_UINT8:
- tensor = at::CPU(at::kByte).tensor();
- break;
+ return at::kByte;
case onnx::TensorProto_DataType_INT8:
- tensor = at::CPU(at::kChar).tensor();
- break;
+ return at::kChar;
case onnx::TensorProto_DataType_INT16:
- tensor = at::CPU(at::kShort).tensor();
- break;
+ return at::kShort;
case onnx::TensorProto_DataType_INT32:
- tensor = at::CPU(at::kInt).tensor();
- break;
+ return at::kInt;
case onnx::TensorProto_DataType_INT64:
- tensor = at::CPU(at::kLong).tensor();
- break;
+ return at::kLong;
case onnx::TensorProto_DataType_FLOAT16:
- tensor = at::CPU(at::kHalf).tensor();
- break;
+ return at::kHalf;
case onnx::TensorProto_DataType_FLOAT:
- tensor = at::CPU(at::kFloat).tensor();
- break;
+ return at::kFloat;
case onnx::TensorProto_DataType_DOUBLE:
- tensor = at::CPU(at::kDouble).tensor();
- break;
+ return at::kDouble;
default:
throw std::runtime_error("Unsupported data type");
}
+}
- std::vector<int64_t> sizes = {tensor_proto.dims().begin(), tensor_proto.dims().end()};
+at::Tensor DecoderBase::buildTensor(const onnx::TensorProto& tensor_proto) {
+ at::Tensor tensor = at::CPU(onnxTypeToATenType(tensor_proto.data_type())).tensor();
+ std::vector<int64_t> sizes = { tensor_proto.dims().begin(), tensor_proto.dims().end() };
tensor.resize_(sizes);
JIT_ASSERT(
@@ -62,22 +73,19 @@
tensor_proto.raw_data().size());
std::memcpy(tensor.data_ptr(), tensor_proto.raw_data().data(), tensor_proto.raw_data().size());
-
return tensor;
}
-void buildBlock(const onnx::GraphProto& graph_proto, Block* block,
- std::unordered_map<std::string, Value*>& value_map);
-
-void buildBlocks(const std::vector<onnx::GraphProto>& graphs_, Node* node,
- std::unordered_map<std::string, Value*>& value_map) {
+void DecoderBase::buildBlocks(
+ const std::vector<onnx::GraphProto>& graphs_, Node* node,
+ std::unordered_map<std::string, Value*>& value_map) {
for (auto g_ : graphs_) {
auto block = node->addBlock();
buildBlock(g_, block, value_map);
}
}
-std::shared_ptr<Graph> buildGraph(const onnx::GraphProto& graph_proto) {
+std::shared_ptr<Graph> DecoderBase::buildGraph(const onnx::GraphProto& graph_proto) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> value_map;
@@ -86,11 +94,13 @@
return graph;
}
-void buildBlock(const onnx::GraphProto& graph_proto, Block* block,
+void DecoderBase::buildBlock(const onnx::GraphProto& graph_proto, Block* block,
std::unordered_map<std::string, Value*>& value_map) {
for (auto & input : graph_proto.input()) {
- value_map[input.name()] = block->addInput();
+ auto value = block->addInput();
+ value_map[input.name()] = value;
+ buildValue(value, input);
}
for (auto & node_ : graph_proto.node()) {
@@ -131,14 +141,18 @@
node->ss_(name, {attr.strings().begin(), attr.strings().end()});
break;
case onnx::AttributeProto_AttributeType_TENSORS:
- node->ts_(name, fmap(attr.tensors(), [](const onnx::TensorProto& t) { return buildTensor(t); }));
+ node->ts_(name, fmap(attr.tensors(), [this](const onnx::TensorProto& t) {
+ return buildTensor(t);
+ }));
break;
case onnx::AttributeProto_AttributeType_GRAPHS:
if (attr.name() == "_blocks") {
buildBlocks({attr.graphs().begin(), attr.graphs().end()}, node, value_map);
}
else {
- node->gs_(name, fmap(attr.graphs(), [](const onnx::GraphProto& g_) { return buildGraph(g_); }));
+ node->gs_(name, fmap(attr.graphs(), [this](const onnx::GraphProto& g_) {
+ return buildGraph(g_);
+ }));
}
break;
}
@@ -151,6 +165,7 @@
for (int i=0; i<node_.output().size(); i++) {
value_map[node_.output(i)] = node->outputs()[i];
+ buildIntermediateValue(node->outputs()[i], node_.output(i));
}
block->appendNode(node);
@@ -158,23 +173,21 @@
for (auto & output : graph_proto.output()) {
Value* v = value_map.at(output.name());
+ buildValue(v, output);
block->registerOutput(v);
}
}
-std::shared_ptr<Graph> buildGraph(const onnx::GraphProto& graph_proto, std::vector<at::Tensor>& initializers) {
+class GraphDecoder : DecoderBase {
+ public:
+ std::shared_ptr<Graph> decode(const std::string& serialized_graph,
+ std::vector<at::Tensor>& initializers);
- auto graph = buildGraph(graph_proto);
-
- for (auto tensor_ : graph_proto.initializer()) {
- initializers.push_back(buildTensor(tensor_));
- }
-
- return graph;
-}
+ void reconstructOutputTypes(Block *b);
+};
// TODO: this should be removed once we'll be able to serialize value types
-void reconstructOutputTypes(Block *b) {
+void GraphDecoder::reconstructOutputTypes(Block *b) {
for (Node * n : b->nodes()) {
if (n->kind() == prim::Constant) {
switch (n->kindOf(attr::value)) {
@@ -211,18 +224,227 @@
}
}
-} // anonymous namespace
+std::shared_ptr<Graph> GraphDecoder::decode(
+ const std::string& serialized_graph,
+ std::vector<at::Tensor>& initializers) {
+ auto model_proto = onnx::ModelProto();
+ model_proto.ParseFromString(serialized_graph);
+
+ auto graph_proto = model_proto.graph();
+ auto graph = buildGraph(graph_proto);
+ for (auto &tensor_ : graph_proto.initializer()) {
+ initializers.push_back(buildTensor(tensor_));
+ }
+ reconstructOutputTypes(graph->block());
+ return graph;
+}
+
+class ModuleDecoder : DecoderBase {
+ public:
+ std::shared_ptr<script::Module> decode(
+ std::shared_ptr<script::Module> root_module,
+ const std::string& serialized_module,
+ const std::unordered_map<std::string, std::string>& storage_map);
+
+ private:
+ virtual std::shared_ptr<Graph> buildGraph(const onnx::GraphProto& graph_proto) override;
+
+ virtual at::Tensor buildTensor(const onnx::TensorProto& tensor_proto) override;
+
+ TypePtr buildType(const onnx::TypeProto& type_proto);
+
+ virtual void buildValue(Value* value, const onnx::ValueInfoProto& valueinfo_proto) override;
+
+ virtual void buildIntermediateValue(Value* value, const std::string& name) override;
+
+ at::Tensor buildParameter(const onnx::TensorProto& tensor_proto);
+
+ at::Tensor buildTensorCommon(const onnx::TensorProto& tensor_proto,
+ const int64_t storage_offset,
+ const std::vector<int64_t>& strides);
+
+ std::pair<std::shared_ptr<script::Module>, std::string> parseFullName(
+ std::shared_ptr<script::Module> root_module,
+ const std::string fullname);
+
+ const std::unordered_map<std::string, std::string> *storage_export_map_;
+ std::unordered_map<std::string, std::shared_ptr<at::Tensor>> storage_map_;
+ std::unordered_map<std::string, const onnx::TypeProto*> value_type_map_;
+};
+
+std::shared_ptr<Graph> ModuleDecoder::buildGraph(const onnx::GraphProto& graph_proto) {
+ for (auto &subtype : graph_proto.value_info()) {
+ value_type_map_[subtype.name()] = &subtype.type();
+ }
+ return DecoderBase::buildGraph(graph_proto);
+}
+
+TypePtr ModuleDecoder::buildType(const onnx::TypeProto& type_proto) {
+ auto tensortype_proto = type_proto.tensor_type();
+ auto shape_proto = tensortype_proto.shape();
+ auto kind = type_proto.denotation();
+ if (kind == "DynamicType") {
+ return DynamicType::get();
+ } else if (kind == "TensorType") {
+ // TODO: Don't use DynamicType here
+ return DynamicType::get();
+ } else if (kind == "TupleType") {
+ std::vector<TypePtr> elems;
+ for (auto &subkind : shape_proto.dim()) {
+ auto it = value_type_map_.find(subkind.dim_param());
+ JIT_ASSERT(it != value_type_map_.end());
+ elems.push_back(buildType(*it->second));
+ }
+ return TupleType::create(elems);
+ } else if (kind == "ListType") {
+ auto subkind = shape_proto.dim(0);
+ auto it = value_type_map_.find(subkind.dim_param());
+ JIT_ASSERT(it != value_type_map_.end());
+ return ListType::create(buildType(*it->second));
+ } else if (kind == "NumberType") {
+ return NumberType::get();
+ } else if (kind == "FloatType") {
+ return FloatType::get();
+ } else if (kind == "IntType") {
+ return IntType::get();
+ } else if (kind == "NoneType") {
+ return NoneType::get();
+ } else {
+ throw std::runtime_error("unexpected string for type kind");
+ }
+}
+
+void ModuleDecoder::buildValue(Value* value, const onnx::ValueInfoProto& valueinfo_proto) {
+ value->setType(buildType(valueinfo_proto.type()));
+}
+
+void ModuleDecoder::buildIntermediateValue(Value* value, const std::string& name) {
+ auto it = value_type_map_.find(name);
+ JIT_ASSERT(it != value_type_map_.end());
+ value->setType(buildType(*it->second));
+}
+
+at::Tensor ModuleDecoder::buildParameter(const onnx::TensorProto& tensor_proto) {
+ std::vector<int64_t> strides;
+ // We've stored three other values (is_buffer, requires_grad, storage_offset) before strides; ignore them
+ std::move(tensor_proto.int64_data().begin() + 3, tensor_proto.int64_data().end(), std::back_inserter(strides));
+ auto tensor = buildTensorCommon(tensor_proto, /* storage_offset = */ tensor_proto.int64_data(2), strides);
+ autograd::Variable var = autograd::make_variable(tensor, /* requires_grad = */ tensor_proto.int64_data(1));
+ return var;
+}
+
+at::Tensor ModuleDecoder::buildTensor(const onnx::TensorProto& tensor_proto) {
+ std::vector<int64_t> strides;
+ // We've stored one other value (storage_offset) before strides; ignore it
+ std::move(tensor_proto.int64_data().begin() + 1, tensor_proto.int64_data().end(), std::back_inserter(strides));
+ return buildTensorCommon(tensor_proto, /* storage_offset = */ tensor_proto.int64_data(0), strides);
+}
+
+at::Tensor ModuleDecoder::buildTensorCommon(
+ const onnx::TensorProto& tensor_proto,
+ const int64_t storage_offset,
+ const std::vector<int64_t>& strides) {
+  // NB: storage_offset and strides are passed in separately because
+  // they are encoded differently for parameters and tensors
+ auto storage_name = tensor_proto.doc_string();
+ auto type = onnxTypeToATenType(tensor_proto.data_type());
+ std::vector<int64_t> dims;
+ std::move(tensor_proto.dims().begin(), tensor_proto.dims().end(), std::back_inserter(dims));
+
+ // Find or create the storage
+ at::Tensor *storage_tensor;
+ auto storage_it = storage_map_.find(storage_name);
+ if (storage_it == storage_map_.end()) {
+ auto storage = std::make_shared<at::Tensor>(at::CPU(type).tensor());
+ auto string_it = storage_export_map_->find(storage_name);
+ JIT_ASSERT(string_it != storage_export_map_->end());
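+    // NB: resizing by the byte count over-allocates for multi-byte dtypes,
+    // but it guarantees the storage can hold the raw bytes copied below.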
+ storage->resize_({ static_cast<int64_t>(string_it->second.size()) });
+ std::memcpy(storage->storage()->pImpl()->data(), string_it->second.data(), string_it->second.size());
+ storage_map_.insert(std::make_pair(storage_name, storage));
+ storage_tensor = storage.get();
+ } else {
+ storage_tensor = storage_it->second.get();
+ }
+
+ return at::CPU(onnxTypeToATenType(tensor_proto.data_type())).tensor(
+ *storage_tensor->storage().get(), storage_offset, dims, strides);
+}
+
+// Given a full name of a parameter or method,
+// return the parent submodule and local name
+std::pair<std::shared_ptr<script::Module>, std::string> ModuleDecoder::parseFullName(
+ std::shared_ptr<script::Module> root_module,
+ const std::string fullname) {
+ std::vector<std::string> vec;
+ std::stringstream ss(fullname);
+ std::string name;
+ while (std::getline(ss, name, '.')) {
+ vec.push_back(name);
+ }
+
+ std::shared_ptr<script::Module> curr = root_module;
+ for (size_t i = 0; i < vec.size() - 1; i++) {
+ if (curr->find_module(vec[i]) == nullptr) {
+ curr->register_module(vec[i], std::make_shared<script::Module>());
+ }
+ curr = curr->get_module(vec[i]);
+ }
+ return std::make_pair(curr, vec.back());
+}
+
+std::shared_ptr<script::Module> ModuleDecoder::decode(
+ const std::shared_ptr<script::Module> root_module,
+ const std::string &serialized_module,
+ const std::unordered_map<std::string, std::string> &storage_export_map) {
+ auto model_proto = onnx::ModelProto();
+ model_proto.ParseFromString(serialized_module);
+ auto graph_proto = model_proto.graph();
+
+ std::unordered_map<std::string, at::Tensor*> param_map;
+ storage_export_map_ = &storage_export_map;
+ storage_map_.clear();
+
+ for (auto &tensor_proto : graph_proto.initializer()) {
+ std::shared_ptr<script::Module> parent_module;
+ std::string name;
+ std::tie(parent_module, name) = parseFullName(root_module, tensor_proto.name());
+
+ auto param = buildParameter(tensor_proto);
+    parent_module->register_parameter(name, param, /* is_buffer = */ tensor_proto.int64_data(0));
+ param_map[tensor_proto.name()] = parent_module->parameter_slot(name);
+ }
+
+ for (auto &node_proto : graph_proto.node()) {
+ std::shared_ptr<script::Module> parent_module;
+ std::string name;
+ std::tie(parent_module, name) = parseFullName(root_module, node_proto.name());
+
+ std::vector<at::Tensor*> member_inputs;
+ for (auto ¶m_name : node_proto.input()) {
+ member_inputs.push_back(param_map[param_name]);
+ }
+
+ auto graph = buildGraph(node_proto.attribute(0).g());
+ parent_module->create_method(name, graph, member_inputs);
+ }
+
+ return root_module;
+}
+
+} // namespace
std::shared_ptr<Graph> ImportIRGraph(const std::string& serialized_graph,
std::vector<at::Tensor>& initializers) {
- auto model_proto = ::ONNX_NAMESPACE::ModelProto();
- model_proto.ParseFromString(serialized_graph);
+ GraphDecoder decoder;
+ return decoder.decode(serialized_graph, initializers);
+}
- auto graph = buildGraph(model_proto.graph(), initializers);
-
- reconstructOutputTypes(graph->block());
-
- return graph;
+void ImportIRModule(
+ const std::shared_ptr<script::Module> module,
+ const std::string& serialized_module,
+ const std::unordered_map<std::string, std::string>& storage_map) {
+ ModuleDecoder decoder;
+ decoder.decode(module, serialized_module, storage_map);
}
}}
diff --git a/torch/csrc/jit/import.h b/torch/csrc/jit/import.h
index d593896f..56606cf 100644
--- a/torch/csrc/jit/import.h
+++ b/torch/csrc/jit/import.h
@@ -1,9 +1,15 @@
#pragma once
#include "torch/csrc/jit/ir.h"
+#include "torch/csrc/jit/script/module.h"
namespace torch { namespace jit {
TORCH_API std::shared_ptr<Graph> ImportIRGraph(const std::string& serialized_graph, std::vector<at::Tensor> & initializers);
+TORCH_API void ImportIRModule(
+ const std::shared_ptr<script::Module> module,
+ const std::string& serialized_module,
+ const std::unordered_map<std::string, std::string>& storage_map);
+
}}
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index b72fdb6..59884ed 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -496,5 +496,10 @@
}
return std::make_tuple(graph, variables);
});
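+  // Import a serialized module produced by ScriptModule.export(); storages
+  // maps storage names to their raw byte strings.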
+ m.def("_jit_import_module", [](const std::shared_ptr<script::Module> module,
+ const std::string& serialized_module,
+ const std::unordered_map<std::string, std::string>& storages) {
+ ImportIRModule(module, serialized_module, storages);
+ });
}
}}
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index b91d348..1813327 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -3,6 +3,7 @@
#include "torch/csrc/Device.h"
#include "torch/csrc/Dtype.h"
#include "torch/csrc/Layout.h"
+#include "torch/csrc/jit/export.h"
#include "torch/csrc/jit/script/compiler.h"
#include "torch/csrc/jit/python_tracer.h"
@@ -431,6 +432,21 @@
// public.
py::class_<Module, std::shared_ptr<Module>>(m, "ScriptModule")
.def(py::init<>())
+ .def("export", [](const std::shared_ptr<Module> m) {
+ std::string module;
+ RawDataExportMap export_map;
+ std::tie(module, export_map) = ExportModule(m);
+ std::unordered_map<std::string, py::bytes> python_serialized_export_map;
+ for (auto& kv : export_map) {
+ auto t = kv.second;
+ size_t copy_bytes = t.type().elementSizeInBytes() * t.numel();
+          // TODO: this is an unnecessary copy. In theory we can directly return
+ // the map from identifier to Tensor, but we need some API in Python
+ // to get raw `bytes` containing the raw tensor data.
+ python_serialized_export_map[kv.first] = py::bytes(static_cast<const char*>(t.data_ptr()), copy_bytes);
+ }
+ return std::make_tuple(py::bytes(module), python_serialized_export_map);
+ })
.def("_set_optimized", &Module::set_optimized)
.def(
"_define",