Relanding shape cache (75400) (#75710)
Summary:
https://github.com/pytorch/pytorch/pull/75400
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75710
Reviewed By: malfet
Differential Revision: D35598920
Pulled By: Krovatkin
fbshipit-source-id: 2bbbb3d0c24214b5dbb4ca605e7daa94671f96b0
(cherry picked from commit 572f2f9df5bfd73cd7b83536f619bc86d820ccd8)
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index 4956ad4..c3d3d07 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -467,6 +467,14 @@
// result will be unranked.
SymbolicShape merge(const SymbolicShape& other) const;
+ friend bool operator==(const SymbolicShape& lhs, const SymbolicShape& rhs) {
+ return lhs.dims_ == rhs.dims_;
+ }
+ // Equality is decided by dims_ alone; operator!= is its exact negation.
+ friend bool operator!=(const SymbolicShape& lhs, const SymbolicShape& rhs) {
+ return !(lhs == rhs);
+ }
+
private:
c10::optional<std::vector<ShapeSymbol>> dims_;
};
diff --git a/test/cpp/jit/test_shape_analysis.cpp b/test/cpp/jit/test_shape_analysis.cpp
index 38bbeaa..15f41da 100644
--- a/test/cpp/jit/test_shape_analysis.cpp
+++ b/test/cpp/jit/test_shape_analysis.cpp
@@ -8,6 +8,8 @@
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
+#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
+#include <torch/csrc/jit/passes/symbolic_shape_cache.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
@@ -293,27 +295,34 @@
namespace {
+// NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
+void assertShapeEqual(c10::SymbolicShape& a, c10::SymbolicShape& e) {
+  // Compare the canonicalized forms rather than the raw shapes, so that two
+  // shapes that differ only in their symbol ids are still treated as equal.
+  EXPECT_EQ(CanonicalizedSymbolicShape(a), CanonicalizedSymbolicShape(e));
+}
+
void assertShapeEqual(
c10::optional<std::vector<c10::SymbolicShape>>& actual,
std::vector<c10::optional<int64_t>> expected) {
ASSERT_TRUE(actual.has_value());
ASSERT_EQ(actual->size(), 1);
- auto a_canonical = CanonicalizedSymbolicShape(actual->at(0));
auto symb_expected = c10::SymbolicShape(expected);
- auto b_canonical = CanonicalizedSymbolicShape(symb_expected);
- ASSERT_EQ(a_canonical, b_canonical);
+ assertShapeEqual(actual->at(0), symb_expected);
}
+const FunctionSchema* getSchema(const char* name) {
+ return &(getOperatorForLiteral(name)->schema()); // NOTE(review): assumes the op outlives this temporary handle - confirm
+}
} // namespace
TEST(ShapeAnalysisTest, SymbolicShapeAPI) {
// Figure out how to fetch a function schema
// Ask someone else how to create a function schema / operator in C++
- std::shared_ptr<Operator> op = getOperatorForLiteral(
+ auto schema = getSchema(
"aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor");
- const FunctionSchema* schema = &(op->schema());
c10::IValue const_size_1 = std::vector<int64_t>{64, 56, 56};
c10::IValue const_size_2 = std::vector<int64_t>{1, 56, 56};
@@ -352,5 +361,123 @@
assertShapeEqual(res, {sym_dim, 64, sym_dim, sym_dim});
}
+TEST(ShapeAnalysisTest, SymbolicShapeCaching) {
+ clear_shape_cache(); // start empty so the size checks below are exact
+ auto schema = getSchema("aten::mm(Tensor self, Tensor mat2) -> Tensor");
+
+ c10::IValue const_size_1 = std::vector<int64_t>{64, 56};
+ c10::IValue const_size_2 = std::vector<int64_t>{64, 56};
+ c10::IValue const_size_3 = std::vector<int64_t>{64, 20};
+
+ c10::optional<int64_t> sym_dim = c10::nullopt;
+ c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});
+ c10::SymbolicShape ss2 = c10::SymbolicShape({sym_dim, 64});
+ c10::SymbolicShape ss3 = c10::SymbolicShape({sym_dim, sym_dim});
+
+ auto res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_1});
+ assertShapeEqual(res, {sym_dim, 56});
+ auto res1_val = res->at(0);
+
+ // The exact same arguments should return the exact same result
+ res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_1});
+ auto res2_val = res->at(0);
+ EXPECT_EQ(res1_val, res2_val);
+ EXPECT_EQ(get_shape_cache_size(), 1);
+
+ // Same shape but different symbols should return same shape
+ // but different symbolic indices
+ res = calculateSymbolicShapesOnOp(schema, {ss2, const_size_2});
+ auto res3_val = res->at(0);
+
+ assertShapeEqual(res3_val, res2_val);
+ EXPECT_NE(res3_val, res2_val);
+ EXPECT_EQ(get_shape_cache_size(), 1);
+
+ // Different concrete shape should be cached separately
+ res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_3});
+ assertShapeEqual(res, {sym_dim, 20});
+ EXPECT_EQ(get_shape_cache_size(), 2);
+ // A fully symbolic first argument is a different key from {sym_dim, 64}
+ res = calculateSymbolicShapesOnOp(schema, {ss3, const_size_3});
+ assertShapeEqual(res, {sym_dim, 20});
+ EXPECT_EQ(get_shape_cache_size(), 3);
+ // Passing a symbolic shape for both arguments adds yet another entry
+ res = calculateSymbolicShapesOnOp(schema, {ss3, ss3});
+ assertShapeEqual(res, {sym_dim, sym_dim});
+ EXPECT_EQ(get_shape_cache_size(), 4);
+}
+
+TEST(ShapeAnalysisTest, ShapeCacheMultipleFns) {
+ clear_shape_cache();
+
+ auto squeeze_op =
+ getSchema("aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)");
+ auto mul_tensor =
+ getSchema("aten::mul.Tensor(Tensor self, Tensor other) -> Tensor");
+ auto mul_scalar =
+ getSchema("aten::mul.Scalar(Tensor self, Scalar other) -> Tensor");
+ auto div_tensor =
+ getSchema("aten::div.Tensor(Tensor self, Tensor other) -> Tensor");
+ auto matmul = getSchema("aten::mm(Tensor self, Tensor mat2) -> Tensor");
+
+ c10::IValue const_int = 1;
+
+ c10::optional<int64_t> sym_dim = c10::nullopt;
+ c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});
+
+ auto res = calculateSymbolicShapesOnOp(squeeze_op, {ss1, const_int});
+ assertShapeEqual(res, {sym_dim, 64});
+
+ // Show that cache can handle multiple functions (squeeze + mul.Scalar = 2)
+ res = calculateSymbolicShapesOnOp(mul_scalar, {ss1, const_int});
+ assertShapeEqual(res, {sym_dim, 64});
+ EXPECT_EQ(get_shape_cache_size(), 2);
+
+ res = calculateSymbolicShapesOnOp(mul_tensor, {ss1, ss1});
+ assertShapeEqual(res, {sym_dim, 64});
+ EXPECT_EQ(get_shape_cache_size(), 3);
+
+ // Even when the expected outcome is the same, should not collide
+ res = calculateSymbolicShapesOnOp(div_tensor, {ss1, ss1});
+ assertShapeEqual(res, {sym_dim, 64});
+ EXPECT_EQ(get_shape_cache_size(), 4);
+
+ // Don't lose cached objects: repeating an earlier call must not add an entry
+ res = calculateSymbolicShapesOnOp(mul_scalar, {ss1, const_int});
+ assertShapeEqual(res, {sym_dim, 64});
+ EXPECT_EQ(get_shape_cache_size(), 4);
+
+ res = calculateSymbolicShapesOnOp(matmul, {ss1, ss1});
+ // SSA can infer that sym_dim is 64 as both tensors
+ // use the same sym_dim
+ assertShapeEqual(res, {64, 64});
+ EXPECT_EQ(get_shape_cache_size(), 5);
+}
+
+TEST(ShapeAnalysisTest, TestShapeMultipleReturns) {
+ clear_shape_cache();
+
+ auto max_dim_op = getSchema(
+ "aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)");
+ c10::IValue const_int = 1;
+ c10::IValue false_ival = false;
+
+ c10::optional<int64_t> sym_dim = c10::nullopt;
+ c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});
+ c10::SymbolicShape ss2 = c10::SymbolicShape({sym_dim, 64});
+
+ auto res =
+ calculateSymbolicShapesOnOp(max_dim_op, {ss1, const_int, false_ival});
+ c10::SymbolicShape expected_res = c10::SymbolicShape({sym_dim});
+ assertShapeEqual(res->at(0), expected_res);
+ // res0 and res1 should share the same shape symbol
+ EXPECT_EQ(res->at(0), res->at(1));
+
+ // The second call reuses the single cache entry and must stay consistent
+ res = calculateSymbolicShapesOnOp(max_dim_op, {ss2, const_int, false_ival});
+ assertShapeEqual(res->at(0), expected_res);
+ EXPECT_EQ(res->at(0), res->at(1));
+ EXPECT_EQ(get_shape_cache_size(), 1);
+}
} // namespace jit
} // namespace torch
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index e0927f1..ced45d3 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -306,6 +306,7 @@
"torch/csrc/jit/passes/integer_value_refinement.cpp",
"torch/csrc/jit/passes/replacement_of_old_operators.cpp",
"torch/csrc/jit/passes/symbolic_shape_analysis.cpp",
+ "torch/csrc/jit/passes/symbolic_shape_cache.cpp",
"torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp",
"torch/csrc/jit/passes/specialize_autogradzero.cpp",
"torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp",
diff --git a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
index 4d71b73..582137d 100644
--- a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
+++ b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
@@ -19,6 +19,7 @@
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
+#include <torch/csrc/jit/passes/symbolic_shape_cache.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/exception_message.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
@@ -174,6 +175,17 @@
return symbolic_shape_analysis_test_mode;
}
+using SSArgument = c10::variant<ShapeArguments, IValue>;
+
+std::ostream& operator<<(std::ostream& out, const SSArgument& sa) {
+  if (const ShapeArguments* shape_args = c10::get_if<ShapeArguments>(&sa)) {
+    out << *shape_args;
+  } else {
+    out << c10::get<IValue>(sa);
+  }
+  return out; // whichever alternative sa held has been streamed
+}
+
namespace {
bool isListOfInts(const TypePtr& type) {
@@ -244,8 +256,6 @@
return c10::SymbolicShape(output_shape);
}
-} // namespace
-
// Symbolic Shape Analysis works through iteratively partially evaluating
// a TorchScript shape compute graph by inputing properties from input
// Tensors. We can substitute in properties like `len(x)` and `x[1]`
@@ -260,17 +270,6 @@
// means that we do know its concrete value statically but we can asssign sets
// of tensor dimensions which must be equal at runtime.
-using SSArgument = c10::variant<ShapeArguments, IValue>;
-
-std::ostream& operator<<(std::ostream& out, const SSArgument& sa) {
- if (const IValue* iv = c10::get_if<IValue>(&sa)) {
- out << *iv;
- } else {
- out << c10::get<ShapeArguments>(sa);
- }
- return out;
-}
-
struct SymbolicShapeOpAnalyzer {
std::shared_ptr<Graph> shape_compute_graph_;
const FunctionSchema* schema_;
@@ -1058,6 +1057,7 @@
}
}
}
+} // namespace
void PropagateShapesOnGraph(std::shared_ptr<Graph>& graph) {
AliasDb db(graph);
@@ -1076,6 +1076,16 @@
calculateSymbolicShapesOnOp(
const FunctionSchema* schema,
const std::vector<SSAInput>& inputs) {
+ if (shapeComputeGraphForSchema(*schema) == c10::nullopt) {
+ // Avoid doing all this work for functions that don't have a
+ // supported schema
+ return c10::nullopt;
+ }
+
+ if (auto cached_ret_vec = get_cached_shape_function(schema, inputs)) {
+ return cached_ret_vec;
+ }
+
std::vector<SSArgument> ssa_args;
for (auto& arg : inputs) {
if (const IValue* ival = c10::get_if<IValue>(&arg)) {
@@ -1087,7 +1097,11 @@
}
auto op_analyzer = SymbolicShapeOpAnalyzer(schema);
- return op_analyzer.run(ssa_args);
+ auto res = op_analyzer.run(ssa_args);
+ if (res.has_value()) {
+ cache_shape_function(schema, inputs, res.value());
+ }
+ return res;
}
} // namespace jit
diff --git a/torch/csrc/jit/passes/symbolic_shape_analysis.h b/torch/csrc/jit/passes/symbolic_shape_analysis.h
index 5e6239ab..99a7798 100644
--- a/torch/csrc/jit/passes/symbolic_shape_analysis.h
+++ b/torch/csrc/jit/passes/symbolic_shape_analysis.h
@@ -53,66 +53,5 @@
calculateSymbolicShapesOnOp(
const FunctionSchema* schema,
const std::vector<SSAInput>& inputs);
-
-struct TORCH_API CanonicalizedSymbolicShape {
- CanonicalizedSymbolicShape(
- c10::SymbolicShape& orig_shape,
- std::unordered_map<int64_t, int64_t>& ss_map) {
- init(orig_shape, ss_map);
- }
-
- CanonicalizedSymbolicShape(c10::SymbolicShape& orig_shape) {
- std::unordered_map<int64_t, int64_t> new_ssmap;
- init(orig_shape, new_ssmap);
- }
-
- private:
- c10::optional<std::vector<int64_t>> values_;
- std::vector<bool> is_symbolic_;
-
- void init(
- c10::SymbolicShape& orig_shape,
- std::unordered_map<int64_t, int64_t>& ss_map) {
- auto sizes = orig_shape.sizes();
- if (!sizes) {
- values_ = c10::nullopt;
- return;
- }
- values_ = std::vector<int64_t>();
- int64_t cur_symbolic_index = -(int64_t)ss_map.size() - 1;
- for (auto& cur_shape : *sizes) {
- if (cur_shape.is_static()) {
- is_symbolic_.emplace_back(false);
- values_->push_back(cur_shape.static_size());
- } else {
- // Check for aliasing
- is_symbolic_.emplace_back(true);
- auto it = ss_map.find(cur_shape.value());
-
- if (it == ss_map.end()) {
- values_->push_back(cur_symbolic_index);
- ss_map.insert({cur_shape.value(), cur_symbolic_index});
- cur_symbolic_index--;
- } else {
- values_->push_back(it->second);
- }
- }
- }
- }
-
- friend bool operator==(
- const CanonicalizedSymbolicShape& a,
- const CanonicalizedSymbolicShape& b) {
- if (a.values_.has_value() != b.values_.has_value()) {
- return false;
- }
- if (!a.values_.has_value()) {
- return true;
- }
- return (
- a.values_.value() == b.values_.value() &&
- a.is_symbolic_ == b.is_symbolic_);
- };
-};
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/passes/symbolic_shape_cache.cpp b/torch/csrc/jit/passes/symbolic_shape_cache.cpp
new file mode 100644
index 0000000..62bf488
--- /dev/null
+++ b/torch/csrc/jit/passes/symbolic_shape_cache.cpp
@@ -0,0 +1,208 @@
+#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
+#include <torch/csrc/jit/passes/symbolic_shape_cache.h>
+#include <torch/csrc/lazy/core/cache.h>
+
+// SHAPE CACHING CODE
+namespace torch {
+namespace jit {
+namespace {
+using CanonicalArg = c10::variant<CanonicalizedSymbolicShape, IValue>;
+using CanonicalArgVec = std::vector<CanonicalArg>;
+using CanonicalRet = std::vector<CanonicalizedSymbolicShape>;
+using ShapeCacheKey = std::tuple<c10::OperatorName, CanonicalArgVec>;
+
+CanonicalArgVec cannonicalizeVec(
+    const std::vector<SSAInput>& arg_vec,
+    std::unordered_map<int64_t, int64_t>& ss_map,
+    bool deep_copy = true) {
+  // Key form of each input: IValues by value, shapes canonicalized via ss_map.
+  CanonicalArgVec result;
+  result.reserve(arg_vec.size());
+  for (const auto& input : arg_vec) {
+    const IValue* ival = c10::get_if<IValue>(&input);
+    if (ival == nullptr) {
+      const auto& shape = c10::get<at::SymbolicShape>(input);
+      result.emplace_back(CanonicalizedSymbolicShape(shape, ss_map));
+    } else if (deep_copy) {
+      result.push_back(ival->deepcopy());
+    } else {
+      result.push_back(*ival);
+    }
+  }
+  return result;
+}
+
+std::vector<CanonicalizedSymbolicShape> cannonicalizeVec(
+    const std::vector<at::SymbolicShape>& ret_vec,
+    std::unordered_map<int64_t, int64_t>& ss_map) {
+  CanonicalRet out; // each shape reads from, and may extend, ss_map
+  out.reserve(ret_vec.size());
+  for (const auto& shape : ret_vec) {
+    out.emplace_back(shape, ss_map);
+  }
+  return out;
+}
+
+struct ArgumentsHasher {
+  size_t operator()(const ShapeCacheKey& cacheKey) const {
+    // Folds operator name, argument count and per-argument hashes together.
+    // TODO: ignore arguments unused by the shape function (not needed initially)
+    auto& op_name = std::get<0>(cacheKey);
+    auto& arg_vec = std::get<1>(cacheKey);
+
+    size_t hash_val = c10::hash<c10::OperatorName>()(op_name);
+    hash_val = at::hash_combine(std::hash<size_t>{}(arg_vec.size()), hash_val);
+    for (const CanonicalArg& arg : arg_vec) {
+      size_t cur_arg = 0;
+      if (const IValue* ival = c10::get_if<IValue>(&arg)) {
+        // IValue doesn't hash List (as Python doesn't), so we will do a custom
+        // list hash
+        if (ival->isList()) {
+          TORCH_INTERNAL_ASSERT(ival->isIntList(), "Unexpected Args in List");
+          cur_arg = ival->toListRef().size();
+          for (const IValue& elem_ival : ival->toListRef()) {
+            cur_arg = at::hash_combine(cur_arg, IValue::hash(elem_ival));
+          }
+        } else {
+          // Hash the value itself, not the pointer returned by get_if.
+          cur_arg = IValue::hash(*ival);
+        }
+      } else {
+        cur_arg = c10::get<CanonicalizedSymbolicShape>(arg).hash();
+      }
+      hash_val = at::hash_combine(hash_val, cur_arg);
+    }
+    return hash_val;
+  }
+};
+
+using ShapeCache = lazy::Cache<
+    ShapeCacheKey,
+    std::vector<CanonicalizedSymbolicShape>,
+    ArgumentsHasher>;
+// File-local cache instance; kShapeCacheSize is presumably its entry capacity.
+constexpr size_t kShapeCacheSize = 1024;
+ShapeCache shapeCache(kShapeCacheSize);
+
+ShapeCacheKey get_cache_key(
+    const FunctionSchema* schema,
+    const std::vector<SSAInput>& arg_vec,
+    std::unordered_map<int64_t, int64_t>& ss_map,
+    bool deep_copy = true) {
+  return std::make_tuple( // key = (operator name, canonicalized arguments)
+      schema->operator_name(), cannonicalizeVec(arg_vec, ss_map, deep_copy));
+}
+
+} // namespace
+
+TORCH_API void cache_shape_function(
+    const FunctionSchema* schema,
+    const std::vector<SSAInput>& arg_vec,
+    const std::vector<at::SymbolicShape>& ret_vec) {
+  // Arguments first, then results: both share ss_map so aliased symbols agree.
+  // TODO: compare perf using std::vector<std::tuple<int64_t, int64_t>>
+  std::unordered_map<int64_t, int64_t> ss_map;
+  const auto key = get_cache_key(schema, arg_vec, ss_map, /* deep_copy */ true);
+  shapeCache.Add(
+      key, std::make_shared<CanonicalRet>(cannonicalizeVec(ret_vec, ss_map)));
+}
+
+TORCH_API c10::optional<std::vector<at::SymbolicShape>>
+get_cached_shape_function(
+    const FunctionSchema* schema,
+    const std::vector<SSAInput>& arg_vec) {
+  // TODO: compare perf using std::vector<std::tuple<int64_t, int64_t>> for both
+  // ss_map and inverse_ss_map
+  std::unordered_map<int64_t, int64_t> ss_map;
+  const auto key =
+      get_cache_key(schema, arg_vec, ss_map, /* deep_copy */ false);
+  const auto hit = shapeCache.Get(key);
+  if (!hit) {
+    return c10::nullopt;
+  }
+  // Invert ss_map (symbol -> canonical id) to translate cached ids back.
+  std::unordered_map<int64_t, int64_t> inverse_ss_map;
+  for (const auto& entry : ss_map) {
+    inverse_ss_map[entry.second] = entry.first;
+  }
+  std::vector<at::SymbolicShape> ret_vec;
+  for (const auto& canonical : *hit) {
+    ret_vec.emplace_back(canonical.toSymbolicShape(inverse_ss_map));
+  }
+  return ret_vec;
+}
+
+// Test-only helpers. This one drops every entry held by the shape cache.
+TORCH_API void clear_shape_cache() {
+  shapeCache.Clear();
+}
+// Reports the cache's current element count (its Numel()).
+TORCH_API size_t get_shape_cache_size() {
+  return shapeCache.Numel();
+}
+
+void CanonicalizedSymbolicShape::init(
+    const c10::SymbolicShape& orig_shape,
+    std::unordered_map<int64_t, int64_t>& ss_map) {
+  // Unranked shapes set values_ to nullopt; static dims keep their size and
+  // each distinct symbol becomes a negative id recorded (and reused) in ss_map.
+  const auto dims = orig_shape.sizes();
+  if (!dims) {
+    values_ = c10::nullopt;
+    return;
+  }
+  values_ = std::vector<int64_t>();
+  for (const auto& dim : *dims) {
+    if (dim.is_static()) {
+      values_->push_back(dim.static_size());
+      continue;
+    }
+    const auto known = ss_map.find(dim.value());
+    if (known != ss_map.end()) {
+      values_->push_back(known->second);
+      continue;
+    }
+    // New symbol: its id is -(number of symbols already in ss_map) - 1.
+    const int64_t fresh_id = -static_cast<int64_t>(ss_map.size()) - 1;
+    ss_map.insert({dim.value(), fresh_id});
+    values_->push_back(fresh_id);
+  }
+}
+
+c10::SymbolicShape CanonicalizedSymbolicShape::toSymbolicShape(
+    std::unordered_map<int64_t, int64_t>& inverse_ss_map) const {
+  if (!values_) {
+    return c10::SymbolicShape();
+  }
+  std::vector<at::ShapeSymbol> dims; // >= 0: static size, < 0: canonical id
+  for (const int64_t entry : *values_) {
+    if (entry >= 0) {
+      dims.push_back(at::ShapeSymbol::fromStaticSize(entry));
+      continue;
+    }
+    const auto known = inverse_ss_map.find(entry);
+    if (known == inverse_ss_map.end()) {
+      dims.push_back(at::ShapeSymbol::newSymbol());
+      inverse_ss_map.insert({entry, dims.back().value()});
+    } else {
+      // NOTE(review): known->second is a symbol value, not a size - confirm.
+      dims.push_back(at::ShapeSymbol::fromStaticSize(known->second));
+    }
+  }
+  return c10::SymbolicShape(std::move(dims));
+}
+
+size_t CanonicalizedSymbolicShape::hash() const {
+  // Ranked shapes hash their canonical values; unranked ones share a constant.
+  return values_.has_value()
+      ? c10::hash<std::vector<int64_t>>()(values_.value())
+      : 0x8cc80c80; // random value to prevent hash collisions
+}
+
+bool operator==(
+    const CanonicalizedSymbolicShape& a,
+    const CanonicalizedSymbolicShape& b) {
+  return a.values_ == b.values_; // both unranked, or identical canonical dims
+}
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/passes/symbolic_shape_cache.h b/torch/csrc/jit/passes/symbolic_shape_cache.h
new file mode 100644
index 0000000..02e00ac
--- /dev/null
+++ b/torch/csrc/jit/passes/symbolic_shape_cache.h
@@ -0,0 +1,57 @@
+#pragma once
+
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
+
+namespace torch {
+namespace jit {
+
+struct TORCH_API CanonicalizedSymbolicShape {
+  // TODO: Consider in the future if it is reasonable to
+  // merge code with SymbolicShape or VaryingShape while keeping
+  // the two not implicitly convertible (which would cause bugs).
+  CanonicalizedSymbolicShape(
+      const c10::SymbolicShape& orig_shape,
+      std::unordered_map<int64_t, int64_t>& ss_map) {
+    init(orig_shape, ss_map);
+  }
+  // Stand-alone form: canonicalizes against a fresh, private symbol map.
+  CanonicalizedSymbolicShape(c10::SymbolicShape& orig_shape) {
+    std::unordered_map<int64_t, int64_t> new_ssmap;
+    init(orig_shape, new_ssmap);
+  }
+
+  size_t hash() const;
+  // Rebuilds a SymbolicShape; inverse_ss_map is taken by mutable reference.
+  c10::SymbolicShape toSymbolicShape(
+      std::unordered_map<int64_t, int64_t>& inverse_ss_map) const;
+
+  TORCH_API friend bool operator==(
+      const CanonicalizedSymbolicShape& a,
+      const CanonicalizedSymbolicShape& b);
+
+ private:
+  c10::optional<std::vector<int64_t>> values_; // no value when there are no dims
+
+  void init(
+      const c10::SymbolicShape& orig_shape,
+      std::unordered_map<int64_t, int64_t>& ss_map);
+};
+
+// SHAPE CACHE API: lookup by (schema, arg_vec); the result is optional.
+TORCH_API c10::optional<std::vector<at::SymbolicShape>>
+get_cached_shape_function(
+    const FunctionSchema* schema,
+    const std::vector<SSAInput>& arg_vec);
+// Stores ret_vec as the result shapes for this (schema, arg_vec) pair.
+TORCH_API void cache_shape_function(
+    const FunctionSchema* schema,
+    const std::vector<SSAInput>& arg_vec,
+    const std::vector<at::SymbolicShape>& ret_vec);
+
+// For use in test code
+TORCH_API void clear_shape_cache();
+TORCH_API size_t get_shape_cache_size();
+
+} // namespace jit
+} // namespace torch