Many symintifications (#87604)

Adds SymInt support for:
expand_inplace
conv conv_double_backward
convolution
adaptive_avg_pool2d_symint
_embedding_bag_backward_symint
cudnn_grid_sampler
CUDA 32-bit indexing
nll_loss / nll_loss_2d
tensor_split
pooling same mode
cudnn_is_acceptable
storage nbytes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87604
Approved by: https://github.com/ezyang
diff --git a/aten/src/ATen/ExpandUtils.h b/aten/src/ATen/ExpandUtils.h
index 7798946..786cbf1 100644
--- a/aten/src/ATen/ExpandUtils.h
+++ b/aten/src/ATen/ExpandUtils.h
@@ -94,10 +94,11 @@
 inline c10::MaybeOwned<Tensor> expand_inplace(
     const Tensor& tensor,
     const Tensor& to_expand) {
-  if (tensor.sizes().equals(to_expand.sizes())) {
+  if (tensor.sym_sizes().equals(to_expand.sym_sizes())) {
     return c10::MaybeOwned<Tensor>::borrowed(to_expand);
   }
-  return c10::MaybeOwned<Tensor>::owned(to_expand.expand(tensor.sizes()));
+  return c10::MaybeOwned<Tensor>::owned(
+      to_expand.expand_symint(tensor.sym_sizes()));
 }
 
 inline c10::MaybeOwned<Tensor> expand_inplace(
diff --git a/aten/src/ATen/core/TensorBase.h b/aten/src/ATen/core/TensorBase.h
index 08a14f2..0ecd445 100644
--- a/aten/src/ATen/core/TensorBase.h
+++ b/aten/src/ATen/core/TensorBase.h
@@ -956,10 +956,20 @@
 IntArrayRef sizes(const TensorBase& t) { return t.sizes(); }
 
 template <typename T, typename = enable_if_symint<T>>
+c10::SymInt size(const TensorBase& t, int64_t dim) { return t.sym_size(dim); }
+template <typename T, typename = enable_if_int<T>>
+int64_t size(const TensorBase& t, int64_t dim) { return t.size(dim); }
+
+template <typename T, typename = enable_if_symint<T>>
 c10::SymIntArrayRef strides(const TensorBase& t) { return t.sym_strides(); }
 template <typename T, typename = enable_if_int<T>>
 IntArrayRef strides(const TensorBase& t) { return t.strides(); }
 
+template <typename T, typename = enable_if_symint<T>>
+c10::SymInt numel(const TensorBase& t) { return t.sym_numel(); }
+template <typename T, typename = enable_if_int<T>>
+int64_t numel(const TensorBase& t) { return t.numel(); }
+
 } // namespace symint
 
 } // namespace at
diff --git a/aten/src/ATen/functorch/BatchRulesConvolution.cpp b/aten/src/ATen/functorch/BatchRulesConvolution.cpp
index 0640af3..79523ed 100644
--- a/aten/src/ATen/functorch/BatchRulesConvolution.cpp
+++ b/aten/src/ATen/functorch/BatchRulesConvolution.cpp
@@ -17,7 +17,7 @@
 // we do not support batch_group_count (which is needed for convolution backwards).
 // Instead, there's a convolution_backward op that needs a batching rule.
 std::tuple<Tensor,optional<int64_t>>
-convolution_batch_rule(const Tensor& lhs, optional<int64_t> lhs_bdim, const Tensor& rhs, optional<int64_t> rhs_bdim, const optional<Tensor>& bias, optional<int64_t> bias_bdim, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, IntArrayRef output_padding, int64_t groups) {
+convolution_batch_rule(const Tensor& lhs, optional<int64_t> lhs_bdim, const Tensor& rhs, optional<int64_t> rhs_bdim, const optional<Tensor>& bias, optional<int64_t> bias_bdim, IntArrayRef stride, c10::SymIntArrayRef padding, IntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, int64_t groups) {
   DimVector lhs_spec(stride.size() + 2);
   std::iota(lhs_spec.begin(), lhs_spec.end(), 0);
   DimVector rhs_spec = lhs_spec;
@@ -42,13 +42,13 @@
   std::tuple<Tensor, optional<int64_t>> result;
   if (lhs_bdim && !rhs_bdim) {
     auto new_x = reshape_dim_into(*lhs_bdim, lhs_spec[0], lhs);
-    auto out = at::convolution(new_x, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
+    auto out = at::convolution_symint(new_x, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
     out = reshape_dim_outof(out_spec[0], lhs.sizes()[*lhs_bdim], out);
     result = std::make_tuple(out, out_spec[0]);
   } else if (!lhs_bdim && rhs_bdim) {
     if (groups == 1) {
       auto new_w = reshape_dim_into(*rhs_bdim, rhs_spec[0], rhs);
-      auto out = at::convolution(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
+      auto out = at::convolution_symint(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
       out = reshape_dim_outof(out_spec[1], rhs.size(*rhs_bdim), out);
       result = std::make_tuple(out, out_spec[1]);
     } else {
@@ -62,7 +62,7 @@
         // BIOHW -> I(BO)HW
         auto new_w = reshape_dim_into(*rhs_bdim, 1, rhs);
         // NIHW, I(BO)HW -> N(GBO)HW
-        auto out = at::convolution(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
+        auto out = at::convolution_symint(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
         // N(GBO)HW -> NG(BO)HW
         out = reshape_dim_outof(1, groups, out);
         // NG(BO)HW -> NGBOHW
@@ -84,7 +84,7 @@
         // G(BO)IHW -> (GBO)IHW
         new_w = reshape_dim_into(0, 0, new_w);
         // N(GI)HW, (GBO)IHW -> N(GBO)HW
-        auto out = at::convolution(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
+        auto out = at::convolution_symint(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
         // N(GBO)HW -> NG(BO)HW
         out = reshape_dim_outof(1, groups, out);
         // NG(BO)HW -> NGBOHW
@@ -99,11 +99,11 @@
     groups *= lhs.sizes()[*lhs_bdim];
     auto dim_with_groups = transposed ? 1 : 0;
     auto new_w = reshape_dim_into(*rhs_bdim, rhs_spec[dim_with_groups], rhs);
-    auto out = at::convolution(new_x, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
+    auto out = at::convolution_symint(new_x, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
     out = reshape_dim_outof(out_spec[1], lhs.sizes()[*lhs_bdim], out);
     result = std::make_tuple(out, out_spec[1]);
   } else {
-    result = std::make_tuple(at::convolution(lhs, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups), nullopt);
+    result = std::make_tuple(at::convolution_symint(lhs, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups), nullopt);
   }
   if (separate_bias) {
     auto A = std::get<0>(result);
@@ -244,8 +244,8 @@
     const Tensor& grad_output, optional<int64_t> grad_output_bdim,
     const Tensor& input, optional<int64_t> input_bdim,
     const Tensor& weight, optional<int64_t> weight_bdim,
-    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed,
-    IntArrayRef output_padding, int64_t groups) {
+    IntArrayRef stride, c10::SymIntArrayRef padding, IntArrayRef dilation, bool transposed,
+    c10::SymIntArrayRef output_padding, int64_t groups) {
   const std::array<bool, 3> mask = {true, false, false};
   if (grad_output_bdim && weight_bdim) {
     // regular: BNO, BOI -> N(BO), (BO)I -> N(BI)
@@ -254,7 +254,7 @@
     const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
     const auto weight_ = reshape_dim_into(*weight_bdim, 0, weight);
     auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
-    const auto result = at::convolution_backward(
+    const auto result = at::convolution_backward_symint(
         grad_output_, dummy_input, weight_, nullopt, stride, padding,
         dilation, transposed, output_padding, groups * batch_size, mask);
     const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
@@ -265,7 +265,7 @@
     const auto batch_size = grad_output.size(*grad_output_bdim);
     const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output);
     auto dummy_input = make_dummy(input, input_bdim, 0, batch_size);
-    const auto result = at::convolution_backward(
+    const auto result = at::convolution_backward_symint(
         grad_output_, dummy_input, weight, nullopt, stride, padding,
         dilation, transposed, output_padding, groups, mask);
     const auto grad_input = reshape_dim_outof(0, batch_size, std::get<0>(result));
@@ -278,7 +278,7 @@
       const auto in_ch_dim = transposed ? 0 : 1;
       const auto weight_ = reshape_dim_into(*weight_bdim, in_ch_dim, weight);
       auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
-      const auto result = at::convolution_backward(
+      const auto result = at::convolution_backward_symint(
           grad_output, dummy_input, weight_, nullopt, stride, padding,
           dilation, transposed, output_padding, groups, mask);
       const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
@@ -289,7 +289,7 @@
       // N(GO), B(GO)I -> N(GO), (GO)(BI) -> N(GBI)
       const auto weight_ = reshape_dim_into(*weight_bdim, 1, weight);
       auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
-      const auto result = at::convolution_backward(
+      const auto result = at::convolution_backward_symint(
           grad_output, dummy_input, weight_, nullopt, stride, padding,
           dilation, transposed, output_padding, groups, mask);
       grad_input = std::get<0>(result); // N(GBI)
@@ -300,7 +300,7 @@
       weight_ = weight_.transpose(0, 1);                       // GBIO
       weight_ = weight_.flatten(0, 2);                         // (GBI)O
       const auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
-      const auto result = at::convolution_backward(
+      const auto result = at::convolution_backward_symint(
           grad_output, dummy_input, weight_, nullopt, stride, padding,
           dilation, transposed, output_padding, groups, mask);
       grad_input = std::get<0>(result); // N(GBI)
@@ -314,7 +314,7 @@
   } else {
     TORCH_INTERNAL_ASSERT(input_bdim);
     const auto dummy_input = make_dummy(input, input_bdim, 0, 1);
-    const auto result = at::convolution_backward(
+    const auto result = at::convolution_backward_symint(
         grad_output, dummy_input, weight, nullopt, stride, padding,
         dilation, transposed, output_padding, groups, mask);
     return std::make_tuple(std::get<0>(result), nullopt);
@@ -325,8 +325,8 @@
     const Tensor& grad_output, optional<int64_t> grad_output_bdim,
     const Tensor& input, optional<int64_t> input_bdim,
     const Tensor& weight, optional<int64_t> weight_bdim,
-    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed,
-    IntArrayRef output_padding, int64_t groups) {
+    IntArrayRef stride, c10::SymIntArrayRef padding, IntArrayRef dilation, bool transposed,
+    c10::SymIntArrayRef output_padding, int64_t groups) {
   const std::array<bool, 3> mask = {false, true, false};
   if (grad_output_bdim && input_bdim) {
     // BNO, BNI -> N(BO), N(BI) -> (BO)I (regular) (BI)O (transposed)
@@ -334,7 +334,7 @@
     const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
     const auto input_ = reshape_dim_into(*input_bdim, 1, input);
     const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
-    const auto result = at::convolution_backward(
+    const auto result = at::convolution_backward_symint(
         grad_output_, input_, dummy_weight, nullopt, stride, padding,
         dilation, transposed, output_padding, groups * batch_size, mask);
     auto grad_weight = std::get<1>(result);
@@ -348,7 +348,7 @@
       const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
       const auto out_ch_dim = transposed ? 1 : 0;
       const auto dummy_weight = make_dummy(weight, weight_bdim, out_ch_dim, batch_size);
-      const auto result = at::convolution_backward(
+      const auto result = at::convolution_backward_symint(
           grad_output_, input, dummy_weight, nullopt, stride, padding,
           dilation, transposed, output_padding, groups, mask);
       auto grad_weight = std::get<1>(result);
@@ -362,7 +362,7 @@
       if (!transposed) {
         // BN(GO), N(GI) -> N(GBO), N(GI) -> (GBO)I
         const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
-        const auto result = at::convolution_backward(
+        const auto result = at::convolution_backward_symint(
             grad_output_, input, dummy_weight, nullopt, stride, padding,
             dilation, transposed, output_padding, groups, mask);
         auto grad_weight = std::get<1>(result);
@@ -373,7 +373,7 @@
       } else {
         // BN(GO), N(GI) -> N(GBO), N(GI) -> (GI)(BO)
         const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size);
-        const auto result = at::convolution_backward(
+        const auto result = at::convolution_backward_symint(
             grad_output_, input, dummy_weight, nullopt, stride, padding,
             dilation, transposed, output_padding, groups, mask);
         auto grad_weight = std::get<1>(result);
@@ -389,7 +389,7 @@
       const auto input_ = reshape_dim_into(*input_bdim, 1, input);
       const auto in_ch_dim = transposed ? 0 : 1;
       const auto dummy_weight = make_dummy(weight, weight_bdim, in_ch_dim, batch_size);
-      const auto result = at::convolution_backward(
+      const auto result = at::convolution_backward_symint(
           grad_output, input_, dummy_weight, nullopt, stride, padding,
           dilation, transposed, output_padding, groups, mask);
       auto grad_weight = std::get<1>(result);
@@ -403,7 +403,7 @@
       if (!transposed) {
         // regular: N(GO), BN(GI) -> N(GO), N(GBI) -> (GO)(BI)
         const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size);
-        const auto result = at::convolution_backward(
+        const auto result = at::convolution_backward_symint(
             grad_output, input_, dummy_weight, nullopt, stride, padding,
             dilation, transposed, output_padding, groups, mask);
         auto grad_weight = std::get<1>(result);
@@ -412,7 +412,7 @@
       } else {
         // transposed: N(GO), BN(GI) -> N(GO), N(GBI) -> (GBI)O
         const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
-        const auto result = at::convolution_backward(
+        const auto result = at::convolution_backward_symint(
             grad_output, input_, dummy_weight, nullopt, stride, padding,
             dilation, transposed, output_padding, groups, mask);
         auto grad_weight = std::get<1>(result);
@@ -425,7 +425,7 @@
   } else {
     TORCH_INTERNAL_ASSERT(weight_bdim);
     const auto dummy_weight = make_dummy(weight, weight_bdim, 0, 1);
-    const auto result = at::convolution_backward(
+    const auto result = at::convolution_backward_symint(
         grad_output, input, dummy_weight, nullopt, stride, padding,
         dilation, transposed, output_padding, groups, mask);
     return std::make_tuple(std::get<1>(result), nullopt);
@@ -436,8 +436,8 @@
 std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
     const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_,
     const c10::OptionalArrayRef<SymInt> bias_sizes_opt,
-    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed,
-    IntArrayRef output_padding, int64_t groups, std::array<bool, 3> output_mask) {
+    IntArrayRef stride, c10::SymIntArrayRef padding, IntArrayRef dilation, bool transposed,
+    c10::SymIntArrayRef output_padding, int64_t groups, std::array<bool, 3> output_mask) {
   const auto maybe_layer = maybeCurrentDynamicLayer();
   TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
   int64_t cur_level = maybe_layer->layerId();
@@ -487,7 +487,7 @@
     const auto batch_size = weight.size(*weight_bdim);
     input = reshape_dim_into(*input_bdim, 1, input);
     weight = reshape_dim_into(*weight_bdim, 0, weight);
-    const auto result = at::convolution_backward(
+    const auto result = at::convolution_backward_symint(
         grad_output, input, weight, nullopt, stride, padding, dilation,
         transposed, output_padding, batch_size * groups, output_mask);
     // N(BI), (BO)I -> NBI, BOI
diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
index f1108ba..24a1c4a 100644
--- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
+++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
@@ -242,7 +242,7 @@
   OP_DECOMPOSE2(where, ScalarSelf);
   OP_DECOMPOSE(orgqr);
   OP_DECOMPOSE2(unflatten, int);
-  OP_DECOMPOSE(_convolution_double_backward);
+  m.impl("_convolution_double_backward", native::_convolution_double_backward);
   OP_DECOMPOSE(conv_transpose1d);
   OP_DECOMPOSE2(conv_transpose2d, input);
   OP_DECOMPOSE2(conv_transpose3d, input);
diff --git a/aten/src/ATen/native/AdaptiveAveragePooling.cpp b/aten/src/ATen/native/AdaptiveAveragePooling.cpp
index 40b05d7..b612ef0 100644
--- a/aten/src/ATen/native/AdaptiveAveragePooling.cpp
+++ b/aten/src/ATen/native/AdaptiveAveragePooling.cpp
@@ -130,9 +130,9 @@
       Tensor out = input.mean({-1, -2}, /* keepdim = */ true);
       if (input.suggest_memory_format() == at::MemoryFormat::ChannelsLast) {
         // assert ndim == 4, since ndim = 3 doesn't give channels_last
-        const int n = input.size(0);
-        const int c = input.size(1);
-        out.as_strided_({n, c, 1, 1}, {c, 1, c, c});
+        const auto n = input.sym_size(0);
+        const auto c = input.sym_size(1);
+        out.as_strided__symint({n, c, 1, 1}, {c, 1, c, c});
       }
       return out;
     } else {
diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp
index 4d68f23..64f6d14 100644
--- a/aten/src/ATen/native/Convolution.cpp
+++ b/aten/src/ATen/native/Convolution.cpp
@@ -910,8 +910,8 @@
   auto k = weight.dim();
   TORCH_CHECK(k > 2, "weight should have at least three dimensions");
   auto dim = static_cast<size_t>(k - 2);
-  auto weight_sizes = weight.sizes();
-  auto input_sizes = input.sizes();
+  auto weight_sizes = weight.sym_sizes();
+  auto input_sizes = input.sym_sizes();
   TORCH_CHECK(k == input.dim(),
               "Expected ", k, "-dimensional input for ",
               k, "-dimensional weight", weight_sizes, ", but got ",
@@ -926,7 +926,7 @@
   }
 
   // Calculate the correct padding
-  DimVector padding_l, padding_r;
+  SymDimVector padding_l, padding_r;
   bool symmetric_padding = true;
   for (auto i: c10::irange(dim)) {
     auto s = stride.size() == 1 ? stride[0] : stride[i];
@@ -942,14 +942,14 @@
 
   if (symmetric_padding) {
     // All backends handle symmetric padding natively
-    DimVector output_padding(static_cast<size_t>(dim));
-    return at::convolution(input, weight, bias, stride, padding_l, dilation,
+    SymDimVector output_padding(static_cast<size_t>(dim));
+    return at::convolution_symint(input, weight, bias, stride, padding_l, dilation,
                                false, output_padding, groups);
   }
 
   TORCH_WARN_ONCE("Using padding='same' with even kernel lengths and odd dilation may"
                   " require a zero-padded copy of the input be created");
-  SmallVector<int64_t, kDimVectorStaticSize * 2> pad_nd(static_cast<size_t>(2 * dim));
+  SmallVector<c10::SymInt, kDimVectorStaticSize * 2> pad_nd(static_cast<size_t>(2 * dim));
   for (auto i: c10::irange(dim)) {
     // Apply padding by the difference, leaving only a symmetric padding
     auto delta_pad = padding_r[i] - padding_l[i];
@@ -961,10 +961,10 @@
       padding_l[i] = padding_r[i];
     }
   }
-  auto padded_input = at::constant_pad_nd(input, pad_nd, 0);
-  DimVector output_padding(static_cast<size_t>(dim));
-  return at::convolution(padded_input, weight, bias, stride, padding_l,
-                         dilation, false, output_padding, groups);
+  auto padded_input = at::constant_pad_nd_symint(input, pad_nd, 0);
+  SymDimVector output_padding(static_cast<size_t>(dim));
+  return at::convolution_symint(padded_input, weight, bias, stride, padding_l,
+                                dilation, false, output_padding, groups);
 }
 
 Tensor _convolution_mode(
diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp
index 7d4a89d..2140494 100644
--- a/aten/src/ATen/native/EmbeddingBag.cpp
+++ b/aten/src/ATen/native/EmbeddingBag.cpp
@@ -1307,7 +1307,7 @@
   checkContiguous("embedding_bag", offsets_arg);
 
   Tensor offset2bag_;
-  if (indices.numel() != 0 && offset2bag.numel() == 0) {
+  if (indices.sym_numel() != 0 && offset2bag.sym_numel() == 0) {
     offset2bag_ = offsets.new_zeros(
       {indices.size(0) + 1}, offsets.options()); // offset2bag = [0 0 0 0 0]
 
diff --git a/aten/src/ATen/native/GridSamplerUtils.h b/aten/src/ATen/native/GridSamplerUtils.h
index 0b6f29d..7c22fed 100644
--- a/aten/src/ATen/native/GridSamplerUtils.h
+++ b/aten/src/ATen/native/GridSamplerUtils.h
@@ -101,7 +101,7 @@
     at::native::canUse32BitIndexMath(input) &&
     at::native::canUse32BitIndexMath(grid) &&
     input.dim() == 4 &&
-    input.size(1) <= 1024);
+    input.sym_size(1) <= 1024);
 }
 
 } // anonymous namespace
diff --git a/aten/src/ATen/native/IndexingUtils.cpp b/aten/src/ATen/native/IndexingUtils.cpp
index c5f5ff6..2dba197 100644
--- a/aten/src/ATen/native/IndexingUtils.cpp
+++ b/aten/src/ATen/native/IndexingUtils.cpp
@@ -4,7 +4,7 @@
 namespace at { namespace native {
 
 bool canUse32BitIndexMath(const TensorBase& t, int64_t max_elem) {
-  int64_t elements = t.numel();
+  auto elements = t.sym_numel();
   if (elements >= max_elem) {
     return false;
   }
@@ -12,16 +12,16 @@
     return max_elem > 0;
   }
 
-  int64_t offset = 0;
-  int64_t linearId = elements - 1;
+  c10::SymInt offset = 0;
+  auto linearId = elements - 1;
 
   // NOTE: Assumes all strides are positive, which is true for now
   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
   for (int i = t.dim() - 1; i >= 0; --i) {
-    int64_t curDimIndex = linearId % t.size(i);
-    int64_t curDimOffset = curDimIndex * t.stride(i);
+    auto curDimIndex = linearId % t.sym_size(i);
+    auto curDimOffset = curDimIndex * t.sym_stride(i);
     offset += curDimOffset;
-    linearId /= t.size(i);
+    linearId /= t.sym_size(i);
   }
 
   if (offset >= max_elem) {
diff --git a/aten/src/ATen/native/LossNLL.cpp b/aten/src/ATen/native/LossNLL.cpp
index 8e5864b..28fc605 100644
--- a/aten/src/ATen/native/LossNLL.cpp
+++ b/aten/src/ATen/native/LossNLL.cpp
@@ -656,7 +656,7 @@
   c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
   const Tensor& weight = *weight_maybe_owned;
 
-  return std::get<0>(at::nll_loss_forward(self, target, weight, reduction, ignore_index));
+  return std::get<0>(at::nll_loss_forward_symint(self, target, weight, reduction, ignore_index));
 }
 
 Tensor nll_loss_nd_symint(
diff --git a/aten/src/ATen/native/LossNLL2d.cpp b/aten/src/ATen/native/LossNLL2d.cpp
index ab7c084..aee22ce 100644
--- a/aten/src/ATen/native/LossNLL2d.cpp
+++ b/aten/src/ATen/native/LossNLL2d.cpp
@@ -498,7 +498,7 @@
   c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
   const Tensor& weight = *weight_maybe_owned;
 
-  return std::get<0>(at::nll_loss2d_forward(self, target, weight, reduction, ignore_index));
+  return std::get<0>(at::nll_loss2d_forward_symint(self, target, weight, reduction, ignore_index));
 }
 
 } // namespace native
diff --git a/aten/src/ATen/native/NonSymbolicBC.h b/aten/src/ATen/native/NonSymbolicBC.h
index e7d31ae..f57c868 100644
--- a/aten/src/ATen/native/NonSymbolicBC.h
+++ b/aten/src/ATen/native/NonSymbolicBC.h
@@ -22,4 +22,5 @@
 TORCH_API at::Tensor value_selecting_reduction_backward(const at::Tensor & grad, int64_t dim, const at::Tensor & indices, at::IntArrayRef sizes, bool keepdim);
 TORCH_API at::Tensor trace_backward(const at::Tensor & grad, at::IntArrayRef sizes);
 TORCH_API at::Tensor index_select_backward(const at::Tensor & grad, at::IntArrayRef self_sizes, int64_t dim, const at::Tensor & index);
+TORCH_API std::vector<Tensor> tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim);
 }}
diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h
index cf5b45b..0ff4490 100644
--- a/aten/src/ATen/native/Pool.h
+++ b/aten/src/ATen/native/Pool.h
@@ -67,17 +67,18 @@
         inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode);
 }
 
-inline std::pair<int64_t, int64_t> pooling_same_mode_padding_lr(
-    int64_t inputSize, int64_t kernelSize, int64_t stride, int64_t dilation) {
+template <typename T>
+std::pair<T, T> _pooling_same_mode_padding_lr(
+    T inputSize, T kernelSize, int64_t stride, int64_t dilation) {
   // NOTE: with strides, the output shape is ceil(inputSize/stride)
-  auto total_padding = dilation * (kernelSize - 1);
+  auto total_padding = T(dilation) * (kernelSize - 1);
 
   // Prefer symmetric padding if possible
   if (stride > 2 && (total_padding % 2 == 1)) {
     // The floor in the output size calculation gives us a little wiggle room
     auto wiggle_room = inputSize % stride - 1;
     if (wiggle_room > 0) {
-      --total_padding;
+      total_padding = total_padding - 1;
     }
   }
 
@@ -85,6 +86,15 @@
   return {left, total_padding - left};
 }
 
+inline std::pair<int64_t, int64_t> pooling_same_mode_padding_lr(
+    int64_t inputSize, int64_t kernelSize, int64_t stride, int64_t dilation) {
+  return _pooling_same_mode_padding_lr(inputSize, kernelSize, stride, dilation);
+}
+
+inline std::pair<c10::SymInt, c10::SymInt> pooling_same_mode_padding_lr(
+    c10::SymInt inputSize, c10::SymInt kernelSize, int64_t stride, int64_t dilation) {
+  return _pooling_same_mode_padding_lr(inputSize, kernelSize, stride, dilation);
+}
 
 // AveragePool2d/DilatedMaxPool2d (forward)
 static inline void
diff --git a/aten/src/ATen/native/TensorProperties.cpp b/aten/src/ATen/native/TensorProperties.cpp
index 6a703cb..e37dbf5 100644
--- a/aten/src/ATen/native/TensorProperties.cpp
+++ b/aten/src/ATen/native/TensorProperties.cpp
@@ -69,7 +69,7 @@
   // tensors. Maybe some cuDNN functions actually support empty tensors, but
   // native/THNN kernels shouldn't be much slower because the output is also
   // likely empty.
-  if (self.numel() == 0) return false;
+  if (self.sym_numel() == 0) return false;
   // NB: In the old Python code, there was also a test to see if the
   // cuDNN library was actually dynamically linked or not.  I'm not
   // sure if we can actually test this.
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index d251135..2051cda 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -917,9 +917,12 @@
   }
 }
 
-std::vector<Tensor> tensor_split(const Tensor& self, int64_t sections, int64_t dim) {
+std::vector<Tensor> tensor_split_sections_symint(const Tensor& self, c10::SymInt sym_sections, int64_t dim) {
   TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims");
   int64_t dim_ = maybe_wrap_dim(dim, self.dim());
+  // NB: intentional, sections specifies number of output tensors, which
+  // cannot be polymorphic
+  int64_t sections = sym_sections.guard_int(__FILE__, __LINE__);
   TORCH_CHECK(sections > 0, "number of sections must be larger than 0, got ", sections);
   const auto dim_size = self.sym_size(dim_);
   std::vector<Tensor> splits(sections);
@@ -934,21 +937,30 @@
   return splits;
 }
 
-std::vector<Tensor> tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim) {
+template <typename T>
+std::vector<Tensor> _tensor_split_indices(const Tensor& self, ArrayRef<T> indices, int64_t dim) {
   TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims");
   int64_t dim_ = maybe_wrap_dim(dim, self.dim());
   int64_t num_indices = indices.size();
   std::vector<Tensor> splits(num_indices + 1);
-  int64_t start_idx = 0;
+  T start_idx(0);
   for (const auto split_idx : c10::irange(num_indices)) {
-    int64_t end_idx = indices[split_idx];
-    splits[split_idx] = at::slice(self, dim_, start_idx, end_idx);
+    auto end_idx = indices[split_idx];
+    splits[split_idx] = at::symint::slice<T>(self, dim_, start_idx, end_idx);
     start_idx = end_idx;
   }
-  splits[num_indices] = at::slice(self, dim_, start_idx, self.size(dim_));
+  splits[num_indices] = at::symint::slice<T>(self, dim_, start_idx, at::symint::size<T>(self, dim_));
   return splits;
 }
 
+std::vector<Tensor> tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim) {
+  return _tensor_split_indices(self, indices, dim);
+}
+
+std::vector<Tensor> tensor_split_indices_symint(const Tensor& self, SymIntArrayRef indices, int64_t dim) {
+  return _tensor_split_indices(self, indices, dim);
+}
+
 std::vector<Tensor> tensor_split(const Tensor& self, const Tensor& tensor_indices_or_sections, int64_t dim) {
   TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims");
   auto split_device = tensor_indices_or_sections.device();
@@ -1174,8 +1186,8 @@
   return result;
 }
 
-const Tensor &as_strided_(const Tensor& self, IntArrayRef size, IntArrayRef stride, optional<int64_t> storage_offset_) {
-  auto storage_offset = storage_offset_.value_or(self.storage_offset());
+const Tensor &as_strided__symint(const Tensor& self, SymIntArrayRef size, SymIntArrayRef stride, optional<c10::SymInt> storage_offset_) {
+  auto storage_offset = storage_offset_.value_or(self.sym_storage_offset());
   setStrided(self, size, stride, storage_offset);
   return self;
 }
diff --git a/aten/src/ATen/native/group_norm.cpp b/aten/src/ATen/native/group_norm.cpp
index 5b38b02..c12d8d2 100644
--- a/aten/src/ATen/native/group_norm.cpp
+++ b/aten/src/ATen/native/group_norm.cpp
@@ -23,13 +23,15 @@
 #include <vector>
 
 namespace at {
+
 namespace native {
 
+template <typename T>
 void check_group_norm_inputs(
     const Tensor& input,
     const Tensor& weight,
     const Tensor& bias,
-    int64_t C,
+    T C,
     int64_t num_groups) {
   TORCH_CHECK(
       num_groups > 0,
@@ -43,14 +45,14 @@
       "num_groups=",
       num_groups);
   TORCH_CHECK(
-      !weight.defined() || (weight.dim() == 1 && weight.numel() == C),
+      !weight.defined() || (weight.dim() == 1 && at::symint::numel<T>(weight) == C),
       "Expected weight to be a vector of size equal to the number of ",
       "channels in input, but got weight of shape ",
       weight.sizes(),
       " and input of shape ",
       input.sizes());
   TORCH_CHECK(
-      !bias.defined() || (bias.dim() == 1 && bias.numel() == C),
+      !bias.defined() || (bias.dim() == 1 && at::symint::numel<T>(bias) == C),
       "Expected bias to be a vector of size equal to the number of ",
       "channels in input, but got bias of shape ",
       weight.sizes(),
@@ -171,13 +173,13 @@
   const Tensor& weight = *weight_maybe_owned;
   const Tensor& bias = c10::value_or_else(bias_opt, [] { return Tensor(); });
 
-  const int64_t N = input.size(0);
-  const int64_t C = input.size(1);
+  const auto N = input.sym_size(0);
+  const auto C = input.sym_size(1);
   check_group_norm_inputs(input, weight, bias, C, num_groups);
 
-  const auto input_shape = input.sizes();
-  const int64_t HxW =
-      c10::multiply_integers(input_shape.cbegin() + 2, input_shape.cend());
+  const auto input_shape = input.sym_sizes();
+  const auto HxW =
+      c10::multiply_integers(input_shape.slice(2));
 
   const Tensor kEmpty;
   auto memory_format = input.suggest_memory_format();
@@ -185,10 +187,10 @@
       input.contiguous(memory_format) : input.contiguous();
   const auto& gamma = weight.defined() ? weight.contiguous() : kEmpty;
   const auto& beta = bias.defined() ? bias.contiguous() : kEmpty;
-  TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
-  TORCH_CHECK(!beta.defined() || beta.numel() == C);
+  TORCH_CHECK(!gamma.defined() || gamma.sym_numel() == C);
+  TORCH_CHECK(!beta.defined() || beta.sym_numel() == C);
   return std::get<0>(
-      at::native_group_norm(X, gamma, beta, N, C, HxW, num_groups, eps));
+      at::native_group_norm_symint(X, gamma, beta, N, C, HxW, num_groups, eps));
 }
 
 DEFINE_DISPATCH(GroupNormKernel);
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 69951d7..2922e2b 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -815,7 +815,7 @@
   device_guard: False
   tags: inplace_view
   dispatch:
-    CompositeExplicitAutogradNonFunctional: as_strided_
+    CompositeExplicitAutogradNonFunctional: as_strided__symint
 
 - func: asin(Tensor self) -> Tensor
   device_check: NoCheck   # TensorIterator
@@ -1294,11 +1294,15 @@
     CompositeImplicitAutograd: chunk
     NestedTensorCPU, NestedTensorCUDA: chunk_nested_tensor
 
-- func: tensor_split.sections(Tensor(a -> *) self, int sections, int dim=0) -> Tensor(a)[]
+- func: tensor_split.sections(Tensor(a -> *) self, SymInt sections, int dim=0) -> Tensor(a)[]
   variants: function, method
+  dispatch:
+    CompositeImplicitAutograd: tensor_split_sections_symint
 
-- func: tensor_split.indices(Tensor(a -> *) self, int[] indices, int dim=0) -> Tensor(a)[]
+- func: tensor_split.indices(Tensor(a -> *) self, SymInt[] indices, int dim=0) -> Tensor(a)[]
   variants: function, method
+  dispatch:
+    CompositeImplicitAutograd: tensor_split_indices_symint
 
 - func: tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[]
   variants: function, method
@@ -1465,13 +1469,13 @@
   variants: method
   manual_cpp_binding: True
 
-- func: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor
+- func: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups) -> Tensor
   dispatch:
     CompositeExplicitAutograd: convolution
   autogen: convolution.out
   tags: canonical
 
-- func: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+- func: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
   dispatch:
     CompositeExplicitAutograd, CUDA: convolution_backward
   autogen: convolution_backward.out
@@ -1487,7 +1491,7 @@
     CompositeExplicitAutograd: convolution_backward_overrideable
   autogen: convolution_backward_overrideable.out
 
-- func: _convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor
+- func: _convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor
   dispatch:
     CompositeExplicitAutograd: _convolution
   autogen: _convolution.out
@@ -1496,7 +1500,7 @@
 
 - func: _convolution_mode(Tensor input, Tensor weight, Tensor? bias, int[] stride, str padding, int[] dilation, int groups) -> Tensor
 
-- func: _convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+- func: _convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
 
 - func: conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor
 
@@ -3561,7 +3565,7 @@
     MPS: mps_convolution_backward
   autogen: mps_convolution_backward.out
 
-- func: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor
+- func: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, int[] stride, int[] dilation, int groups) -> Tensor
   dispatch:
     CompositeExplicitAutograd: mkldnn_convolution
   autogen: mkldnn_convolution.out
@@ -3576,17 +3580,17 @@
     CUDA: miopen_batch_norm_backward
   autogen: miopen_batch_norm_backward.out
 
-- func: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
+- func: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
   dispatch:
     CUDA: miopen_convolution
   autogen: miopen_convolution.out
 
-- func: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
+- func: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
   dispatch:
     CUDA: miopen_convolution_transpose
   autogen: miopen_convolution_transpose.out
 
-- func: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
+- func: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
   dispatch:
     CUDA: miopen_depthwise_convolution
   autogen: miopen_depthwise_convolution.out
@@ -3840,7 +3844,7 @@
 
 - func: _nnpack_available() -> bool
 
-- func: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, int[2] padding, int[2] stride=1) -> Tensor
+- func: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, int[2] stride=1) -> Tensor
   variants: function
   dispatch:
     CompositeExplicitAutograd: _nnpack_spatial_convolution
@@ -11470,24 +11474,24 @@
 # these are the same thing, but we give them different prefixes to
 # make the operational distinction clear.
 
-- func: slow_conv_transpose2d.out(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
+- func: slow_conv_transpose2d.out(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, int[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
   python_module: nn
   structured: True
   dispatch:
     CPU: slow_conv_transpose2d_structured_cpu
     CUDA: slow_conv_transpose2d_structured_cuda
 
-- func: slow_conv_transpose2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int[2] dilation=1) -> Tensor
+- func: slow_conv_transpose2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, int[2] dilation=1) -> Tensor
   python_module: nn
   structured_delegate: slow_conv_transpose2d.out
 
-- func: slow_conv_transpose3d.out(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] output_padding=0, int[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
+- func: slow_conv_transpose3d.out(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, int[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
   python_module: nn
   dispatch:
     CPU: slow_conv_transpose3d_out_cpu
     CUDA: slow_conv_transpose3d_out_cuda
 
-- func: slow_conv_transpose3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] output_padding=0, int[3] dilation=1) -> Tensor
+- func: slow_conv_transpose3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, int[3] dilation=1) -> Tensor
   python_module: nn
   dispatch:
     CPU: slow_conv_transpose3d_cpu
@@ -11524,47 +11528,47 @@
     CUDA: slow_conv2d_backward_cuda
   autogen: _slow_conv2d_backward.output_mask_out
 
-- func: _conv_depthwise2d.out(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding, int[2] dilation, *, Tensor(a!) out) -> Tensor(a!)
+- func: _conv_depthwise2d.out(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, SymInt[2] padding, int[2] dilation, *, Tensor(a!) out) -> Tensor(a!)
   use_const_ref_for_mutable_tensors: True
   python_module: nn
   dispatch:
     CUDA: conv_depthwise2d_cuda_out
 
-- func: _conv_depthwise2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding, int[2] dilation) -> Tensor
+- func: _conv_depthwise2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, SymInt[2] padding, int[2] dilation) -> Tensor
   python_module: nn
   dispatch:
     CUDA: conv_depthwise2d_cuda
 
-- func: conv_depthwise3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, int[3] padding, int[3] dilation) -> Tensor
+- func: conv_depthwise3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, SymInt[3] padding, int[3] dilation) -> Tensor
   python_module: nn
   dispatch:
     CUDA: conv_depthwise3d_cuda
   autogen: conv_depthwise3d.out
 
-- func: slow_conv3d.out(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, *, Tensor(a!) out) -> Tensor(a!)
+- func: slow_conv3d.out(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!)
   python_module: nn
 
-- func: slow_conv3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0) -> Tensor
+- func: slow_conv3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, SymInt[3] padding=0) -> Tensor
   python_module: nn
 
-- func: slow_conv3d_forward.output(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, int[3] padding, *, Tensor(a!) output) -> Tensor(a!)
+- func: slow_conv3d_forward.output(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, SymInt[3] padding, *, Tensor(a!) output) -> Tensor(a!)
   python_module: nn
   dispatch:
     CPU: slow_conv3d_forward_out_cpu
 
-- func: slow_conv3d_forward(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, int[3] padding) -> Tensor
+- func: slow_conv3d_forward(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, SymInt[3] padding) -> Tensor
   python_module: nn
   dispatch:
     CPU: slow_conv3d_forward_cpu
 
-- func: slow_conv_dilated2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1) -> Tensor
+- func: slow_conv_dilated2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, SymInt[2] padding=0, int[2] dilation=1) -> Tensor
   python_module: nn
   dispatch:
     CPU: slow_conv_dilated2d_cpu
     CUDA: slow_conv_dilated2d_cuda
   autogen: slow_conv_dilated2d.out
 
-- func: slow_conv_dilated3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1) -> Tensor
+- func: slow_conv_dilated3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, SymInt[3] padding=0, int[3] dilation=1) -> Tensor
   python_module: nn
   dispatch:
     CPU: slow_conv_dilated3d_cpu
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index d406f2e..15e0e6a 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -1128,8 +1128,6 @@
     skip('nn.functional.batch_norm', ''),  # '0 is not tracked with proxy for <torch.fx.experimental.proxy_te..
     xfail('nn.functional.bilinear', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.binary_cross_entropy', ''),  # aten.fill_.Scalar - couldn't find symbolic meta funct...
-    xfail('nn.functional.conv1d', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('nn.functional.conv2d', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.cosine_embedding_loss', ''),  # Cannot call sizes() on tensor with symbolic sizes/st...
     xfail('nn.functional.cosine_similarity', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.cross_entropy', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -1230,7 +1228,6 @@
     xfail('trapezoid', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('trapz', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('triangular_solve', ''),  # aten.triangular_solve.default - couldn't find symbolic meta function/de...
-    xfail('unbind', ''),  # tensor_split() received an invalid combination of arguments - got (FakeTensor, torch...
     xfail('unflatten', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('var', ''),  # Cannot call numel() on tensor with symbolic sizes/strides
     xfail('var_mean', ''),  # Cannot call numel() on tensor with symbolic sizes/strides
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 1e72d5a..fae5536 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1248,8 +1248,6 @@
     xfail('nn.functional.avg_pool3d', ''),  # aten.avg_pool3d.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.bilinear', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.binary_cross_entropy', ''),  # aten.new_empty.default - couldn't find symbolic meta function/decom...
-    xfail('nn.functional.conv1d', ''),  # aten.convolution.default - couldn't find symbolic meta function/decomposition
-    xfail('nn.functional.conv2d', ''),  # aten.convolution.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.cosine_embedding_loss', ''),  # The underlying op of 'aten.stride' has no overload name '_schema'
     xfail('nn.functional.cosine_similarity', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.cross_entropy', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
@@ -1262,7 +1260,6 @@
     xfail('nn.functional.fractional_max_pool2d', ''),  # argument 'size' must be tuple of ints, but found element of t...
     xfail('nn.functional.fractional_max_pool3d', ''),  # argument 'size' must be tuple of ints, but found element of t...
     xfail('nn.functional.grid_sample', ''),  # aten.grid_sampler_2d.default - couldn't find symbolic meta function/decompos...
-    xfail('nn.functional.group_norm', ''),  # 'torch._C.SymIntNode' and 'int'
     xfail('nn.functional.hinge_embedding_loss', ''),  # aten.empty_like.default - couldn't find symbolic meta function/deco...
     xfail('nn.functional.interpolate', 'area'),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.interpolate', 'bicubic'),  # aten.upsample_bicubic2d.vec - couldn't find symbolic meta function/d...
@@ -1355,7 +1352,6 @@
     xfail('view_as_complex', ''),  # aten.view_as_complex.default - couldn't find symbolic meta function/decomposition
     xfail('view_as', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('vsplit', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
-    xfail('unbind', ''),  # aten.unbind.int - couldn't find symbolic meta function/decomposition
     xfail('unique_consecutive', ''),  # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
     xfail('unique', ''),  # aten._unique2.default - couldn't find symbolic meta function/decomposition
 }
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index c77f63e..6945dae 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -2206,19 +2206,19 @@
   indices: non_differentiable
   result: auto_linear
 
-- name: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor
+- name: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups) -> Tensor
   input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
   result: convolution_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, stride, padding, dilation, transposed, output_padding, groups)
 
 # TorchScript serializes calls to _convolution so this entry is present until that is changed to use convolution.
 # Note that the benchmark, deterministic, cudnn_enabled, and allow_tf32 flags are queried from the global context
 # by convolution_backward instead of being passed along from the forward pass.
-- name: _convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor
+- name: _convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor
   input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
   result: _convolution_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32)
 
-- name: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
-  grad_output, input, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask)
+- name: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+  grad_output, input, weight: _convolution_double_backward_symint(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask)
   result0: std::get<0>(convolution_backward_symint(grad_output_p, input_p, weight_t, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {true, false, false})) + std::get<0>(convolution_backward_symint(grad_output_t, input_p, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {true, false, false}))
   result1: std::get<1>(convolution_backward_symint(grad_output_p, input_t, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {false, true, false})) + std::get<1>(convolution_backward_symint(grad_output_t, input_p, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {false, true, false}))
   result2: convolution_backward_jvp_grad_bias(grad_output_t, result2)
@@ -2229,10 +2229,10 @@
 - name: convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
   grad_output, input, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask)
 
-- name: slow_conv_transpose2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int[2] dilation=1) -> Tensor
+- name: slow_conv_transpose2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, int[2] dilation=1) -> Tensor
   self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
 
-- name: slow_conv_transpose3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] output_padding=0, int[3] dilation=1) -> Tensor
+- name: slow_conv_transpose3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, int[3] dilation=1) -> Tensor
   self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
 
 - name: _slow_conv2d_forward(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding) -> Tensor
@@ -2241,20 +2241,20 @@
 - name: _slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
   grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1}}, false, {{0, 0}}, 1, grad_input_mask)
 
-- name: _conv_depthwise2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding, int[2] dilation) -> Tensor
+- name: _conv_depthwise2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, SymInt[2] padding, int[2] dilation) -> Tensor
   self, weight, bias: "grad.defined() ? convolution_backward_symint(grad.contiguous(), self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
 
-- name: conv_depthwise3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, int[3] padding, int[3] dilation) -> Tensor
+- name: conv_depthwise3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, SymInt[3] padding, int[3] dilation) -> Tensor
   self, weight, bias: "grad.defined() ? convolution_backward_symint(grad.contiguous(), self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
 
-- name: slow_conv3d_forward(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, int[3] padding) -> Tensor
+- name: slow_conv3d_forward(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, SymInt[3] padding) -> Tensor
   self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, /*dilation=*/ {{1, 1, 1}}, false, /*output_padding=*/ {{0, 0, 0}}, 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
 
-- name: slow_conv_dilated2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1) -> Tensor
-  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+- name: slow_conv_dilated2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, SymInt[2] padding=0, int[2] dilation=1) -> Tensor
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector<c10::SymInt>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
 
-- name: slow_conv_dilated3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1) -> Tensor
-  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+- name: slow_conv_dilated3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, SymInt[3] padding=0, int[3] dilation=1) -> Tensor
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector<c10::SymInt>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
 
 - name: col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor
   self: im2col(grad, kernel_size, dilation, padding, stride)
@@ -2608,9 +2608,9 @@
 
 # nnpack
 
-- name: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, int[2] padding, int[2] stride=1) -> Tensor
+- name: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, int[2] stride=1) -> Tensor
   # NNPACK does not support strided convolutions in the backwards path, which is the reason why we are using the closest available function that does here.
-  input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, std::vector<int64_t>(padding.size(), 1), false, std::vector<int64_t>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+  input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, std::vector<int64_t>(padding.size(), 1), false, std::vector<c10::SymInt>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
 
 #LSTM MPS
 - name: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
@@ -2641,14 +2641,14 @@
 
 # miopen
 
-- name: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
+- name: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
   self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
 
-- name: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
-  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+- name: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector<c10::SymInt>(padding.size(), 0), groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
 
-- name: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
-  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+- name: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector<c10::SymInt>(padding.size(), 0), groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
 
 - name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)
   input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()"
