StmInt support for InferSize (#84903)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84903
Approved by: https://github.com/ezyang
diff --git a/aten/src/ATen/InferSize.h b/aten/src/ATen/InferSize.h
index e0bedb7..594b873 100644
--- a/aten/src/ATen/InferSize.h
+++ b/aten/src/ATen/InferSize.h
@@ -2,6 +2,8 @@
#include <ATen/DimVector.h>
#include <c10/core/ScalarType.h>
+#include <c10/core/SymIntArrayRef.h>
+#include <c10/util/DimVector.h>
#include <c10/util/Optional.h>
#include <sstream>
#include <vector>
@@ -14,9 +16,13 @@
// templated to handle std::vector<int64_t> and DimVector use cases, see
// below
//
-template <typename ResultVec>
-inline void infer_size_impl(IntArrayRef shape, int64_t numel, ResultVec& res) {
- int64_t newsize = 1;
+template <typename InputArrayRef, typename NumelType, typename ResultVec>
+inline void infer_size_impl(
+ InputArrayRef shape,
+ NumelType numel,
+ ResultVec& res) {
+ NumelType newsize = 1;
+ // N.B. this is an index, not a sym dim!
auto infer_dim = c10::optional<int64_t>();
for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
if (shape[dim] == -1) {
@@ -69,4 +75,13 @@
return res;
}
+inline at::SymDimVector infer_size_dv(
+ c10::SymIntArrayRef shape,
+ c10::SymInt numel) {
+ auto res = at::SymDimVector(shape);
+ infer_size_impl<c10::SymIntArrayRef, c10::SymInt, at::SymDimVector>(
+ shape, numel, res);
+ return res;
+}
+
} // namespace at