Add support for upper and lower bound shape functions in SSA

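This allows registering a pair of shape compute graphs for a schema: one
computing a lower bound and one computing an upper bound on each output
dimension. This is useful for ops such as aten::nonzero, whose output shape
depends on input values rather than only on input shapes. When symbolic
shapes are calculated for such an op, both graphs are run and their results
are merged dimension by dimension: a dimension on which the bounds agree
stays concrete, and a dimension on which they differ becomes a fresh
symbol. The merged result is cached like any other shape function result.

A minimal sketch of the merge semantics (combine_bounds is internal to
symbolic_shape_analysis.cpp, and the concrete values here assume nonzero's
registered bounds are 0 and numel(input) rows, as the new test suggests):

    // Bounds for aten::nonzero on a 5x10 input: 0 to 50 rows, 2 columns.
    c10::SymbolicShape lower({c10::ShapeSymbol::fromStaticSize(0),
                              c10::ShapeSymbol::fromStaticSize(2)});
    c10::SymbolicShape upper({c10::ShapeSymbol::fromStaticSize(50),
                              c10::ShapeSymbol::fromStaticSize(2)});
    // Dim 0 differs between the bounds -> fresh symbol; dim 1 agrees -> 2.
    c10::SymbolicShape merged = combine_bounds(lower, upper);
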
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77389

Approved by: https://github.com/eellison
diff --git a/test/cpp/jit/test_shape_analysis.cpp b/test/cpp/jit/test_shape_analysis.cpp
index 15f41da..162c14b 100644
--- a/test/cpp/jit/test_shape_analysis.cpp
+++ b/test/cpp/jit/test_shape_analysis.cpp
@@ -295,6 +295,8 @@
 
 namespace {
 
+c10::optional<int64_t> sym_dim = c10::nullopt;
+
 // NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
 void assertShapeEqual(c10::SymbolicShape& a, c10::SymbolicShape& e) {
   auto a_canonical = CanonicalizedSymbolicShape(a);
@@ -328,7 +330,6 @@
   c10::IValue const_size_2 = std::vector<int64_t>{1, 56, 56};
 
   // Check vector initializer list syntax
-  c10::optional<int64_t> sym_dim = c10::nullopt;
   c10::SymbolicShape ss_concrete =
       std::vector<c10::optional<int64_t>>{1, 56, 56};
   c10::SymbolicShape ss1 = std::vector<c10::optional<int64_t>>{sym_dim, 56, 56};
@@ -361,6 +362,22 @@
   assertShapeEqual(res, {sym_dim, 64, sym_dim, sym_dim});
 }
 
+TEST(ShapeAnalysisTest, BoundedSymbolicShapes) {
+  auto schema = getSchema("aten::nonzero(Tensor self) -> (Tensor)");
+
+  // Test that we generate symbolic shapes for the output of a nonzero op
+  c10::IValue const_size_1 = std::vector<int64_t>{5, 10};
+  auto res =
+      calculateSymbolicShapesOnOp(schema, std::vector<SSAInput>{const_size_1});
+  assertShapeEqual(res, {sym_dim, 2});
+
+  // Test that nonzero can also create concrete shapes
+  c10::IValue const_size_2 = std::vector<int64_t>{1, 0};
+  res =
+      calculateSymbolicShapesOnOp(schema, std::vector<SSAInput>{const_size_2});
+  assertShapeEqual(res, {0, 2});
+}
+
 TEST(ShapeAnalysisTest, SymbolicShapeCaching) {
   clear_shape_cache();
   auto schema = getSchema("aten::mm(Tensor self, Tensor mat2) -> Tensor");
@@ -369,7 +386,6 @@
   c10::IValue const_size_2 = std::vector<int64_t>{64, 56};
   c10::IValue const_size_3 = std::vector<int64_t>{64, 20};
 
-  c10::optional<int64_t> sym_dim = c10::nullopt;
   c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});
   c10::SymbolicShape ss2 = c10::SymbolicShape({sym_dim, 64});
   c10::SymbolicShape ss3 = c10::SymbolicShape({sym_dim, sym_dim});
@@ -422,7 +438,6 @@
 
   c10::IValue const_int = 1;
 
-  c10::optional<int64_t> sym_dim = c10::nullopt;
   c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});
 
   auto res = calculateSymbolicShapesOnOp(squeeze_op, {ss1, const_int});
@@ -462,7 +477,6 @@
   c10::IValue const_int = 1;
   c10::IValue false_ival = false;
 
-  c10::optional<int64_t> sym_dim = c10::nullopt;
   c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});
   c10::SymbolicShape ss2 = c10::SymbolicShape({sym_dim, 64});
 
diff --git a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
index 8f7fdc5..7b3a1e0 100644
--- a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
+++ b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
@@ -568,6 +568,13 @@
     shape_compute_graph_ = (*maybe_graph)->copy();
   }
 
+  SymbolicShapeOpAnalyzer(
+      const FunctionSchema* schema,
+      std::shared_ptr<Graph> graph)
+      : schema_(schema) {
+    shape_compute_graph_ = graph->copy();
+  }
+
   c10::optional<std::vector<c10::SymbolicShape>> run(
       std::vector<SSArgument>& inputs) {
     if (!shape_compute_graph_) {
@@ -747,6 +754,29 @@
   return op_analyzer.getShapeComputeGraph();
 }
 
+// Merge lower- and upper-bound shapes for the same op into one symbolic
+// shape: a dimension on which the bounds agree is kept as-is, and a
+// dimension on which they differ becomes a fresh, unconstrained symbol.
+c10::SymbolicShape combine_bounds(
+    c10::SymbolicShape& lower_bound,
+    c10::SymbolicShape& upper_bound) {
+  // TODO: At some point we might want to add support for dynamic dims
+  TORCH_INTERNAL_ASSERT(lower_bound.rank() == upper_bound.rank());
+  if (lower_bound.rank() == c10::nullopt) {
+    return c10::SymbolicShape();
+  }
+  std::vector<c10::ShapeSymbol> merged_shapes;
+  for (size_t i = 0; i < *lower_bound.rank(); i++) {
+    // TODO: Merge equivalent expressions (not needed for current use case)
+    if (lower_bound[i] == upper_bound[i]) {
+      merged_shapes.push_back(lower_bound[i]);
+    } else {
+      merged_shapes.push_back(c10::ShapeSymbol::newSymbol());
+    }
+  }
+  return c10::SymbolicShape(merged_shapes);
+}
+
 struct SymbolicShapeGraphAnalyzer {
   SymbolicShapeGraphAnalyzer(
       std::shared_ptr<Graph>& graph,
@@ -1076,7 +1103,9 @@
 calculateSymbolicShapesOnOp(
     const FunctionSchema* schema,
     const std::vector<SSAInput>& inputs) {
-  if (shapeComputeGraphForSchema(*schema) == c10::nullopt) {
+  auto bounded_graphs = boundedGraphsForSchema(*schema);
+  auto has_shape_compute = shapeComputeGraphForSchema(*schema) != c10::nullopt;
+  if (!has_shape_compute && bounded_graphs == c10::nullopt) {
     // Avoid doing all this work for functions that don't have a
     // supported schema
     return c10::nullopt;
@@ -1095,6 +1124,27 @@
       ssa_args.emplace_back(ShapeArguments(*ss));
     }
   }
+  // If bounded shape graphs are registered, run both bounds and merge the results
+  if (bounded_graphs) {
+    auto lower_bound =
+        SymbolicShapeOpAnalyzer(schema, bounded_graphs->lower_bound);
+    auto lower_bound_res = lower_bound.run(ssa_args);
+    auto upper_bound =
+        SymbolicShapeOpAnalyzer(schema, bounded_graphs->upper_bound);
+    auto upper_bound_res = upper_bound.run(ssa_args);
+    // Stitch together the values
+    if (lower_bound_res.has_value() && upper_bound_res.has_value()) {
+      TORCH_INTERNAL_ASSERT(lower_bound_res->size() == upper_bound_res->size());
+      auto merged_res = std::vector<c10::SymbolicShape>();
+      for (size_t i = 0; i < lower_bound_res->size(); i++) {
+        merged_res.push_back(
+            combine_bounds(lower_bound_res->at(i), upper_bound_res->at(i)));
+      }
+      cache_shape_function(schema, inputs, merged_res);
+      return merged_res;
+    }
+    return c10::nullopt;
+  }
 
   auto op_analyzer = SymbolicShapeOpAnalyzer(schema);
   auto res = op_analyzer.run(ssa_args);
diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
index 8b1136b..1c1315a 100644
--- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
+++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
@@ -72,6 +72,9 @@
 std::unordered_map<const FunctionSchema*, std::shared_ptr<Graph>>
     cached_schema_to_graph;
 
+std::unordered_map<const FunctionSchema*, BoundedShapeGraphs>
+    cached_bounded_schema_to_graph;
+
 // CompilationUnit that holds all these Functions and keeps them alive.
 auto compilation_unit = std::make_shared<CompilationUnit>();
 
@@ -237,34 +240,58 @@
   }
 }
 
+// Look up the TorchScript shape function with the given name (reusing a
+// previously transformed graph when possible) and prepare it for the schema.
+std::shared_ptr<Graph> genShapeComputeFn(
+    const FunctionSchema* schema_string,
+    const std::string& shape_compute_function_name,
+    std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
+    const CompilationUnit& module) {
+  std::shared_ptr<Graph> graph;
+  if (reused_functions.count(shape_compute_function_name)) {
+    graph = reused_functions[shape_compute_function_name];
+  } else {
+    Function& shape_compute_function =
+        module.get_function(shape_compute_function_name);
+    graph = toGraphFunction(shape_compute_function).graph();
+
+    transformShapeFunction(schema_string, graph);
+    // NB: we lint the shape functions registered in source
+    // in a test file
+    // LintShapeComputeGraph(schema_string, graph);
+
+    reused_functions[shape_compute_function_name] = graph;
+  }
+  // allow extra unused arguments so one shape function can serve many schemas, e.g. unary
+  TORCH_INTERNAL_ASSERT(
+      graph->inputs().size() <= schema_string->arguments().size());
+  return graph;
+}
+
 void registerSchema(
     const FunctionSchema* schema_string,
     const std::string& shape_compute_function_name,
     std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
     const CompilationUnit& module) {
-  if (reused_functions.count(shape_compute_function_name)) {
-    auto graph = reused_functions[shape_compute_function_name];
-
-    // allow extra unused arguments to map multiple functions to e.g. unary
-    TORCH_INTERNAL_ASSERT(
-        graph->inputs().size() <= schema_string->arguments().size());
-
-    cached_schema_to_graph[schema_string] = graph;
-    return;
-  }
-
-  Function& shape_compute_function =
-      module.get_function(shape_compute_function_name);
-  std::shared_ptr<Graph> graph =
-      toGraphFunction(shape_compute_function).graph();
-
-  transformShapeFunction(schema_string, graph);
-  // NB: we lint the shape functions registered in source
-  // in a test file
-  // LintShapeComputeGraph(schema_string, graph);
+  auto graph = genShapeComputeFn(
+      schema_string, shape_compute_function_name, reused_functions, module);
 
   cached_schema_to_graph[schema_string] = graph;
-  reused_functions[shape_compute_function_name] = graph;
+}
+
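+// Register a (lower bound, upper bound) pair of shape compute graphs for a
+// schema, for ops whose output shape depends on input values.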
+void registerBoundedSchema(
+    const FunctionSchema* schema_string,
+    const std::string& lower_bound_function_name,
+    const std::string& upper_bound_function_name,
+    std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
+    const CompilationUnit& module) {
+  auto lower_graph = genShapeComputeFn(
+      schema_string, lower_bound_function_name, reused_functions, module);
+  auto upper_graph = genShapeComputeFn(
+      schema_string, upper_bound_function_name, reused_functions, module);
+  cached_bounded_schema_to_graph[schema_string] = {lower_graph, upper_graph};
 }
 
 void loadModule(const CompilationUnit& module) {
@@ -304,6 +327,20 @@
       }
     }
   }
+
+  // Now register the bounded schemas
+  for (const auto& pair : GetBoundedShapeMappings().getAllKeysAndValues()) {
+    const FunctionSchema* schema_string = &pair.first->schema();
+    const std::string& lower_bound_function_name = pair.second.first;
+    const std::string& upper_bound_function_name = pair.second.second;
+
+    registerBoundedSchema(
+        schema_string,
+        lower_bound_function_name,
+        upper_bound_function_name,
+        reused_functions,
+        module);
+  }
 }
 
 void loadFunctions() {
@@ -341,6 +378,21 @@
   return c10::nullopt;
 }
 
+TORCH_API c10::optional<BoundedShapeGraphs> boundedGraphsForSchema(
+    const FunctionSchema& schema) {
+  std::lock_guard<std::mutex> guard(lock);
+  if (cached_bounded_schema_to_graph.size() == 0) {
+    loadFunctions();
+  }
+  GRAPH_DEBUG("Trying to find schema in bounded graphs: ", schema);
+  auto cache_it = cached_bounded_schema_to_graph.find(&schema);
+  if (cache_it != cached_bounded_schema_to_graph.end()) {
+    return cache_it->second;
+  }
+
+  return c10::nullopt;
+}
+
 void RegisterShapeComputeGraphForSchema(
     const FunctionSchema& schema,
     std::shared_ptr<Graph> g) {
diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.h b/torch/csrc/jit/runtime/symbolic_shape_registry.h
index cdca73c..0d92e35 100644
--- a/torch/csrc/jit/runtime/symbolic_shape_registry.h
+++ b/torch/csrc/jit/runtime/symbolic_shape_registry.h
@@ -46,6 +46,14 @@
 Please file an issue.
 */
 
+// A pair of graphs that compute a lower and an upper bound on each output
+// dimension, for ops (e.g. aten::nonzero) whose output shape depends on
+// input values. Symbolic shape analysis merges the two results per-dim.
+struct BoundedShapeGraphs {
+  std::shared_ptr<Graph> lower_bound;
+  std::shared_ptr<Graph> upper_bound;
+};
+
 TORCH_API void RegisterShapeComputeGraphForSchema(
     const FunctionSchema& schema,
     std::shared_ptr<Graph> g);
@@ -53,6 +58,9 @@
 TORCH_API c10::optional<std::shared_ptr<Graph>> shapeComputeGraphForSchema(
     const FunctionSchema& schema);
 
+TORCH_API c10::optional<BoundedShapeGraphs> boundedGraphsForSchema(
+    const FunctionSchema& schema);
+
 TORCH_API std::vector<const FunctionSchema*> RegisteredShapeComputeSchemas();
 
 TORCH_API void LintShapeComputeGraph(