Add support for 32KB multi_tensor_apply kernel arguments (#134373)
## Benchmark
On H100 SXM (HBM2e, 500W TDP), CUDA Toolkit=12.2, Driver Version=535.154.05, with [this script](https://gist.github.com/yifuwang/178c1f4bf951c5794ea79c04d90e44fa) (`torch._foreach_copy_`):
**Baseline**
```
https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmp0g_x4sys
device ms: 0.891, cpu ms: 7.200
memory bandwidth: 1457.727 GB/s
```
Single iteration trace:
<img width="1432" alt="image" src="https://github.com/user-attachments/assets/8ef54365-0265-4281-a0f0-d4c2f448300e">
**This PR**
```
https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmp3jqiugli
device ms: 0.683, cpu ms: 6.745
memory bandwidth: 1902.010 GB/s
```
Single iteration trace:
<img width="1074" alt="image" src="https://github.com/user-attachments/assets/e52acad1-d09b-492c-9611-6d69e339f3ac">
## Binary Size and Kernel Specialization
The binary size of `libtorch_cuda.so` increased by 6MB (243MB -> 249MB).
```
// NOTE: [32KB kernel argument size support]
// 32KB kernel argument size support has three requirements:
// - CUDART_VERSION >= 12010
// - Driver version >= 530
// - GPU arch >= VOLTA
//
// Due to minor version compatibility, it is possible for binaries built with
// CUDART_VERSION >= 12010 to run with driver version < 530. Since driver
// version can only be checked at runtime, if CUDART_VERSION >= 12010, we have
// to build both 4KB and 32KB kernels and determine the appropriate kernel to
// dispatch at runtime.
//
// - If CUDART_VERSION < 12010, only 4KB kernels will be instantiated.
//
// - If CUDART_VERSION >= 12010:
// - Host code:
// - We always instantiate the launching stub for both 4KB and 32KB kernels.
// - Device code:
// - If __CUDA_ARCH__ >= 700, we always instantiate both 4KB and 32KB
// kernels.
// - If __CUDA_ARCH__ < 700, it's not possible to even compile an empty
// 32KB kernel (formal parameter space overflowed). Thus, we only
// instantiate a declaration for 32KB kernels. This is valid as long as the
// declaration-only kernel is not launched.
//
// - At runtime, we dispatch to the 32KB kernel if driver version >= 530 and
// GPU arch >= VOLTA.
//
// - TODO(yifu): once there's a CUDART version that is not compatible with any
//   driver version below 530, we can decide at compile time not to compile
//   the kernels for the 4KB kernel argument size.
//
// https://developer.nvidia.com/blog/cuda-12-1-supports-large-kernel-parameters/
```
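The diff implements this with a `DISPATCH_MULTI_TENSOR_APPLY` macro and a cached `supports_large_kernel_arg()` runtime check. The snippet below is a minimal, self-contained CUDA sketch of the same pattern with illustrative names and payload sizes (it is not the PR's code); it assumes a CUDA 12.1+ toolkit and relies on the runtime-reported driver API version 12010 corresponding to the r530 driver branch, matching the check in `MultiTensorApply.cpp`:
```
#include <cuda_runtime.h>

// Metadata sized for the classic 4KB limit vs. the 32KB limit (illustrative sizes).
struct SmallPayload { const void* ptrs[400]; };   // ~3.2KB of kernel arguments
struct LargePayload { const void* ptrs[3800]; };  // ~30KB of kernel arguments

// 4KB variant: always defined; works on every supported arch.
__global__ void apply_small(SmallPayload p) { (void)p; /* ... */ }

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700
// 32KB variant: definable on Volta+ (and in the host compilation pass).
__global__ void apply_large(LargePayload p) { (void)p; /* ... */ }
#else
// Below Volta even an empty 32KB-argument kernel fails to compile
// ("formal parameter space overflowed"), so only a declaration is emitted.
__global__ void apply_large(LargePayload p);
#endif

// Host-side dispatch mirroring the note above: require driver >= r530
// (driver API version 12010) and compute capability >= 7.0 at runtime.
void launch(int device) {
  int driver_ver = 0;
  cudaDriverGetVersion(&driver_ver);
  cudaDeviceProp prop{};
  cudaGetDeviceProperties(&prop, device);

  if (driver_ver >= 12010 && prop.major >= 7) {
    apply_large<<<1, 1>>>(LargePayload{});   // 32KB kernel argument path
  } else {
    apply_small<<<1, 1>>>(SmallPayload{});   // 4KB fallback path
  }
}
```
The per-functor `use_large_kernel_arg` flag and the `std::enable_if` split on `multi_tensor_apply_kernel` in the diff serve the same purpose as the `#if` split here, just folded into the existing templated launcher.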
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134373
Approved by: https://github.com/eqy, https://github.com/crcrpar, https://github.com/janeyx99
diff --git a/aten/src/ATen/native/cuda/AmpKernels.cu b/aten/src/ATen/native/cuda/AmpKernels.cu
index 8c161ca..07cb84f 100644
--- a/aten/src/ATen/native/cuda/AmpKernels.cu
+++ b/aten/src/ATen/native/cuda/AmpKernels.cu
@@ -157,21 +157,24 @@
using opmath_t = at::opmath_type<scalar_t>;
// multi_tensor_apply guards onto tensor_lists[0][0], no need to guard explicitly.
- multi_tensor_apply<1>(tensor_lists,
- UnaryOpFunctor<scalar_t,
- /* depth */ 1,
- /* r_args_depth */ 1,
- /* res_arg_index */ 0>(),
- [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t {
- // There is a slight asymmetry here with the TensorIterator kernel above.
- // MTA Functors ensure val comes in as opmath_t rather than scalar_t.
- if (!isfinite_ensure_cuda_math(val)) {
- *found_inf_ptr = 1.f;
- }
- // Every thread accesses inv_scale, but it will hit in cache.
- const auto inv_scale_val = *inv_scale_ptr;
- return static_cast<opmath_t>(inv_scale_val == 1.f ? val : val * inv_scale_val);
- });
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<1>(tensor_lists,
+ UnaryOpFunctor<scalar_t,
+ /* depth */ 1,
+ /* r_args_depth */ 1,
+ /* res_arg_index */ 0,
+ large_kernel_arg>(),
+ [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t {
+ // There is a slight asymmetry here with the TensorIterator kernel above.
+ // MTA Functors ensure val comes in as opmath_t rather than scalar_t.
+ if (!isfinite_ensure_cuda_math(val)) {
+ *found_inf_ptr = 1.f;
+ }
+ // Every thread accesses inv_scale, but it will hit in cache.
+ const auto inv_scale_val = *inv_scale_ptr;
+ return static_cast<opmath_t>(inv_scale_val == 1.f ? val : val * inv_scale_val);
+ });
+ });
});
}
diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu
index 533aa38..90c180c 100644
--- a/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu
+++ b/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu
@@ -41,15 +41,18 @@
tensor_lists.emplace_back(std::move(vec_res));
using opmath_t = at::opmath_type<T>;
- multi_tensor_apply<3>(
- tensor_lists,
- BinaryOpListAlphaFunctor<
- T,
- /* depth */ 3,
- /* r_args_depth */ 2,
- /* res_arg_index */ 2>(),
- Op<opmath_t>(),
- alpha.to<opmath_t>());
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<3>(
+ tensor_lists,
+ BinaryOpListAlphaFunctor<
+ T,
+ /* depth */ 3,
+ /* r_args_depth */ 2,
+ /* res_arg_index */ 2,
+ large_kernel_arg>(),
+ Op<opmath_t>(),
+ alpha.to<opmath_t>());
+ });
return tensor_lists[2];
}
@@ -64,15 +67,18 @@
tensor_lists.emplace_back(tensors2.vec());
using opmath_t = at::opmath_type<T>;
- multi_tensor_apply<2>(
- tensor_lists,
- BinaryOpListAlphaFunctor<
- T,
- /* depth */ 2,
- /* r_args_depth */ 2,
- /* res_arg_index */ 0>(),
- Op<opmath_t>(),
- alpha.to<opmath_t>());
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<2>(
+ tensor_lists,
+ BinaryOpListAlphaFunctor<
+ T,
+ /* depth */ 2,
+ /* r_args_depth */ 2,
+ /* res_arg_index */ 0,
+ large_kernel_arg>(),
+ Op<opmath_t>(),
+ alpha.to<opmath_t>());
+ });
increment_version(tensors1);
}
@@ -331,13 +337,15 @@
typename src_t,
int depth,
int r_args_depth,
- int res_arg_index>
+ int res_arg_index,
+ bool large_kernel_arg>
struct CopyFunctor {
+ static constexpr bool use_large_kernel_arg = large_kernel_arg;
static_assert(depth == 2 && r_args_depth == 1 && res_arg_index == 1);
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
- TensorListMetadata<depth>& tl,
+ TensorListMetadata<depth, large_kernel_arg>& tl,
Op op) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@@ -420,14 +428,17 @@
using opmath_t = at::opmath_type<scalar_t>;
AT_DISPATCH_SOURCE_TYPES(src[0].scalar_type(), "foreach_tensor_copy", [&] {
if constexpr (std::is_same_v<scalar_t, src_t>) {
- multi_tensor_apply<2>(
- tensor_lists,
- UnaryOpFunctor<
- scalar_t,
- /* depth */ 2,
- /* r_args_depth */ 1,
- /* res_arg_index */ 1>(),
- Copy<opmath_t, opmath_t>());
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<2>(
+ tensor_lists,
+ UnaryOpFunctor<
+ scalar_t,
+ /* depth */ 2,
+ /* r_args_depth */ 1,
+ /* res_arg_index */ 1,
+ large_kernel_arg>(),
+ Copy<opmath_t, opmath_t>());
+ });
} else {
// Ref:
// https://github.com/pytorch/pytorch/blob/656134c38f4737d13c3f43fc5c59470bc23c1d2f/aten/src/ATen/native/Copy.cpp#L299-L301
@@ -435,15 +446,18 @@
TORCH_WARN_ONCE(
"Casting complex values to real discards the imaginary part");
}
- multi_tensor_apply<2>(
- tensor_lists,
- CopyFunctor<
- scalar_t,
- src_t,
- /* depth */ 2,
- /* r_args_depth */ 1,
- /* res_arg_index */ 1>(),
- Copy<scalar_t, src_t>());
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<2>(
+ tensor_lists,
+ CopyFunctor<
+ scalar_t,
+ src_t,
+ /* depth */ 2,
+ /* r_args_depth */ 1,
+ /* res_arg_index */ 1,
+ large_kernel_arg>(),
+ Copy<scalar_t, src_t>());
+ });
}
});
});
diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu
index 80d748d..1b0045d 100644
--- a/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu
+++ b/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu
@@ -36,15 +36,18 @@
tensor_lists.emplace_back(std::move(vec_res));
using opmath_t = at::opmath_type<T>;
- multi_tensor_apply<2>(
- tensor_lists,
- BinaryOpScalarFunctor<
- T,
- /* depth */ 2,
- /* r_args_depth */ 1,
- /* res_arg_index */ 1>(),
- Op<opmath_t>(),
- scalar.to<opmath_t>());
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<2>(
+ tensor_lists,
+ BinaryOpScalarFunctor<
+ T,
+ /* depth */ 2,
+ /* r_args_depth */ 1,
+ /* res_arg_index */ 1,
+ large_kernel_arg>(),
+ Op<opmath_t>(),
+ scalar.to<opmath_t>());
+ });
return tensor_lists[1];
}
@@ -54,15 +57,18 @@
tensor_lists.emplace_back(tensors.vec());
using opmath_t = at::opmath_type<T>;
- multi_tensor_apply<1>(
- tensor_lists,
- BinaryOpScalarFunctor<
- T,
- /* depth */ 1,
- /* r_args_depth */ 1,
- /* res_arg_index */ 0>(),
- Op<opmath_t>(),
- scalar.to<opmath_t>());
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<1>(
+ tensor_lists,
+ BinaryOpScalarFunctor<
+ T,
+ /* depth */ 1,
+ /* r_args_depth */ 1,
+ /* res_arg_index */ 0,
+ large_kernel_arg>(),
+ Op<opmath_t>(),
+ scalar.to<opmath_t>());
+ });
increment_version(tensors);
}
diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu
index dcb9318..c423006 100644
--- a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu
+++ b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu
@@ -36,16 +36,19 @@
tensor_lists.emplace_back(vec_res);
using opmath_t = at::opmath_type<T>;
- multi_tensor_apply<2, opmath_t>(
- tensor_lists,
- scalars,
- BinaryOpScalarListFunctor<
- T,
- /* depth */ 2,
- /* r_args_depth */ 1,
- /* res_arg_index */ 1>(),
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<2, opmath_t>(
+ tensor_lists,
+ scalars,
+ BinaryOpScalarListFunctor<
+ T,
+ /* depth */ 2,
+ /* r_args_depth */ 1,
+ /* res_arg_index */ 1,
+ large_kernel_arg>(),
- Op<opmath_t>());
+ Op<opmath_t>());
+ });
return tensor_lists[1];
}
@@ -55,15 +58,18 @@
tensor_lists.emplace_back(tensors.vec());
using opmath_t = at::opmath_type<T>;
- multi_tensor_apply<1, opmath_t>(
- tensor_lists,
- scalars,
- BinaryOpScalarListFunctor<
- T,
- /* depth */ 1,
- /* r_args_depth */ 1,
- /* res_arg_index */ 0>(),
- Op<opmath_t>());
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<1, opmath_t>(
+ tensor_lists,
+ scalars,
+ BinaryOpScalarListFunctor<
+ T,
+ /* depth */ 1,
+ /* r_args_depth */ 1,
+ /* res_arg_index */ 0,
+ large_kernel_arg>(),
+ Op<opmath_t>());
+ });
increment_version(tensors);
}
diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu
index ad5eeee..4163b30 100644
--- a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu
+++ b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu
@@ -46,16 +46,19 @@
tensor_lists.emplace_back(std::move(vec_res));
using opmath_t = at::opmath_type<T>;
- multi_tensor_apply<2>(
- tensor_lists,
- BinaryOpScalarTensorFunctor<
- T,
- /* depth */ 2,
- /* r_args_depth */ 1,
- /* res_arg_index */ 1>(),
- Op<opmath_t>(),
- scalar.data_ptr<T>(),
- alpha.to<opmath_t>());
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<2>(
+ tensor_lists,
+ BinaryOpScalarTensorFunctor<
+ T,
+ /* depth */ 2,
+ /* r_args_depth */ 1,
+ /* res_arg_index */ 1,
+ large_kernel_arg>(),
+ Op<opmath_t>(),
+ scalar.data_ptr<T>(),
+ alpha.to<opmath_t>());
+ });
return tensor_lists[1];
}
@@ -81,16 +84,19 @@
tensor_lists.emplace_back(tensors.vec());
using opmath_t = at::opmath_type<T>;
- multi_tensor_apply<1>(
- tensor_lists,
- BinaryOpScalarTensorFunctor<
- T,
- /* depth */ 1,
- /* r_args_depth */ 1,
- /* res_arg_index */ 0>(),
- Op<opmath_t>(),
- scalar.data_ptr<T>(),
- alpha.to<opmath_t>());
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<1>(
+ tensor_lists,
+ BinaryOpScalarTensorFunctor<
+ T,
+ /* depth */ 1,
+ /* r_args_depth */ 1,
+ /* res_arg_index */ 0,
+ large_kernel_arg>(),
+ Op<opmath_t>(),
+ scalar.data_ptr<T>(),
+ alpha.to<opmath_t>());
+ });
increment_version(tensors);
}
diff --git a/aten/src/ATen/native/cuda/ForeachFunctors.cuh b/aten/src/ATen/native/cuda/ForeachFunctors.cuh
index 55e4fd7..fb8d849 100644
--- a/aten/src/ATen/native/cuda/ForeachFunctors.cuh
+++ b/aten/src/ATen/native/cuda/ForeachFunctors.cuh
@@ -18,10 +18,10 @@
}
// Initializes args and checks if all args are aligned
-template <int depth, typename T>
+template <int depth, typename T, bool large_kernel_arg>
__device__ bool init_args(
T** args,
- TensorListMetadata<depth>& tl,
+ TensorListMetadata<depth, large_kernel_arg>& tl,
const int64_t chunk_idx,
const int64_t chunk_size,
const int64_t tensor_loc) {
@@ -38,10 +38,10 @@
}
// Initializes args and checks if all args are aligned
-template <int depth, typename T, typename T2>
+template <int depth, typename T, typename T2, bool large_kernel_arg>
__device__ bool init_args(
T** args,
- TensorListScalarListMetadata<T2, depth>& tl,
+ TensorListScalarListMetadata<T2, depth, large_kernel_arg>& tl,
const int64_t chunk_idx,
const int64_t chunk_size,
const int64_t tensor_loc) {
@@ -57,10 +57,10 @@
return all_aligned;
}
-template <int depth, typename T>
+template <int depth, typename T, bool large_kernel_arg>
__device__ bool init_args(
T** args,
- FusedOptimizerTensorListMetadata<depth>& tl,
+ FusedOptimizerTensorListMetadata<depth, large_kernel_arg>& tl,
const int64_t chunk_idx,
const int64_t chunk_size,
const int64_t tensor_loc) {
@@ -203,13 +203,19 @@
//
// Binary Functors
//
-template <typename T, int depth, int r_args_depth, int res_arg_index>
+template <
+ typename T,
+ int depth,
+ int r_args_depth,
+ int res_arg_index,
+ bool large_kernel_arg>
struct BinaryOpScalarFunctor {
+ static constexpr bool use_large_kernel_arg = large_kernel_arg;
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
- TensorListMetadata<depth>& tl,
+ TensorListMetadata<depth, large_kernel_arg>& tl,
Op op,
opmath_t scalar) {
const int tensor_loc = tl.block_to_tensor[blockIdx.x];
@@ -227,13 +233,19 @@
}
};
-template <typename T, int depth, int r_args_depth, int res_arg_index>
+template <
+ typename T,
+ int depth,
+ int r_args_depth,
+ int res_arg_index,
+ bool large_kernel_arg>
struct BinaryOpScalarListFunctor {
+ static constexpr bool use_large_kernel_arg = large_kernel_arg;
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
- TensorListScalarListMetadata<opmath_t, depth>& tl,
+ TensorListScalarListMetadata<opmath_t, depth, large_kernel_arg>& tl,
Op op) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@@ -251,13 +263,19 @@
}
};
-template <typename T, int depth, int r_args_depth, int res_arg_index>
+template <
+ typename T,
+ int depth,
+ int r_args_depth,
+ int res_arg_index,
+ bool large_kernel_arg>
struct BinaryOpListAlphaFunctor {
+ static constexpr bool use_large_kernel_arg = large_kernel_arg;
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
- TensorListMetadata<depth>& tl,
+ TensorListMetadata<depth, large_kernel_arg>& tl,
Op op,
opmath_t alpha) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@@ -303,13 +321,19 @@
}
};
-template <typename T, int depth, int r_args_depth, int res_arg_index>
+template <
+ typename T,
+ int depth,
+ int r_args_depth,
+ int res_arg_index,
+ bool large_kernel_arg>
struct BinaryOpScalarTensorFunctor {
+ static constexpr bool use_large_kernel_arg = large_kernel_arg;
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
- TensorListMetadata<depth>& tl,
+ TensorListMetadata<depth, large_kernel_arg>& tl,
Op op,
T* scalar,
opmath_t alpha) {
@@ -361,11 +385,17 @@
// Unary Functors
//
-template <typename T, int depth, int r_args_depth, int res_arg_index>
+template <
+ typename T,
+ int depth,
+ int r_args_depth,
+ int res_arg_index,
+ bool large_kernel_arg>
struct ZeroFunctor {
+ static constexpr bool use_large_kernel_arg = large_kernel_arg;
__device__ __forceinline__ void operator()(
int chunk_size,
- TensorListMetadata<1>& tl) {
+ TensorListMetadata<1, large_kernel_arg>& tl) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
auto n = tl.numel_for_tensor[tensor_loc];
@@ -401,13 +431,19 @@
}
};
-template <typename T, int depth, int r_args_depth, int res_arg_index>
+template <
+ typename T,
+ int depth,
+ int r_args_depth,
+ int res_arg_index,
+ bool large_kernel_arg>
struct UnaryOpFunctor {
+ static constexpr bool use_large_kernel_arg = large_kernel_arg;
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
- TensorListMetadata<depth>& tl,
+ TensorListMetadata<depth, large_kernel_arg>& tl,
Op op) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@@ -453,13 +489,19 @@
// Pointwise Functors
//
-template <typename T, int depth, int r_args_depth, int res_arg_index>
+template <
+ typename T,
+ int depth,
+ int r_args_depth,
+ int res_arg_index,
+ bool large_kernel_arg>
struct PointwiseOpScalarFunctor {
+ static constexpr bool use_large_kernel_arg = large_kernel_arg;
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
- TensorListMetadata<depth>& tl,
+ TensorListMetadata<depth, large_kernel_arg>& tl,
Op op,
opmath_t scalar) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@@ -477,13 +519,19 @@
}
};
-template <typename T, int depth, int r_args_depth, int res_arg_index>
+template <
+ typename T,
+ int depth,
+ int r_args_depth,
+ int res_arg_index,
+ bool large_kernel_arg>
struct PointwiseOpScalarListFunctor {
+ static constexpr bool use_large_kernel_arg = large_kernel_arg;
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
- TensorListScalarListMetadata<opmath_t, depth>& tl,
+ TensorListScalarListMetadata<opmath_t, depth, large_kernel_arg>& tl,
Op op) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@@ -501,13 +549,14 @@
}
};
-template <typename T, int depth>
+template <typename T, int depth, bool large_kernel_arg>
struct PointwiseOpListFunctor {
+ static constexpr bool use_large_kernel_arg = large_kernel_arg;
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
- TensorListMetadata<depth>& tl,
+ TensorListMetadata<depth, large_kernel_arg>& tl,
Op op) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@@ -552,13 +601,19 @@
}
};
-template <typename T, int depth, int r_args_depth, int res_arg_index>
+template <
+ typename T,
+ int depth,
+ int r_args_depth,
+ int res_arg_index,
+ bool large_kernel_arg>
struct TernaryOpListFunctor {
+ static constexpr bool use_large_kernel_arg = large_kernel_arg;
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
- TensorListMetadata<depth>& tl,
+ TensorListMetadata<depth, large_kernel_arg>& tl,
Op op) {
static_assert(depth == 3 || depth == 4, "");
static_assert(depth >= r_args_depth, "");
@@ -606,13 +661,19 @@
}
};
-template <typename T, int depth, int r_args_depth, int res_arg_index>
+template <
+ typename T,
+ int depth,
+ int r_args_depth,
+ int res_arg_index,
+ bool large_kernel_arg>
struct TernaryOpScalarFunctor {
+ static constexpr bool use_large_kernel_arg = large_kernel_arg;
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
- TensorListMetadata<depth>& tl,
+ TensorListMetadata<depth, large_kernel_arg>& tl,
Op op,
opmath_t alpha) {
static_assert(depth == 2 || depth == 3, "");
diff --git a/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu b/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu
index 7a3276c..bad152c 100644
--- a/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu
+++ b/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu
@@ -46,15 +46,18 @@
"foreach_pointwise_op_cuda",
[&]() {
using opmath_t = at::opmath_type<scalar_t>;
- multi_tensor_apply<4>(
- tensor_lists,
- PointwiseOpScalarFunctor<
- scalar_t,
- /* depth */ 4,
- /* r_args_depth */ 3,
- /* res_arg_index */ 3>(),
- Op<opmath_t>(),
- scalar.to<opmath_t>());
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<4>(
+ tensor_lists,
+ PointwiseOpScalarFunctor<
+ scalar_t,
+ /* depth */ 4,
+ /* r_args_depth */ 3,
+ /* res_arg_index */ 3,
+ large_kernel_arg>(),
+ Op<opmath_t>(),
+ scalar.to<opmath_t>());
+ });
});
return tensor_lists[3];
@@ -78,15 +81,18 @@
"foreach_pointwise_op__cuda",
[&]() {
using opmath_t = at::opmath_type<scalar_t>;
- multi_tensor_apply<3>(
- tensor_lists,
- PointwiseOpScalarFunctor<
- scalar_t,
- /* depth */ 3,
- /* r_args_depth */ 3,
- /* res_arg_index */ 0>(),
- Op<opmath_t>(),
- scalar.to<opmath_t>());
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<3>(
+ tensor_lists,
+ PointwiseOpScalarFunctor<
+ scalar_t,
+ /* depth */ 3,
+ /* r_args_depth */ 3,
+ /* res_arg_index */ 0,
+ large_kernel_arg>(),
+ Op<opmath_t>(),
+ scalar.to<opmath_t>());
+ });
});
increment_version(input);
}
@@ -110,15 +116,18 @@
"foreach_pointwise_op__cuda",
[&]() {
using opmath_t = at::opmath_type<scalar_t>;
- multi_tensor_apply<3, opmath_t>(
- tensor_lists,
- scalars,
- PointwiseOpScalarListFunctor<
- scalar_t,
- /* depth */ 3,
- /* r_args_depth */ 3,
- /* res_arg_index */ 0>(),
- Op<opmath_t>());
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<3, opmath_t>(
+ tensor_lists,
+ scalars,
+ PointwiseOpScalarListFunctor<
+ scalar_t,
+ /* depth */ 3,
+ /* r_args_depth */ 3,
+ /* res_arg_index */ 0,
+ large_kernel_arg>(),
+ Op<opmath_t>());
+ });
});
increment_version(input);
}
@@ -149,15 +158,18 @@
"foreach_pointwise_op_cuda",
[&]() {
using opmath_t = at::opmath_type<scalar_t>;
- multi_tensor_apply<4, opmath_t>(
- tensor_lists,
- scalars,
- PointwiseOpScalarListFunctor<
- scalar_t,
- /* depth */ 4,
- /* r_args_depth */ 3,
- /* res_arg_index */ 3>(),
- Op<opmath_t>());
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<4, opmath_t>(
+ tensor_lists,
+ scalars,
+ PointwiseOpScalarListFunctor<
+ scalar_t,
+ /* depth */ 4,
+ /* r_args_depth */ 3,
+ /* res_arg_index */ 3,
+ large_kernel_arg>(),
+ Op<opmath_t>());
+ });
});
return tensor_lists[3];
diff --git a/aten/src/ATen/native/cuda/ForeachReduceOp.cu b/aten/src/ATen/native/cuda/ForeachReduceOp.cu
index 61793fd..1ce18b7 100644
--- a/aten/src/ATen/native/cuda/ForeachReduceOp.cu
+++ b/aten/src/ATen/native/cuda/ForeachReduceOp.cu
@@ -50,11 +50,13 @@
typename T,
int depth = 1,
int r_args_depth = 1,
- int res_arg_index = 0>
+ int res_arg_index = 0,
+ bool large_kernel_arg = false>
struct LpMaxFunctor {
+ static constexpr bool use_large_kernel_arg = large_kernel_arg;
__device__ __forceinline__ void operator()(
int chunk_size,
- TensorListMetadata<depth>& tl,
+ TensorListMetadata<depth, large_kernel_arg>& tl,
T* output_per_tensor_ptr,
const int max_chunks_per_tensor) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@@ -178,11 +180,13 @@
tensor_lists[0][0].scalar_type(),
"foreach_tensor_max_cuda_scalar_type",
[&]() {
- multi_tensor_apply<1>(
- tensor_lists,
- LpMaxFunctor<scalar_t>(),
- output_per_tensor.mutable_data_ptr<scalar_t>(),
- max_chunks_per_tensor);
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<1>(
+ tensor_lists,
+ LpMaxFunctor<scalar_t, 1, 1, 0, large_kernel_arg>(),
+ output_per_tensor.mutable_data_ptr<scalar_t>(),
+ max_chunks_per_tensor);
+ });
C10_CUDA_KERNEL_LAUNCH_CHECK();
const at::cuda::OptionalCUDAGuard device_guard(
@@ -239,12 +243,14 @@
typename out_t,
int depth = 1,
int r_args_depth = 1,
- int res_arg_index = 0>
+ int res_arg_index = 0,
+ bool large_kernel_arg = false>
struct LpNormFunctor {
+ static constexpr bool use_large_kernel_arg = large_kernel_arg;
using out_opmath_t = typename at::opmath_type<out_t>;
__device__ __forceinline__ void operator()(
int chunk_size,
- TensorListMetadata<depth>& tl,
+ TensorListMetadata<depth, large_kernel_arg>& tl,
out_opmath_t* output_per_tensor_ptr,
const int max_chunks_per_tensor) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@@ -476,23 +482,50 @@
output_dtype, "foreach_tensor_norm_cuda_out_dtype", [&]() {
using out_opmath_t = typename at::opmath_type<out_t>;
if (p == static_cast<double>(1)) {
- multi_tensor_apply<1>(
- tensor_lists,
- LpNormFunctor<scalar_t, NormType::L1, out_t>(),
- output_per_tensor.mutable_data_ptr<out_opmath_t>(),
- max_chunks_per_tensor);
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<1>(
+ tensor_lists,
+ LpNormFunctor<
+ scalar_t,
+ NormType::L1,
+ out_t,
+ 1,
+ 1,
+ 0,
+ large_kernel_arg>(),
+ output_per_tensor.mutable_data_ptr<out_opmath_t>(),
+ max_chunks_per_tensor);
+ });
} else if (p == static_cast<double>(2)) {
- multi_tensor_apply<1>(
- tensor_lists,
- LpNormFunctor<scalar_t, NormType::L2, out_t>(),
- output_per_tensor.mutable_data_ptr<out_opmath_t>(),
- max_chunks_per_tensor);
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<1>(
+ tensor_lists,
+ LpNormFunctor<
+ scalar_t,
+ NormType::L2,
+ out_t,
+ 1,
+ 1,
+ 0,
+ large_kernel_arg>(),
+ output_per_tensor.mutable_data_ptr<out_opmath_t>(),
+ max_chunks_per_tensor);
+ });
} else if (p == std::numeric_limits<double>::infinity()) {
- multi_tensor_apply<1>(
- tensor_lists,
- LpNormFunctor<scalar_t, NormType::LInf, out_t>(),
- output_per_tensor.mutable_data_ptr<out_opmath_t>(),
- max_chunks_per_tensor);
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<1>(
+ tensor_lists,
+ LpNormFunctor<
+ scalar_t,
+ NormType::LInf,
+ out_t,
+ 1,
+ 1,
+ 0,
+ large_kernel_arg>(),
+ output_per_tensor.mutable_data_ptr<out_opmath_t>(),
+ max_chunks_per_tensor);
+ });
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
const at::cuda::OptionalCUDAGuard device_guard(
diff --git a/aten/src/ATen/native/cuda/ForeachTernaryOp.cu b/aten/src/ATen/native/cuda/ForeachTernaryOp.cu
index e13f201..e73862d 100644
--- a/aten/src/ATen/native/cuda/ForeachTernaryOp.cu
+++ b/aten/src/ATen/native/cuda/ForeachTernaryOp.cu
@@ -46,14 +46,17 @@
"foreach_tensor_lerp_ternary_cuda",
[&]() {
using opmath_t = typename at::opmath_type<scalar_t>;
- multi_tensor_apply<4>(
- tensor_lists,
- TernaryOpListFunctor<
- scalar_t,
- /* depth */ 4,
- /* r_args_depth */ 3,
- /* res_arg_index */ 3>(),
- LerpFunctor<opmath_t>());
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<4>(
+ tensor_lists,
+ TernaryOpListFunctor<
+ scalar_t,
+ /* depth */ 4,
+ /* r_args_depth */ 3,
+ /* res_arg_index */ 3,
+ large_kernel_arg>(),
+ LerpFunctor<opmath_t>());
+ });
});
return tensor_lists[3];
@@ -77,14 +80,17 @@
"foreach_tensor_lerp_ternary_cuda_",
[&]() {
using opmath_t = typename at::opmath_type<scalar_t>;
- multi_tensor_apply<3>(
- tensor_lists,
- TernaryOpListFunctor<
- scalar_t,
- /* depth */ 3,
- /* r_args_depth */ 3,
- /* res_arg_index */ 0>(),
- LerpFunctor<opmath_t>());
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<3>(
+ tensor_lists,
+ TernaryOpListFunctor<
+ scalar_t,
+ /* depth */ 3,
+ /* r_args_depth */ 3,
+ /* res_arg_index */ 0,
+ large_kernel_arg>(),
+ LerpFunctor<opmath_t>());
+ });
});
increment_version(tensors1);
}
@@ -113,15 +119,18 @@
"foreach_tensor_lerp_scalar_cuda",
[&]() {
using opmath_t = typename at::opmath_type<scalar_t>;
- multi_tensor_apply<3>(
- tensor_lists,
- TernaryOpScalarFunctor<
- scalar_t,
- /* depth */ 3,
- /* r_args_depth */ 2,
- /* res_arg_index */ 2>(),
- LerpFunctor<opmath_t>(),
- weight.to<opmath_t>());
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<3>(
+ tensor_lists,
+ TernaryOpScalarFunctor<
+ scalar_t,
+ /* depth */ 3,
+ /* r_args_depth */ 2,
+ /* res_arg_index */ 2,
+ large_kernel_arg>(),
+ LerpFunctor<opmath_t>(),
+ weight.to<opmath_t>());
+ });
});
return tensor_lists[2];
@@ -145,15 +154,18 @@
"foreach_tensor_lerp_scalar_cuda_",
[&]() {
using opmath_t = typename at::opmath_type<scalar_t>;
- multi_tensor_apply<2>(
- tensor_lists,
- TernaryOpScalarFunctor<
- scalar_t,
- /* depth */ 2,
- /* r_args_depth */ 2,
- /* res_arg_index */ 0>(),
- LerpFunctor<opmath_t>(),
- weight.to<opmath_t>());
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<2>(
+ tensor_lists,
+ TernaryOpScalarFunctor<
+ scalar_t,
+ /* depth */ 2,
+ /* r_args_depth */ 2,
+ /* res_arg_index */ 0,
+ large_kernel_arg>(),
+ LerpFunctor<opmath_t>(),
+ weight.to<opmath_t>());
+ });
});
}
} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu
index 1a969cf..fa333b4 100644
--- a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu
+++ b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu
@@ -56,14 +56,17 @@
tensor_lists.emplace_back(std::move(vec_res));
using opmath_t = typename at::opmath_type<scalar_t>;
- multi_tensor_apply<2>(
- tensor_lists,
- UnaryOpFunctor<
- scalar_t,
- /* depth */ 2,
- /* r_args_depth */ 1,
- /* res_arg_index */ 1>(),
- Op<opmath_t>());
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<2>(
+ tensor_lists,
+ UnaryOpFunctor<
+ scalar_t,
+ /* depth */ 2,
+ /* r_args_depth */ 1,
+ /* res_arg_index */ 1,
+ large_kernel_arg>(),
+ Op<opmath_t>());
+ });
return tensor_lists[1];
}
@@ -73,14 +76,17 @@
std::vector<std::vector<at::Tensor>> tensor_lists;
tensor_lists.emplace_back(tensors.vec());
using opmath_t = typename at::opmath_type<scalar_t>;
- multi_tensor_apply<1>(
- tensor_lists,
- UnaryOpFunctor<
- scalar_t,
- /* depth */ 1,
- /* r_args_depth */ 1,
- /* res_arg_index */ 0>(),
- Op<opmath_t>());
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<1>(
+ tensor_lists,
+ UnaryOpFunctor<
+ scalar_t,
+ /* depth */ 1,
+ /* r_args_depth */ 1,
+ /* res_arg_index */ 0,
+ large_kernel_arg>(),
+ Op<opmath_t>());
+ });
increment_version(tensors);
}
@@ -395,13 +401,16 @@
tensors[0].scalar_type(),
"foreach_zero_cuda_",
[&]() {
- multi_tensor_apply<1>(
- tensor_lists,
- ZeroFunctor<
- scalar_t,
- /* depth */ 1,
- /* r_args_depth */ 1,
- /* res_arg_index */ 0>());
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<1>(
+ tensor_lists,
+ ZeroFunctor<
+ scalar_t,
+ /* depth */ 1,
+ /* r_args_depth */ 1,
+ /* res_arg_index */ 0,
+ large_kernel_arg>());
+ });
});
}
diff --git a/aten/src/ATen/native/cuda/FusedSgdKernel.cu b/aten/src/ATen/native/cuda/FusedSgdKernel.cu
index beca6f7..b8ab3af 100644
--- a/aten/src/ATen/native/cuda/FusedSgdKernel.cu
+++ b/aten/src/ATen/native/cuda/FusedSgdKernel.cu
@@ -56,14 +56,15 @@
}
}
-template <typename scalar_t, int depth>
+template <typename scalar_t, int depth, bool large_kernel_arg>
struct FusedSgdMathFunctor {
+ static constexpr bool use_large_kernel_arg = large_kernel_arg;
static_assert(
depth == 2 || depth == 3,
"depth of 2 for SGD w/ momentum == 0, 3 for SGD w/ momentum != 0");
C10_DEVICE __forceinline__ void operator()(
const int chunk_size,
- TensorListMetadata<depth>& tl,
+ TensorListMetadata<depth, large_kernel_arg>& tl,
const double weight_decay,
const double momentum,
const float* lr_ptr,
@@ -172,19 +173,21 @@
params[0].scalar_type(),
"fused_sgd_with_momentum_kernel_cuda",
[&]() {
- multi_tensor_apply<3>(
- tensor_lists,
- FusedSgdMathFunctor<scalar_t, 3>(),
- weight_decay,
- momentum,
- lr_ptr,
- lr,
- dampening,
- nesterov,
- maximize,
- is_first_step,
- grad_scale_ptr,
- found_inf_ptr);
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<3>(
+ tensor_lists,
+ FusedSgdMathFunctor<scalar_t, 3, large_kernel_arg>(),
+ weight_decay,
+ momentum,
+ lr_ptr,
+ lr,
+ dampening,
+ nesterov,
+ maximize,
+ is_first_step,
+ grad_scale_ptr,
+ found_inf_ptr);
+ });
});
}
@@ -246,19 +249,21 @@
params[0].scalar_type(),
"fused_sgd_with_momentum_kernel_cuda",
[&]() {
- multi_tensor_apply<3>(
- tensor_lists,
- FusedSgdMathFunctor<scalar_t, 3>(),
- weight_decay,
- momentum,
- lr.data_ptr<float>(),
- 1.0,
- dampening,
- nesterov,
- maximize,
- is_first_step,
- grad_scale_ptr,
- found_inf_ptr);
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<3>(
+ tensor_lists,
+ FusedSgdMathFunctor<scalar_t, 3, large_kernel_arg>(),
+ weight_decay,
+ momentum,
+ lr.data_ptr<float>(),
+ 1.0,
+ dampening,
+ nesterov,
+ maximize,
+ is_first_step,
+ grad_scale_ptr,
+ found_inf_ptr);
+ });
});
}
@@ -312,19 +317,21 @@
params[0].scalar_type(),
"fused_sgd_kernel_cuda",
[&]() {
- multi_tensor_apply<2>(
- tensor_lists,
- FusedSgdMathFunctor<scalar_t, 2>(),
- weight_decay,
- momentum,
- lr_ptr,
- lr,
- dampening,
- nesterov,
- maximize,
- /* is_first_step */ false,
- grad_scale_ptr,
- found_inf_ptr);
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<2>(
+ tensor_lists,
+ FusedSgdMathFunctor<scalar_t, 2, large_kernel_arg>(),
+ weight_decay,
+ momentum,
+ lr_ptr,
+ lr,
+ dampening,
+ nesterov,
+ maximize,
+ /* is_first_step */ false,
+ grad_scale_ptr,
+ found_inf_ptr);
+ });
});
}
@@ -404,19 +411,21 @@
params[0].scalar_type(),
"fused_sgd_kernel_cuda",
[&]() {
- multi_tensor_apply<2>(
- tensor_lists,
- FusedSgdMathFunctor<scalar_t, 2>(),
- weight_decay,
- momentum,
- lr.data_ptr<float>(),
- 1.0,
- dampening,
- nesterov,
- maximize,
- /* is_first_step */ false,
- grad_scale_ptr,
- found_inf_ptr);
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply<2>(
+ tensor_lists,
+ FusedSgdMathFunctor<scalar_t, 2, large_kernel_arg>(),
+ weight_decay,
+ momentum,
+ lr.data_ptr<float>(),
+ 1.0,
+ dampening,
+ nesterov,
+ maximize,
+ /* is_first_step */ false,
+ grad_scale_ptr,
+ found_inf_ptr);
+ });
});
}
diff --git a/aten/src/ATen/native/cuda/MultiTensorApply.cpp b/aten/src/ATen/native/cuda/MultiTensorApply.cpp
new file mode 100644
index 0000000..208adb0
--- /dev/null
+++ b/aten/src/ATen/native/cuda/MultiTensorApply.cpp
@@ -0,0 +1,25 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGraphsC10Utils.h>
+
+#include <cuda_runtime.h>
+
+namespace at::native {
+
+bool supports_large_kernel_arg() {
+#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDART_VERSION) && CUDART_VERSION >= 12010
+ static std::optional<bool> supports_large_kernel_arg_ = std::nullopt;
+ if (!supports_large_kernel_arg_.has_value()) {
+ int driver_ver = 0;
+ AT_CUDA_CHECK(cudaDriverGetVersion(&driver_ver));
+ cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
+ supports_large_kernel_arg_ = (driver_ver >= 12010) && prop->major >= 7;
+ }
+ const bool is_capturing = at::cuda::currentStreamCaptureStatusMayInitCtx() !=
+ at::cuda::CaptureStatus::None;
+ return !is_capturing && *supports_large_kernel_arg_;
+#else
+ return false;
+#endif
+}
+
+} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/MultiTensorApply.cuh b/aten/src/ATen/native/cuda/MultiTensorApply.cuh
index 17f1444..9b9c3ee 100644
--- a/aten/src/ATen/native/cuda/MultiTensorApply.cuh
+++ b/aten/src/ATen/native/cuda/MultiTensorApply.cuh
@@ -8,20 +8,105 @@
namespace at::native {
+// NOTE: [32KB kernel argument size support]
+// 32KB kernel argument size support has three requirements:
+// - CUDART_VERSION >= 12010
+// - Driver version >= 530
+// - GPU arch >= VOLTA
+//
+// Due to minor version compatibility, it is possible for binaries built with
+// CUDART_VERSION >= 12010 to run with driver version < 530. Since driver
+// version can only be checked at runtime, if CUDART_VERSION >= 12010, we have
+// to build both 4KB and 32KB kernels and determine the appropriate kernel to
+// dispatch at runtime.
+//
+// - If CUDART_VERSION < 12010, only 4KB kernels will be instantiated.
+//
+// - If CUDART_VERSION >= 12010:
+// - Host code:
+// - We always instantiate the launching stub for both 4KB and 32KB kernels.
+// - Device code:
+// - If __CUDA_ARCH__ >= 700, we always instantiate both 4KB and 32KB
+// kernels.
+// - If __CUDA_ARCH__ < 700, it's not possible to even compile an empty
+// 32KB kernel (formal parameter space overflowed). Thus, we only
+// instantiate a declaration for 32KB kernels. This is valid as long as the
+// declaration-only kernel is not launched.
+//
+// - At runtime, we dispatch to the 32KB kernel if driver version >= 530 and
+// GPU arch >= VOLTA.
+//
+// - TODO(yifu): once there's a CUDART version that is not compatible with any
+//   driver version below 530, we can decide at compile time not to compile
+//   the kernels for the 4KB kernel argument size.
+//
+// https://developer.nvidia.com/blog/cuda-12-1-supports-large-kernel-parameters/
+bool supports_large_kernel_arg();
+
namespace {
static constexpr int64_t kILP = 4;
static constexpr int64_t kChunkSize = 65536;
static constexpr int64_t kBlockSize = 512;
-// TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy`
-// TensorListMetadata has to be < 4KB - the limit for kernel launch argument
-static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
-static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
-static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30};
-static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {
- 72,
- 60};
+// MSVC has a problem with constexpr and can't handle passing them to templates
+// as arguments. We need to replace it with const static.
+// https://github.com/state-spaces/mamba/issues/12#issuecomment-1848835662
+#if !defined(_WIN32)
+#define SWITCH_TYPE constexpr bool
+#else
+#define SWITCH_TYPE const static bool
+#endif
+
+#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDART_VERSION) && \
+ CUDART_VERSION >= 12010
+#define DISPATCH_MULTI_TENSOR_APPLY(...) \
+ if (at::native::supports_large_kernel_arg()) { \
+ SWITCH_TYPE large_kernel_arg C10_UNUSED = true; \
+ __VA_ARGS__(); \
+ } else { \
+ SWITCH_TYPE large_kernel_arg C10_UNUSED = false; \
+ __VA_ARGS__(); \
+ }
+#else
+#define DISPATCH_MULTI_TENSOR_APPLY(...) \
+ do { \
+ SWITCH_TYPE large_kernel_arg C10_UNUSED = false; \
+ __VA_ARGS__(); \
+ } while (0);
+#endif
+
+template <bool large_kernel_arg>
+struct DepthToMaxConfig;
+
+template <>
+struct DepthToMaxConfig<false> {
+ static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
+ static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
+ static constexpr int depth_to_max_tensors_scalarlist[5] =
+ {96, 64, 48, 36, 30};
+ static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {
+ 72,
+ 60};
+ using TensorIdxType = unsigned char;
+};
+
+template <>
+struct DepthToMaxConfig<true> {
+ // TODO(yifu): These values are not yet optimally tuned. I simply multiplied
+ // the values tuned for 4KB kernel argument size limit by 7 (the kernel
+ // argument size limit increased by 8x but we need to change the type of
+  // block_to_tensor from unsigned char to uint16_t to support a larger number
+  // of tensors).
+ static constexpr int depth_to_max_tensors[5] = {770, 448, 336, 252, 210};
+ static constexpr int depth_to_max_blocks[5] = {2240, 2240, 2240, 2240, 2240};
+ static constexpr int depth_to_max_tensors_scalarlist[5] =
+ {672, 448, 336, 252, 210};
+ static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {
+ 504,
+ 420};
+ using TensorIdxType = uint16_t;
+};
template <typename T>
__device__ __forceinline__ bool is_aligned(T* p) {
@@ -38,73 +123,101 @@
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
-template <int n>
+template <int n, bool large_kernel_arg>
struct TensorListMetadata {
- const void* addresses[n][depth_to_max_tensors[n - 1]];
- int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
- unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
- int block_to_chunk[depth_to_max_blocks[n - 1]];
+ using Conf = DepthToMaxConfig<large_kernel_arg>;
+ const void* addresses[n][Conf::depth_to_max_tensors[n - 1]];
+ int64_t numel_for_tensor[Conf::depth_to_max_tensors[n - 1]];
+ typename Conf::TensorIdxType
+ block_to_tensor[Conf::depth_to_max_blocks[n - 1]];
+ int block_to_chunk[Conf::depth_to_max_blocks[n - 1]];
int start_tensor_this_launch;
};
-template <typename scalar_vals_t, int n>
+template <typename scalar_vals_t, int n, bool large_kernel_arg>
struct TensorListScalarListMetadata {
- const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]];
- int64_t numel_for_tensor[depth_to_max_tensors_scalarlist[n - 1]];
- scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n - 1]];
- unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
- int block_to_chunk[depth_to_max_blocks[n - 1]];
+ using Conf = DepthToMaxConfig<large_kernel_arg>;
+ const void* addresses[n][Conf::depth_to_max_tensors_scalarlist[n - 1]];
+ int64_t numel_for_tensor[Conf::depth_to_max_tensors_scalarlist[n - 1]];
+ scalar_vals_t scalar_vals[Conf::depth_to_max_tensors_scalarlist[n - 1]];
+ typename Conf::TensorIdxType
+ block_to_tensor[Conf::depth_to_max_blocks[n - 1]];
+ int block_to_chunk[Conf::depth_to_max_blocks[n - 1]];
};
// note(mkozuki): `n` of 1&2 violate the limit of cuda kernel argument size of
// 4kb with `c10::complex<double>`
-template <>
-struct TensorListScalarListMetadata<c10::complex<double>, 1> {
- const void* addresses[1]
- [depth_to_max_tensors_scalarlist_of_complex_double[0]];
- int64_t
- numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[0]];
+template <bool large_kernel_arg>
+struct TensorListScalarListMetadata<c10::complex<double>, 1, large_kernel_arg> {
+ using Conf = DepthToMaxConfig<large_kernel_arg>;
+ const void*
+ addresses[1][Conf::depth_to_max_tensors_scalarlist_of_complex_double[0]];
+ int64_t numel_for_tensor
+ [Conf::depth_to_max_tensors_scalarlist_of_complex_double[0]];
c10::complex<double>
- scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[0]];
- unsigned char block_to_tensor[depth_to_max_blocks[1 - 1]];
- int block_to_chunk[depth_to_max_blocks[1 - 1]];
+ scalar_vals[Conf::depth_to_max_tensors_scalarlist_of_complex_double[0]];
+ typename Conf::TensorIdxType
+ block_to_tensor[Conf::depth_to_max_blocks[1 - 1]];
+ int block_to_chunk[Conf::depth_to_max_blocks[1 - 1]];
};
-template <>
-struct TensorListScalarListMetadata<c10::complex<double>, 2> {
- const void* addresses[2]
- [depth_to_max_tensors_scalarlist_of_complex_double[1]];
- int64_t
- numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[1]];
+template <bool large_kernel_arg>
+struct TensorListScalarListMetadata<c10::complex<double>, 2, large_kernel_arg> {
+ using Conf = DepthToMaxConfig<large_kernel_arg>;
+ const void*
+ addresses[2][Conf::depth_to_max_tensors_scalarlist_of_complex_double[1]];
+ int64_t numel_for_tensor
+ [Conf::depth_to_max_tensors_scalarlist_of_complex_double[1]];
c10::complex<double>
- scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[1]];
- unsigned char block_to_tensor[depth_to_max_blocks[2 - 1]];
- int block_to_chunk[depth_to_max_blocks[2 - 1]];
+ scalar_vals[Conf::depth_to_max_tensors_scalarlist_of_complex_double[1]];
+ typename Conf::TensorIdxType
+ block_to_tensor[Conf::depth_to_max_blocks[2 - 1]];
+ int block_to_chunk[Conf::depth_to_max_blocks[2 - 1]];
};
// NOTE(crcrpar): This is a conservative resolution to handle `state_steps`
// whose each element is `at::Tensor` of 1 element representing the number of
// `step`s called so far.
-template <int n>
+template <int n, bool large_kernel_arg>
struct FusedOptimizerTensorListMetadata {
- const void* addresses[n][depth_to_max_tensors[n - 1]];
- int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
- const void* state_steps_addresses[depth_to_max_tensors_scalarlist[n - 1]];
- unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
- int block_to_chunk[depth_to_max_blocks[n - 1]];
+ using Conf = DepthToMaxConfig<large_kernel_arg>;
+ const void* addresses[n][Conf::depth_to_max_tensors[n - 1]];
+ int64_t numel_for_tensor[Conf::depth_to_max_tensors[n - 1]];
+ const void*
+ state_steps_addresses[Conf::depth_to_max_tensors_scalarlist[n - 1]];
+ typename Conf::TensorIdxType
+ block_to_tensor[Conf::depth_to_max_blocks[n - 1]];
+ int block_to_chunk[Conf::depth_to_max_blocks[n - 1]];
int start_tensor_this_launch;
};
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700
template <typename T, typename U, typename... ArgTypes>
C10_LAUNCH_BOUNDS_1(kBlockSize)
-__global__ void multi_tensor_apply_kernel(
- T tensorListMeta,
- U callable,
- ArgTypes... args) {
+__global__ typename std::enable_if<U::use_large_kernel_arg, void>::type
+ multi_tensor_apply_kernel(T tensorListMeta, U callable, ArgTypes... args) {
// Hand the chunk information to the user-supplied functor to process however
// it likes.
callable(kChunkSize, tensorListMeta, args...);
}
+#else
+// When compiling device code with __CUDA_ARCH__ < 700, we only instantiate a
+// declaration for the 32KB kernels.
+// For details see: [32KB kernel argument size support]
+#pragma nv_diag_suppress 114 // Function was referenced but not defined
+template <typename T, typename U, typename... ArgTypes>
+C10_LAUNCH_BOUNDS_1(kBlockSize)
+__global__ typename std::enable_if<U::use_large_kernel_arg, void>::type
+ multi_tensor_apply_kernel(T tensorListMeta, U callable, ArgTypes... args);
+#pragma nv_diag_default 114 // Function was referenced but not defined
+#endif
+
+template <typename T, typename U, typename... ArgTypes>
+C10_LAUNCH_BOUNDS_1(kBlockSize)
+__global__ typename std::enable_if<!U::use_large_kernel_arg, void>::type
+ multi_tensor_apply_kernel(T tensorListMeta, U callable, ArgTypes... args) {
+ callable(kChunkSize, tensorListMeta, args...);
+}
} // namespace
@@ -133,7 +246,10 @@
"Number of tensor lists has to match the depth.");
const size_t n_tensors = tensor_lists[0].size();
using scalar_vals_t = typename T::opmath_t;
- TensorListScalarListMetadata<scalar_vals_t, depth> tensorListMeta;
+ TensorListScalarListMetadata<scalar_vals_t, depth, T::use_large_kernel_arg>
+ tensorListMeta;
+
+ using Conf = DepthToMaxConfig<T::use_large_kernel_arg>;
int loc_block_info = 0;
int loc_tensor_info = 0;
@@ -167,10 +283,11 @@
// a tensor is not considered full unless all its chunks have been
// processed
const bool tensors_full =
- (loc_tensor_info == depth_to_max_tensors_scalarlist[depth - 1] &&
+ (loc_tensor_info ==
+ Conf::depth_to_max_tensors_scalarlist[depth - 1] &&
chunk == chunks - 1);
const bool blocks_full =
- (loc_block_info == depth_to_max_blocks[depth - 1]);
+ (loc_block_info == Conf::depth_to_max_blocks[depth - 1]);
if (tensors_full || blocks_full) {
multi_tensor_apply_kernel<<<
@@ -223,9 +340,11 @@
tensor_lists.size() == depth,
"Number of tensor lists has to match the depth.");
const size_t n_tensors = tensor_lists[0].size();
- TensorListMetadata<depth> tensorListMeta;
+ TensorListMetadata<depth, T::use_large_kernel_arg> tensorListMeta;
tensorListMeta.start_tensor_this_launch = 0;
+ using Conf = DepthToMaxConfig<T::use_large_kernel_arg>;
+
int loc_block_info = 0;
int loc_tensor_info = 0;
for (size_t t = 0; t < n_tensors; t++) {
@@ -250,10 +369,10 @@
loc_block_info++;
const bool tensors_full =
- (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
+ (loc_tensor_info == Conf::depth_to_max_tensors[depth - 1] &&
chunk == chunks - 1);
const bool blocks_full =
- (loc_block_info == depth_to_max_blocks[depth - 1]);
+ (loc_block_info == Conf::depth_to_max_blocks[depth - 1]);
if (tensors_full || blocks_full) {
multi_tensor_apply_kernel<<<
@@ -304,7 +423,10 @@
tensor_lists.size() == depth,
"Number of tensor lists has to match the depth");
const auto num_tensors = tensor_lists[0].size();
- FusedOptimizerTensorListMetadata<depth> tensorListMeta;
+ FusedOptimizerTensorListMetadata<depth, T::use_large_kernel_arg>
+ tensorListMeta;
+
+ using Conf = DepthToMaxConfig<T::use_large_kernel_arg>;
int loc_block_info = 0;
int loc_tensor_info = 0;
@@ -333,9 +455,10 @@
loc_block_info++;
const auto tensor_full =
- (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
+ (loc_tensor_info == Conf::depth_to_max_tensors[depth - 1] &&
chunk == chunks - 1);
- const auto blocks_full = loc_block_info == depth_to_max_blocks[depth - 1];
+ const auto blocks_full =
+ loc_block_info == Conf::depth_to_max_blocks[depth - 1];
if (tensor_full || blocks_full) {
multi_tensor_apply_kernel<<<
diff --git a/aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu b/aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu
index cef07de..0c6318d 100644
--- a/aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu
+++ b/aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu
@@ -42,19 +42,26 @@
params[0].scalar_type(),
"fused_adam_kernel_cuda",
[&]() {
- multi_tensor_apply_for_fused_optimizer<5>(
- tensor_lists,
- state_steps,
- FusedAdamMathFunctor<scalar_t, 5, ADAM_MODE::ORIGINAL, true>(),
- lr_ptr, // unused
- lr,
- beta1,
- beta2,
- weight_decay,
- eps,
- maximize,
- grad_scale_ptr,
- found_inf_ptr);
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply_for_fused_optimizer<5>(
+ tensor_lists,
+ state_steps,
+ FusedAdamMathFunctor<
+ scalar_t,
+ 5,
+ ADAM_MODE::ORIGINAL,
+ true,
+ large_kernel_arg>(),
+ lr_ptr, // unused
+ lr,
+ beta1,
+ beta2,
+ weight_decay,
+ eps,
+ maximize,
+ grad_scale_ptr,
+ found_inf_ptr);
+ });
});
}
@@ -93,19 +100,26 @@
params[0].scalar_type(),
"fused_adam_kernel_cuda",
[&]() {
- multi_tensor_apply_for_fused_optimizer<5>(
- tensor_lists,
- state_steps,
- FusedAdamMathFunctor<scalar_t, 5, ADAM_MODE::ORIGINAL, true>(),
- lr_ptr,
- 1.0, // unused
- beta1,
- beta2,
- weight_decay,
- eps,
- maximize,
- grad_scale_ptr,
- found_inf_ptr);
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply_for_fused_optimizer<5>(
+ tensor_lists,
+ state_steps,
+ FusedAdamMathFunctor<
+ scalar_t,
+ 5,
+ ADAM_MODE::ORIGINAL,
+ true,
+ large_kernel_arg>(),
+ lr_ptr,
+ 1.0, // unused
+ beta1,
+ beta2,
+ weight_decay,
+ eps,
+ maximize,
+ grad_scale_ptr,
+ found_inf_ptr);
+ });
});
}
diff --git a/aten/src/ATen/native/cuda/fused_adam_impl.cu b/aten/src/ATen/native/cuda/fused_adam_impl.cu
index 2c1f5ce..43a4cbf 100644
--- a/aten/src/ATen/native/cuda/fused_adam_impl.cu
+++ b/aten/src/ATen/native/cuda/fused_adam_impl.cu
@@ -37,19 +37,26 @@
params[0].scalar_type(),
"fused_adam_kernel_cuda",
[&]() {
- multi_tensor_apply_for_fused_optimizer<4>(
- tensor_lists,
- state_steps,
- FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ORIGINAL, false>(),
- lr_ptr, // unused
- lr,
- beta1,
- beta2,
- weight_decay,
- eps,
- maximize,
- grad_scale_ptr,
- found_inf_ptr);
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply_for_fused_optimizer<4>(
+ tensor_lists,
+ state_steps,
+ FusedAdamMathFunctor<
+ scalar_t,
+ 4,
+ ADAM_MODE::ORIGINAL,
+ false,
+ large_kernel_arg>(),
+ lr_ptr, // unused
+ lr,
+ beta1,
+ beta2,
+ weight_decay,
+ eps,
+ maximize,
+ grad_scale_ptr,
+ found_inf_ptr);
+ });
});
}
@@ -83,19 +90,26 @@
params[0].scalar_type(),
"fused_adam_kernel_cuda",
[&]() {
- multi_tensor_apply_for_fused_optimizer<4>(
- tensor_lists,
- state_steps,
- FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ORIGINAL, false>(),
- lr_ptr,
- 1.0, // unused
- beta1,
- beta2,
- weight_decay,
- eps,
- maximize,
- grad_scale_ptr,
- found_inf_ptr);
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply_for_fused_optimizer<4>(
+ tensor_lists,
+ state_steps,
+ FusedAdamMathFunctor<
+ scalar_t,
+ 4,
+ ADAM_MODE::ORIGINAL,
+ false,
+ large_kernel_arg>(),
+ lr_ptr,
+ 1.0, // unused
+ beta1,
+ beta2,
+ weight_decay,
+ eps,
+ maximize,
+ grad_scale_ptr,
+ found_inf_ptr);
+ });
});
}
diff --git a/aten/src/ATen/native/cuda/fused_adam_utils.cuh b/aten/src/ATen/native/cuda/fused_adam_utils.cuh
index 1821959..43627fc 100644
--- a/aten/src/ATen/native/cuda/fused_adam_utils.cuh
+++ b/aten/src/ATen/native/cuda/fused_adam_utils.cuh
@@ -102,15 +102,21 @@
// parameter updates accordingly. To be functionally on par with `torch.optim`
// optimizers and `_multi_tensor` ones, the kernel below writes out gradients
// only when `grad_scale_ptr != nullptr.
-template <typename scalar_type, int depth, ADAM_MODE adam_mode, bool amsgrad>
+template <
+ typename scalar_type,
+ int depth,
+ ADAM_MODE adam_mode,
+ bool amsgrad,
+ bool large_kernel_arg>
struct FusedAdamMathFunctor {
+ static constexpr bool use_large_kernel_arg = large_kernel_arg;
static_assert(
depth == 4 || depth == 5,
"depth of 4 for Adam, depth of 5 for Adam with AMSGrad.");
using opmath_t = at::opmath_type<scalar_type>;
C10_DEVICE __forceinline__ void operator()(
int chunk_size,
- FusedOptimizerTensorListMetadata<depth>& tl,
+ FusedOptimizerTensorListMetadata<depth, large_kernel_arg>& tl,
const float* lr_ptr,
const double& lr,
const double& beta1,
diff --git a/aten/src/ATen/native/cuda/fused_adamw_amsgrad_impl.cu b/aten/src/ATen/native/cuda/fused_adamw_amsgrad_impl.cu
index 8a22b57..3b3b732 100644
--- a/aten/src/ATen/native/cuda/fused_adamw_amsgrad_impl.cu
+++ b/aten/src/ATen/native/cuda/fused_adamw_amsgrad_impl.cu
@@ -43,19 +43,26 @@
params[0].scalar_type(),
"fused_adamw_kernel_cuda",
[&]() {
- multi_tensor_apply_for_fused_optimizer<5>(
- tensor_lists,
- state_steps,
- FusedAdamMathFunctor<scalar_t, 5, ADAM_MODE::ADAMW, true>(),
- lr_ptr, // unused
- lr,
- beta1,
- beta2,
- weight_decay,
- eps,
- maximize,
- grad_scale_ptr,
- found_inf_ptr);
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply_for_fused_optimizer<5>(
+ tensor_lists,
+ state_steps,
+ FusedAdamMathFunctor<
+ scalar_t,
+ 5,
+ ADAM_MODE::ADAMW,
+ true,
+ large_kernel_arg>(),
+ lr_ptr, // unused
+ lr,
+ beta1,
+ beta2,
+ weight_decay,
+ eps,
+ maximize,
+ grad_scale_ptr,
+ found_inf_ptr);
+ });
});
}
@@ -94,19 +101,26 @@
params[0].scalar_type(),
"fused_adamw_kernel_cuda",
[&]() {
- multi_tensor_apply_for_fused_optimizer<5>(
- tensor_lists,
- state_steps,
- FusedAdamMathFunctor<scalar_t, 5, ADAM_MODE::ADAMW, true>(),
- lr_ptr,
- 1.0, // unused
- beta1,
- beta2,
- weight_decay,
- eps,
- maximize,
- grad_scale_ptr,
- found_inf_ptr);
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply_for_fused_optimizer<5>(
+ tensor_lists,
+ state_steps,
+ FusedAdamMathFunctor<
+ scalar_t,
+ 5,
+ ADAM_MODE::ADAMW,
+ true,
+ large_kernel_arg>(),
+ lr_ptr,
+ 1.0, // unused
+ beta1,
+ beta2,
+ weight_decay,
+ eps,
+ maximize,
+ grad_scale_ptr,
+ found_inf_ptr);
+ });
});
}
diff --git a/aten/src/ATen/native/cuda/fused_adamw_impl.cu b/aten/src/ATen/native/cuda/fused_adamw_impl.cu
index b0f9dc6..ff65768 100644
--- a/aten/src/ATen/native/cuda/fused_adamw_impl.cu
+++ b/aten/src/ATen/native/cuda/fused_adamw_impl.cu
@@ -38,19 +38,26 @@
params[0].scalar_type(),
"fused_adamw_kernel_cuda",
[&]() {
- multi_tensor_apply_for_fused_optimizer<4>(
- tensor_lists,
- state_steps,
- FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ADAMW, false>(),
- lr_ptr, // unused
- lr,
- beta1,
- beta2,
- weight_decay,
- eps,
- maximize,
- grad_scale_ptr,
- found_inf_ptr);
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply_for_fused_optimizer<4>(
+ tensor_lists,
+ state_steps,
+ FusedAdamMathFunctor<
+ scalar_t,
+ 4,
+ ADAM_MODE::ADAMW,
+ false,
+ large_kernel_arg>(),
+ lr_ptr, // unused
+ lr,
+ beta1,
+ beta2,
+ weight_decay,
+ eps,
+ maximize,
+ grad_scale_ptr,
+ found_inf_ptr);
+ });
});
}
@@ -84,19 +91,26 @@
params[0].scalar_type(),
"fused_adamw_kernel_cuda",
[&]() {
- multi_tensor_apply_for_fused_optimizer<4>(
- tensor_lists,
- state_steps,
- FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ADAMW, false>(),
- lr_ptr,
- 1.0, // unused
- beta1,
- beta2,
- weight_decay,
- eps,
- maximize,
- grad_scale_ptr,
- found_inf_ptr);
+ DISPATCH_MULTI_TENSOR_APPLY([&]() {
+ multi_tensor_apply_for_fused_optimizer<4>(
+ tensor_lists,
+ state_steps,
+ FusedAdamMathFunctor<
+ scalar_t,
+ 4,
+ ADAM_MODE::ADAMW,
+ false,
+ large_kernel_arg>(),
+ lr_ptr,
+ 1.0, // unused
+ beta1,
+ beta2,
+ weight_decay,
+ eps,
+ maximize,
+ grad_scale_ptr,
+ found_inf_ptr);
+ });
});
}
diff --git a/build_variables.bzl b/build_variables.bzl
index 7ef2e0f..bf4d737 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -1456,6 +1456,7 @@
"aten/src/ATen/native/cuda/Equal.cpp",
"aten/src/ATen/native/cuda/GridSampler.cpp",
"aten/src/ATen/native/cuda/IndexKernel.cpp",
+ "aten/src/ATen/native/cuda/MultiTensorApply.cpp",
"aten/src/ATen/native/cuda/ReduceOps.cpp",
"aten/src/ATen/native/cuda/ScanKernels.cpp",
"aten/src/ATen/native/cuda/Sort.cpp",