Add support for 32KB multi_tensor_apply kernel arguments (#134373)

## Benchmark

On H100 SXM (HBM2e, 500W TDP), with CUDA Toolkit 12.2 and driver 535.154.05, running [this script](https://gist.github.com/yifuwang/178c1f4bf951c5794ea79c04d90e44fa) to benchmark `torch._foreach_copy_`:

**Baseline**
```
https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmp0g_x4sys
device ms: 0.891, cpu ms: 7.200
memory bandwidth: 1457.727 GB/s
```

Single iteration trace:
<img width="1432" alt="image" src="https://github.com/user-attachments/assets/8ef54365-0265-4281-a0f0-d4c2f448300e">

**This PR**
```
https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmp3jqiugli
device ms: 0.683, cpu ms: 6.745
memory bandwidth: 1902.010 GB/s
```

Single iteration trace:
<img width="1074" alt="image" src="https://github.com/user-attachments/assets/e52acad1-d09b-492c-9611-6d69e339f3ac">

## Binary Size and Kernel Specialization
The binary size of `libtorch_cuda.so` increased by 6 MB (243 MB -> 249 MB).

```
// NOTE: [32KB kernel argument size support]
// 32KB kernel argument size support has three requirements:
// - CUDART_VERSION >= 12010
// - Driver version >= 530
// - GPU arch >= VOLTA
//
// Due to CUDA minor version compatibility, it is possible for binaries built
// with CUDART_VERSION >= 12010 to run with driver version < 530. Since driver
// version can only be checked at runtime, if CUDART_VERSION >= 12010, we have
// to build both 4KB and 32KB kernels and determine the appropriate kernel to
// dispatch at runtime.
//
// - If CUDART_VERSION < 12010, only 4KB kernels will be instantiated.
//
// - If CUDART_VERSION >= 12010:
//   - Host code:
//     - We always instantiate the launching stub for both 4KB and 32KB kernels.
//   - Device code:
//     - If __CUDA_ARCH__ >= 700, we always instantiate both 4KB and 32KB
//     kernels.
//     - If __CUDA_ARCH__ < 700, it's not possible to even compile an empty
//     32KB kernel (formal parameter space overflowed). Thus, we only
//     instantiate a declaration for 32KB kernels. This is valid as long as the
//     declaration-only kernel is not launched.
//
// - At runtime, we dispatch to the 32KB kernel if driver version >= 530 and
// GPU arch >= VOLTA.
//
// - TODO(yifu): once there's a CUDART version that is no longer compatible
// with any driver version below 530, we can drop the 4KB kernel argument
// size variants at compile time.
//
// https://developer.nvidia.com/blog/cuda-12-1-supports-large-kernel-parameters/
```
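
To make the compile/dispatch split concrete, here is a minimal standalone sketch of the technique under the stated requirements. It is not the PyTorch implementation; the struct capacities, names, and driver check are simplified assumptions, and it must be built with nvcc >= 12.1 targeting sm_70 or newer:
```
// Standalone sketch of the technique described above (not PyTorch code).
// Struct capacities, names, and the driver check are simplified assumptions;
// build with nvcc >= 12.1 for sm_70 or newer.
#include <cuda_runtime.h>
#include <cstdio>

template <bool large>
struct Meta {
  // ~3.2KB vs ~28KB of kernel arguments, standing in for TensorListMetadata.
  static constexpr int kCapacity = large ? 3500 : 400;
  const void* addresses[kCapacity];
};

template <bool large>
__global__ void apply_kernel(Meta<large> meta) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    printf("kernel argument size: %d bytes\n", (int)sizeof(meta));
  }
}

// Same spirit as supports_large_kernel_arg(): driver >= 530 (reported as
// 12010 by cudaDriverGetVersion) and compute capability >= 7.0 (Volta).
bool can_use_large_kernel_arg() {
  int driver = 0, major = 0, device = 0;
  cudaDriverGetVersion(&driver);
  cudaGetDevice(&device);
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
  return driver >= 12010 && major >= 7;
}

int main() {
  Meta<true> big{};
  Meta<false> small{};
  // Both instantiations exist in the binary; the runtime check picks one.
  if (can_use_large_kernel_arg()) {
    apply_kernel<<<1, 32>>>(big);
  } else {
    apply_kernel<<<1, 32>>>(small);
  }
  return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}
```
In the PR itself, the same selection is performed by the `DISPATCH_MULTI_TENSOR_APPLY` macro, which binds `large_kernel_arg` inside the wrapped lambda so that each functor and `TensorListMetadata` instantiation picks the matching capacity (see the diff below).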
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134373
Approved by: https://github.com/eqy, https://github.com/crcrpar, https://github.com/janeyx99
diff --git a/aten/src/ATen/native/cuda/AmpKernels.cu b/aten/src/ATen/native/cuda/AmpKernels.cu
index 8c161ca..07cb84f 100644
--- a/aten/src/ATen/native/cuda/AmpKernels.cu
+++ b/aten/src/ATen/native/cuda/AmpKernels.cu
@@ -157,21 +157,24 @@
       using opmath_t = at::opmath_type<scalar_t>;
 
       // multi_tensor_apply guards onto tensor_lists[0][0], no need to guard explicitly.
-      multi_tensor_apply<1>(tensor_lists,
-                            UnaryOpFunctor<scalar_t,
-                                           /* depth */ 1,
-                                           /* r_args_depth */ 1,
-                                           /* res_arg_index */ 0>(),
-                            [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t {
-                              // There is a slight asymmetry here with the TensorIterator kernel above.
-                              // MTA Functors ensure val comes in as opmath_t rather than scalar_t.
-                              if (!isfinite_ensure_cuda_math(val)) {
-                                *found_inf_ptr = 1.f;
-                              }
-                              // Every thread accesses inv_scale, but it will hit in cache.
-                              const auto inv_scale_val = *inv_scale_ptr;
-                              return static_cast<opmath_t>(inv_scale_val == 1.f ? val : val * inv_scale_val);
-                            });
+      DISPATCH_MULTI_TENSOR_APPLY([&]() {
+        multi_tensor_apply<1>(tensor_lists,
+                              UnaryOpFunctor<scalar_t,
+                                             /* depth */ 1,
+                                             /* r_args_depth */ 1,
+                                             /* res_arg_index */ 0,
+                                             large_kernel_arg>(),
+                              [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t {
+                                // There is a slight asymmetry here with the TensorIterator kernel above.
+                                // MTA Functors ensure val comes in as opmath_t rather than scalar_t.
+                                if (!isfinite_ensure_cuda_math(val)) {
+                                  *found_inf_ptr = 1.f;
+                                }
+                                // Every thread accesses inv_scale, but it will hit in cache.
+                                const auto inv_scale_val = *inv_scale_ptr;
+                                return static_cast<opmath_t>(inv_scale_val == 1.f ? val : val * inv_scale_val);
+                              });
+      });
     });
 }
 
diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu
index 533aa38..90c180c 100644
--- a/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu
+++ b/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu
@@ -41,15 +41,18 @@
   tensor_lists.emplace_back(std::move(vec_res));
 
   using opmath_t = at::opmath_type<T>;
-  multi_tensor_apply<3>(
-      tensor_lists,
-      BinaryOpListAlphaFunctor<
-          T,
-          /* depth */ 3,
-          /* r_args_depth */ 2,
-          /* res_arg_index */ 2>(),
-      Op<opmath_t>(),
-      alpha.to<opmath_t>());
+  DISPATCH_MULTI_TENSOR_APPLY([&]() {
+    multi_tensor_apply<3>(
+        tensor_lists,
+        BinaryOpListAlphaFunctor<
+            T,
+            /* depth */ 3,
+            /* r_args_depth */ 2,
+            /* res_arg_index */ 2,
+            large_kernel_arg>(),
+        Op<opmath_t>(),
+        alpha.to<opmath_t>());
+  });
 
   return tensor_lists[2];
 }
@@ -64,15 +67,18 @@
   tensor_lists.emplace_back(tensors2.vec());
 
   using opmath_t = at::opmath_type<T>;
-  multi_tensor_apply<2>(
-      tensor_lists,
-      BinaryOpListAlphaFunctor<
-          T,
-          /* depth */ 2,
-          /* r_args_depth */ 2,
-          /* res_arg_index */ 0>(),
-      Op<opmath_t>(),
-      alpha.to<opmath_t>());
+  DISPATCH_MULTI_TENSOR_APPLY([&]() {
+    multi_tensor_apply<2>(
+        tensor_lists,
+        BinaryOpListAlphaFunctor<
+            T,
+            /* depth */ 2,
+            /* r_args_depth */ 2,
+            /* res_arg_index */ 0,
+            large_kernel_arg>(),
+        Op<opmath_t>(),
+        alpha.to<opmath_t>());
+  });
   increment_version(tensors1);
 }
 
@@ -331,13 +337,15 @@
     typename src_t,
     int depth,
     int r_args_depth,
-    int res_arg_index>
+    int res_arg_index,
+    bool large_kernel_arg>
 struct CopyFunctor {
+  static constexpr bool use_large_kernel_arg = large_kernel_arg;
   static_assert(depth == 2 && r_args_depth == 1 && res_arg_index == 1);
   template <typename Op>
   __device__ __forceinline__ void operator()(
       int chunk_size,
-      TensorListMetadata<depth>& tl,
+      TensorListMetadata<depth, large_kernel_arg>& tl,
       Op op) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
     const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@@ -420,14 +428,17 @@
         using opmath_t = at::opmath_type<scalar_t>;
         AT_DISPATCH_SOURCE_TYPES(src[0].scalar_type(), "foreach_tensor_copy", [&] {
           if constexpr (std::is_same_v<scalar_t, src_t>) {
-            multi_tensor_apply<2>(
-                tensor_lists,
-                UnaryOpFunctor<
-                    scalar_t,
-                    /* depth */ 2,
-                    /* r_args_depth */ 1,
-                    /* res_arg_index */ 1>(),
-                Copy<opmath_t, opmath_t>());
+            DISPATCH_MULTI_TENSOR_APPLY([&]() {
+              multi_tensor_apply<2>(
+                  tensor_lists,
+                  UnaryOpFunctor<
+                      scalar_t,
+                      /* depth */ 2,
+                      /* r_args_depth */ 1,
+                      /* res_arg_index */ 1,
+                      large_kernel_arg>(),
+                  Copy<opmath_t, opmath_t>());
+            });
           } else {
             // Ref:
             // https://github.com/pytorch/pytorch/blob/656134c38f4737d13c3f43fc5c59470bc23c1d2f/aten/src/ATen/native/Copy.cpp#L299-L301
@@ -435,15 +446,18 @@
               TORCH_WARN_ONCE(
                   "Casting complex values to real discards the imaginary part");
             }
-            multi_tensor_apply<2>(
-                tensor_lists,
-                CopyFunctor<
-                    scalar_t,
-                    src_t,
-                    /* depth */ 2,
-                    /* r_args_depth */ 1,
-                    /* res_arg_index */ 1>(),
-                Copy<scalar_t, src_t>());
+            DISPATCH_MULTI_TENSOR_APPLY([&]() {
+              multi_tensor_apply<2>(
+                  tensor_lists,
+                  CopyFunctor<
+                      scalar_t,
+                      src_t,
+                      /* depth */ 2,
+                      /* r_args_depth */ 1,
+                      /* res_arg_index */ 1,
+                      large_kernel_arg>(),
+                  Copy<scalar_t, src_t>());
+            });
           }
         });
       });
diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu
index 80d748d..1b0045d 100644
--- a/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu
+++ b/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu
@@ -36,15 +36,18 @@
   tensor_lists.emplace_back(std::move(vec_res));
 
   using opmath_t = at::opmath_type<T>;
-  multi_tensor_apply<2>(
-      tensor_lists,
-      BinaryOpScalarFunctor<
-          T,
-          /* depth */ 2,
-          /* r_args_depth */ 1,
-          /* res_arg_index */ 1>(),
-      Op<opmath_t>(),
-      scalar.to<opmath_t>());
+  DISPATCH_MULTI_TENSOR_APPLY([&]() {
+    multi_tensor_apply<2>(
+        tensor_lists,
+        BinaryOpScalarFunctor<
+            T,
+            /* depth */ 2,
+            /* r_args_depth */ 1,
+            /* res_arg_index */ 1,
+            large_kernel_arg>(),
+        Op<opmath_t>(),
+        scalar.to<opmath_t>());
+  });
   return tensor_lists[1];
 }
 
@@ -54,15 +57,18 @@
   tensor_lists.emplace_back(tensors.vec());
 
   using opmath_t = at::opmath_type<T>;
-  multi_tensor_apply<1>(
-      tensor_lists,
-      BinaryOpScalarFunctor<
-          T,
-          /* depth */ 1,
-          /* r_args_depth */ 1,
-          /* res_arg_index */ 0>(),
-      Op<opmath_t>(),
-      scalar.to<opmath_t>());
+  DISPATCH_MULTI_TENSOR_APPLY([&]() {
+    multi_tensor_apply<1>(
+        tensor_lists,
+        BinaryOpScalarFunctor<
+            T,
+            /* depth */ 1,
+            /* r_args_depth */ 1,
+            /* res_arg_index */ 0,
+            large_kernel_arg>(),
+        Op<opmath_t>(),
+        scalar.to<opmath_t>());
+  });
   increment_version(tensors);
 }
 
diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu
index dcb9318..c423006 100644
--- a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu
+++ b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu
@@ -36,16 +36,19 @@
   tensor_lists.emplace_back(vec_res);
 
   using opmath_t = at::opmath_type<T>;
-  multi_tensor_apply<2, opmath_t>(
-      tensor_lists,
-      scalars,
-      BinaryOpScalarListFunctor<
-          T,
-          /* depth */ 2,
-          /* r_args_depth */ 1,
-          /* res_arg_index */ 1>(),
+  DISPATCH_MULTI_TENSOR_APPLY([&]() {
+    multi_tensor_apply<2, opmath_t>(
+        tensor_lists,
+        scalars,
+        BinaryOpScalarListFunctor<
+            T,
+            /* depth */ 2,
+            /* r_args_depth */ 1,
+            /* res_arg_index */ 1,
+            large_kernel_arg>(),
 
-      Op<opmath_t>());
+        Op<opmath_t>());
+  });
   return tensor_lists[1];
 }
 
@@ -55,15 +58,18 @@
   tensor_lists.emplace_back(tensors.vec());
 
   using opmath_t = at::opmath_type<T>;
-  multi_tensor_apply<1, opmath_t>(
-      tensor_lists,
-      scalars,
-      BinaryOpScalarListFunctor<
-          T,
-          /* depth */ 1,
-          /* r_args_depth */ 1,
-          /* res_arg_index */ 0>(),
-      Op<opmath_t>());
+  DISPATCH_MULTI_TENSOR_APPLY([&]() {
+    multi_tensor_apply<1, opmath_t>(
+        tensor_lists,
+        scalars,
+        BinaryOpScalarListFunctor<
+            T,
+            /* depth */ 1,
+            /* r_args_depth */ 1,
+            /* res_arg_index */ 0,
+            large_kernel_arg>(),
+        Op<opmath_t>());
+  });
   increment_version(tensors);
 }
 
diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu
index ad5eeee..4163b30 100644
--- a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu
+++ b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu
@@ -46,16 +46,19 @@
   tensor_lists.emplace_back(std::move(vec_res));
 
   using opmath_t = at::opmath_type<T>;
-  multi_tensor_apply<2>(
-      tensor_lists,
-      BinaryOpScalarTensorFunctor<
-          T,
-          /* depth */ 2,
-          /* r_args_depth */ 1,
-          /* res_arg_index */ 1>(),
-      Op<opmath_t>(),
-      scalar.data_ptr<T>(),
-      alpha.to<opmath_t>());
+  DISPATCH_MULTI_TENSOR_APPLY([&]() {
+    multi_tensor_apply<2>(
+        tensor_lists,
+        BinaryOpScalarTensorFunctor<
+            T,
+            /* depth */ 2,
+            /* r_args_depth */ 1,
+            /* res_arg_index */ 1,
+            large_kernel_arg>(),
+        Op<opmath_t>(),
+        scalar.data_ptr<T>(),
+        alpha.to<opmath_t>());
+  });
   return tensor_lists[1];
 }
 
@@ -81,16 +84,19 @@
   tensor_lists.emplace_back(tensors.vec());
 
   using opmath_t = at::opmath_type<T>;
-  multi_tensor_apply<1>(
-      tensor_lists,
-      BinaryOpScalarTensorFunctor<
-          T,
-          /* depth */ 1,
-          /* r_args_depth */ 1,
-          /* res_arg_index */ 0>(),
-      Op<opmath_t>(),
-      scalar.data_ptr<T>(),
-      alpha.to<opmath_t>());
+  DISPATCH_MULTI_TENSOR_APPLY([&]() {
+    multi_tensor_apply<1>(
+        tensor_lists,
+        BinaryOpScalarTensorFunctor<
+            T,
+            /* depth */ 1,
+            /* r_args_depth */ 1,
+            /* res_arg_index */ 0,
+            large_kernel_arg>(),
+        Op<opmath_t>(),
+        scalar.data_ptr<T>(),
+        alpha.to<opmath_t>());
+  });
   increment_version(tensors);
 }
 
diff --git a/aten/src/ATen/native/cuda/ForeachFunctors.cuh b/aten/src/ATen/native/cuda/ForeachFunctors.cuh
index 55e4fd7..fb8d849 100644
--- a/aten/src/ATen/native/cuda/ForeachFunctors.cuh
+++ b/aten/src/ATen/native/cuda/ForeachFunctors.cuh
@@ -18,10 +18,10 @@
 }
 
 // Initializes args and checks if all args are aligned
-template <int depth, typename T>
+template <int depth, typename T, bool large_kernel_arg>
 __device__ bool init_args(
     T** args,
-    TensorListMetadata<depth>& tl,
+    TensorListMetadata<depth, large_kernel_arg>& tl,
     const int64_t chunk_idx,
     const int64_t chunk_size,
     const int64_t tensor_loc) {
@@ -38,10 +38,10 @@
 }
 
 // Initializes args and checks if all args are aligned
-template <int depth, typename T, typename T2>
+template <int depth, typename T, typename T2, bool large_kernel_arg>
 __device__ bool init_args(
     T** args,
-    TensorListScalarListMetadata<T2, depth>& tl,
+    TensorListScalarListMetadata<T2, depth, large_kernel_arg>& tl,
     const int64_t chunk_idx,
     const int64_t chunk_size,
     const int64_t tensor_loc) {
@@ -57,10 +57,10 @@
   return all_aligned;
 }
 
-template <int depth, typename T>
+template <int depth, typename T, bool large_kernel_arg>
 __device__ bool init_args(
     T** args,
-    FusedOptimizerTensorListMetadata<depth>& tl,
+    FusedOptimizerTensorListMetadata<depth, large_kernel_arg>& tl,
     const int64_t chunk_idx,
     const int64_t chunk_size,
     const int64_t tensor_loc) {
@@ -203,13 +203,19 @@
 //
 // Binary Functors
 //
-template <typename T, int depth, int r_args_depth, int res_arg_index>
+template <
+    typename T,
+    int depth,
+    int r_args_depth,
+    int res_arg_index,
+    bool large_kernel_arg>
 struct BinaryOpScalarFunctor {
+  static constexpr bool use_large_kernel_arg = large_kernel_arg;
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
       int chunk_size,
-      TensorListMetadata<depth>& tl,
+      TensorListMetadata<depth, large_kernel_arg>& tl,
       Op op,
       opmath_t scalar) {
     const int tensor_loc = tl.block_to_tensor[blockIdx.x];
@@ -227,13 +233,19 @@
   }
 };
 
-template <typename T, int depth, int r_args_depth, int res_arg_index>
+template <
+    typename T,
+    int depth,
+    int r_args_depth,
+    int res_arg_index,
+    bool large_kernel_arg>
 struct BinaryOpScalarListFunctor {
+  static constexpr bool use_large_kernel_arg = large_kernel_arg;
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
       int chunk_size,
-      TensorListScalarListMetadata<opmath_t, depth>& tl,
+      TensorListScalarListMetadata<opmath_t, depth, large_kernel_arg>& tl,
       Op op) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
     const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@@ -251,13 +263,19 @@
   }
 };
 
-template <typename T, int depth, int r_args_depth, int res_arg_index>
+template <
+    typename T,
+    int depth,
+    int r_args_depth,
+    int res_arg_index,
+    bool large_kernel_arg>
 struct BinaryOpListAlphaFunctor {
+  static constexpr bool use_large_kernel_arg = large_kernel_arg;
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
       int chunk_size,
-      TensorListMetadata<depth>& tl,
+      TensorListMetadata<depth, large_kernel_arg>& tl,
       Op op,
       opmath_t alpha) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@@ -303,13 +321,19 @@
   }
 };
 
-template <typename T, int depth, int r_args_depth, int res_arg_index>
+template <
+    typename T,
+    int depth,
+    int r_args_depth,
+    int res_arg_index,
+    bool large_kernel_arg>
 struct BinaryOpScalarTensorFunctor {
+  static constexpr bool use_large_kernel_arg = large_kernel_arg;
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
       int chunk_size,
-      TensorListMetadata<depth>& tl,
+      TensorListMetadata<depth, large_kernel_arg>& tl,
       Op op,
       T* scalar,
       opmath_t alpha) {
@@ -361,11 +385,17 @@
 // Unary Functors
 //
 
-template <typename T, int depth, int r_args_depth, int res_arg_index>
+template <
+    typename T,
+    int depth,
+    int r_args_depth,
+    int res_arg_index,
+    bool large_kernel_arg>
 struct ZeroFunctor {
+  static constexpr bool use_large_kernel_arg = large_kernel_arg;
   __device__ __forceinline__ void operator()(
       int chunk_size,
-      TensorListMetadata<1>& tl) {
+      TensorListMetadata<1, large_kernel_arg>& tl) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
     const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
     auto n = tl.numel_for_tensor[tensor_loc];
@@ -401,13 +431,19 @@
   }
 };
 
-template <typename T, int depth, int r_args_depth, int res_arg_index>
+template <
+    typename T,
+    int depth,
+    int r_args_depth,
+    int res_arg_index,
+    bool large_kernel_arg>
 struct UnaryOpFunctor {
+  static constexpr bool use_large_kernel_arg = large_kernel_arg;
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
       int chunk_size,
-      TensorListMetadata<depth>& tl,
+      TensorListMetadata<depth, large_kernel_arg>& tl,
       Op op) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
     const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@@ -453,13 +489,19 @@
 // Pointwise Functors
 //
 
-template <typename T, int depth, int r_args_depth, int res_arg_index>
+template <
+    typename T,
+    int depth,
+    int r_args_depth,
+    int res_arg_index,
+    bool large_kernel_arg>
 struct PointwiseOpScalarFunctor {
+  static constexpr bool use_large_kernel_arg = large_kernel_arg;
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
       int chunk_size,
-      TensorListMetadata<depth>& tl,
+      TensorListMetadata<depth, large_kernel_arg>& tl,
       Op op,
       opmath_t scalar) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@@ -477,13 +519,19 @@
   }
 };
 
-template <typename T, int depth, int r_args_depth, int res_arg_index>
+template <
+    typename T,
+    int depth,
+    int r_args_depth,
+    int res_arg_index,
+    bool large_kernel_arg>
 struct PointwiseOpScalarListFunctor {
+  static constexpr bool use_large_kernel_arg = large_kernel_arg;
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
       int chunk_size,
-      TensorListScalarListMetadata<opmath_t, depth>& tl,
+      TensorListScalarListMetadata<opmath_t, depth, large_kernel_arg>& tl,
       Op op) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
     const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@@ -501,13 +549,14 @@
   }
 };
 
-template <typename T, int depth>
+template <typename T, int depth, bool large_kernel_arg>
 struct PointwiseOpListFunctor {
+  static constexpr bool use_large_kernel_arg = large_kernel_arg;
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
       int chunk_size,
-      TensorListMetadata<depth>& tl,
+      TensorListMetadata<depth, large_kernel_arg>& tl,
       Op op) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
     const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@@ -552,13 +601,19 @@
   }
 };
 
-template <typename T, int depth, int r_args_depth, int res_arg_index>
+template <
+    typename T,
+    int depth,
+    int r_args_depth,
+    int res_arg_index,
+    bool large_kernel_arg>
 struct TernaryOpListFunctor {
+  static constexpr bool use_large_kernel_arg = large_kernel_arg;
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
       int chunk_size,
-      TensorListMetadata<depth>& tl,
+      TensorListMetadata<depth, large_kernel_arg>& tl,
       Op op) {
     static_assert(depth == 3 || depth == 4, "");
     static_assert(depth >= r_args_depth, "");
@@ -606,13 +661,19 @@
   }
 };
 
-template <typename T, int depth, int r_args_depth, int res_arg_index>
+template <
+    typename T,
+    int depth,
+    int r_args_depth,
+    int res_arg_index,
+    bool large_kernel_arg>
 struct TernaryOpScalarFunctor {
+  static constexpr bool use_large_kernel_arg = large_kernel_arg;
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
       int chunk_size,
-      TensorListMetadata<depth>& tl,
+      TensorListMetadata<depth, large_kernel_arg>& tl,
       Op op,
       opmath_t alpha) {
     static_assert(depth == 2 || depth == 3, "");
diff --git a/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu b/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu
index 7a3276c..bad152c 100644
--- a/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu
+++ b/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu
@@ -46,15 +46,18 @@
       "foreach_pointwise_op_cuda",
       [&]() {
         using opmath_t = at::opmath_type<scalar_t>;
-        multi_tensor_apply<4>(
-            tensor_lists,
-            PointwiseOpScalarFunctor<
-                scalar_t,
-                /* depth */ 4,
-                /* r_args_depth */ 3,
-                /* res_arg_index */ 3>(),
-            Op<opmath_t>(),
-            scalar.to<opmath_t>());
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply<4>(
+              tensor_lists,
+              PointwiseOpScalarFunctor<
+                  scalar_t,
+                  /* depth */ 4,
+                  /* r_args_depth */ 3,
+                  /* res_arg_index */ 3,
+                  large_kernel_arg>(),
+              Op<opmath_t>(),
+              scalar.to<opmath_t>());
+        });
       });
 
   return tensor_lists[3];
@@ -78,15 +81,18 @@
       "foreach_pointwise_op__cuda",
       [&]() {
         using opmath_t = at::opmath_type<scalar_t>;
-        multi_tensor_apply<3>(
-            tensor_lists,
-            PointwiseOpScalarFunctor<
-                scalar_t,
-                /* depth */ 3,
-                /* r_args_depth */ 3,
-                /* res_arg_index */ 0>(),
-            Op<opmath_t>(),
-            scalar.to<opmath_t>());
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply<3>(
+              tensor_lists,
+              PointwiseOpScalarFunctor<
+                  scalar_t,
+                  /* depth */ 3,
+                  /* r_args_depth */ 3,
+                  /* res_arg_index */ 0,
+                  large_kernel_arg>(),
+              Op<opmath_t>(),
+              scalar.to<opmath_t>());
+        });
       });
   increment_version(input);
 }
@@ -110,15 +116,18 @@
       "foreach_pointwise_op__cuda",
       [&]() {
         using opmath_t = at::opmath_type<scalar_t>;
-        multi_tensor_apply<3, opmath_t>(
-            tensor_lists,
-            scalars,
-            PointwiseOpScalarListFunctor<
-                scalar_t,
-                /* depth */ 3,
-                /* r_args_depth */ 3,
-                /* res_arg_index */ 0>(),
-            Op<opmath_t>());
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply<3, opmath_t>(
+              tensor_lists,
+              scalars,
+              PointwiseOpScalarListFunctor<
+                  scalar_t,
+                  /* depth */ 3,
+                  /* r_args_depth */ 3,
+                  /* res_arg_index */ 0,
+                  large_kernel_arg>(),
+              Op<opmath_t>());
+        });
       });
   increment_version(input);
 }
@@ -149,15 +158,18 @@
       "foreach_pointwise_op_cuda",
       [&]() {
         using opmath_t = at::opmath_type<scalar_t>;
-        multi_tensor_apply<4, opmath_t>(
-            tensor_lists,
-            scalars,
-            PointwiseOpScalarListFunctor<
-                scalar_t,
-                /* depth */ 4,
-                /* r_args_depth */ 3,
-                /* res_arg_index */ 3>(),
-            Op<opmath_t>());
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply<4, opmath_t>(
+              tensor_lists,
+              scalars,
+              PointwiseOpScalarListFunctor<
+                  scalar_t,
+                  /* depth */ 4,
+                  /* r_args_depth */ 3,
+                  /* res_arg_index */ 3,
+                  large_kernel_arg>(),
+              Op<opmath_t>());
+        });
       });
 
   return tensor_lists[3];
diff --git a/aten/src/ATen/native/cuda/ForeachReduceOp.cu b/aten/src/ATen/native/cuda/ForeachReduceOp.cu
index 61793fd..1ce18b7 100644
--- a/aten/src/ATen/native/cuda/ForeachReduceOp.cu
+++ b/aten/src/ATen/native/cuda/ForeachReduceOp.cu
@@ -50,11 +50,13 @@
     typename T,
     int depth = 1,
     int r_args_depth = 1,
-    int res_arg_index = 0>
+    int res_arg_index = 0,
+    bool large_kernel_arg = false>
 struct LpMaxFunctor {
+  static constexpr bool use_large_kernel_arg = large_kernel_arg;
   __device__ __forceinline__ void operator()(
       int chunk_size,
-      TensorListMetadata<depth>& tl,
+      TensorListMetadata<depth, large_kernel_arg>& tl,
       T* output_per_tensor_ptr,
       const int max_chunks_per_tensor) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@@ -178,11 +180,13 @@
       tensor_lists[0][0].scalar_type(),
       "foreach_tensor_max_cuda_scalar_type",
       [&]() {
-        multi_tensor_apply<1>(
-            tensor_lists,
-            LpMaxFunctor<scalar_t>(),
-            output_per_tensor.mutable_data_ptr<scalar_t>(),
-            max_chunks_per_tensor);
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply<1>(
+              tensor_lists,
+              LpMaxFunctor<scalar_t, 1, 1, 0, large_kernel_arg>(),
+              output_per_tensor.mutable_data_ptr<scalar_t>(),
+              max_chunks_per_tensor);
+        });
 
         C10_CUDA_KERNEL_LAUNCH_CHECK();
         const at::cuda::OptionalCUDAGuard device_guard(
@@ -239,12 +243,14 @@
     typename out_t,
     int depth = 1,
     int r_args_depth = 1,
-    int res_arg_index = 0>
+    int res_arg_index = 0,
+    bool large_kernel_arg = false>
 struct LpNormFunctor {
+  static constexpr bool use_large_kernel_arg = large_kernel_arg;
   using out_opmath_t = typename at::opmath_type<out_t>;
   __device__ __forceinline__ void operator()(
       int chunk_size,
-      TensorListMetadata<depth>& tl,
+      TensorListMetadata<depth, large_kernel_arg>& tl,
       out_opmath_t* output_per_tensor_ptr,
       const int max_chunks_per_tensor) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@@ -476,23 +482,50 @@
             output_dtype, "foreach_tensor_norm_cuda_out_dtype", [&]() {
               using out_opmath_t = typename at::opmath_type<out_t>;
               if (p == static_cast<double>(1)) {
-                multi_tensor_apply<1>(
-                    tensor_lists,
-                    LpNormFunctor<scalar_t, NormType::L1, out_t>(),
-                    output_per_tensor.mutable_data_ptr<out_opmath_t>(),
-                    max_chunks_per_tensor);
+                DISPATCH_MULTI_TENSOR_APPLY([&]() {
+                  multi_tensor_apply<1>(
+                      tensor_lists,
+                      LpNormFunctor<
+                          scalar_t,
+                          NormType::L1,
+                          out_t,
+                          1,
+                          1,
+                          0,
+                          large_kernel_arg>(),
+                      output_per_tensor.mutable_data_ptr<out_opmath_t>(),
+                      max_chunks_per_tensor);
+                });
               } else if (p == static_cast<double>(2)) {
-                multi_tensor_apply<1>(
-                    tensor_lists,
-                    LpNormFunctor<scalar_t, NormType::L2, out_t>(),
-                    output_per_tensor.mutable_data_ptr<out_opmath_t>(),
-                    max_chunks_per_tensor);
+                DISPATCH_MULTI_TENSOR_APPLY([&]() {
+                  multi_tensor_apply<1>(
+                      tensor_lists,
+                      LpNormFunctor<
+                          scalar_t,
+                          NormType::L2,
+                          out_t,
+                          1,
+                          1,
+                          0,
+                          large_kernel_arg>(),
+                      output_per_tensor.mutable_data_ptr<out_opmath_t>(),
+                      max_chunks_per_tensor);
+                });
               } else if (p == std::numeric_limits<double>::infinity()) {
-                multi_tensor_apply<1>(
-                    tensor_lists,
-                    LpNormFunctor<scalar_t, NormType::LInf, out_t>(),
-                    output_per_tensor.mutable_data_ptr<out_opmath_t>(),
-                    max_chunks_per_tensor);
+                DISPATCH_MULTI_TENSOR_APPLY([&]() {
+                  multi_tensor_apply<1>(
+                      tensor_lists,
+                      LpNormFunctor<
+                          scalar_t,
+                          NormType::LInf,
+                          out_t,
+                          1,
+                          1,
+                          0,
+                          large_kernel_arg>(),
+                      output_per_tensor.mutable_data_ptr<out_opmath_t>(),
+                      max_chunks_per_tensor);
+                });
               }
               C10_CUDA_KERNEL_LAUNCH_CHECK();
               const at::cuda::OptionalCUDAGuard device_guard(
diff --git a/aten/src/ATen/native/cuda/ForeachTernaryOp.cu b/aten/src/ATen/native/cuda/ForeachTernaryOp.cu
index e13f201..e73862d 100644
--- a/aten/src/ATen/native/cuda/ForeachTernaryOp.cu
+++ b/aten/src/ATen/native/cuda/ForeachTernaryOp.cu
@@ -46,14 +46,17 @@
       "foreach_tensor_lerp_ternary_cuda",
       [&]() {
         using opmath_t = typename at::opmath_type<scalar_t>;
-        multi_tensor_apply<4>(
-            tensor_lists,
-            TernaryOpListFunctor<
-                scalar_t,
-                /* depth */ 4,
-                /* r_args_depth */ 3,
-                /* res_arg_index */ 3>(),
-            LerpFunctor<opmath_t>());
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply<4>(
+              tensor_lists,
+              TernaryOpListFunctor<
+                  scalar_t,
+                  /* depth */ 4,
+                  /* r_args_depth */ 3,
+                  /* res_arg_index */ 3,
+                  large_kernel_arg>(),
+              LerpFunctor<opmath_t>());
+        });
       });
 
   return tensor_lists[3];
@@ -77,14 +80,17 @@
       "foreach_tensor_lerp_ternary_cuda_",
       [&]() {
         using opmath_t = typename at::opmath_type<scalar_t>;
-        multi_tensor_apply<3>(
-            tensor_lists,
-            TernaryOpListFunctor<
-                scalar_t,
-                /* depth */ 3,
-                /* r_args_depth */ 3,
-                /* res_arg_index */ 0>(),
-            LerpFunctor<opmath_t>());
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply<3>(
+              tensor_lists,
+              TernaryOpListFunctor<
+                  scalar_t,
+                  /* depth */ 3,
+                  /* r_args_depth */ 3,
+                  /* res_arg_index */ 0,
+                  large_kernel_arg>(),
+              LerpFunctor<opmath_t>());
+        });
       });
   increment_version(tensors1);
 }
@@ -113,15 +119,18 @@
       "foreach_tensor_lerp_scalar_cuda",
       [&]() {
         using opmath_t = typename at::opmath_type<scalar_t>;
-        multi_tensor_apply<3>(
-            tensor_lists,
-            TernaryOpScalarFunctor<
-                scalar_t,
-                /* depth */ 3,
-                /* r_args_depth */ 2,
-                /* res_arg_index */ 2>(),
-            LerpFunctor<opmath_t>(),
-            weight.to<opmath_t>());
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply<3>(
+              tensor_lists,
+              TernaryOpScalarFunctor<
+                  scalar_t,
+                  /* depth */ 3,
+                  /* r_args_depth */ 2,
+                  /* res_arg_index */ 2,
+                  large_kernel_arg>(),
+              LerpFunctor<opmath_t>(),
+              weight.to<opmath_t>());
+        });
       });
 
   return tensor_lists[2];
@@ -145,15 +154,18 @@
       "foreach_tensor_lerp_scalar_cuda_",
       [&]() {
         using opmath_t = typename at::opmath_type<scalar_t>;
-        multi_tensor_apply<2>(
-            tensor_lists,
-            TernaryOpScalarFunctor<
-                scalar_t,
-                /* depth */ 2,
-                /* r_args_depth */ 2,
-                /* res_arg_index */ 0>(),
-            LerpFunctor<opmath_t>(),
-            weight.to<opmath_t>());
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply<2>(
+              tensor_lists,
+              TernaryOpScalarFunctor<
+                  scalar_t,
+                  /* depth */ 2,
+                  /* r_args_depth */ 2,
+                  /* res_arg_index */ 0,
+                  large_kernel_arg>(),
+              LerpFunctor<opmath_t>(),
+              weight.to<opmath_t>());
+        });
       });
 }
 } // namespace at::native
diff --git a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu
index 1a969cf..fa333b4 100644
--- a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu
+++ b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu
@@ -56,14 +56,17 @@
   tensor_lists.emplace_back(std::move(vec_res));
 
   using opmath_t = typename at::opmath_type<scalar_t>;
-  multi_tensor_apply<2>(
-      tensor_lists,
-      UnaryOpFunctor<
-          scalar_t,
-          /* depth */ 2,
-          /* r_args_depth */ 1,
-          /* res_arg_index */ 1>(),
-      Op<opmath_t>());
+  DISPATCH_MULTI_TENSOR_APPLY([&]() {
+    multi_tensor_apply<2>(
+        tensor_lists,
+        UnaryOpFunctor<
+            scalar_t,
+            /* depth */ 2,
+            /* r_args_depth */ 1,
+            /* res_arg_index */ 1,
+            large_kernel_arg>(),
+        Op<opmath_t>());
+  });
 
   return tensor_lists[1];
 }
@@ -73,14 +76,17 @@
   std::vector<std::vector<at::Tensor>> tensor_lists;
   tensor_lists.emplace_back(tensors.vec());
   using opmath_t = typename at::opmath_type<scalar_t>;
-  multi_tensor_apply<1>(
-      tensor_lists,
-      UnaryOpFunctor<
-          scalar_t,
-          /* depth */ 1,
-          /* r_args_depth */ 1,
-          /* res_arg_index */ 0>(),
-      Op<opmath_t>());
+  DISPATCH_MULTI_TENSOR_APPLY([&]() {
+    multi_tensor_apply<1>(
+        tensor_lists,
+        UnaryOpFunctor<
+            scalar_t,
+            /* depth */ 1,
+            /* r_args_depth */ 1,
+            /* res_arg_index */ 0,
+            large_kernel_arg>(),
+        Op<opmath_t>());
+  });
   increment_version(tensors);
 }
 
@@ -395,13 +401,16 @@
       tensors[0].scalar_type(),
       "foreach_zero_cuda_",
       [&]() {
-        multi_tensor_apply<1>(
-            tensor_lists,
-            ZeroFunctor<
-                scalar_t,
-                /* depth */ 1,
-                /* r_args_depth */ 1,
-                /* res_arg_index */ 0>());
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply<1>(
+              tensor_lists,
+              ZeroFunctor<
+                  scalar_t,
+                  /* depth */ 1,
+                  /* r_args_depth */ 1,
+                  /* res_arg_index */ 0,
+                  large_kernel_arg>());
+        });
       });
 }
 
diff --git a/aten/src/ATen/native/cuda/FusedSgdKernel.cu b/aten/src/ATen/native/cuda/FusedSgdKernel.cu
index beca6f7..b8ab3af 100644
--- a/aten/src/ATen/native/cuda/FusedSgdKernel.cu
+++ b/aten/src/ATen/native/cuda/FusedSgdKernel.cu
@@ -56,14 +56,15 @@
   }
 }
 
-template <typename scalar_t, int depth>
+template <typename scalar_t, int depth, bool large_kernel_arg>
 struct FusedSgdMathFunctor {
+  static constexpr bool use_large_kernel_arg = large_kernel_arg;
   static_assert(
       depth == 2 || depth == 3,
       "depth of 2 for SGD w/ momentum == 0, 3 for SGD w/ momentum != 0");
   C10_DEVICE __forceinline__ void operator()(
       const int chunk_size,
-      TensorListMetadata<depth>& tl,
+      TensorListMetadata<depth, large_kernel_arg>& tl,
       const double weight_decay,
       const double momentum,
       const float* lr_ptr,
@@ -172,19 +173,21 @@
       params[0].scalar_type(),
       "fused_sgd_with_momentum_kernel_cuda",
       [&]() {
-        multi_tensor_apply<3>(
-            tensor_lists,
-            FusedSgdMathFunctor<scalar_t, 3>(),
-            weight_decay,
-            momentum,
-            lr_ptr,
-            lr,
-            dampening,
-            nesterov,
-            maximize,
-            is_first_step,
-            grad_scale_ptr,
-            found_inf_ptr);
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply<3>(
+              tensor_lists,
+              FusedSgdMathFunctor<scalar_t, 3, large_kernel_arg>(),
+              weight_decay,
+              momentum,
+              lr_ptr,
+              lr,
+              dampening,
+              nesterov,
+              maximize,
+              is_first_step,
+              grad_scale_ptr,
+              found_inf_ptr);
+        });
       });
 }
 
@@ -246,19 +249,21 @@
       params[0].scalar_type(),
       "fused_sgd_with_momentum_kernel_cuda",
       [&]() {
-        multi_tensor_apply<3>(
-            tensor_lists,
-            FusedSgdMathFunctor<scalar_t, 3>(),
-            weight_decay,
-            momentum,
-            lr.data_ptr<float>(),
-            1.0,
-            dampening,
-            nesterov,
-            maximize,
-            is_first_step,
-            grad_scale_ptr,
-            found_inf_ptr);
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply<3>(
+              tensor_lists,
+              FusedSgdMathFunctor<scalar_t, 3, large_kernel_arg>(),
+              weight_decay,
+              momentum,
+              lr.data_ptr<float>(),
+              1.0,
+              dampening,
+              nesterov,
+              maximize,
+              is_first_step,
+              grad_scale_ptr,
+              found_inf_ptr);
+        });
       });
 }
 
@@ -312,19 +317,21 @@
       params[0].scalar_type(),
       "fused_sgd_kernel_cuda",
       [&]() {
-        multi_tensor_apply<2>(
-            tensor_lists,
-            FusedSgdMathFunctor<scalar_t, 2>(),
-            weight_decay,
-            momentum,
-            lr_ptr,
-            lr,
-            dampening,
-            nesterov,
-            maximize,
-            /* is_first_step */ false,
-            grad_scale_ptr,
-            found_inf_ptr);
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply<2>(
+              tensor_lists,
+              FusedSgdMathFunctor<scalar_t, 2, large_kernel_arg>(),
+              weight_decay,
+              momentum,
+              lr_ptr,
+              lr,
+              dampening,
+              nesterov,
+              maximize,
+              /* is_first_step */ false,
+              grad_scale_ptr,
+              found_inf_ptr);
+        });
       });
 }
 
@@ -404,19 +411,21 @@
       params[0].scalar_type(),
       "fused_sgd_kernel_cuda",
       [&]() {
-        multi_tensor_apply<2>(
-            tensor_lists,
-            FusedSgdMathFunctor<scalar_t, 2>(),
-            weight_decay,
-            momentum,
-            lr.data_ptr<float>(),
-            1.0,
-            dampening,
-            nesterov,
-            maximize,
-            /* is_first_step */ false,
-            grad_scale_ptr,
-            found_inf_ptr);
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply<2>(
+              tensor_lists,
+              FusedSgdMathFunctor<scalar_t, 2, large_kernel_arg>(),
+              weight_decay,
+              momentum,
+              lr.data_ptr<float>(),
+              1.0,
+              dampening,
+              nesterov,
+              maximize,
+              /* is_first_step */ false,
+              grad_scale_ptr,
+              found_inf_ptr);
+        });
       });
 }
 
diff --git a/aten/src/ATen/native/cuda/MultiTensorApply.cpp b/aten/src/ATen/native/cuda/MultiTensorApply.cpp
new file mode 100644
index 0000000..208adb0
--- /dev/null
+++ b/aten/src/ATen/native/cuda/MultiTensorApply.cpp
@@ -0,0 +1,25 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGraphsC10Utils.h>
+
+#include <cuda_runtime.h>
+
+namespace at::native {
+
+bool supports_large_kernel_arg() {
+#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDART_VERSION) && CUDART_VERSION >= 12010
+  static std::optional<bool> supports_large_kernel_arg_ = std::nullopt;
+  if (!supports_large_kernel_arg_.has_value()) {
+    int driver_ver = 0;
+    AT_CUDA_CHECK(cudaDriverGetVersion(&driver_ver));
+    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
+    supports_large_kernel_arg_ = (driver_ver >= 12010) && prop->major >= 7;
+  }
+  const bool is_capturing = at::cuda::currentStreamCaptureStatusMayInitCtx() !=
+      at::cuda::CaptureStatus::None;
+  return !is_capturing && *supports_large_kernel_arg_;
+#else
+  return false;
+#endif
+}
+
+} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/MultiTensorApply.cuh b/aten/src/ATen/native/cuda/MultiTensorApply.cuh
index 17f1444..9b9c3ee 100644
--- a/aten/src/ATen/native/cuda/MultiTensorApply.cuh
+++ b/aten/src/ATen/native/cuda/MultiTensorApply.cuh
@@ -8,20 +8,105 @@
 
 namespace at::native {
 
+// NOTE: [32KB kernel argument size support]
+// 32KB kernel argument size support has three requirements:
+// - CUDART_VERSION >= 12010
+// - Driver version >= 530
+// - GPU arch >= VOLTA
+//
+// Due to CUDA minor version compatibility, it is possible for binaries built
+// with CUDART_VERSION >= 12010 to run with driver version < 530. Since driver
+// version can only be checked at runtime, if CUDART_VERSION >= 12010, we have
+// to build both 4KB and 32KB kernels and determine the appropriate kernel to
+// dispatch at runtime.
+//
+// - If CUDART_VERSION < 12010, only 4KB kernels will be instantiated.
+//
+// - If CUDART_VERSION >= 12010:
+//   - Host code:
+//     - We always instantiate the launching stub for both 4KB and 32KB kernels.
+//   - Device code:
+//     - If __CUDA_ARCH__ >= 700, we always instantiate both 4KB and 32KB
+//     kernels.
+//     - If __CUDA_ARCH__ < 700, it's not possible to even compile an empty
+//     32KB kernel (formal parameter space overflowed). Thus, we only
+//     instantiate a declaration for 32KB kernels. This is valid as long as the
+//     declaration-only kernel is not launched.
+//
+// - At runtime, we dispatch to the 32KB kernel if driver version >= 530 and
+// GPU arch >= VOLTA.
+//
+// - TODO(yifu): once there's a CUDART version that is no longer compatible
+// with any driver version below 530, we can drop the 4KB kernel argument
+// size variants at compile time.
+//
+// https://developer.nvidia.com/blog/cuda-12-1-supports-large-kernel-parameters/
+bool supports_large_kernel_arg();
+
 namespace {
 
 static constexpr int64_t kILP = 4;
 static constexpr int64_t kChunkSize = 65536;
 static constexpr int64_t kBlockSize = 512;
 
-// TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy`
-// TensorListMetadata has to be < 4KB - the limit for kernel launch argument
-static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
-static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
-static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30};
-static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {
-    72,
-    60};
+// MSVC has a problem with constexpr and can't handle passing them to templates
+// as arguments. We need to replace it with const static.
+// https://github.com/state-spaces/mamba/issues/12#issuecomment-1848835662
+#if !defined(_WIN32)
+#define SWITCH_TYPE constexpr bool
+#else
+#define SWITCH_TYPE const static bool
+#endif
+
+#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDART_VERSION) && \
+    CUDART_VERSION >= 12010
+#define DISPATCH_MULTI_TENSOR_APPLY(...)             \
+  if (at::native::supports_large_kernel_arg()) {     \
+    SWITCH_TYPE large_kernel_arg C10_UNUSED = true;  \
+    __VA_ARGS__();                                   \
+  } else {                                           \
+    SWITCH_TYPE large_kernel_arg C10_UNUSED = false; \
+    __VA_ARGS__();                                   \
+  }
+#else
+#define DISPATCH_MULTI_TENSOR_APPLY(...)             \
+  do {                                               \
+    SWITCH_TYPE large_kernel_arg C10_UNUSED = false; \
+    __VA_ARGS__();                                   \
+  } while (0);
+#endif
+
+template <bool large_kernel_arg>
+struct DepthToMaxConfig;
+
+template <>
+struct DepthToMaxConfig<false> {
+  static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
+  static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
+  static constexpr int depth_to_max_tensors_scalarlist[5] =
+      {96, 64, 48, 36, 30};
+  static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {
+      72,
+      60};
+  using TensorIdxType = unsigned char;
+};
+
+template <>
+struct DepthToMaxConfig<true> {
+  // TODO(yifu): These values are not yet optimally tuned. I simply multiplied
+  // the values tuned for 4KB kernel argument size limit by 7 (the kernel
+  // argument size limit increased by 8x but we need to change the type of
+  // block_to_tensor from unsigned char to uint16_t to support larger number of
+  // tensors).
+  static constexpr int depth_to_max_tensors[5] = {770, 448, 336, 252, 210};
+  static constexpr int depth_to_max_blocks[5] = {2240, 2240, 2240, 2240, 2240};
+  static constexpr int depth_to_max_tensors_scalarlist[5] =
+      {672, 448, 336, 252, 210};
+  static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {
+      504,
+      420};
+  using TensorIdxType = uint16_t;
+};
 
 template <typename T>
 __device__ __forceinline__ bool is_aligned(T* p) {
@@ -38,73 +123,101 @@
   ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
 }
 
-template <int n>
+template <int n, bool large_kernel_arg>
 struct TensorListMetadata {
-  const void* addresses[n][depth_to_max_tensors[n - 1]];
-  int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
-  unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
-  int block_to_chunk[depth_to_max_blocks[n - 1]];
+  using Conf = DepthToMaxConfig<large_kernel_arg>;
+  const void* addresses[n][Conf::depth_to_max_tensors[n - 1]];
+  int64_t numel_for_tensor[Conf::depth_to_max_tensors[n - 1]];
+  typename Conf::TensorIdxType
+      block_to_tensor[Conf::depth_to_max_blocks[n - 1]];
+  int block_to_chunk[Conf::depth_to_max_blocks[n - 1]];
   int start_tensor_this_launch;
 };
 
-template <typename scalar_vals_t, int n>
+template <typename scalar_vals_t, int n, bool large_kernel_arg>
 struct TensorListScalarListMetadata {
-  const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]];
-  int64_t numel_for_tensor[depth_to_max_tensors_scalarlist[n - 1]];
-  scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n - 1]];
-  unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
-  int block_to_chunk[depth_to_max_blocks[n - 1]];
+  using Conf = DepthToMaxConfig<large_kernel_arg>;
+  const void* addresses[n][Conf::depth_to_max_tensors_scalarlist[n - 1]];
+  int64_t numel_for_tensor[Conf::depth_to_max_tensors_scalarlist[n - 1]];
+  scalar_vals_t scalar_vals[Conf::depth_to_max_tensors_scalarlist[n - 1]];
+  typename Conf::TensorIdxType
+      block_to_tensor[Conf::depth_to_max_blocks[n - 1]];
+  int block_to_chunk[Conf::depth_to_max_blocks[n - 1]];
 };
 
 // note(mkozuki): `n` of 1&2 violate the limit of cuda kernel argument size of
 // 4kb with `c10::complex<double>`
-template <>
-struct TensorListScalarListMetadata<c10::complex<double>, 1> {
-  const void* addresses[1]
-                       [depth_to_max_tensors_scalarlist_of_complex_double[0]];
-  int64_t
-      numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[0]];
+template <bool large_kernel_arg>
+struct TensorListScalarListMetadata<c10::complex<double>, 1, large_kernel_arg> {
+  using Conf = DepthToMaxConfig<large_kernel_arg>;
+  const void*
+      addresses[1][Conf::depth_to_max_tensors_scalarlist_of_complex_double[0]];
+  int64_t numel_for_tensor
+      [Conf::depth_to_max_tensors_scalarlist_of_complex_double[0]];
   c10::complex<double>
-      scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[0]];
-  unsigned char block_to_tensor[depth_to_max_blocks[1 - 1]];
-  int block_to_chunk[depth_to_max_blocks[1 - 1]];
+      scalar_vals[Conf::depth_to_max_tensors_scalarlist_of_complex_double[0]];
+  typename Conf::TensorIdxType
+      block_to_tensor[Conf::depth_to_max_blocks[1 - 1]];
+  int block_to_chunk[Conf::depth_to_max_blocks[1 - 1]];
 };
 
-template <>
-struct TensorListScalarListMetadata<c10::complex<double>, 2> {
-  const void* addresses[2]
-                       [depth_to_max_tensors_scalarlist_of_complex_double[1]];
-  int64_t
-      numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[1]];
+template <bool large_kernel_arg>
+struct TensorListScalarListMetadata<c10::complex<double>, 2, large_kernel_arg> {
+  using Conf = DepthToMaxConfig<large_kernel_arg>;
+  const void*
+      addresses[2][Conf::depth_to_max_tensors_scalarlist_of_complex_double[1]];
+  int64_t numel_for_tensor
+      [Conf::depth_to_max_tensors_scalarlist_of_complex_double[1]];
   c10::complex<double>
-      scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[1]];
-  unsigned char block_to_tensor[depth_to_max_blocks[2 - 1]];
-  int block_to_chunk[depth_to_max_blocks[2 - 1]];
+      scalar_vals[Conf::depth_to_max_tensors_scalarlist_of_complex_double[1]];
+  typename Conf::TensorIdxType
+      block_to_tensor[Conf::depth_to_max_blocks[2 - 1]];
+  int block_to_chunk[Conf::depth_to_max_blocks[2 - 1]];
 };
 
 // NOTE(crcrpar): This is a conservative resolution to handle `state_steps`
 // whose each element is `at::Tensor` of 1 element representing the number of
 // `step`s called so far.
-template <int n>
+template <int n, bool large_kernel_arg>
 struct FusedOptimizerTensorListMetadata {
-  const void* addresses[n][depth_to_max_tensors[n - 1]];
-  int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
-  const void* state_steps_addresses[depth_to_max_tensors_scalarlist[n - 1]];
-  unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
-  int block_to_chunk[depth_to_max_blocks[n - 1]];
+  using Conf = DepthToMaxConfig<large_kernel_arg>;
+  const void* addresses[n][Conf::depth_to_max_tensors[n - 1]];
+  int64_t numel_for_tensor[Conf::depth_to_max_tensors[n - 1]];
+  const void*
+      state_steps_addresses[Conf::depth_to_max_tensors_scalarlist[n - 1]];
+  typename Conf::TensorIdxType
+      block_to_tensor[Conf::depth_to_max_blocks[n - 1]];
+  int block_to_chunk[Conf::depth_to_max_blocks[n - 1]];
   int start_tensor_this_launch;
 };
 
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700
 template <typename T, typename U, typename... ArgTypes>
 C10_LAUNCH_BOUNDS_1(kBlockSize)
-__global__ void multi_tensor_apply_kernel(
-    T tensorListMeta,
-    U callable,
-    ArgTypes... args) {
+__global__ typename std::enable_if<U::use_large_kernel_arg, void>::type
+    multi_tensor_apply_kernel(T tensorListMeta, U callable, ArgTypes... args) {
   // Hand the chunk information to the user-supplied functor to process however
   // it likes.
   callable(kChunkSize, tensorListMeta, args...);
 }
+#else
+// When compiling device code with __CUDA_ARCH__ < 700, we only instantiate a
+// declaration for the 32KB kernels.
+// For details see: [32KB kernel argument size support]
+#pragma nv_diag_suppress 114 // Function was referenced but not defined
+template <typename T, typename U, typename... ArgTypes>
+C10_LAUNCH_BOUNDS_1(kBlockSize)
+__global__ typename std::enable_if<U::use_large_kernel_arg, void>::type
+    multi_tensor_apply_kernel(T tensorListMeta, U callable, ArgTypes... args);
+#pragma nv_diag_default 114 // Function was referenced but not defined
+#endif
+
+template <typename T, typename U, typename... ArgTypes>
+C10_LAUNCH_BOUNDS_1(kBlockSize)
+__global__ typename std::enable_if<!U::use_large_kernel_arg, void>::type
+    multi_tensor_apply_kernel(T tensorListMeta, U callable, ArgTypes... args) {
+  callable(kChunkSize, tensorListMeta, args...);
+}
 
 } // namespace
 
@@ -133,7 +246,10 @@
       "Number of tensor lists has to match the depth.");
   const size_t n_tensors = tensor_lists[0].size();
   using scalar_vals_t = typename T::opmath_t;
-  TensorListScalarListMetadata<scalar_vals_t, depth> tensorListMeta;
+  TensorListScalarListMetadata<scalar_vals_t, depth, T::use_large_kernel_arg>
+      tensorListMeta;
+
+  using Conf = DepthToMaxConfig<T::use_large_kernel_arg>;
 
   int loc_block_info = 0;
   int loc_tensor_info = 0;
@@ -167,10 +283,11 @@
       // a tensor is not considered full unless all its chunks have been
       // processed
       const bool tensors_full =
-          (loc_tensor_info == depth_to_max_tensors_scalarlist[depth - 1] &&
+          (loc_tensor_info ==
+               Conf::depth_to_max_tensors_scalarlist[depth - 1] &&
            chunk == chunks - 1);
       const bool blocks_full =
-          (loc_block_info == depth_to_max_blocks[depth - 1]);
+          (loc_block_info == Conf::depth_to_max_blocks[depth - 1]);
 
       if (tensors_full || blocks_full) {
         multi_tensor_apply_kernel<<<
@@ -223,9 +340,11 @@
       tensor_lists.size() == depth,
       "Number of tensor lists has to match the depth.");
   const size_t n_tensors = tensor_lists[0].size();
-  TensorListMetadata<depth> tensorListMeta;
+  TensorListMetadata<depth, T::use_large_kernel_arg> tensorListMeta;
   tensorListMeta.start_tensor_this_launch = 0;
 
+  using Conf = DepthToMaxConfig<T::use_large_kernel_arg>;
+
   int loc_block_info = 0;
   int loc_tensor_info = 0;
   for (size_t t = 0; t < n_tensors; t++) {
@@ -250,10 +369,10 @@
       loc_block_info++;
 
       const bool tensors_full =
-          (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
+          (loc_tensor_info == Conf::depth_to_max_tensors[depth - 1] &&
            chunk == chunks - 1);
       const bool blocks_full =
-          (loc_block_info == depth_to_max_blocks[depth - 1]);
+          (loc_block_info == Conf::depth_to_max_blocks[depth - 1]);
 
       if (tensors_full || blocks_full) {
         multi_tensor_apply_kernel<<<
@@ -304,7 +423,10 @@
       tensor_lists.size() == depth,
       "Number of tensor lists has to match the depth");
   const auto num_tensors = tensor_lists[0].size();
-  FusedOptimizerTensorListMetadata<depth> tensorListMeta;
+  FusedOptimizerTensorListMetadata<depth, T::use_large_kernel_arg>
+      tensorListMeta;
+
+  using Conf = DepthToMaxConfig<T::use_large_kernel_arg>;
 
   int loc_block_info = 0;
   int loc_tensor_info = 0;
@@ -333,9 +455,10 @@
       loc_block_info++;
 
       const auto tensor_full =
-          (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
+          (loc_tensor_info == Conf::depth_to_max_tensors[depth - 1] &&
            chunk == chunks - 1);
-      const auto blocks_full = loc_block_info == depth_to_max_blocks[depth - 1];
+      const auto blocks_full =
+          loc_block_info == Conf::depth_to_max_blocks[depth - 1];
 
       if (tensor_full || blocks_full) {
         multi_tensor_apply_kernel<<<
diff --git a/aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu b/aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu
index cef07de..0c6318d 100644
--- a/aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu
+++ b/aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu
@@ -42,19 +42,26 @@
       params[0].scalar_type(),
       "fused_adam_kernel_cuda",
       [&]() {
-        multi_tensor_apply_for_fused_optimizer<5>(
-            tensor_lists,
-            state_steps,
-            FusedAdamMathFunctor<scalar_t, 5, ADAM_MODE::ORIGINAL, true>(),
-            lr_ptr, // unused
-            lr,
-            beta1,
-            beta2,
-            weight_decay,
-            eps,
-            maximize,
-            grad_scale_ptr,
-            found_inf_ptr);
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply_for_fused_optimizer<5>(
+              tensor_lists,
+              state_steps,
+              FusedAdamMathFunctor<
+                  scalar_t,
+                  5,
+                  ADAM_MODE::ORIGINAL,
+                  true,
+                  large_kernel_arg>(),
+              lr_ptr, // unused
+              lr,
+              beta1,
+              beta2,
+              weight_decay,
+              eps,
+              maximize,
+              grad_scale_ptr,
+              found_inf_ptr);
+        });
       });
 }
 
@@ -93,19 +100,26 @@
       params[0].scalar_type(),
       "fused_adam_kernel_cuda",
       [&]() {
-        multi_tensor_apply_for_fused_optimizer<5>(
-            tensor_lists,
-            state_steps,
-            FusedAdamMathFunctor<scalar_t, 5, ADAM_MODE::ORIGINAL, true>(),
-            lr_ptr,
-            1.0, // unused
-            beta1,
-            beta2,
-            weight_decay,
-            eps,
-            maximize,
-            grad_scale_ptr,
-            found_inf_ptr);
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply_for_fused_optimizer<5>(
+              tensor_lists,
+              state_steps,
+              FusedAdamMathFunctor<
+                  scalar_t,
+                  5,
+                  ADAM_MODE::ORIGINAL,
+                  true,
+                  large_kernel_arg>(),
+              lr_ptr,
+              1.0, // unused
+              beta1,
+              beta2,
+              weight_decay,
+              eps,
+              maximize,
+              grad_scale_ptr,
+              found_inf_ptr);
+        });
       });
 }
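
Every fused-optimizer call site is now wrapped in `DISPATCH_MULTI_TENSOR_APPLY`, with `large_kernel_arg` available inside the lambda as a compile-time constant. The macro below is a hypothetical sketch of that shape, not the actual definition, and `supports_32kb_kernel_arg()` is a placeholder for whatever driver/architecture check the host code really performs:

```
// Hypothetical sketch of a dispatch macro with the shape used by the call
// sites above; the real macro and its runtime check live in the PyTorch
// sources, not here.
#include <cuda_runtime.h>

// Placeholder predicate: a driver exposing CUDA 12.1+ (12010) on a device with
// compute capability >= 7.0 is one plausible proxy for 32KB argument support.
inline bool supports_32kb_kernel_arg(int device) {
  int driver_api_version = 0;
  if (cudaDriverGetVersion(&driver_api_version) != cudaSuccess ||
      driver_api_version < 12010) {
    return false;
  }
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
    return false;
  }
  return prop.major >= 7;
}

// The lambda passed at the call site is expanded and invoked inside a scope
// that declares `large_kernel_arg` as a constexpr bool, so the lambda body can
// use it as a template argument without capturing it.
#define DISPATCH_MULTI_TENSOR_APPLY(...)            \
  if (supports_32kb_kernel_arg(/*device=*/0)) {     \
    constexpr bool large_kernel_arg = true;         \
    __VA_ARGS__();                                  \
  } else {                                          \
    constexpr bool large_kernel_arg = false;        \
    __VA_ARGS__();                                  \
  }
```

With this shape, both lambdas are compiled regardless of which branch runs, and only the branch taken at runtime launches its kernel variant.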
 
diff --git a/aten/src/ATen/native/cuda/fused_adam_impl.cu b/aten/src/ATen/native/cuda/fused_adam_impl.cu
index 2c1f5ce..43a4cbf 100644
--- a/aten/src/ATen/native/cuda/fused_adam_impl.cu
+++ b/aten/src/ATen/native/cuda/fused_adam_impl.cu
@@ -37,19 +37,26 @@
       params[0].scalar_type(),
       "fused_adam_kernel_cuda",
       [&]() {
-        multi_tensor_apply_for_fused_optimizer<4>(
-            tensor_lists,
-            state_steps,
-            FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ORIGINAL, false>(),
-            lr_ptr, // unused
-            lr,
-            beta1,
-            beta2,
-            weight_decay,
-            eps,
-            maximize,
-            grad_scale_ptr,
-            found_inf_ptr);
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply_for_fused_optimizer<4>(
+              tensor_lists,
+              state_steps,
+              FusedAdamMathFunctor<
+                  scalar_t,
+                  4,
+                  ADAM_MODE::ORIGINAL,
+                  false,
+                  large_kernel_arg>(),
+              lr_ptr, // unused
+              lr,
+              beta1,
+              beta2,
+              weight_decay,
+              eps,
+              maximize,
+              grad_scale_ptr,
+              found_inf_ptr);
+        });
       });
 }
 
@@ -83,19 +90,26 @@
       params[0].scalar_type(),
       "fused_adam_kernel_cuda",
       [&]() {
-        multi_tensor_apply_for_fused_optimizer<4>(
-            tensor_lists,
-            state_steps,
-            FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ORIGINAL, false>(),
-            lr_ptr,
-            1.0, // unused
-            beta1,
-            beta2,
-            weight_decay,
-            eps,
-            maximize,
-            grad_scale_ptr,
-            found_inf_ptr);
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply_for_fused_optimizer<4>(
+              tensor_lists,
+              state_steps,
+              FusedAdamMathFunctor<
+                  scalar_t,
+                  4,
+                  ADAM_MODE::ORIGINAL,
+                  false,
+                  large_kernel_arg>(),
+              lr_ptr,
+              1.0, // unused
+              beta1,
+              beta2,
+              weight_decay,
+              eps,
+              maximize,
+              grad_scale_ptr,
+              found_inf_ptr);
+        });
       });
 }
 
diff --git a/aten/src/ATen/native/cuda/fused_adam_utils.cuh b/aten/src/ATen/native/cuda/fused_adam_utils.cuh
index 1821959..43627fc 100644
--- a/aten/src/ATen/native/cuda/fused_adam_utils.cuh
+++ b/aten/src/ATen/native/cuda/fused_adam_utils.cuh
@@ -102,15 +102,21 @@
 // parameter updates accordingly. To be functionally on par with `torch.optim`
 // optimizers and `_multi_tensor` ones, the kernel below writes out gradients
 // only when `grad_scale_ptr != nullptr`.
-template <typename scalar_type, int depth, ADAM_MODE adam_mode, bool amsgrad>
+template <
+    typename scalar_type,
+    int depth,
+    ADAM_MODE adam_mode,
+    bool amsgrad,
+    bool large_kernel_arg>
 struct FusedAdamMathFunctor {
+  static constexpr bool use_large_kernel_arg = large_kernel_arg;
   static_assert(
       depth == 4 || depth == 5,
       "depth of 4 for Adam, depth of 5 for Adam with AMSGrad.");
   using opmath_t = at::opmath_type<scalar_type>;
   C10_DEVICE __forceinline__ void operator()(
       int chunk_size,
-      FusedOptimizerTensorListMetadata<depth>& tl,
+      FusedOptimizerTensorListMetadata<depth, large_kernel_arg>& tl,
       const float* lr_ptr,
       const double& lr,
       const double& beta1,
diff --git a/aten/src/ATen/native/cuda/fused_adamw_amsgrad_impl.cu b/aten/src/ATen/native/cuda/fused_adamw_amsgrad_impl.cu
index 8a22b57..3b3b732 100644
--- a/aten/src/ATen/native/cuda/fused_adamw_amsgrad_impl.cu
+++ b/aten/src/ATen/native/cuda/fused_adamw_amsgrad_impl.cu
@@ -43,19 +43,26 @@
       params[0].scalar_type(),
       "fused_adamw_kernel_cuda",
       [&]() {
-        multi_tensor_apply_for_fused_optimizer<5>(
-            tensor_lists,
-            state_steps,
-            FusedAdamMathFunctor<scalar_t, 5, ADAM_MODE::ADAMW, true>(),
-            lr_ptr, // unused
-            lr,
-            beta1,
-            beta2,
-            weight_decay,
-            eps,
-            maximize,
-            grad_scale_ptr,
-            found_inf_ptr);
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply_for_fused_optimizer<5>(
+              tensor_lists,
+              state_steps,
+              FusedAdamMathFunctor<
+                  scalar_t,
+                  5,
+                  ADAM_MODE::ADAMW,
+                  true,
+                  large_kernel_arg>(),
+              lr_ptr, // unused
+              lr,
+              beta1,
+              beta2,
+              weight_decay,
+              eps,
+              maximize,
+              grad_scale_ptr,
+              found_inf_ptr);
+        });
       });
 }
 
@@ -94,19 +101,26 @@
       params[0].scalar_type(),
       "fused_adamw_kernel_cuda",
       [&]() {
-        multi_tensor_apply_for_fused_optimizer<5>(
-            tensor_lists,
-            state_steps,
-            FusedAdamMathFunctor<scalar_t, 5, ADAM_MODE::ADAMW, true>(),
-            lr_ptr,
-            1.0, // unused
-            beta1,
-            beta2,
-            weight_decay,
-            eps,
-            maximize,
-            grad_scale_ptr,
-            found_inf_ptr);
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply_for_fused_optimizer<5>(
+              tensor_lists,
+              state_steps,
+              FusedAdamMathFunctor<
+                  scalar_t,
+                  5,
+                  ADAM_MODE::ADAMW,
+                  true,
+                  large_kernel_arg>(),
+              lr_ptr,
+              1.0, // unused
+              beta1,
+              beta2,
+              weight_decay,
+              eps,
+              maximize,
+              grad_scale_ptr,
+              found_inf_ptr);
+        });
       });
 }
 
diff --git a/aten/src/ATen/native/cuda/fused_adamw_impl.cu b/aten/src/ATen/native/cuda/fused_adamw_impl.cu
index b0f9dc6..ff65768 100644
--- a/aten/src/ATen/native/cuda/fused_adamw_impl.cu
+++ b/aten/src/ATen/native/cuda/fused_adamw_impl.cu
@@ -38,19 +38,26 @@
       params[0].scalar_type(),
       "fused_adamw_kernel_cuda",
       [&]() {
-        multi_tensor_apply_for_fused_optimizer<4>(
-            tensor_lists,
-            state_steps,
-            FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ADAMW, false>(),
-            lr_ptr, // unused
-            lr,
-            beta1,
-            beta2,
-            weight_decay,
-            eps,
-            maximize,
-            grad_scale_ptr,
-            found_inf_ptr);
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply_for_fused_optimizer<4>(
+              tensor_lists,
+              state_steps,
+              FusedAdamMathFunctor<
+                  scalar_t,
+                  4,
+                  ADAM_MODE::ADAMW,
+                  false,
+                  large_kernel_arg>(),
+              lr_ptr, // unused
+              lr,
+              beta1,
+              beta2,
+              weight_decay,
+              eps,
+              maximize,
+              grad_scale_ptr,
+              found_inf_ptr);
+        });
       });
 }
 
@@ -84,19 +91,26 @@
       params[0].scalar_type(),
       "fused_adamw_kernel_cuda",
       [&]() {
-        multi_tensor_apply_for_fused_optimizer<4>(
-            tensor_lists,
-            state_steps,
-            FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ADAMW, false>(),
-            lr_ptr,
-            1.0, // unused
-            beta1,
-            beta2,
-            weight_decay,
-            eps,
-            maximize,
-            grad_scale_ptr,
-            found_inf_ptr);
+        DISPATCH_MULTI_TENSOR_APPLY([&]() {
+          multi_tensor_apply_for_fused_optimizer<4>(
+              tensor_lists,
+              state_steps,
+              FusedAdamMathFunctor<
+                  scalar_t,
+                  4,
+                  ADAM_MODE::ADAMW,
+                  false,
+                  large_kernel_arg>(),
+              lr_ptr,
+              1.0, // unused
+              beta1,
+              beta2,
+              weight_decay,
+              eps,
+              maximize,
+              grad_scale_ptr,
+              found_inf_ptr);
+        });
       });
 }
 
diff --git a/build_variables.bzl b/build_variables.bzl
index 7ef2e0f..bf4d737 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -1456,6 +1456,7 @@
     "aten/src/ATen/native/cuda/Equal.cpp",
     "aten/src/ATen/native/cuda/GridSampler.cpp",
     "aten/src/ATen/native/cuda/IndexKernel.cpp",
+    "aten/src/ATen/native/cuda/MultiTensorApply.cpp",
     "aten/src/ATen/native/cuda/ReduceOps.cpp",
     "aten/src/ATen/native/cuda/ScanKernels.cpp",
     "aten/src/ATen/native/cuda/Sort.cpp",