Refactoring serialization of ONNX initializers to be name-based (#17420)

Summary:
Currently, serialization of model parameters in ONNX export depends on the order in which they are stored in a container (a `list` on the Python side, a `std::vector` on the C++ side). This has worked fine until now, but any pass on the graph that mutates the parameter list can break strictly order-based serialization, because a tensor's position no longer identifies it.
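
To make the fragility concrete, here is a minimal sketch of the positional pairing the old code performed (the names and tensor shapes are made up for illustration, not taken from this PR):

    import torch

    # Hypothetical order-based pairing: initializer names are recovered
    # purely by position among the graph inputs, assuming the trailing
    # inputs correspond one-to-one to the parameter list.
    graph_input_names = ["input.1", "weight", "bias"]  # stand-in for graph.inputs()
    params = [torch.randn(3, 3), torch.randn(3)]       # weight, bias (in that order)

    offset = len(graph_input_names) - len(params)
    pairs = dict(zip(graph_input_names[offset:], params))

    # A pass that drops or reorders an entry in `params` without updating
    # the graph inputs in lockstep silently mis-assigns every name after it.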

This PR is the first in a series that brings more passes (such as constant folding) into ONNX export. It lays the groundwork by moving serialization in ONNX export from an order-based to a name-based approach, which is more amenable to such passes.

houseroad - As discussed, this change uses a map for export and removes the code in `export.cpp` that relied on ordering to compute initializer names.
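
The Python side of the change reduces to building that name-to-tensor map once, up front. A condensed, self-contained sketch of the new pairing (names and tensors are placeholders; the real construction is in the torch/onnx/utils.py hunk below):

    import torch

    # Name-based pairing: the trailing graph inputs still line up with the
    # parameters at map-construction time, but from here on every lookup
    # is by name, so later passes may mutate the graph freely.
    input_and_param_names = ["input.1", "weight", "bias"]  # [v.uniqueName() for v in graph.inputs()]
    params = [torch.randn(3, 3), torch.randn(3)]           # list(_unique_state_dict(model).values())

    param_names = input_and_param_names[len(input_and_param_names) - len(params):]
    params_dict = dict(zip(param_names, params))           # {'weight': ..., 'bias': ...}

    # graph._export_onnx(params_dict, ...) then serializes initializers by
    # looking up names in this map instead of counting positions.
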
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17420

Differential Revision: D14361993

Pulled By: houseroad

fbshipit-source-id: da93e945d55755c126de06641f35df87d1648cc4
diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp
index 4974929..39417ba 100644
--- a/torch/csrc/jit/export.cpp
+++ b/torch/csrc/jit/export.cpp
@@ -115,15 +115,23 @@
   }
 
  protected:
+  // Use std::map instead of std::unordered_map for the initializers
+  // passed to EncodeGraph so that the order in which initializers
+  // get written to the ONNX graph is always deterministic and
+  // predictable. While this is not an ONNX requirement, it is
+  // needed by tests that validate ONNX graphs via
+  // _export_to_pretty_string().
   void EncodeGraph(
       onnx::GraphProto* graph_proto,
       const std::shared_ptr<Graph>& graph,
-      const std::vector<at::Tensor>& initializers = {});
+      const std::map<std::string, at::Tensor>& initializers = 
+        std::map<std::string, at::Tensor>());
 
   void EncodeBlock(
       onnx::GraphProto* graph_proto,
       const Block* block,
-      const std::vector<at::Tensor>& initializers = {});
+      const std::map<std::string, at::Tensor>& initializers = 
+        std::map<std::string, at::Tensor>());
 
   virtual void EncodeTensor(
       onnx::TensorProto* tensor_proto,
@@ -209,14 +217,14 @@
 void EncoderBase::EncodeGraph(
     onnx::GraphProto* graph_proto,
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor>& initializers) {
+    const std::map<std::string, at::Tensor>& initializers) {
   EncodeBlock(graph_proto, graph->block(), initializers);
 }
 
 void EncoderBase::EncodeBlock(
     onnx::GraphProto* graph_proto,
     const Block* block,
-    const std::vector<at::Tensor>& initializers) {
+    const std::map<std::string, at::Tensor>& initializers) {
   AT_ASSERT(graph_proto != nullptr);
   std::string block_name = "torch-jit-export";
   if (num_blocks_) {
@@ -303,16 +311,11 @@
       EncodeBlock(false_g, node->blocks()[1]);
     }
   }
-  auto num_initializers = initializers.size();
-  AT_ASSERT(block->inputs().size() >= num_initializers);
-  size_t inputs_count = block->inputs().size() - num_initializers;
-  for (auto& tensor : initializers) {
-    // TODO: stop using positions to determine which initializers
-    // match to which inputs
-    std::string name = graph_proto->input(inputs_count++).name();
+  AT_ASSERT(block->inputs().size() >= initializers.size());
+  for (auto& name_tensor_pair : initializers) {
     auto p = graph_proto->add_initializer();
-    p->set_name(name);
-    EncodeTensor(p, tensor, name);
+    p->set_name(name_tensor_pair.first);
+    EncodeTensor(p, name_tensor_pair.second, name_tensor_pair.first);
   }
 }
 
@@ -386,7 +389,7 @@
       const std::shared_ptr<Graph>& graph,
       int64_t onnx_opset_version,
       onnx_torch::OperatorExportTypes operator_export_type,
-      const std::vector<at::Tensor>& initializers,
+      const std::map<std::string, at::Tensor>& initializers,
       bool defer_weight_export,
       bool strip_doc);
 
@@ -408,7 +411,7 @@
     const std::shared_ptr<Graph>& graph,
     int64_t onnx_opset_version,
     onnx_torch::OperatorExportTypes operator_export_type,
-    const std::vector<at::Tensor>& initializers,
+    const std::map<std::string, at::Tensor>& initializers,
     bool defer_weight_export,
     bool strip_doc)
     : EncoderBase(operator_export_type, strip_doc),
@@ -858,7 +861,7 @@
 
 std::string pretty_print_onnx(
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor>& initializers,
+    const std::map<std::string, at::Tensor>& initializers,
     int64_t onnx_opset_version,
     bool defer_weight_export,
     ::torch::onnx::OperatorExportTypes operator_export_type,
@@ -883,7 +886,7 @@
 // libtorch will be able to import the IR and play it back.
 std::tuple<std::string, RawDataExportMap> export_onnx(
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor>& initializers,
+    const std::map<std::string, at::Tensor>& initializers,
     int64_t onnx_opset_version,
     bool defer_weight_export,
     ::torch::onnx::OperatorExportTypes operator_export_type) {
diff --git a/torch/csrc/jit/export.h b/torch/csrc/jit/export.h
index ae49245..1904723 100644
--- a/torch/csrc/jit/export.h
+++ b/torch/csrc/jit/export.h
@@ -21,7 +21,7 @@
 
 TORCH_API std::tuple<std::string, RawDataExportMap> export_onnx(
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor>& initializers,
+    const std::map<std::string, at::Tensor>& initializers,
     int64_t onnx_opset_version,
     bool defer_weight_export = false,
     ::torch::onnx::OperatorExportTypes operator_export_type =
@@ -30,7 +30,7 @@
 // For testing purposes
 TORCH_API std::string pretty_print_onnx(
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor>& initializers,
+    const std::map<std::string, at::Tensor>& initializers,
     int64_t onnx_opset_version,
     bool defer_weight_export,
     ::torch::onnx::OperatorExportTypes operator_export_type =
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index 04e60a3..1b204aa 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -14,6 +14,7 @@
 
 #include <iostream>
 #include <sstream>
+#include <unordered_map>
 
 namespace torch {
 namespace jit {
@@ -221,7 +222,7 @@
       .def(
           "_export_onnx",
           [](const std::shared_ptr<Graph> g,
-             const std::vector<at::Tensor>& initializers,
+             const std::map<std::string, at::Tensor>& initializers,
              int64_t onnx_opset_version,
              bool defer_weight_export,
              ::torch::onnx::OperatorExportTypes operator_export_type) {
@@ -255,7 +256,7 @@
       .def(
           "_pretty_print_onnx",
           [](const std::shared_ptr<Graph> g,
-             const std::vector<at::Tensor>& initializers,
+             const std::map<std::string, at::Tensor>& initializers,
              int64_t onnx_opset_version,
              bool defer_weight_export,
              ::torch::onnx::OperatorExportTypes operator_export_type,
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 7641aa0..0dea267 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -233,6 +233,10 @@
         graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
         params = list(_unique_state_dict(model).values())
 
+    input_and_param_names = [val.uniqueName() for val in graph.inputs()]
+    param_names = input_and_param_names[len(input_and_param_names) - len(params):]
+    params_dict = dict(zip(param_names, params))
+
     graph = _optimize_graph(graph, operator_export_type)
 
     # NB: ONNX requires complete information about output types, which might be
@@ -246,7 +250,7 @@
     if verbose:
         print(graph)
 
-    return graph, params, torch_out
+    return graph, params_dict, torch_out
 
 
 def export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=False,
@@ -294,17 +298,17 @@
     if opset_version is None:
         opset_version = _default_onnx_opset_version
     _set_opset_version(opset_version)
-    graph, params, torch_out = _model_to_graph(model, args, f, verbose,
-                                               training, input_names,
-                                               output_names, operator_export_type,
-                                               example_outputs, propagate)
+    graph, params_dict, torch_out = _model_to_graph(model, args, f, verbose,
+                                                    training, input_names,
+                                                    output_names, operator_export_type,
+                                                    example_outputs, propagate)
 
     # TODO: Don't allocate a in-memory string for the protobuf
     defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
     if export_params:
-        proto, export_map = graph._export_onnx(params, opset_version, defer_weight_export, operator_export_type)
+        proto, export_map = graph._export_onnx(params_dict, opset_version, defer_weight_export, operator_export_type)
     else:
-        proto, export_map = graph._export_onnx([], opset_version, False, operator_export_type)
+        proto, export_map = graph._export_onnx({}, opset_version, False, operator_export_type)
 
     if export_type == ExportTypes.PROTOBUF_FILE:
         assert(len(export_map) == 0)