- Refactoring serialization of ONNX initializers to be name-based (#17420)
Summary:
Currently, serialization of model parameters in ONNX export depends on the order in which they are stored in a container (`list` on the Python side and `std::vector` on the C++ side). This has worked so far, but any graph pass that mutates the parameter list (adding, removing, or reordering parameters) can break strictly order-based serialization.
This PR is the first in a series that brings more passes (such as constant folding) into ONNX export. It lays the groundwork by moving serialization in ONNX export from an order-based to a name-based approach, which is more amenable to such passes.
houseroad - As discussed, this change uses a map for export and removes the code in `export.cpp` that relied on input order to compute initializer names.
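For illustration, a minimal sketch of the new name-based mapping built in `torch/onnx/utils.py` (the input names and placeholder tensors below are hypothetical; the real code derives them from `graph.inputs()` and `_unique_state_dict(model)`):

    # Graph inputs list the data inputs first, then the parameters.
    graph_input_names = ["input.1", "weight", "bias"]
    params = ["<weight tensor>", "<bias tensor>"]

    # The trailing graph inputs correspond to the parameters, so zip
    # them by position once to build the name -> tensor map.
    param_names = graph_input_names[len(graph_input_names) - len(params):]
    params_dict = dict(zip(param_names, params))
    # params_dict == {"weight": "<weight tensor>", "bias": "<bias tensor>"}

Downstream consumers (`_export_onnx`, `EncodeBlock`) then look up initializers by name instead of by position.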
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17420
Differential Revision: D14361993
Pulled By: houseroad
fbshipit-source-id: da93e945d55755c126de06641f35df87d1648cc4
diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp
index 4974929..39417ba 100644
--- a/torch/csrc/jit/export.cpp
+++ b/torch/csrc/jit/export.cpp
@@ -115,15 +115,23 @@
}
protected:
+ // Using std::map instead of std::unordered_map for initializers
+ // in EncodeGraph so that the order in which initializers are
+ // written to the ONNX graph is always deterministic and
+ // predictable. While this is not an ONNX requirement, it is
+ // needed by tests that use _export_to_pretty_string() to
+ // validate ONNX graphs.
void EncodeGraph(
onnx::GraphProto* graph_proto,
const std::shared_ptr<Graph>& graph,
- const std::vector<at::Tensor>& initializers = {});
+ const std::map<std::string, at::Tensor>& initializers =
+ std::map<std::string, at::Tensor>());
void EncodeBlock(
onnx::GraphProto* graph_proto,
const Block* block,
- const std::vector<at::Tensor>& initializers = {});
+ const std::map<std::string, at::Tensor>& initializers =
+ std::map<std::string, at::Tensor>());
virtual void EncodeTensor(
onnx::TensorProto* tensor_proto,
@@ -209,14 +217,14 @@
void EncoderBase::EncodeGraph(
onnx::GraphProto* graph_proto,
const std::shared_ptr<Graph>& graph,
- const std::vector<at::Tensor>& initializers) {
+ const std::map<std::string, at::Tensor>& initializers) {
EncodeBlock(graph_proto, graph->block(), initializers);
}
void EncoderBase::EncodeBlock(
onnx::GraphProto* graph_proto,
const Block* block,
- const std::vector<at::Tensor>& initializers) {
+ const std::map<std::string, at::Tensor>& initializers) {
AT_ASSERT(graph_proto != nullptr);
std::string block_name = "torch-jit-export";
if (num_blocks_) {
@@ -303,16 +311,11 @@
EncodeBlock(false_g, node->blocks()[1]);
}
}
- auto num_initializers = initializers.size();
- AT_ASSERT(block->inputs().size() >= num_initializers);
- size_t inputs_count = block->inputs().size() - num_initializers;
- for (auto& tensor : initializers) {
- // TODO: stop using positions to determine which initializers
- // match to which inputs
- std::string name = graph_proto->input(inputs_count++).name();
+ AT_ASSERT(block->inputs().size() >= initializers.size());
+ for (auto& name_tensor_pair : initializers) {
auto p = graph_proto->add_initializer();
- p->set_name(name);
- EncodeTensor(p, tensor, name);
+ p->set_name(name_tensor_pair.first);
+ EncodeTensor(p, name_tensor_pair.second, name_tensor_pair.first);
}
}
@@ -386,7 +389,7 @@
const std::shared_ptr<Graph>& graph,
int64_t onnx_opset_version,
onnx_torch::OperatorExportTypes operator_export_type,
- const std::vector<at::Tensor>& initializers,
+ const std::map<std::string, at::Tensor>& initializers,
bool defer_weight_export,
bool strip_doc);
@@ -408,7 +411,7 @@
const std::shared_ptr<Graph>& graph,
int64_t onnx_opset_version,
onnx_torch::OperatorExportTypes operator_export_type,
- const std::vector<at::Tensor>& initializers,
+ const std::map<std::string, at::Tensor>& initializers,
bool defer_weight_export,
bool strip_doc)
: EncoderBase(operator_export_type, strip_doc),
@@ -858,7 +861,7 @@
std::string pretty_print_onnx(
const std::shared_ptr<Graph>& graph,
- const std::vector<at::Tensor>& initializers,
+ const std::map<std::string, at::Tensor>& initializers,
int64_t onnx_opset_version,
bool defer_weight_export,
::torch::onnx::OperatorExportTypes operator_export_type,
@@ -883,7 +886,7 @@
// libtorch will be able to import the IR and play it back.
std::tuple<std::string, RawDataExportMap> export_onnx(
const std::shared_ptr<Graph>& graph,
- const std::vector<at::Tensor>& initializers,
+ const std::map<std::string, at::Tensor>& initializers,
int64_t onnx_opset_version,
bool defer_weight_export,
::torch::onnx::OperatorExportTypes operator_export_type) {
diff --git a/torch/csrc/jit/export.h b/torch/csrc/jit/export.h
index ae49245..1904723 100644
--- a/torch/csrc/jit/export.h
+++ b/torch/csrc/jit/export.h
@@ -21,7 +21,7 @@
TORCH_API std::tuple<std::string, RawDataExportMap> export_onnx(
const std::shared_ptr<Graph>& graph,
- const std::vector<at::Tensor>& initializers,
+ const std::map<std::string, at::Tensor>& initializers,
int64_t onnx_opset_version,
bool defer_weight_export = false,
::torch::onnx::OperatorExportTypes operator_export_type =
@@ -30,7 +30,7 @@
// For testing purposes
TORCH_API std::string pretty_print_onnx(
const std::shared_ptr<Graph>& graph,
- const std::vector<at::Tensor>& initializers,
+ const std::map<std::string, at::Tensor>& initializers,
int64_t onnx_opset_version,
bool defer_weight_export,
::torch::onnx::OperatorExportTypes operator_export_type =
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index 04e60a3..1b204aa 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -14,6 +14,7 @@
#include <iostream>
#include <sstream>
+#include <unordered_map>
namespace torch {
namespace jit {
@@ -221,7 +222,7 @@
.def(
"_export_onnx",
[](const std::shared_ptr<Graph> g,
- const std::vector<at::Tensor>& initializers,
+ const std::map<std::string, at::Tensor>& initializers,
int64_t onnx_opset_version,
bool defer_weight_export,
::torch::onnx::OperatorExportTypes operator_export_type) {
@@ -255,7 +256,7 @@
.def(
"_pretty_print_onnx",
[](const std::shared_ptr<Graph> g,
- const std::vector<at::Tensor>& initializers,
+ const std::map<std::string, at::Tensor>& initializers,
int64_t onnx_opset_version,
bool defer_weight_export,
::torch::onnx::OperatorExportTypes operator_export_type,
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 7641aa0..0dea267 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -233,6 +233,10 @@
graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
params = list(_unique_state_dict(model).values())
+ input_and_param_names = [val.uniqueName() for val in graph.inputs()]
+ param_names = input_and_param_names[len(input_and_param_names) - len(params):]
+ params_dict = dict(zip(param_names, params))
+
graph = _optimize_graph(graph, operator_export_type)
# NB: ONNX requires complete information about output types, which might be
@@ -246,7 +250,7 @@
if verbose:
print(graph)
- return graph, params, torch_out
+ return graph, params_dict, torch_out
def export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=False,
@@ -294,17 +298,17 @@
if opset_version is None:
opset_version = _default_onnx_opset_version
_set_opset_version(opset_version)
- graph, params, torch_out = _model_to_graph(model, args, f, verbose,
- training, input_names,
- output_names, operator_export_type,
- example_outputs, propagate)
+ graph, params_dict, torch_out = _model_to_graph(model, args, f, verbose,
+ training, input_names,
+ output_names, operator_export_type,
+ example_outputs, propagate)
# TODO: Don't allocate a in-memory string for the protobuf
defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
if export_params:
- proto, export_map = graph._export_onnx(params, opset_version, defer_weight_export, operator_export_type)
+ proto, export_map = graph._export_onnx(params_dict, opset_version, defer_weight_export, operator_export_type)
else:
- proto, export_map = graph._export_onnx([], opset_version, False, operator_export_type)
+ proto, export_map = graph._export_onnx({}, opset_version, False, operator_export_type)
if export_type == ExportTypes.PROTOBUF_FILE:
assert(len(export_map) == 0)
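A note on the `std::map` choice in `export.cpp` above: `std::map` iterates its keys in sorted order, so the initializer list written to the ONNX graph is identical across runs, which the `_export_to_pretty_string()`-based tests rely on. A rough Python analogue of the resulting serialization order (parameter names here are hypothetical):

    params_dict = {"fc.weight": 0, "conv1.bias": 1, "fc.bias": 2}

    # std::map iterates in sorted key order, independent of insertion
    # order, so initializers are emitted deterministically:
    for name in sorted(params_dict):
        print(name)  # conv1.bias, fc.bias, fc.weight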