symintify unbind_backward and tensor_split (#86357)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86357
Approved by: https://github.com/albanD
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 0d91e25..ac36151 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -743,14 +743,14 @@
TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims");
int64_t dim_ = maybe_wrap_dim(dim, self.dim());
TORCH_CHECK(sections > 0, "number of sections must be larger than 0, got ", sections);
- const auto dim_size = self.size(dim_);
+ const auto dim_size = self.sym_size(dim_);
std::vector<Tensor> splits(sections);
- int64_t min_split_size = dim_size / sections;
- int64_t num_splits_one_extra = dim_size % sections;
- int64_t start_idx = 0;
+ auto min_split_size = dim_size / sections;
+ auto num_splits_one_extra = dim_size % sections;
+ c10::SymInt start_idx = 0;
for (const auto split_idx : c10::irange(sections)) {
- int64_t split_size = (split_idx < num_splits_one_extra) ? (min_split_size + 1) : min_split_size;
- splits[split_idx] = at::slice(self, dim_, start_idx, start_idx + split_size);
+ auto split_size = (num_splits_one_extra > split_idx) ? (min_split_size + 1) : min_split_size;
+ splits[split_idx] = at::slice_symint(self, dim_, start_idx, start_idx + split_size);
start_idx += split_size;
}
return splits;
diff --git a/functorch/test/test_aotdispatch.py b/functorch/test/test_aotdispatch.py
index e1923ff..0b8b0fa 100644
--- a/functorch/test/test_aotdispatch.py
+++ b/functorch/test/test_aotdispatch.py
@@ -904,7 +904,6 @@
xfail('nn.functional.feature_alpha_dropout', 'with_train'), # Cannot call numel() on tensor with symbol...
xfail('nn.functional.fractional_max_pool2d', ''), # rand() received an invalid combination of arguments - g...
xfail('nn.functional.fractional_max_pool3d', ''), # rand() received an invalid combination of arguments - g...
- xfail('nn.functional.glu', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.grid_sample', ''), # prims::arange() Expected a value of type 'number' for argument...
xfail('nn.functional.group_norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.hinge_embedding_loss', ''), # aten.zeros_like.default - couldn't find symbolic meta...
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 73e1497..a0c7950 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1196,7 +1196,6 @@
xfail('nn.functional.feature_alpha_dropout', 'with_train'), # Tensors of type TensorImpl do not have numel
xfail('nn.functional.fractional_max_pool2d', ''), # argument 'size' must be tuple of ints, but found element of t...
xfail('nn.functional.fractional_max_pool3d', ''), # argument 'size' must be tuple of ints, but found element of t...
- xfail('nn.functional.glu', ''), # aten.glu.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.grid_sample', ''), # aten.grid_sampler_2d.default - couldn't find symbolic meta function/decompos...
xfail('nn.functional.group_norm', ''), # 'torch._C.SymIntNode' and 'int'
xfail('nn.functional.hinge_embedding_loss', ''), # aten.empty_like.default - couldn't find symbolic meta function/deco...
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 05ae6e9..45c9eab 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -831,18 +831,19 @@
}
Tensor unbind_backward(const variable_list& grads, int64_t dim) {
- IntArrayRef sizes;
+ c10::SymIntArrayRef sizes;
at::TensorOptions o;
for (const auto& v : grads) {
if (v.defined()) {
- sizes = v.sizes();
+ sizes = v.sym_sizes();
o = static_cast<Tensor>(v).options();
break;
}
}
auto grads_tensors = fmap(grads, [&](const Variable& v) {
return (
- v.defined() ? static_cast<Tensor>(v) : at::zeros({}, o).expand(sizes));
+ v.defined() ? static_cast<Tensor>(v)
+ : at::zeros({}, o).expand_symint(sizes));
});
return at::stack(grads_tensors, dim);
}