symintify unbind_backward and tensor_split (#86357)

Convert `tensor_split` and `unbind_backward` to SymInt: query sizes via `sym_size()`/`sym_sizes()`, carry the running offset as a `c10::SymInt`, and dispatch to `slice_symint`/`expand_symint`, so both paths work when tracing with symbolic shapes. This lets us drop the `nn.functional.glu` xfails in the AOT dispatch and proxy tensor test suites.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86357
Approved by: https://github.com/albanD
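
Before the diff, a minimal sketch of what this change unlocks, assuming the proxy-tensor tracer from this era of the codebase (`make_fx` with `tracing_mode="symbolic"`): previously, tracing `tensor_split` hit `Tensor::size()` on a tensor with symbolic sizes; with `sym_size()`/`slice_symint` the trace goes through and the slice bounds stay symbolic.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # Split along dim 0 into 3 sections; sizes depend on x.size(0),
    # which is a SymInt under symbolic tracing.
    return torch.tensor_split(x, 3, dim=0)

gm = make_fx(f, tracing_mode="symbolic")(torch.randn(7, 4))
print(gm.graph)  # slice calls whose bounds are expressions of x.size(0)
```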
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 0d91e25..ac36151 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -743,14 +743,14 @@
   TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims");
   int64_t dim_ = maybe_wrap_dim(dim, self.dim());
   TORCH_CHECK(sections > 0, "number of sections must be larger than 0, got ", sections);
-  const auto dim_size = self.size(dim_);
+  const auto dim_size = self.sym_size(dim_);
   std::vector<Tensor> splits(sections);
-  int64_t min_split_size = dim_size / sections;
-  int64_t num_splits_one_extra = dim_size % sections;
-  int64_t start_idx = 0;
+  auto min_split_size = dim_size / sections;
+  auto num_splits_one_extra = dim_size % sections;
+  c10::SymInt start_idx = 0;
   for (const auto split_idx : c10::irange(sections)) {
-    int64_t split_size = (split_idx < num_splits_one_extra) ? (min_split_size + 1) : min_split_size;
-    splits[split_idx] = at::slice(self, dim_, start_idx, start_idx + split_size);
+    auto split_size = (num_splits_one_extra > split_idx) ? (min_split_size + 1) : min_split_size;
+    splits[split_idx] = at::slice_symint(self, dim_, start_idx, start_idx + split_size);
     start_idx += split_size;
   }
   return splits;
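
A plain-Python restatement of the split arithmetic above (an illustrative helper, not a PyTorch API): the first `dim_size % sections` chunks each get one extra element, so the sizes sum back to `dim_size`.

```python
def split_sizes(dim_size: int, sections: int) -> list[int]:
    # divmod mirrors dim_size / sections and dim_size % sections in the C++.
    min_split_size, num_splits_one_extra = divmod(dim_size, sections)
    return [min_split_size + 1 if i < num_splits_one_extra else min_split_size
            for i in range(sections)]

assert split_sizes(7, 3) == [3, 2, 2]  # 7 = 3 + 2 + 2
```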
diff --git a/functorch/test/test_aotdispatch.py b/functorch/test/test_aotdispatch.py
index e1923ff..0b8b0fa 100644
--- a/functorch/test/test_aotdispatch.py
+++ b/functorch/test/test_aotdispatch.py
@@ -904,7 +904,6 @@
     xfail('nn.functional.feature_alpha_dropout', 'with_train'),  # Cannot call numel() on tensor with symbol...
     xfail('nn.functional.fractional_max_pool2d', ''),  # rand() received an invalid combination of arguments - g...
     xfail('nn.functional.fractional_max_pool3d', ''),  # rand() received an invalid combination of arguments - g...
-    xfail('nn.functional.glu', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.grid_sample', ''),  # prims::arange() Expected a value of type 'number' for argument...
     xfail('nn.functional.group_norm', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.hinge_embedding_loss', ''),  # aten.zeros_like.default - couldn't find symbolic meta...
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 73e1497..a0c7950 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1196,7 +1196,6 @@
     xfail('nn.functional.feature_alpha_dropout', 'with_train'),  # Tensors of type TensorImpl do not have numel
     xfail('nn.functional.fractional_max_pool2d', ''),  # argument 'size' must be tuple of ints, but found element of t...
     xfail('nn.functional.fractional_max_pool3d', ''),  # argument 'size' must be tuple of ints, but found element of t...
-    xfail('nn.functional.glu', ''),  # aten.glu.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.grid_sample', ''),  # aten.grid_sampler_2d.default - couldn't find symbolic meta function/decompos...
     xfail('nn.functional.group_norm', ''),  # 'torch._C.SymIntNode' and 'int'
     xfail('nn.functional.hinge_embedding_loss', ''),  # aten.empty_like.default - couldn't find symbolic meta function/deco...
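
The same `nn.functional.glu` xfail is removed from both test suites above; `glu` splits its input in half along `dim` (via `tensor_split` in the reference decomposition, presumably the source of the old failures), so symintifying `tensor_split` is what makes it pass. A hedged check corresponding to the removed xfails:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # glu halves the last dim: sigmoid-gated product of the two halves.
    return torch.nn.functional.glu(x, dim=-1)

# Should now trace without "Cannot call sizes() on tensor with symbolic sizes/strides".
gm = make_fx(f, tracing_mode="symbolic")(torch.randn(4, 8))
```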
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 05ae6e9..45c9eab 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -831,18 +831,19 @@
 }
 
 Tensor unbind_backward(const variable_list& grads, int64_t dim) {
-  IntArrayRef sizes;
+  c10::SymIntArrayRef sizes;
   at::TensorOptions o;
   for (const auto& v : grads) {
     if (v.defined()) {
-      sizes = v.sizes();
+      sizes = v.sym_sizes();
       o = static_cast<Tensor>(v).options();
       break;
     }
   }
   auto grads_tensors = fmap(grads, [&](const Variable& v) {
     return (
-        v.defined() ? static_cast<Tensor>(v) : at::zeros({}, o).expand(sizes));
+        v.defined() ? static_cast<Tensor>(v)
+                    : at::zeros({}, o).expand_symint(sizes));
   });
   return at::stack(grads_tensors, dim);
 }
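
A plain-Python sketch of the backward logic above (illustrative only, not the autograd API): gradients for undefined outputs of `unbind` are replaced by a scalar zero expanded to the shape of the first defined gradient, then everything is stacked back along the unbound dim. The symintified version does the same thing with `sym_sizes()`/`expand_symint`, so the expand works when that shape is symbolic.

```python
import torch
from typing import Optional

def unbind_backward_sketch(grads: list[Optional[torch.Tensor]], dim: int) -> torch.Tensor:
    # First defined gradient supplies the shape/dtype/device template,
    # mirroring the loop over `grads` in the C++.
    template = next(g for g in grads if g is not None)
    filled = [g if g is not None
              else torch.zeros((), dtype=template.dtype,
                               device=template.device).expand(template.shape)
              for g in grads]
    return torch.stack(filled, dim)
```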