@@ -2667,8 +2667,8 @@
   dropout_state: non_differentiable
 
 # mkldnn
-- name: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor
-  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ std::vector<int64_t>(padding.size(), 0), groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+- name: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, int[] stride, int[] dilation, int groups) -> Tensor
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ std::vector<c10::SymInt>(padding.size(), 0), groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
 
 - name: mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor
   self, weight, bias: mkldnn_linear_backward(self, grad, weight, grad_input_mask)
diff --git a/tools/jit/gen_unboxing.py b/tools/jit/gen_unboxing.py
index ebeaa21..79c594a 100644
--- a/tools/jit/gen_unboxing.py
+++ b/tools/jit/gen_unboxing.py
@@ -116,7 +116,9 @@
                 # from wrapping/unwrapping TensorOptios.
                 # However, we would look to include default args for schema parsing.
                 # Default args only show up in the nonfaithful C++ API,
-                arg_default = cpp.default_expr(arg.argument.default, arg.argument.type)
+                arg_default = cpp.default_expr(
+                    arg.argument.default, arg.argument.type, symint=False
+                )
                 if arg_default.startswith("{"):
                     arg_cpp = f"c10::IntArrayRef({arg_default})"
                 else:
diff --git a/torch/csrc/StorageMethods.cpp b/torch/csrc/StorageMethods.cpp
index 2b74c8a..29f0f67 100644
--- a/torch/csrc/StorageMethods.cpp
+++ b/torch/csrc/StorageMethods.cpp
@@ -41,7 +41,7 @@
 static PyObject* THPStorage_nbytes(PyObject* _self, PyObject* noargs) {
   HANDLE_TH_ERRORS
   auto self = (THPStorage*)_self;
-  return THPUtils_packUInt64(self->cdata->nbytes());
+  return py::cast(self->cdata->sym_nbytes()).release().ptr();
   END_HANDLE_TH_ERRORS
 }
 
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 86b893b..3358d96 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -1098,15 +1098,15 @@
     const Tensor& bias_p,
     const Tensor& bias_t,
     IntArrayRef stride,
