Revert "Revert "Enabling SymInt in autograd; take 3 (#81145)"" ; make sure is_intlist checks for symintnodes (#82189)
### Description
Re-lands "Enabling SymInt in autograd; take 3" (#81145), which had been reverted, and additionally hardens the Python argument parser: `is_int_list` now rejects sequences containing `SymIntNode`s, so calls with symbolic sizes fall through to the `SymInt[]` overloads instead of being matched against the plain `int[]` ones.
### Testing
Covered by the updated expectations in `test/test_dynamic_shapes.py`, `test/test_functionalization.py`, and `test/test_python_dispatch.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82189
Approved by: https://github.com/ezyang
diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp
index f52af4c..0740773 100644
--- a/aten/src/ATen/BatchingRegistrations.cpp
+++ b/aten/src/ATen/BatchingRegistrations.cpp
@@ -184,10 +184,13 @@
return self_physical.getPhysicalToLogicalMap().apply(result);
}
-Tensor expand_batching_rule_symint(const Tensor& self, SymIntArrayRef psize, bool implicit) {
+Tensor expand_symint_batching_rule(const Tensor& self, SymIntArrayRef psize, bool implicit) {
return expand_batching_rule(self, asIntArrayRefSlow(psize), implicit);
}
+Tensor sum_symint_batching_rule(const Tensor& input_t, c10::SymIntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
+ return sum_batching_rule(input_t, c10::asIntArrayRefSlow(dim), keepdim, opt_dtype);
+}
std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
@@ -468,6 +471,10 @@
return self_physical.getPhysicalToLogicalMap().apply(result);
}
+Tensor view_symint_batching_rule(const Tensor& self, c10::SymIntArrayRef size) {
+ return view_batching_rule(self, asIntArrayRefSlow(size));
+}
+
Tensor view_as_complex_batching_rule(const Tensor& self) {
// guard against the user passing in a batch of scalar tensors with batch
// size equal to 2.
@@ -1082,6 +1089,7 @@
m.impl("_new_zeros_with_same_feature_meta", _new_zeros_with_same_feature_meta_batching_rule);
m.impl("sum.dim_IntList", sum_batching_rule);
+ m.impl("sum.SymInt", sum_symint_batching_rule);
m.impl("is_complex", native::is_complex);
// inplace operations
@@ -1096,7 +1104,7 @@
m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
m.impl("diagonal", diagonal_batching_rule);
m.impl("expand", expand_batching_rule);
- m.impl("expand.SymInt", expand_batching_rule_symint);
+ m.impl("expand.SymInt", expand_symint_batching_rule);
m.impl("expand_as", native::expand_as); // composite wrt autograd
m.impl("movedim.intlist", movedim_batching_rule);
m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
@@ -1125,6 +1133,7 @@
m.impl("unfold", unfold_batching_rule);
m.impl("unsqueeze", unsqueeze_batching_rule);
m.impl("view", view_batching_rule);
+ m.impl("view.SymInt", view_symint_batching_rule);
m.impl("view_as", native::view_as); // composite wrt autograd
// clamp operations
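The `*_symint` batching rules added above (and the other `*_symint` wrappers later in this patch) all use the same stop-gap: convert the incoming `c10::SymIntArrayRef` to a plain `IntArrayRef` with `c10::asIntArrayRefSlow` and forward to the existing integer kernel. A minimal sketch of that conversion, assuming only concrete (non-symbolic) sizes:

```cpp
#include <ATen/ATen.h>
#include <c10/core/SymIntArrayRef.h>

int main() {
  // Build a SymIntArrayRef from concrete values; no SymbolicIntNode involved.
  c10::SymInt dims[] = {c10::SymInt(4), c10::SymInt(2)};
  c10::SymIntArrayRef sym_size(dims, 2);

  // asIntArrayRefSlow checks that every element is a plain integer and
  // hands the same data back as an IntArrayRef usable by integer kernels.
  at::IntArrayRef size = c10::asIntArrayRefSlow(sym_size);

  at::Tensor t = at::arange(8).view(size);  // same as view({4, 2})
  return t.size(0) == 4 ? 0 : 1;
}
```

`asIntArrayRefSlow` errors out if any element actually holds a `SymbolicIntNode`, so it is only safe on shapes that are already concrete.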
diff --git a/aten/src/ATen/ExpandUtils.h b/aten/src/ATen/ExpandUtils.h
index c7b9e90..7a81076 100644
--- a/aten/src/ATen/ExpandUtils.h
+++ b/aten/src/ATen/ExpandUtils.h
@@ -437,17 +437,16 @@
return result;
}
-// Sums `tensor` repeatedly to produce a tensor of shape `shape`.
-// Precondition: is_expandable_to(shape, tensor.sizes()) must be true
static inline Tensor sum_to(
Tensor tensor,
- const IntArrayRef shape,
+ const c10::SymIntArrayRef shape,
bool always_return_non_view = false) {
if (shape.size() == 0) {
return tensor.sum();
}
- c10::SmallVector<int64_t, 8> reduce_dims;
- const at::IntArrayRef sizes = tensor.sizes();
+
+ auto sizes = tensor.sym_sizes();
+ c10::SmallVector<c10::SymInt, 8> reduce_dims;
const int64_t leading_dims = sizes.size() - shape.size();
for (const auto i : c10::irange(leading_dims)) {
reduce_dims.push_back(i);
@@ -457,29 +456,44 @@
reduce_dims.push_back(i);
}
}
+
if (!reduce_dims.empty()) {
- tensor = tensor.sum(reduce_dims, /*keepdim=*/true);
+ tensor = tensor.sum_symint(reduce_dims, /*keepdim=*/true);
}
+
if (always_return_non_view) {
// This is only actually used by the functionalization pass.
// We want to be able to guarantee that this function doesn't return a view
// of the input.
- return leading_dims > 0 ? at::view_copy(tensor, shape) : tensor.clone();
+ return leading_dims > 0 ? at::view_copy_symint(tensor, shape)
+ : tensor.clone();
} else {
- return leading_dims > 0 ? tensor.view(shape) : tensor;
+ return leading_dims > 0 ? tensor.view_symint(shape) : tensor;
}
}
-// True if `shape` can be broadcasted to `desired`
-static inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) {
+// Sums `tensor` repeatedly to produce a tensor of shape `shape`.
+// Precondition: is_expandable_to(shape, tensor.sizes()) must be true
+static inline Tensor sum_to(
+ Tensor tensor,
+ const IntArrayRef shape,
+ bool always_return_non_view = false) {
+ auto sym_size = c10::SymIntArrayRef(
+ reinterpret_cast<const c10::SymInt*>(shape.data()), shape.size());
+ return sum_to(tensor, sym_size, always_return_non_view);
+}
+
+static inline bool is_expandable_to(
+ SymIntArrayRef shape,
+ c10::SymIntArrayRef desired) {
size_t ndim = shape.size();
size_t target_dim = desired.size();
if (ndim > target_dim) {
return false;
}
for (const auto i : c10::irange(ndim)) {
- int64_t size = shape[ndim - i - 1];
- int64_t target = desired[target_dim - i - 1];
+ auto size = shape[ndim - i - 1];
+ auto target = desired[target_dim - i - 1];
if (size != target && size != 1) {
return false;
}
@@ -487,4 +501,12 @@
return true;
}
+static inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) {
+ auto sym_shape = c10::SymIntArrayRef(
+ reinterpret_cast<const c10::SymInt*>(shape.data()), shape.size());
+ auto sym_desired = c10::SymIntArrayRef(
+ reinterpret_cast<const c10::SymInt*>(desired.data()), desired.size());
+ return is_expandable_to(sym_shape, sym_desired);
+}
+
} // namespace at
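`at::sum_to` is how autograd reduces a gradient that was broadcast during the forward pass back down to the input's shape; the real implementation now operates on `c10::SymIntArrayRef`, with the `IntArrayRef` overloads above reinterpreting and forwarding. A small usage sketch with concrete sizes:

```cpp
#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>
#include <vector>

int main() {
  // A gradient that was broadcast from [3, 1] up to [2, 3, 4].
  at::Tensor grad = at::ones({2, 3, 4});
  std::vector<int64_t> target = {3, 1};

  // is_expandable_to answers "can `shape` broadcast to `desired`?"
  TORCH_INTERNAL_ASSERT(at::is_expandable_to(target, grad.sizes()));

  // sum_to reduces the leading dim and the dims that were expanded from 1
  // (keepdim=true), then views the result back to the requested shape.
  at::Tensor reduced = at::sum_to(grad, target);
  TORCH_INTERNAL_ASSERT(reduced.dim() == 2 && reduced.size(0) == 3 && reduced.size(1) == 1);
  return 0;
}
```

With symbolic sizes the same code path stays in `sum_symint` and `view_symint`, so the shape never has to be materialized to plain integers.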
diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp
index a8e3c4d..471c74a 100644
--- a/aten/src/ATen/FunctionalInverses.cpp
+++ b/aten/src/ATen/FunctionalInverses.cpp
@@ -299,6 +299,14 @@
}
}
+Tensor FunctionalInverses::view_copy_SymInt_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, c10::SymIntArrayRef size) {
+ if (reapply_views) {
+ return mutated_view.view_symint(base.sym_sizes());
+ } else {
+ return at::view_copy_symint(mutated_view, base.sym_sizes());
+ }
+}
+
Tensor FunctionalInverses::view_copy_dtype_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::ScalarType dtype) {
if (reapply_views) {
return mutated_view.view(base.scalar_type());
diff --git a/aten/src/ATen/core/NamedRegistrations.cpp b/aten/src/ATen/core/NamedRegistrations.cpp
index bb67593..a9ae2f1 100644
--- a/aten/src/ATen/core/NamedRegistrations.cpp
+++ b/aten/src/ATen/core/NamedRegistrations.cpp
@@ -467,6 +467,7 @@
m.impl("sum.IntList_out", CppFunction::makeFallthrough());
m.impl("sum.dim_DimnameList", CppFunction::makeFallthrough());
m.impl("sum.dim_IntList", CppFunction::makeFallthrough());
+ m.impl("sum.SymInt", CppFunction::makeFallthrough());
m.impl("t", CppFunction::makeFallthrough());
m.impl("tan", CppFunction::makeFallthrough());
m.impl("tan.out", CppFunction::makeFallthrough());
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index cced7d6..913391c 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -1079,6 +1079,10 @@
return at::sum(self, dimnames_to_positions(self, dim), keepdim, dtype);
}
+Tensor sum_symint(const Tensor& input_t, c10::SymIntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
+ return at::sum(input_t, c10::asIntArrayRefSlow(dim), keepdim, opt_dtype);
+}
+
Tensor& sum_out(const Tensor& self, DimnameList dim,
bool keepdim, optional<ScalarType> opt_dtype, Tensor& result) {
return at::sum_out(result, self, dimnames_to_positions(self, dim), keepdim, opt_dtype);
diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp
index 9536da8..11f2a6d 100644
--- a/aten/src/ATen/native/TensorFactories.cpp
+++ b/aten/src/ATen/native/TensorFactories.cpp
@@ -1083,6 +1083,14 @@
return result.zero_();
}
+Tensor zeros_symint(c10::SymIntArrayRef size,
+ c10::optional<ScalarType> dtype,
+ c10::optional<Layout> layout,
+ c10::optional<Device> device,
+ c10::optional<bool> pin_memory) {
+ return zeros(asIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+}
+
Tensor _efficientzerotensor(IntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index c46d2dc..285f645 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -27,6 +27,7 @@
#include <algorithm>
#include <cstdint>
#include <vector>
+#include <c10/util/StringUtil.h>
namespace at {
namespace meta {
@@ -3105,6 +3106,11 @@
return view_impl(self, size);
}
+Tensor view_symint(const Tensor& self,
+ c10::SymIntArrayRef size) {
+ return self.view(c10::asIntArrayRefSlow(size));
+}
+
Tensor alias(const Tensor& self) {
return alias_with_sizes_and_strides(self, self.sizes(), self.strides());
}
diff --git a/aten/src/ATen/native/mkldnn/TensorShape.cpp b/aten/src/ATen/native/mkldnn/TensorShape.cpp
index 7676edc..ec3c58e 100644
--- a/aten/src/ATen/native/mkldnn/TensorShape.cpp
+++ b/aten/src/ATen/native/mkldnn/TensorShape.cpp
@@ -2,6 +2,7 @@
#include <ATen/Config.h>
#include <ATen/InferSize.h>
#include <ATen/NativeFunctions.h>
+#include <c10/core/SymIntArrayRef.h>
#if !AT_MKLDNN_ENABLED()
@@ -86,3 +87,15 @@
} // namespace at
#endif // AT_MKLDNN_ENABLED
+
+
+namespace at {
+namespace native {
+
+
+Tensor mkldnn_view_symint(const Tensor& self, c10::SymIntArrayRef size) {
+ return mkldnn_view(self, c10::asIntArrayRefSlow(size));
+}
+
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/mkldnn/TensorShape.h b/aten/src/ATen/native/mkldnn/TensorShape.h
index bbb8ea9..92af7e2 100644
--- a/aten/src/ATen/native/mkldnn/TensorShape.h
+++ b/aten/src/ATen/native/mkldnn/TensorShape.h
@@ -1,12 +1,15 @@
#pragma once
#include <ATen/ATen.h>
+#include <c10/core/SymIntArrayRef.h>
namespace at {
namespace native {
Tensor mkldnn_view(const Tensor& self, IntArrayRef size);
+Tensor mkldnn_view_symint(const Tensor& self, c10::SymIntArrayRef size);
+
Tensor mkldnn_clone(const Tensor& self);
} // namespace native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index f1e86fa..1f32e6e 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4588,6 +4588,12 @@
CompositeExplicitAutograd: sum
SparseCsrCPU, SparseCsrCUDA: sum_csr
+- func: sum.SymInt(Tensor self, SymInt[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+ device_check: NoCheck # TensorIterator
+ variants: function, method
+ dispatch:
+ CompositeExplicitAutograd: sum_symint
+
- func: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
structured_delegate: sum.IntList_out
device_check: NoCheck # TensorIterator
@@ -5197,6 +5203,8 @@
- func: zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+- func: zeros.SymInt(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+
- func: zeros.out(int[] size, *, Tensor(a!) out) -> Tensor(a!)
- func: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
@@ -6448,6 +6456,14 @@
CUDA: masked_softmax_backward_cuda
CPU: masked_softmax_backward_cpu
+- func: view.SymInt(Tensor(a) self, SymInt[] size) -> Tensor(a)
+ variants: method
+ device_check: NoCheck
+ device_guard: False
+ dispatch:
+ CompositeExplicitAutograd: view_symint
+ MkldnnCPU: mkldnn_view_symint
+
- func: view(Tensor(a) self, int[] size) -> Tensor(a)
variants: method
device_check: NoCheck
@@ -12335,6 +12351,13 @@
CompositeExplicitAutograd: _neg_view_copy_out
+- func: view_copy.SymInt(Tensor self, SymInt[] size) -> Tensor
+ variants: function
+ dispatch:
+ CompositeExplicitAutograd: view_copy_SymInt
+ tags: view_copy
+
+
- func: as_strided_copy.out(Tensor self, int[] size, int[] stride, int? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)
variants: function
dispatch:
diff --git a/c10/core/SymInt.cpp b/c10/core/SymInt.cpp
index fe550a6..fc2bbdb 100644
--- a/c10/core/SymInt.cpp
+++ b/c10/core/SymInt.cpp
@@ -25,6 +25,56 @@
return SymInt(data_ + sci.data_);
}
+bool SymInt::operator!=(SymInt sci) const {
+ if (!is_symbolic() && !sci.is_symbolic()) {
+ return data_ != sci.data_;
+ }
+  // TODO: This is way too much boilerplate
+ std::shared_ptr<SymbolicIntNode> a =
+ is_symbolic() ? toSymbolicIntNode() : nullptr;
+ std::shared_ptr<SymbolicIntNode> b =
+ sci.is_symbolic() ? sci.toSymbolicIntNode() : nullptr;
+
+ SymbolicIntNode* common = a ? a.get() : b.get();
+ // TODO: technically we need to check that the classes match
+ if (!a) {
+ a = common->wrap(data_);
+ toSymInt(a); //
+ }
+ if (!b) {
+ b = common->wrap(sci.data_);
+ toSymInt(b);
+ }
+
+ auto c = a->ne(b);
+ return c->bool_();
+}
+
+bool SymInt::operator==(SymInt sci) const {
+ if (!is_symbolic() && !sci.is_symbolic()) {
+ return data_ == sci.data_;
+ }
+  // TODO: This is way too much boilerplate
+ std::shared_ptr<SymbolicIntNode> a =
+ is_symbolic() ? toSymbolicIntNode() : nullptr;
+ std::shared_ptr<SymbolicIntNode> b =
+ sci.is_symbolic() ? sci.toSymbolicIntNode() : nullptr;
+
+ SymbolicIntNode* common = a ? a.get() : b.get();
+ // TODO: technically we need to check that the classes match
+ if (!a) {
+ a = common->wrap(data_);
+ toSymInt(a); //
+ }
+ if (!b) {
+ b = common->wrap(sci.data_);
+ toSymInt(b);
+ }
+
+ auto c = a->eq(b);
+ return c->bool_();
+}
+
SymInt SymInt::operator*(SymInt sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return SymInt(data_ * sci.data_);
@@ -68,13 +118,11 @@
}
bool SymInt::operator==(int64_t sci) const {
- TORCH_CHECK(!this->is_symbolic(), "Symbolic eq isn't supported yet");
- return data_ == sci;
+ return *this == c10::SymInt(sci);
}
bool SymInt::operator!=(int64_t sci) const {
- TORCH_CHECK(!this->is_symbolic(), "Symbolic neq isn't supported yet");
- return data_ != sci;
+ return *this != c10::SymInt(sci);
}
SymInt SymInt::operator*(int64_t sci) const {
diff --git a/c10/core/SymInt.h b/c10/core/SymInt.h
index e9e7d09..50a695d 100644
--- a/c10/core/SymInt.h
+++ b/c10/core/SymInt.h
@@ -39,16 +39,10 @@
return (MASK & static_cast<uint64_t>(this->data_)) == IS_SYM;
}
- bool operator==(const SymInt& p2) const {
- return data_ == p2.data_;
- }
-
- bool operator!=(const SymInt& p2) const {
- return data_ != p2.data_;
- }
-
SymInt operator+(SymInt sci) const;
SymInt operator*(SymInt sci) const;
+ bool operator==(SymInt sci) const;
+ bool operator!=(SymInt p2) const;
bool operator<(SymInt sci) const;
void operator*=(SymInt sci);
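`operator==`/`operator!=` are no longer inline integer compares; the out-of-line definitions above wrap the non-symbolic side with the other side's `SymbolicIntNode::wrap` and dispatch to `eq()`/`ne()`, falling back to a plain `data_` compare when neither side is symbolic. A small sketch of the non-symbolic fast path:

```cpp
#include <c10/core/SymInt.h>
#include <c10/util/Exception.h>

int main() {
  c10::SymInt a(3);
  c10::SymInt b(3);
  c10::SymInt c(4);

  // Neither side is symbolic, so these compare data_ directly.
  TORCH_INTERNAL_ASSERT(a == b);
  TORCH_INTERNAL_ASSERT(a != c);

  // The int64_t overloads now route through the SymInt comparison as well,
  // so they keep working once a SymbolicIntNode backend is attached.
  TORCH_INTERNAL_ASSERT(a == 3);
  TORCH_INTERNAL_ASSERT(a != 4);
  return 0;
}
```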
diff --git a/c10/core/SymIntArrayRef.cpp b/c10/core/SymIntArrayRef.cpp
index 151be20..7898a8a 100644
--- a/c10/core/SymIntArrayRef.cpp
+++ b/c10/core/SymIntArrayRef.cpp
@@ -31,8 +31,4 @@
return os;
}
-std::ostream& operator<<(std::ostream& out, const c10::SymIntArrayRef& list) {
- return out << list.wrapped_symint_array_ref;
-}
-
} // namespace c10
diff --git a/c10/core/SymIntArrayRef.h b/c10/core/SymIntArrayRef.h
index 77c151f..bf2eb65 100644
--- a/c10/core/SymIntArrayRef.h
+++ b/c10/core/SymIntArrayRef.h
@@ -63,6 +63,11 @@
size_t length)
: wrapped_symint_array_ref(data, length) {}
+ template <typename U>
+ /* implicit */ SymIntArrayRef(
+ const SmallVectorTemplateCommon<c10::SymInt, U>& Vec)
+ : wrapped_symint_array_ref(Vec) {}
+
/// Construct an SymIntArrayRef from a range.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef(
const c10::SymInt* begin,
@@ -193,6 +198,10 @@
TORCH_API c10::optional<at::IntArrayRef> asIntArrayRefSlowOpt(
c10::SymIntArrayRef ar);
-std::ostream& operator<<(std::ostream& out, const c10::SymIntArrayRef& list);
+inline std::ostream& operator<<(
+ std::ostream& out,
+ const c10::SymIntArrayRef& list) {
+ return out << list.wrapped_symint_array_ref;
+}
} // namespace c10
diff --git a/c10/core/SymbolicIntNode.h b/c10/core/SymbolicIntNode.h
index b31d4b9..5a91644 100644
--- a/c10/core/SymbolicIntNode.h
+++ b/c10/core/SymbolicIntNode.h
@@ -38,6 +38,10 @@
const std::shared_ptr<SymbolicIntNode>& other) {
TORCH_CHECK(false, "NYI");
};
+ virtual std::shared_ptr<SymbolicIntNode> ne(
+ const std::shared_ptr<SymbolicIntNode>& other) {
+ TORCH_CHECK(false, "NYI");
+ };
virtual std::shared_ptr<SymbolicIntNode> gt(
const std::shared_ptr<SymbolicIntNode>& other) {
TORCH_CHECK(false, "NYI");
diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index d038def..8de6860 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -230,6 +230,9 @@
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
+ z = y.expand((y.shape[1],))
+ z = y.expand(y.shape[1])
+
@skipIfNoSympy
def test_size_expressions(self):
shape_env = ShapeEnv()
diff --git a/test/test_functionalization.py b/test/test_functionalization.py
index b8d29a0..d6be39c 100644
--- a/test/test_functionalization.py
+++ b/test/test_functionalization.py
@@ -130,7 +130,7 @@
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
- empty_1 = torch.ops.aten.empty.SymInt([], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+ empty_1 = torch.ops.aten.empty.memory_format([], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, fill_scalar); view_copy_default = fill_scalar = None
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, add_tensor); add_tensor = None
return mul_tensor
@@ -151,8 +151,8 @@
def forward(self, a_1):
- empty = torch.ops.aten.empty.SymInt([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
- empty_1 = torch.ops.aten.empty.SymInt([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+ empty = torch.ops.aten.empty.memory_format([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+ empty_1 = torch.ops.aten.empty.memory_format([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
aminmax_default = torch.ops.aten.aminmax.default(a_1, dim = 0); a_1 = None
getitem = aminmax_default[0]
getitem_1 = aminmax_default[1]; aminmax_default = None
@@ -258,7 +258,7 @@
def forward(self, a_1):
- empty = torch.ops.aten.empty.SymInt([0], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+ empty = torch.ops.aten.empty.memory_format([0], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
cat_default = torch.ops.aten.cat.default([a_1]); a_1 = None
return cat_default
""")
@@ -653,8 +653,8 @@
def forward(self, a_1):
- expand_copy_sym_int = torch.ops.aten.expand_copy.SymInt(a_1, [2, 2]); a_1 = None
- return expand_copy_sym_int
+ expand_copy_default = torch.ops.aten.expand_copy.default(a_1, [2, 2]); a_1 = None
+ return expand_copy_default
""")
def test_fill_(self):
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index e10b5cf..de720e5 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -746,7 +746,8 @@
with capture_logs(is_mode=True) as logs:
with enable_torch_dispatch_mode(LoggingTensorMode()):
torch.empty([])
- self.assertExpectedInline('\n'.join(logs), ("$0 = torch._ops.aten.empty.SymInt([], dtype=torch.float32," +
+
+ self.assertExpectedInline('\n'.join(logs), ("$0 = torch._ops.aten.empty.memory_format([], dtype=torch.float32," +
" device=device(type='cpu'), pin_memory=False)"))
def test_enable_torch_dispatch_mode_unrelated_tensors(self) -> None:
@@ -774,8 +775,8 @@
x + y
self.assertExpectedInline('\n'.join(logs), """\
-$0 = torch._ops.aten.empty.SymInt([], dtype=torch.float32, device=device(type='cpu'), pin_memory=False)
-$0 = torch._ops.aten.empty.SymInt([], dtype=torch.float32, device=device(type='cpu'), pin_memory=False)
+$0 = torch._ops.aten.empty.memory_format([], dtype=torch.float32, device=device(type='cpu'), pin_memory=False)
+$0 = torch._ops.aten.empty.memory_format([], dtype=torch.float32, device=device(type='cpu'), pin_memory=False)
$3 = torch._ops.aten.add.Tensor($1, $2)
$3 = torch._ops.aten.add.Tensor($1, $2)""")
@@ -786,7 +787,7 @@
torch.empty([])
x + y
self.assertExpectedInline('\n'.join(logs), """\
-$0 = torch._ops.aten.empty.SymInt([], dtype=torch.float32, device=device(type='cpu'), pin_memory=False)
+$0 = torch._ops.aten.empty.memory_format([], dtype=torch.float32, device=device(type='cpu'), pin_memory=False)
$3 = torch._ops.aten.add.Tensor($1, $2)""")
x = torch.randn([])
@@ -798,8 +799,8 @@
x + y
self.assertExpectedInline('\n'.join(logs2), """\
-$0 = torch._ops.aten.empty.SymInt([], dtype=torch.float32, device=device(type='cpu'), pin_memory=False)
-$0 = torch._ops.aten.empty.SymInt([], dtype=torch.float32, device=device(type='cpu'), pin_memory=False)
+$0 = torch._ops.aten.empty.memory_format([], dtype=torch.float32, device=device(type='cpu'), pin_memory=False)
+$0 = torch._ops.aten.empty.memory_format([], dtype=torch.float32, device=device(type='cpu'), pin_memory=False)
$3 = torch._ops.aten.add.Tensor($1, $2)
$3 = torch._ops.aten.add.Tensor($1, $2)""")
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index b568f41..a36e9d7 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1527,6 +1527,10 @@
self: grad.expand(self.sizes())
result: auto_linear
+- name: sum.SymInt(Tensor self, SymInt[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+ self: sum_backward(grad, self.sym_sizes(), dim, keepdim)
+ result: auto_linear
+
- name: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
self: sum_backward(grad, self.sizes(), dim, keepdim)
result: auto_linear
@@ -1713,6 +1717,12 @@
self: grad.reshape(self.sizes())
result: auto_linear
+- name: view.SymInt(Tensor(a) self, SymInt[] size) -> Tensor(a)
+ # TODO: add proper double backward for view.SymInt
+ # by SymIntizing `reshape`
+ self: grad.reshape(c10::asIntArrayRefSlow(self.sym_sizes()))
+ result: auto_linear
+
- name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)
output_differentiability: [False]
diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py
index fb8c93c..5122276 100644
--- a/tools/autograd/gen_autograd_functions.py
+++ b/tools/autograd/gen_autograd_functions.py
@@ -27,6 +27,7 @@
optionalIntArrayRefT,
scalarT,
stringT,
+ symIntArrayRefT,
tensorListT,
tensorT,
)
@@ -281,6 +282,20 @@
return tup;
"""
+GETTER_BODY_ARRAYREF_SYMINT = """\
+PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
+for (auto i : c10::irange(prop.size())) {
+ auto si = prop[i];
+ if (si.is_symbolic()) {
+ auto py_symint = py::cast(si.toSymbolicIntNode()).release().ptr();
+ PyTuple_SetItem(tup, (Py_ssize_t) i, py_symint);
+ } else {
+ PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong(si.data()));
+ }
+}
+return tup;
+"""
+
GETTER_BODY_ARRAYREF_DOUBLE = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
@@ -520,6 +535,13 @@
op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
)
)
+ elif type == BaseCType(symIntArrayRefT):
+ saved_variables.append(f"std::vector<c10::SymInt> {name};")
+ getter_definitions.append(
+ GETTER_DEFINITION.substitute(
+ op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT
+ )
+ )
elif type == BaseCType(optionalIntArrayRefT):
saved_variables.append(f"c10::OptionalArray<int64_t> {name};")
getter_definitions.append(
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index 3c58359..e3cb8b8 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -1123,9 +1123,9 @@
str(t1) == "Tensor[]"
and str(t2).find("[]") != -1
or
- # Prioritize SymIntArrayRef overload over IntArrayRef
- str(t1) == "int[]"
- and str(t2) == "SymInt[]"
+ # Prioritize IntArrayRef overload over SymIntArrayRef
+ str(t1) == "SymInt[]"
+ and str(t2) == "int[]"
)
def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool:
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 1c2e593..ce74c77 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -48,6 +48,7 @@
scalarT,
SpecialArgName,
stringT,
+ symIntArrayRefT,
tensorListT,
tensorT,
TupleCType,
@@ -1091,6 +1092,8 @@
name += "_"
elif type == BaseCType(intArrayRefT):
expr = expr + ".vec()"
+ elif type == BaseCType(symIntArrayRefT):
+ expr = expr + ".vec()"
elif type == BaseCType(stringT):
expr = f"std::string({expr})"
elif type == OptionalCType(BaseCType(stringT)):
diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py
index 09685f6..7bf43cf 100644
--- a/tools/autograd/load_derivatives.py
+++ b/tools/autograd/load_derivatives.py
@@ -28,6 +28,7 @@
scalarTypeT,
SpecialArgName,
stringT,
+ symIntArrayRefT,
tensorGeometryT,
tensorOptionsT,
typeAndSizeT,
@@ -696,6 +697,14 @@
"nctype": lambda name: NamedCType(name, BaseCType(intArrayRefT)),
},
),
+ # replace self.sym_sizes() with self_sym_sizes
+ (
+ r"{}.sym_sizes\(\)",
+ {
+ "suffix": "_sym_sizes",
+ "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
+ },
+ ),
# replace self->sizes() with self_sizes_opt
(
r"{}->sizes\(\)",
diff --git a/tools/autograd/templates/python_functions.cpp b/tools/autograd/templates/python_functions.cpp
index 3be7a01..e913d08 100644
--- a/tools/autograd/templates/python_functions.cpp
+++ b/tools/autograd/templates/python_functions.cpp
@@ -5,6 +5,7 @@
#include <Python.h>
#include <ATen/ATen.h>
+#include <c10/core/SymbolicIntNode.h>
#include "torch/csrc/autograd/generated/Functions.h"
#include "torch/csrc/autograd/python_cpp_function.h"
#include <torch/csrc/autograd/python_variable.h>
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 9a563ab..dae458a 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -595,6 +595,21 @@
return grad.expand(sizes);
}
+Tensor sum_backward(
+ const Tensor& grad,
+ c10::SymIntArrayRef sizes,
+ c10::SymIntArrayRef dims,
+ bool keepdim) {
+ if (!keepdim && sizes.size() > 0 && dims.size() > 0) {
+ // we are only using `keepdim=true` path for SymInts for now
+ TORCH_CHECK_NOT_IMPLEMENTED(
+ false,
+ "Only the keepdim=true path is implemented to support symints in autograd");
+ } else {
+ return grad.expand_symint(sizes);
+ }
+}
+
Tensor nansum_backward(
const Tensor& grad,
const Tensor& self,
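The `sum_backward` overload for `sum.SymInt` only supports the `keepdim=true` path for now: with `keepdim=true` the reduced dimensions survive as size 1, so the gradient can be recovered with a single `expand_symint` to the input's symbolic sizes, without re-inserting dimensions. A C++-frontend sketch of that shape relationship with concrete sizes:

```cpp
#include <torch/torch.h>
#include <vector>

int main() {
  auto x = torch::randn({2, 3}, torch::requires_grad());

  // keepdim=true keeps the summed dimension around as size 1 ...
  std::vector<int64_t> dim = {1};
  auto y = x.sum(dim, /*keepdim=*/true);   // shape [2, 1]

  // ... so the backward pass only has to expand [2, 1] back to [2, 3].
  y.backward(torch::ones_like(y));
  TORCH_INTERNAL_ASSERT(x.grad().sizes().equals(x.sizes()));
  return 0;
}
```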
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 5df214b..07e5ca4 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -157,6 +157,11 @@
at::IntArrayRef sizes,
at::OptionalIntArrayRef opt_dims,
bool keepdim);
+at::Tensor sum_backward(
+ const at::Tensor& grad,
+ c10::SymIntArrayRef sizes,
+ c10::SymIntArrayRef dims,
+ bool keepdim);
at::Tensor nansum_backward(
const at::Tensor& grad,
const at::Tensor& self,
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index f98da97..575b0c5 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -186,11 +186,11 @@
/// of the new input.
uint32_t add_input_metadata(
const at::TensorOptions& options,
- at::IntArrayRef shape,
+ c10::SymIntArrayRef shape,
bool is_tensor_subclass) noexcept {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint32_t input_nr = input_metadata_.size();
- auto meta_shape = MetadataShape{c10::in_place_type<at::DimVector>, shape};
+ auto meta_shape = MetadataShape{c10::in_place_type<SymIntSmallVec>, shape};
input_metadata_.emplace_back(options, meta_shape, is_tensor_subclass);
return input_nr;
}
diff --git a/torch/csrc/autograd/input_metadata.h b/torch/csrc/autograd/input_metadata.h
index 917acc2..7cb9e8a 100644
--- a/torch/csrc/autograd/input_metadata.h
+++ b/torch/csrc/autograd/input_metadata.h
@@ -6,9 +6,12 @@
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
+#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
+#include <c10/util/DimVector.h>
#include <c10/util/Exception.h>
+#include <c10/util/SmallVector.h>
#include <c10/util/variant.h>
#ifndef AT_PER_OPERATOR_HEADERS
@@ -23,7 +26,8 @@
namespace torch {
namespace autograd {
-using MetadataShape = c10::variant<at::DimVector, at::Tensor>;
+using SymIntSmallVec = c10::SmallVector<c10::SymInt, c10::kDimVectorStaticSize>;
+using MetadataShape = c10::variant<SymIntSmallVec, at::Tensor>;
/**
* Records TensorOptions, shape of the tensor, whether or not the Python
@@ -81,7 +85,7 @@
TORCH_CHECK(
!is_nested_tensor(),
"Zeros is not currently supported for nested tensors.")
- return at::zeros(shape_as_dim_vector(), options_);
+ return at::zeros_symint(shape_as_dim_vector(), options_);
}
bool is_same_shape(const at::Tensor& grad) const {
@@ -92,7 +96,7 @@
return at::native::get_nested_size_tensor(grad).is_same_size(
shape_as_tensor());
}
- return grad.sizes().equals(shape_as_dim_vector());
+ return grad.sym_sizes().equals(shape_as_dim_vector());
}
bool is_expandable_to_shape(const at::Tensor& grad) const {
// Currently NestedTensors are not expandable. If this support is added then
@@ -102,7 +106,7 @@
"Both grad and InputMetadata need to be either nested or non nested tensors.")
return grad.is_nested()
? false
- : at::is_expandable_to(shape_as_dim_vector(), grad.sizes());
+ : at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
}
at::Tensor reduce_grad(at::Tensor& grad) const {
@@ -127,7 +131,7 @@
if (is_nested_tensor()) {
ss << shape_as_tensor();
} else {
- ss << shape_as_dim_vector();
+ ss << c10::asIntArrayRefSlow(shape_as_dim_vector());
}
return ss;
}
@@ -141,12 +145,14 @@
auto nested_size = at::native::get_nested_size_tensor(input);
return MetadataShape{c10::in_place_type<at::Tensor>, nested_size};
}
- return MetadataShape{c10::in_place_type<at::DimVector>, input.sizes()};
+ return MetadataShape{c10::in_place_type<SymIntSmallVec>, input.sym_sizes()};
}
- at::DimVector shape_as_dim_vector() const {
- return c10::get<at::DimVector>(shape_);
+ c10::SymIntArrayRef shape_as_dim_vector() const {
+ const auto& dim_shape = c10::get<SymIntSmallVec>(shape_);
+ return c10::SymIntArrayRef(dim_shape.data(), dim_shape.size());
}
+
at::Tensor shape_as_tensor() const {
return c10::get<at::Tensor>(shape_);
}
diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp
index ce62fec..aa3f17f 100644
--- a/torch/csrc/autograd/variable.cpp
+++ b/torch/csrc/autograd/variable.cpp
@@ -692,7 +692,8 @@
torch::autograd::collect_next_edges(view_info.base_));
fn->add_input_metadata(
view_info.base_.options(),
- self.sizes(), // Note: sizes(), not base_.sizes(), is intentional
+ self.sym_sizes(), // Note: sizes(), not base_.sizes(), is
+ // intentional
self.unsafeGetTensorImpl()->is_python_dispatch());
diff_view_meta->grad_fn_ = std::move(fn);
}
diff --git a/torch/csrc/lazy/core/tensor_impl.cpp b/torch/csrc/lazy/core/tensor_impl.cpp
index 85a5e9f..6f3d13d 100644
--- a/torch/csrc/lazy/core/tensor_impl.cpp
+++ b/torch/csrc/lazy/core/tensor_impl.cpp
@@ -143,9 +143,13 @@
}
c10::SymIntArrayRef LTCTensorImpl::sym_sizes_custom() const {
- return FLAGS_ltc_enable_symbolic_shapes
- ? c10::SymIntArrayRef(sym_sizes_.data(), sym_sizes_.size())
- : TensorImpl::sym_sizes_default();
+ if (FLAGS_ltc_enable_symbolic_shapes) {
+ return c10::SymIntArrayRef(sym_sizes_.data(), sym_sizes_.size());
+ }
+
+ // return upper bound
+ const_cast<LTCTensorImpl*>(this)->setup_size_properties();
+ return TensorImpl::sym_sizes_default();
}
c10::SymIntArrayRef LTCTensorImpl::sym_sizes() const {
diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp
index d2b985c..6e14b7d 100644
--- a/torch/csrc/utils/python_arg_parser.cpp
+++ b/torch/csrc/utils/python_arg_parser.cpp
@@ -649,12 +649,29 @@
static bool is_int_list(PyObject* obj, int broadcast_size) {
if (PyTuple_Check(obj) || PyList_Check(obj)) {
- if (PySequence_Size(obj) == 0) {
+ auto len = PySequence_Size(obj);
+ if (len == 0) {
return true;
}
auto item = py::reinterpret_steal<py::object>(PySequence_GetItem(obj, 0));
+ bool int_first = false;
if (THPUtils_checkIndex(item.ptr())) {
+ // we still have to check that the rest of items are NOT symint nodes
+ int_first = true;
+ }
+
+ // Make sure none of the later arguments are SymInt
+ // NB: do NOT check that the later arguments are ints, as this is
+ // BC-breaking for FX
+ for (int i = 1; i < len; i++) {
+ if (torch::is_symint_node(
+ py::reinterpret_steal<py::object>(PySequence_GetItem(obj, i)))) {
+ return false;
+ }
+ }
+
+ if (int_first) {
return true;
}
@@ -1227,10 +1244,14 @@
// if there is a single positional IntArrayRef argument, i.e. expand(..),
// view(...), allow a var-args style IntArrayRef, so expand(5,3) behaves as
// expand((5,3))
+ int int_list_overload = false;
if (max_pos_args == 1 &&
(params[0].type_ == ParameterType::INT_LIST ||
params[0].type_ == ParameterType::SYM_INT_LIST)) {
allow_varargs_intlist = true;
+ if (params[0].type_ == ParameterType::INT_LIST) {
+ int_list_overload = true;
+ }
}
if (nargs > max_pos_args && !allow_varargs_intlist) {
@@ -1287,7 +1308,8 @@
// should avoid having complex signatures that make use of it...
} else if (
allow_varargs_intlist && arg_pos == 0 && !is_kwd &&
- is_int_or_symint(obj)) {
+ (is_int_list(args, param.size) ||
+ ((is_int_or_symint_list(args, param.size) && !int_list_overload)))) {
// take all positional arguments as this parameter
// e.g. permute(1, 2, 3) -> permute((1, 2, 3))
dst[i++] = args;
diff --git a/torchgen/gen.py b/torchgen/gen.py
index 5bc5f62..1800ce3 100644
--- a/torchgen/gen.py
+++ b/torchgen/gen.py
@@ -2302,8 +2302,13 @@
if g1.view_copy is None or g2.view_copy is None:
continue
# TODO: make this more first class in the data model
- same_base_op = str(g1.view_copy.func.name.name) == str(
- g2.view_copy.func.name.name
+ g1_base_name = str(g1.view_copy.func.name.name)
+ g2_base_name = str(g2.view_copy.func.name.name)
+
+ same_base_op = (
+ g1_base_name == g2_base_name
+ and g1.view_copy.func.arguments.symints_to_ints()
+ == g2.view_copy.func.arguments.symints_to_ints()
)
op1_not_symint = "SymInt" not in str(g1.view_copy.func.name.overload_name)
op2_symint = "SymInt" in str(g2.view_copy.func.name.overload_name)
diff --git a/torchgen/model.py b/torchgen/model.py
index de09267..8589706 100644
--- a/torchgen/model.py
+++ b/torchgen/model.py
@@ -791,6 +791,9 @@
backend_metadata,
)
+ def symints_to_ints(self) -> "NativeFunction":
+ return dataclasses.replace(self, func=self.func.symints_to_ints())
+
def validate_unstructured(self) -> None:
# TODO: probably better to accumulate these errors and report them all
# at once
@@ -1184,6 +1187,9 @@
decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")
+ def symints_to_ints(self) -> "FunctionSchema":
+ return dataclasses.replace(self, arguments=self.arguments.symints_to_ints())
+
@staticmethod
def parse(func: str) -> "FunctionSchema":
# We should probably get a proper parser here
@@ -1623,6 +1629,9 @@
def is_list_like(self) -> Optional["ListType"]:
raise NotImplementedError
+ def symint_to_int(self) -> "Type":
+ raise NotImplementedError
+
# Base types are simple, atomic types with no further structure
BaseTy = Enum(
@@ -1663,6 +1672,11 @@
def is_nullable(self) -> bool:
return False
+ def symint_to_int(self) -> "BaseType":
+ if self.name == BaseTy.SymInt:
+ return BaseType(BaseTy.int)
+ return self
+
def is_list_like(self) -> Optional["ListType"]:
return None
@@ -1681,6 +1695,9 @@
def is_nullable(self) -> bool:
return True
+ def symint_to_int(self) -> "Type":
+ return dataclasses.replace(self, elem=self.elem.symint_to_int())
+
def is_list_like(self) -> Optional["ListType"]:
return self.elem.is_list_like()
@@ -1707,6 +1724,9 @@
def is_nullable(self) -> bool:
return self.elem.is_nullable()
+ def symint_to_int(self) -> "ListType":
+ return ListType(self.elem.symint_to_int(), self.size)
+
def is_list_like(self) -> Optional["ListType"]:
return self
@@ -1780,6 +1800,9 @@
def is_write(self) -> bool:
return self.annotation is not None and self.annotation.is_write
+ def symint_to_int(self) -> "Argument":
+ return dataclasses.replace(self, type=self.type.symint_to_int())
+
def __str__(self) -> str:
type = f"{self.type}"
if self.annotation:
@@ -1969,6 +1992,37 @@
if a.annotation is not None and a.annotation.is_write
]
+ def symints_to_ints(self) -> "Arguments":
+ arguments = self
+
+ if arguments.self_arg:
+ arguments = dataclasses.replace(
+ arguments,
+ pre_self_positional=[
+ x.symint_to_int() for x in arguments.pre_self_positional
+ ],
+ )
+
+ if self.tensor_options:
+ arguments = dataclasses.replace(
+ arguments,
+ post_tensor_options_kwarg_only=[
+ x.symint_to_int() for x in arguments.post_tensor_options_kwarg_only
+ ],
+ )
+
+ arguments = dataclasses.replace(
+ arguments,
+ post_self_positional=[
+ x.symint_to_int() for x in arguments.post_self_positional
+ ],
+ pre_tensor_options_kwarg_only=[
+ x.symint_to_int() for x in arguments.pre_tensor_options_kwarg_only
+ ],
+ )
+
+ return arguments
+
def signature(self, *, strip_default: bool = False) -> "Arguments":
# dataclasses.replace could be used here, but it is less
# type safe so for now I've opted to type everything out