StmInt support for InferSize (#84903)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84903
Approved by: https://github.com/ezyang
diff --git a/aten/src/ATen/InferSize.h b/aten/src/ATen/InferSize.h
index e0bedb7..594b873 100644
--- a/aten/src/ATen/InferSize.h
+++ b/aten/src/ATen/InferSize.h
@@ -2,6 +2,8 @@
 
 #include <ATen/DimVector.h>
 #include <c10/core/ScalarType.h>
+#include <c10/core/SymIntArrayRef.h>
+#include <c10/util/DimVector.h>
 #include <c10/util/Optional.h>
 #include <sstream>
 #include <vector>
@@ -14,9 +16,13 @@
 // templated to handle std::vector<int64_t> and DimVector use cases, see
 // below
 //
-template <typename ResultVec>
-inline void infer_size_impl(IntArrayRef shape, int64_t numel, ResultVec& res) {
-  int64_t newsize = 1;
+template <typename InputArrayRef, typename NumelType, typename ResultVec>
+inline void infer_size_impl(
+    InputArrayRef shape,
+    NumelType numel,
+    ResultVec& res) {
+  NumelType newsize = 1;
+  // N.B. this is an index, not a sym dim!
   auto infer_dim = c10::optional<int64_t>();
   for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
     if (shape[dim] == -1) {
@@ -69,4 +75,13 @@
   return res;
 }
 
+inline at::SymDimVector infer_size_dv(
+    c10::SymIntArrayRef shape,
+    c10::SymInt numel) {
+  auto res = at::SymDimVector(shape);
+  infer_size_impl<c10::SymIntArrayRef, c10::SymInt, at::SymDimVector>(
+      shape, numel, res);
+  return res;
+}
+
 } // namespace at