-    IntArrayRef padding,
+    at::SymIntArrayRef padding,
     IntArrayRef dilation,
     bool transposed,
-    IntArrayRef output_padding,
+    at::SymIntArrayRef output_padding,
     int64_t groups) {
   auto bias_t_opt =
       bias_t.defined() ? c10::optional<at::Tensor>(bias_t) : c10::nullopt;
   return (
-      at::convolution(
+      at::convolution_symint(
           input_t,
           weight_p,
           c10::nullopt,
@@ -1116,7 +1116,7 @@
           transposed,
           output_padding,
           groups) +
-      at::convolution(
+      at::convolution_symint(
           input_p,
           weight_t,
           bias_t_opt,
@@ -1136,10 +1136,10 @@
     const Tensor& bias_p,
     const Tensor& bias_t,
     IntArrayRef stride,
-    IntArrayRef padding,
+    at::SymIntArrayRef padding,
     IntArrayRef dilation,
     bool transposed,
-    IntArrayRef output_padding,
+    at::SymIntArrayRef output_padding,
     int64_t groups,
     bool benchmark,
     bool deterministic,
@@ -1148,7 +1148,7 @@
   auto bias_t_opt =
       bias_t.defined() ? c10::optional<at::Tensor>(bias_t) : c10::nullopt;
   return (
-      at::_convolution(
+      at::_convolution_symint(
           input_t,
           weight_p,
           c10::nullopt,
@@ -1162,7 +1162,7 @@
           deterministic,
           cudnn_enabled,
           allow_tf32) +
-      at::_convolution(
+      at::_convolution_symint(
           input_p,
           weight_t,
           bias_t_opt,
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 04416c2..4da8aa0 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -937,10 +937,10 @@
     const Tensor& bias_p,
     const Tensor& bias_t,
     IntArrayRef stride,
-    IntArrayRef padding,
+    at::SymIntArrayRef padding,
     IntArrayRef dilation,
     bool transposed,
-    IntArrayRef output_padding,
+    at::SymIntArrayRef output_padding,
     int64_t groups);
 
 Tensor _convolution_jvp(
@@ -951,10 +951,10 @@
     const Tensor& bias_p,
     const Tensor& bias_t,
     IntArrayRef stride,
-    IntArrayRef padding,
+    at::SymIntArrayRef padding,
     IntArrayRef dilation,
     bool transposed,
-    IntArrayRef output_padding,
+    at::SymIntArrayRef output_padding,
     int64_t groups,
     bool benchmark,
     bool deterministic,
diff --git a/torch/storage.py b/torch/storage.py
index 8e35973..6bfbab3 100644
--- a/torch/storage.py
+++ b/torch/storage.py
@@ -646,7 +646,9 @@
         return self._storage.device
 
     def size(self):
-        return len(self)
+        # NB: don't indirect through __len__, as that requires
+        # an int to be returned
+        return self.nbytes() // self.element_size()
 
     def pickle_storage_type(self):
         try:
diff --git a/torchgen/api/cpp.py b/torchgen/api/cpp.py
index c3b12d0..4b00b53 100644
--- a/torchgen/api/cpp.py
+++ b/torchgen/api/cpp.py
@@ -314,7 +314,7 @@
 }
 
 # Convert a JIT default into C++ expression representing the default
-def default_expr(d: str, t: Type) -> str:
+def default_expr(d: str, t: Type, *, symint: bool) -> str:
     if d == "None" and str(t) == "Tensor?":
         return "{}"
     if isinstance(t, BaseType) and t.name is BaseTy.str:
@@ -342,11 +342,13 @@
         if d == "None":
             return "c10::nullopt"
 
-        return default_expr(d, t.elem)
+        return default_expr(d, t.elem, symint=symint)
 
     if isinstance(t, ListType):
         if d.startswith("[") and d.endswith("]"):
             return "{" + d[1:-1] + "}"
+        elif symint and d.isdigit() and str(t.elem) == "SymInt":
+            return f"c10::SymInt({d})"
         elif t.size is None:
             # NOTE: Sized lists can have scalar defaults
             raise ValueError(f"Expected a list default '[...]' but found: '{d}'")
@@ -386,7 +388,7 @@
             binds = a.name
         default: Optional[str] = None
         if a.name not in cpp_no_default_args and a.default is not None:
-            default = default_expr(a.default, a.type)
+            default = default_expr(a.default, a.type, symint=symint)
         return [
             Binding(
                 nctype=argument_type(a, binds=binds, symint=symint),
diff --git a/torchgen/api/native.py b/torchgen/api/native.py
index b197a2a..7f8b3eb 100644
--- a/torchgen/api/native.py
+++ b/torchgen/api/native.py
@@ -95,7 +95,7 @@
     if isinstance(a, Argument):
         default: Optional[str] = None
         if should_default and a.default is not None:
-            default = cpp.default_expr(a.default, a.type)
+            default = cpp.default_expr(a.default, a.type, symint=symint)
         return [
             Binding(
                 nctype=argument_type(a, binds=a.name, symint=symint),
diff --git a/torchgen/api/python.py b/torchgen/api/python.py
index 96c006b..728ee4c 100644
--- a/torchgen/api/python.py
+++ b/torchgen/api/python.py
@@ -719,7 +719,9 @@
         name=a.name,
         type=a.type,
         # TODO: directly translate a.default to python default
-        default=str(pythonify_default(cpp.default_expr(a.default, a.type)))
+        default=str(
+            pythonify_default(cpp.default_expr(a.default, a.type, symint=False))
+        )
         if a.default is not None
         else None,
         default_init=None,
@@ -804,7 +806,7 @@
             a = getattr(topt_args, name)
             if a.default is None or a.default == "None":
                 return None
-            return cpp.default_expr(a.default, a.type)
+            return cpp.default_expr(a.default, a.type, symint=False)
 
         tensor_options_args.append(
             PythonArgument(
diff --git a/torchgen/gen.py b/torchgen/gen.py
index e537349..79970c9 100644
--- a/torchgen/gen.py
+++ b/torchgen/gen.py
@@ -1151,7 +1151,9 @@
         "type": cpp.argument_type(a, binds="__placeholder__", symint=False).cpp_type(),
     }
     if a.default is not None:
-        arg["default"] = pythonify_default(cpp.default_expr(a.default, a.type))
+        arg["default"] = pythonify_default(
+            cpp.default_expr(a.default, a.type, symint=False)
+        )
     if a.name in kwarg_only_set:
         arg["kwarg_only"] = True
     if a.name in out_arg_set: