Add support for upper and lower bound shape functions in SSA (symbolic shape analysis)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77389
Approved by: https://github.com/eellison
diff --git a/test/cpp/jit/test_shape_analysis.cpp b/test/cpp/jit/test_shape_analysis.cpp
index 15f41da..162c14b 100644
--- a/test/cpp/jit/test_shape_analysis.cpp
+++ b/test/cpp/jit/test_shape_analysis.cpp
@@ -295,6 +295,8 @@
namespace {
+c10::optional<int64_t> sym_dim = c10::nullopt;
+
// NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
void assertShapeEqual(c10::SymbolicShape& a, c10::SymbolicShape& e) {
auto a_canonical = CanonicalizedSymbolicShape(a);
@@ -328,7 +330,6 @@
c10::IValue const_size_2 = std::vector<int64_t>{1, 56, 56};
// Check vector initializer list syntax
- c10::optional<int64_t> sym_dim = c10::nullopt;
c10::SymbolicShape ss_concrete =
std::vector<c10::optional<int64_t>>{1, 56, 56};
c10::SymbolicShape ss1 = std::vector<c10::optional<int64_t>>{sym_dim, 56, 56};
@@ -361,6 +362,22 @@
assertShapeEqual(res, {sym_dim, 64, sym_dim, sym_dim});
}
+TEST(ShapeAnalysisTest, BoundedSymbolicShapes) {
+ auto schema = getSchema("aten::nonzero(Tensor self) -> (Tensor)");
+
+ // Test that we generate symbolic shapes for the output of a nonzero op
+ c10::IValue const_size_1 = std::vector<int64_t>{5, 10};
+ auto res =
+ calculateSymbolicShapesOnOp(schema, std::vector<SSAInput>{const_size_1});
+ assertShapeEqual(res, {sym_dim, 2});
+
+ // Test that nonzero can also create concrete shapes
+ c10::IValue const_size_2 = std::vector<int64_t>({1, 0});
+ res =
+ calculateSymbolicShapesOnOp(schema, std::vector<SSAInput>{const_size_2});
+ assertShapeEqual(res, {0, 2});
+}
+
TEST(ShapeAnalysisTest, SymbolicShapeCaching) {
clear_shape_cache();
auto schema = getSchema("aten::mm(Tensor self, Tensor mat2) -> Tensor");
@@ -369,7 +386,6 @@
c10::IValue const_size_2 = std::vector<int64_t>{64, 56};
c10::IValue const_size_3 = std::vector<int64_t>{64, 20};
- c10::optional<int64_t> sym_dim = c10::nullopt;
c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});
c10::SymbolicShape ss2 = c10::SymbolicShape({sym_dim, 64});
c10::SymbolicShape ss3 = c10::SymbolicShape({sym_dim, sym_dim});
@@ -422,7 +438,6 @@
c10::IValue const_int = 1;
- c10::optional<int64_t> sym_dim = c10::nullopt;
c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});
auto res = calculateSymbolicShapesOnOp(squeeze_op, {ss1, const_int});
@@ -462,7 +477,6 @@
c10::IValue const_int = 1;
c10::IValue false_ival = false;
- c10::optional<int64_t> sym_dim = c10::nullopt;
c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});
c10::SymbolicShape ss2 = c10::SymbolicShape({sym_dim, 64});
diff --git a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
index 8f7fdc5..7b3a1e0 100644
--- a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
+++ b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
@@ -568,6 +568,13 @@
shape_compute_graph_ = (*maybe_graph)->copy();
}
+ SymbolicShapeOpAnalyzer(
+ const FunctionSchema* schema,
+      const std::shared_ptr<Graph>& graph)
+ : schema_(schema) {
+ shape_compute_graph_ = graph->copy();
+ }
+
c10::optional<std::vector<c10::SymbolicShape>> run(
std::vector<SSArgument>& inputs) {
if (!shape_compute_graph_) {
@@ -747,6 +754,26 @@
return op_analyzer.getShapeComputeGraph();
}
+c10::SymbolicShape combine_bounds(
+ c10::SymbolicShape& lower_bound,
+ c10::SymbolicShape& upper_bound) {
+ // TODO: At some point we might want to add support for dynamic dims
+ TORCH_INTERNAL_ASSERT(lower_bound.rank() == upper_bound.rank());
+ if (lower_bound.rank() == c10::nullopt) {
+ return c10::SymbolicShape();
+ }
+ std::vector<c10::ShapeSymbol> merged_shapes;
+  for (size_t i = 0; i < lower_bound.rank(); i++) {
+ // TODO: Merge equivalent expressions (not needed for current use case)
+ if (lower_bound[i] == upper_bound[i]) {
+ merged_shapes.push_back(lower_bound[i]);
+ } else {
+ merged_shapes.push_back(c10::ShapeSymbol::newSymbol());
+ }
+ }
+ return c10::SymbolicShape(merged_shapes);
+}
+
struct SymbolicShapeGraphAnalyzer {
SymbolicShapeGraphAnalyzer(
std::shared_ptr<Graph>& graph,
@@ -1076,7 +1103,9 @@
calculateSymbolicShapesOnOp(
const FunctionSchema* schema,
const std::vector<SSAInput>& inputs) {
- if (shapeComputeGraphForSchema(*schema) == c10::nullopt) {
+ auto bounded_graphs = boundedGraphsForSchema(*schema);
+ auto has_shape_compute = shapeComputeGraphForSchema(*schema) != c10::nullopt;
+ if (!has_shape_compute && bounded_graphs == c10::nullopt) {
// Avoid doing all this work for functions that don't have a
// supported schema
return c10::nullopt;
@@ -1095,6 +1124,27 @@
ssa_args.emplace_back(ShapeArguments(*ss));
}
}
+ // Handle bounded shape option
+ if (bounded_graphs) {
+ auto lower_bound =
+ SymbolicShapeOpAnalyzer(schema, bounded_graphs->lower_bound);
+ auto lower_bound_res = lower_bound.run(ssa_args);
+ auto upper_bound =
+ SymbolicShapeOpAnalyzer(schema, bounded_graphs->upper_bound);
+ auto upper_bound_res = upper_bound.run(ssa_args);
+ // Stitch together the values
+ if (lower_bound_res.has_value() && upper_bound_res.has_value()) {
+ TORCH_INTERNAL_ASSERT(lower_bound_res->size() == upper_bound_res->size());
+ auto merged_res = std::vector<c10::SymbolicShape>();
+ for (size_t i = 0; i < lower_bound_res->size(); i++) {
+ merged_res.push_back(
+ combine_bounds(lower_bound_res->at(i), upper_bound_res->at(i)));
+ }
+ cache_shape_function(schema, inputs, merged_res);
+ return merged_res;
+ }
+ return c10::nullopt;
+ }
auto op_analyzer = SymbolicShapeOpAnalyzer(schema);
auto res = op_analyzer.run(ssa_args);
diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
index 8b1136b..1c1315a 100644
--- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
+++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
@@ -72,6 +72,9 @@
std::unordered_map<const FunctionSchema*, std::shared_ptr<Graph>>
cached_schema_to_graph;
+std::unordered_map<const FunctionSchema*, BoundedShapeGraphs>
+ cached_bounded_schema_to_graph;
+
// CompilationUnit that holds all these Functions and keeps them alive.
auto compilation_unit = std::make_shared<CompilationUnit>();
@@ -237,34 +240,54 @@
}
}
+std::shared_ptr<Graph> genShapeComputeFn(
+ const FunctionSchema* schema_string,
+ const std::string& shape_compute_function_name,
+ std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
+ const CompilationUnit& module) {
+ std::shared_ptr<Graph> graph;
+ if (reused_functions.count(shape_compute_function_name)) {
+ graph = reused_functions[shape_compute_function_name];
+ } else {
+ Function& shape_compute_function =
+ module.get_function(shape_compute_function_name);
+ graph = toGraphFunction(shape_compute_function).graph();
+
+ transformShapeFunction(schema_string, graph);
+ // NB: we lint the shape functions registered in source
+ // in a test file
+ // LintShapeComputeGraph(schema_string, graph);
+
+ reused_functions[shape_compute_function_name] = graph;
+ }
+ // allow extra unused arguments to map multiple functions to e.g. unary
+ TORCH_INTERNAL_ASSERT(
+ graph->inputs().size() <= schema_string->arguments().size());
+ return graph;
+}
+
void registerSchema(
const FunctionSchema* schema_string,
const std::string& shape_compute_function_name,
std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
const CompilationUnit& module) {
- if (reused_functions.count(shape_compute_function_name)) {
- auto graph = reused_functions[shape_compute_function_name];
-
- // allow extra unused arguments to map multiple functions to e.g. unary
- TORCH_INTERNAL_ASSERT(
- graph->inputs().size() <= schema_string->arguments().size());
-
- cached_schema_to_graph[schema_string] = graph;
- return;
- }
-
- Function& shape_compute_function =
- module.get_function(shape_compute_function_name);
- std::shared_ptr<Graph> graph =
- toGraphFunction(shape_compute_function).graph();
-
- transformShapeFunction(schema_string, graph);
- // NB: we lint the shape functions registered in source
- // in a test file
- // LintShapeComputeGraph(schema_string, graph);
+ auto graph = genShapeComputeFn(
+ schema_string, shape_compute_function_name, reused_functions, module);
cached_schema_to_graph[schema_string] = graph;
- reused_functions[shape_compute_function_name] = graph;
+}
+
+void registerBoundedSchema(
+ const FunctionSchema* schema_string,
+ const std::string& lower_bound_function_name,
+ const std::string& upper_bound_function_name,
+ std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
+ const CompilationUnit& module) {
+ auto lower_graph = genShapeComputeFn(
+ schema_string, lower_bound_function_name, reused_functions, module);
+ auto upper_graph = genShapeComputeFn(
+ schema_string, upper_bound_function_name, reused_functions, module);
+ cached_bounded_schema_to_graph[schema_string] = {lower_graph, upper_graph};
}
void loadModule(const CompilationUnit& module) {
@@ -304,6 +327,20 @@
}
}
}
+
+ // Now register the bounded schemas
+ for (const auto& pair : GetBoundedShapeMappings().getAllKeysAndValues()) {
+ const FunctionSchema* schema_string = &pair.first->schema();
+ const std::string& lower_bound_function_name = pair.second.first;
+ const std::string& upper_bound_function_name = pair.second.second;
+
+ registerBoundedSchema(
+ schema_string,
+ lower_bound_function_name,
+ upper_bound_function_name,
+ reused_functions,
+ module);
+ }
}
void loadFunctions() {
@@ -341,6 +378,21 @@
return c10::nullopt;
}
+TORCH_API c10::optional<BoundedShapeGraphs> boundedGraphsForSchema(
+ const FunctionSchema& schema) {
+ std::lock_guard<std::mutex> guard(lock);
+  if (cached_bounded_schema_to_graph.empty()) {
+ loadFunctions();
+ }
+ GRAPH_DEBUG("Trying to find schema in bounded graphs: ", schema);
+ auto cache_it = cached_bounded_schema_to_graph.find(&schema);
+ if (cache_it != cached_bounded_schema_to_graph.end()) {
+ return cache_it->second;
+ }
+
+ return c10::nullopt;
+}
+
void RegisterShapeComputeGraphForSchema(
const FunctionSchema& schema,
std::shared_ptr<Graph> g) {
diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.h b/torch/csrc/jit/runtime/symbolic_shape_registry.h
index cdca73c..0d92e35 100644
--- a/torch/csrc/jit/runtime/symbolic_shape_registry.h
+++ b/torch/csrc/jit/runtime/symbolic_shape_registry.h
@@ -46,6 +46,11 @@
Please file an issue.
*/
+struct BoundedShapeGraphs {
+ std::shared_ptr<Graph> lower_bound;
+ std::shared_ptr<Graph> upper_bound;
+};
+
TORCH_API void RegisterShapeComputeGraphForSchema(
const FunctionSchema& schema,
std::shared_ptr<Graph> g);
@@ -53,6 +58,9 @@
TORCH_API c10::optional<std::shared_ptr<Graph>> shapeComputeGraphForSchema(
const FunctionSchema& schema);
+TORCH_API c10::optional<BoundedShapeGraphs> boundedGraphsForSchema(
+ const FunctionSchema& schema);
+
TORCH_API std::vector<const FunctionSchema*> RegisteredShapeComputeSchemas();
TORCH_API void LintShapeComputeGraph(