[onnx.export] Avoid linear loop over symbol_dim_map (#123029)

This PR is part of an effort to speed up torch.onnx.export (#121422).

- Doing a reverse look-up in `symbol_dim_map` incurs a cost linear in the number of symbols. This happens for each node, so the export as a whole pays a quadratic cost.
- Add a reverse map `dim_symbol_map` that is kept in parallel with `symbol_dim_map` (sketched below). Each reverse look-up then becomes a `std::map::find` instead of a linear scan, removing the quadratic export-time complexity.
- This is a pragmatic solution; if someone more familiar with the code base has a better one, I'd be interested to hear it.
- Resolves (9) in #121422.

(partial fix of #121422)
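
The core idea is the classic dual-map technique: keep a forward map and a reverse map in lock-step so look-ups are cheap in both directions. A minimal sketch of the pattern, using an `int` stand-in for `c10::ShapeSymbol` rather than the real types:

```cpp
#include <map>
#include <optional>
#include <string>

// Hypothetical stand-ins for c10::ShapeSymbol and the real map aliases,
// just to illustrate the pattern; the actual types live in export.h.
using Symbol = int;
using SymbolDimMap = std::map<Symbol, std::string>;
using DimSymbolMap = std::map<std::string, Symbol>;

// Before: the reverse look-up scans every entry -- O(n) per query,
// and it runs once per node, hence the quadratic export time.
std::optional<Symbol> findSymbolLinear(
    const SymbolDimMap& symbol_dim_map,
    const std::string& dim_param) {
  for (const auto& pair : symbol_dim_map) {
    if (pair.second == dim_param) {
      return pair.first;
    }
  }
  return std::nullopt;
}

// After: a parallel reverse map answers the same query in O(log n).
std::optional<Symbol> findSymbolIndexed(
    const DimSymbolMap& dim_symbol_map,
    const std::string& dim_param) {
  auto it = dim_symbol_map.find(dim_param);
  if (it != dim_symbol_map.end()) {
    return it->second;
  }
  return std::nullopt;
}

// Every insertion must update both maps, or they drift apart.
void insertDim(
    SymbolDimMap& symbol_dim_map,
    DimSymbolMap& dim_symbol_map,
    Symbol sym,
    const std::string& dim_param) {
  symbol_dim_map[sym] = dim_param;
  dim_symbol_map[dim_param] = sym;
}
```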

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123029
Approved by: https://github.com/justinchuby
diff --git a/torch/csrc/jit/passes/onnx/constant_map.cpp b/torch/csrc/jit/passes/onnx/constant_map.cpp
index 716232c..8fd1bed 100644
--- a/torch/csrc/jit/passes/onnx/constant_map.cpp
+++ b/torch/csrc/jit/passes/onnx/constant_map.cpp
@@ -227,6 +227,10 @@
   return ConstantValueMap::getInstance().symbolDimMap;
 }
 
+DimSymbolMap& ConstantValueMap::GetDimSymbolMap() {
+  return ConstantValueMap::getInstance().dimSymbolMap;
+}
+
 template <typename Map>
 void UpdateStrKey(
     Map& map,
@@ -271,6 +275,7 @@
   ConstantValueMap::getInstance().shapeValueMap.clear();
   ConstantValueMap::getInstance().inferredShapeData.clear();
   ConstantValueMap::getInstance().symbolDimMap.clear();
+  ConstantValueMap::getInstance().dimSymbolMap.clear();
   ConstantValueMap::getInstance().allGraphInputsStatic = c10::nullopt;
 }
 
@@ -359,6 +364,15 @@
       std::cout << std::endl;
     }
   }
+  std::cout << "DimSymbol Map:" << std::endl;
+  count = 0;
+  for (const auto& x : ConstantValueMap::getInstance().dimSymbolMap) {
+    std::cout << "(" << x.first << ": " << x.second << "), ";
+    count++;
+    if (count % 10 == 0) {
+      std::cout << std::endl;
+    }
+  }
 }
 
 } // namespace jit
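
`ConstantValueMap` holds per-export state as a process-wide singleton, and the new reverse map is exposed through the same static-accessor pattern as the existing maps; crucially, it is also cleared alongside `symbolDimMap` when the cache is reset (the hunk above). A simplified sketch of that pattern, with stand-in names and an `int` in place of `c10::ShapeSymbol`:

```cpp
#include <map>
#include <string>

// Simplified sketch of ConstantValueMap's singleton-accessor pattern;
// the real class holds many more maps than these two.
class ValueCache {
 public:
  static ValueCache& getInstance() {
    static ValueCache instance;  // constructed once, on first use
    return instance;
  }
  static std::map<int, std::string>& GetSymbolDimMap() {
    return getInstance().symbolDimMap;
  }
  static std::map<std::string, int>& GetDimSymbolMap() {
    return getInstance().dimSymbolMap;
  }
  // The forward and reverse maps stay consistent only if every reset
  // clears both of them, as the patch above does.
  static void ClearMaps() {
    getInstance().symbolDimMap.clear();
    getInstance().dimSymbolMap.clear();
  }

 private:
  ValueCache() = default;
  std::map<int, std::string> symbolDimMap;  // symbol -> dim name
  std::map<std::string, int> dimSymbolMap;  // dim name -> symbol
};
```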
diff --git a/torch/csrc/jit/passes/onnx/constant_map.h b/torch/csrc/jit/passes/onnx/constant_map.h
index fe33183..303d373 100644
--- a/torch/csrc/jit/passes/onnx/constant_map.h
+++ b/torch/csrc/jit/passes/onnx/constant_map.h
@@ -70,6 +70,7 @@
   static ShapeDataMap& GetInferredShapeData();
 
   static SymbolDimMap& GetSymbolDimMap();
+  static DimSymbolMap& GetDimSymbolMap();
 
   static void UpdateValueName(
       const std::string& old_name,
@@ -104,6 +105,7 @@
   // during future node-level shape inference.
   ShapeDataMap inferredShapeData;
   SymbolDimMap symbolDimMap;
+  DimSymbolMap dimSymbolMap;
   // Stores if all graph-level inputs have static shape
   c10::optional<bool> allGraphInputsStatic;
 };
diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp
index eefa962..dd79754 100644
--- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp
+++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp
@@ -87,9 +87,14 @@
 namespace onnx = ::ONNX_NAMESPACE;
 namespace diagnostics = ::torch::onnx::diagnostics;
 
+// SymbolDimMap is a Torch-to-ONNX shape look-up, built this way so it can be
+// returned by the export function. During export, however, when we come
+// across new ONNX shapes we need the reverse look-up. To avoid a linear-time
+// scan on every such query, we maintain DimSymbolMap in parallel.
 c10::ShapeSymbol ONNXDimToShapeSymbol(
     const onnx::TensorShapeProto_Dimension& dim,
-    SymbolDimMap& symbol_dim_map) {
+    SymbolDimMap& symbol_dim_map,
+    DimSymbolMap& dim_symbol_map) {
   if (dim.has_dim_value()) {
     return c10::ShapeSymbol::fromStaticSize(dim.dim_value());
   }
@@ -97,11 +102,9 @@
   if (dim.has_dim_param()) {
     // If this param is already known, assign the same Symbol.
     GRAPH_UPDATE("Got dim_param:", dim.dim_param());
-    for (const auto& pair : symbol_dim_map) {
-      if (pair.second == dim.dim_param()) {
-        sym = pair.first;
-        break;
-      }
+    auto maybe_symbol = dim_symbol_map.find(dim.dim_param());
+    if (maybe_symbol != dim_symbol_map.end()) {
+      sym = maybe_symbol->second;
     }
   }
   if (!sym) {
@@ -109,13 +112,15 @@
     // If dim.dim_param() is empty, no need to keep track
     // because there won't be duplicates.
     symbol_dim_map[sym.value()] = dim.dim_param();
+    dim_symbol_map[dim.dim_param()] = sym.value();
   }
   return sym.value();
 }
 
 TensorTypePtr TorchTensorTypeFromONNX(
     const onnx::TypeProto_Tensor& onnx_tensor_type,
-    SymbolDimMap& symbol_dim_map) {
+    SymbolDimMap& symbol_dim_map,
+    DimSymbolMap& dim_symbol_map) {
   std::optional<at::ScalarType> scalar_type;
   if (onnx_tensor_type.has_elem_type()) {
     scalar_type = ONNXTypeToATenType(onnx_tensor_type.elem_type());
@@ -132,8 +137,8 @@
     const auto& onnx_shape = onnx_tensor_type.shape();
 
     for (const auto i : c10::irange(onnx_shape.dim_size())) {
-      sizes.emplace_back(
-          ONNXDimToShapeSymbol(onnx_shape.dim(i), symbol_dim_map));
+      sizes.emplace_back(ONNXDimToShapeSymbol(
+          onnx_shape.dim(i), symbol_dim_map, dim_symbol_map));
     }
     v_type = TensorType::create(scalar_type, at::kCPU, sizes.size(), {});
     v_type = v_type->withSymbolicShapes(c10::SymbolicShape(sizes));
@@ -150,13 +155,14 @@
 
 ListTypePtr TorchListTypeFromONNX(
     const onnx::TypeProto_Sequence& onnx_sequence_type,
-    SymbolDimMap& symbol_dim_map) {
+    SymbolDimMap& symbol_dim_map,
+    DimSymbolMap& dim_symbol_map) {
   if (onnx_sequence_type.has_elem_type()) {
     const auto& onnx_seq_elem_type = onnx_sequence_type.elem_type();
     if (onnx_seq_elem_type.has_tensor_type()) {
       const auto& onnx_tensor_type = onnx_seq_elem_type.tensor_type();
-      const auto v_tensor_type =
-          TorchTensorTypeFromONNX(onnx_tensor_type, symbol_dim_map);
+      const auto v_tensor_type = TorchTensorTypeFromONNX(
+          onnx_tensor_type, symbol_dim_map, dim_symbol_map);
       auto v_type = ListType::create(v_tensor_type);
       return v_type;
     }
@@ -167,21 +173,22 @@
 void UpdateTorchValueByOnnxValueInfo(
     Value* v,
     const onnx::ValueInfoProto& p_info,
-    SymbolDimMap& symbol_dim_map) {
+    SymbolDimMap& symbol_dim_map,
+    DimSymbolMap& dim_symbol_map) {
   if (!p_info.has_type()) {
     return;
   }
 
   const auto& p_type = p_info.type();
   if (p_type.has_tensor_type()) {
-    const auto torch_tensor_type =
-        TorchTensorTypeFromONNX(p_type.tensor_type(), symbol_dim_map);
+    const auto torch_tensor_type = TorchTensorTypeFromONNX(
+        p_type.tensor_type(), symbol_dim_map, dim_symbol_map);
     if (torch_tensor_type) {
       MergeInferredTypeAndSetMap(v, v->type(), torch_tensor_type);
     }
   } else if (p_type.has_sequence_type()) {
-    const auto torch_list_type =
-        TorchListTypeFromONNX(p_type.sequence_type(), symbol_dim_map);
+    const auto torch_list_type = TorchListTypeFromONNX(
+        p_type.sequence_type(), symbol_dim_map, dim_symbol_map);
     if (torch_list_type) {
       MergeInferredTypeAndSetMap(v, v->type(), torch_list_type);
     }
@@ -377,6 +384,7 @@
     std::shared_ptr<Graph> graph,
     std::shared_ptr<onnx::ModelProto>& model_proto,
     SymbolDimMap& symbol_dim_map,
+    DimSymbolMap& dim_symbol_map,
     int opset_version) {
   RawDataExportMap export_map;
   bool val_use_external_data_format;
@@ -402,6 +410,9 @@
           false,
           std::string());
   symbol_dim_map.insert(new_symbol_dim_map.begin(), new_symbol_dim_map.end());
+  for (const auto& pair : new_symbol_dim_map) {
+    dim_symbol_map[pair.second] = pair.first;
+  }
   for (int i = 0; i < model_proto->graph().output_size(); ++i) {
     model_proto->mutable_graph()->mutable_output(i)->clear_type();
   }
@@ -1796,7 +1807,8 @@
     Node* n,
     Node* clone_node,
     const onnx::ModelProto& model_proto,
-    SymbolDimMap& symbol_dim_map) {
+    SymbolDimMap& symbol_dim_map,
+    DimSymbolMap& dim_symbol_map) {
   const auto& graph_proto = model_proto.graph();
 
   // get data from value_info and updated original graph.
@@ -1805,7 +1817,7 @@
         for (size_t i = 0; i < n->outputs().size(); ++i) {
           if (clone_node->output(i)->debugName() == v_info.name()) {
             UpdateTorchValueByOnnxValueInfo(
-                n->output(i), v_info, symbol_dim_map);
+                n->output(i), v_info, symbol_dim_map, dim_symbol_map);
           }
         }
       };
@@ -2040,6 +2052,7 @@
   auto& original_shape_data = ConstantValueMap::GetInferredShapeData();
   ShapeDataMap inferred_shape_data;
   auto& symbol_dim_map = ConstantValueMap::GetSymbolDimMap();
+  auto& dim_symbol_map = ConstantValueMap::GetDimSymbolMap();
 
   SetGraphInputTypeReliable(n->owningGraph());
   GRAPH_UPDATE(
@@ -2094,7 +2107,7 @@
       //       e.g: ListConstruct, ListUnpack, etc.
       std::shared_ptr<onnx::ModelProto> model_proto;
       ConvertGraphToONNXProto(
-          n_graph, model_proto, symbol_dim_map, opset_version);
+          n_graph, model_proto, symbol_dim_map, dim_symbol_map, opset_version);
       GRAPH_DEBUG(
           "ONNX graph to run shape inference: ", prettyPrint(*model_proto));
 
@@ -2119,7 +2132,7 @@
           }
         }
         UpdateOutputTypeByONNXProto(
-            n, clone_node, *model_proto, symbol_dim_map);
+            n, clone_node, *model_proto, symbol_dim_map, dim_symbol_map);
       } catch (std::runtime_error& ex) {
         // TODO: include this as warning once we have a more consolidated
         // warning system.
@@ -2161,8 +2174,8 @@
       int rank = inferred_shape.dim_size();
       std::vector<::c10::ShapeSymbol> final_shape(rank);
       for (int i = 0; i < rank; ++i) {
-        final_shape[i] =
-            ONNXDimToShapeSymbol(inferred_shape.dim(i), symbol_dim_map);
+        final_shape[i] = ONNXDimToShapeSymbol(
+            inferred_shape.dim(i), symbol_dim_map, dim_symbol_map);
       }
       c10::SymbolicShape shape_value(final_shape);
       // Store data propagation result into shapeValueMap
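
`ConvertGraphToONNXProto` is the one place where new symbols arrive in bulk (from the sub-graph export) rather than one at a time, so the reverse map has to be refreshed there too, mirroring the `insert` into the forward map. A minimal sketch of that merge step, again with an `int` stand-in for `c10::ShapeSymbol`:

```cpp
#include <map>
#include <string>

using Symbol = int;  // stand-in for c10::ShapeSymbol
using SymbolDimMap = std::map<Symbol, std::string>;
using DimSymbolMap = std::map<std::string, Symbol>;

// Fold a batch of freshly exported (symbol -> dim name) pairs into the
// forward map, then mirror each pair into the reverse map so later
// ONNXDimToShapeSymbol calls can find them without a linear scan.
void mergeNewSymbols(
    const SymbolDimMap& new_symbol_dim_map,
    SymbolDimMap& symbol_dim_map,
    DimSymbolMap& dim_symbol_map) {
  symbol_dim_map.insert(new_symbol_dim_map.begin(), new_symbol_dim_map.end());
  for (const auto& pair : new_symbol_dim_map) {
    dim_symbol_map[pair.second] = pair.first;
  }
}
```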
diff --git a/torch/csrc/jit/serialization/export.h b/torch/csrc/jit/serialization/export.h
index 3a56cfc..9a7ab2c 100644
--- a/torch/csrc/jit/serialization/export.h
+++ b/torch/csrc/jit/serialization/export.h
@@ -30,6 +30,7 @@
 using RawDataExportMap = std::unordered_map<std::string, at::Tensor>;
 
 using SymbolDimMap = std::map<c10::ShapeSymbol, std::string>;
+using DimSymbolMap = std::map<std::string, c10::ShapeSymbol>;
 
 using NodeNameMap = std::unordered_map<const Node*, std::string>;
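
Both aliases use `std::map`, so the reverse look-up costs O(log n). Since `DimSymbolMap` is keyed by `std::string`, a hash map would make it average O(1) at the cost of ordered iteration; a hypothetical variant, not what this PR does:

```cpp
#include <string>
#include <unordered_map>

// Hypothetical alternative alias: average O(1) reverse look-ups,
// giving up the deterministic iteration order of std::map.
using Symbol = int;  // stand-in for c10::ShapeSymbol
using DimSymbolHashMap = std::unordered_map<std::string, Symbol>;
```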