Use CUDA DSA in caffe2/operators (#95299)

Differential Revision: D42977333

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95299
Approved by: https://github.com/ezyang, https://github.com/malfet
diff --git a/caffe2/operators/batch_permutation_op.cu b/caffe2/operators/batch_permutation_op.cu
index 96bbbdb..81d6d6b 100644
--- a/caffe2/operators/batch_permutation_op.cu
+++ b/caffe2/operators/batch_permutation_op.cu
@@ -1,6 +1,8 @@
 #include "caffe2/core/context_gpu.h"
 #include "caffe2/operators/batch_permutation_op.h"
 
+#include <c10/cuda/CUDADeviceAssertion.h>
+
 namespace caffe2 {
 
 namespace {
@@ -10,14 +12,15 @@
     int K,
     const float* src,
     const int* indices,
-    float* dst) {
+    float* dst,
+    TORCH_DSA_KERNEL_ARGS) {
   if (forward) {
     CUDA_1D_KERNEL_LOOP(index, N * K) {
       int k = index % K;
       int n = index / K;
       int idx = indices[n];
-      CUDA_KERNEL_ASSERT(idx >= 0);
-      CUDA_KERNEL_ASSERT(idx < N);
+      CUDA_KERNEL_ASSERT2(idx >= 0);
+      CUDA_KERNEL_ASSERT2(idx < N);
       dst[index] = src[idx * K + k];
     }
   } else {
@@ -33,13 +36,13 @@
       //     idx = i;
       //   }
       // }
-      // CUDA_KERNEL_ASSERT(idx >= 0);
-      // CUDA_KERNEL_ASSERT(idx < N);
+      // CUDA_KERNEL_ASSERT2(idx >= 0);
+      // CUDA_KERNEL_ASSERT2(idx < N);
       // dst[index] = src[idx * K + k];
 
       int idx = indices[n];
-      CUDA_KERNEL_ASSERT(idx >= 0);
-      CUDA_KERNEL_ASSERT(idx < N);
+      CUDA_KERNEL_ASSERT2(idx >= 0);
+      CUDA_KERNEL_ASSERT2(idx < N);
       dst[idx * K + k] = src[index];
     }
   }
@@ -64,17 +67,17 @@
   auto* Y = Output(0, X.sizes(), at::dtype<float>());
 
   if (X.dim32(0) > 0) {
-    BatchPermutationKernel<true>
-        <<<CAFFE_GET_BLOCKS(X.numel()),
+    TORCH_DSA_KERNEL_LAUNCH(
+    BatchPermutationKernel<true>,
+        CAFFE_GET_BLOCKS(X.numel()),
            CAFFE_CUDA_NUM_THREADS,
            0,
-           context_.cuda_stream()>>>(
+           context_.stream(),
             X.dim32(0),
             X.numel() / X.dim32(0),
             X.data<float>(),
             indices.data<int>(),
             Y->mutable_data<float>());
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
   }
   return true;
 }
@@ -86,17 +89,17 @@
   auto* dX = Output(0, dY.sizes(), at::dtype<float>());
 
   if (dY.dim32(0) > 0) {
-    BatchPermutationKernel<false>
-        <<<CAFFE_GET_BLOCKS(dY.numel()),
+    TORCH_DSA_KERNEL_LAUNCH(
+    BatchPermutationKernel<false>,
+        CAFFE_GET_BLOCKS(dY.numel()),
            CAFFE_CUDA_NUM_THREADS,
            0,
-           context_.cuda_stream()>>>(
+           context_.stream(),
             dY.dim32(0),
             dY.numel() / dY.dim32(0),
             dY.data<float>(),
             indices.data<int>(),
             dX->mutable_data<float>());
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
   }
   return true;
 }
diff --git a/caffe2/operators/boolean_unmask_ops.cu b/caffe2/operators/boolean_unmask_ops.cu
index 9128c4f..67decf6 100644
--- a/caffe2/operators/boolean_unmask_ops.cu
+++ b/caffe2/operators/boolean_unmask_ops.cu
@@ -3,6 +3,8 @@
 #include "caffe2/core/context_gpu.h"
 #include "caffe2/operators/boolean_unmask_ops.h"
 
+#include <c10/cuda/CUDADeviceAssertion.h>
+
 namespace caffe2 {
 
 namespace {
@@ -11,7 +13,8 @@
     const int numMasks,
     const int maskSize,
     int* indices,
-    bool* const masks[]) {
+    bool* const masks[],
+    TORCH_DSA_KERNEL_ARGS) {
   CUDA_1D_KERNEL_LOOP(i, maskSize) {
     for (int j = 0; j < numMasks; ++j) {
       if (masks[j][i]) {
@@ -19,7 +22,7 @@
         return;
       }
     }
-    CUDA_KERNEL_ASSERT(false);
+    CUDA_KERNEL_ASSERT2(false);
   }
 }
 
@@ -30,7 +33,8 @@
     const int* indices,
     char* const values[],
     int* valueSizes,
-    char* dest) {
+    char* dest,
+    TORCH_DSA_KERNEL_ARGS) {
   CUDA_1D_KERNEL_LOOP(j, numMasks) {
     int k = 0;
     for (int i = 0; i < maskSize; ++i) {
@@ -41,7 +45,7 @@
         ++k;
       }
     }
-    CUDA_KERNEL_ASSERT(valueSizes[j] == k);
+    CUDA_KERNEL_ASSERT2(valueSizes[j] == k);
   }
 }
 
@@ -88,20 +92,21 @@
     ReinitializeTensor(&indices_, {maskSize}, at::dtype<int>().device(CUDA));
     auto* indicesData = indices_.mutable_data<int>();
 
-    ComputeIndicesKernel<<<
+    TORCH_DSA_KERNEL_LAUNCH(
+    ComputeIndicesKernel,
         std::min(maskSize, CAFFE_MAXIMUM_NUM_BLOCKS),
         CAFFE_CUDA_NUM_THREADS,
         0,
-        context_.cuda_stream()>>>(
+        context_.stream(),
         numMasks, maskSize, indicesData, masks_.data<bool*>());
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
 
     auto* valueSizesData = valueSizes_.mutable_data<int>();
-    FillValuesKernel<<<
+    TORCH_DSA_KERNEL_LAUNCH(
+        FillValuesKernel,
         std::min(numMasks, CAFFE_MAXIMUM_NUM_BLOCKS),
         CAFFE_CUDA_NUM_THREADS,
         0,
-        context_.cuda_stream()>>>(
+        context_.stream(),
         numMasks,
         maskSize,
         meta.itemsize(),
@@ -109,7 +114,6 @@
         values_.data<char*>(),
         valueSizesData,
         dest);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
 
     return true;
   }
diff --git a/caffe2/operators/cross_entropy_op.cu b/caffe2/operators/cross_entropy_op.cu
index a2734e4..ff91ce5 100644
--- a/caffe2/operators/cross_entropy_op.cu
+++ b/caffe2/operators/cross_entropy_op.cu
@@ -5,15 +5,16 @@
 #include "caffe2/operators/cross_entropy_op.h"
 #include "caffe2/operators/operator_fallback_gpu.h"
 #include "caffe2/utils/cub_namespace.cuh"
+#include <c10/cuda/CUDADeviceAssertion.h>
 
 namespace caffe2 {
 
 namespace {
 __global__ void LabelCrossEntropyKernel(
     const int N, const int D, const float* Xdata, const int* labeldata,
-    const float log_threshold, float* Ydata) {
+    const float log_threshold, float* Ydata, TORCH_DSA_KERNEL_ARGS) {
   CUDA_1D_KERNEL_LOOP(i, N) {
-    CUDA_KERNEL_ASSERT(labeldata[i] >= 0 && labeldata[i] < D);
+    CUDA_KERNEL_ASSERT2(labeldata[i] >= 0 && labeldata[i] < D);
     Ydata[i] = -logf(fmaxf(Xdata[i * D + labeldata[i]], log_threshold));
   }
 }
@@ -44,18 +45,18 @@
       (label.dim() == 1) || (label.dim() == 2 && label.dim32(1) == 1));
   CAFFE_ENFORCE_EQ(label.dim32(0), N);
   auto* Y = Output(0, vector<int64_t>(size_t(1), N), at::dtype<float>());
-  LabelCrossEntropyKernel<<<
+  TORCH_DSA_KERNEL_LAUNCH(
+      LabelCrossEntropyKernel,
       CAFFE_GET_BLOCKS(N),
       CAFFE_CUDA_NUM_THREADS,
       0,
-      context_.cuda_stream()>>>(
+      context_.stream(),
       N,
       D,
       X.data<float>(),
       label.data<int>(),
       kLOG_THRESHOLD(),
       Y->template mutable_data<float>());
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   return true;
 }
diff --git a/caffe2/operators/segment_reduction_op_gpu.cu b/caffe2/operators/segment_reduction_op_gpu.cu
index 6985c3c..e669680 100644
--- a/caffe2/operators/segment_reduction_op_gpu.cu
+++ b/caffe2/operators/segment_reduction_op_gpu.cu
@@ -447,8 +447,9 @@
         size_t smem = sizeof(T) * post * multiple;
 
         // calling cuda kernel with ExactBlock = true, Average = false
-        sparse_length_sum_kernel<InType, T, IndexType, true, false>
-            <<<len_length, block, smem, context_.cuda_stream()>>>(
+        TORCH_DSA_KERNEL_LAUNCH(
+        (sparse_length_sum_kernel<InType, T, IndexType, true, false>)
+            ,len_length, block, smem, context_.stream(),
                 in_data,
                 out_data,
                 prefix_sum_length_data,
@@ -457,11 +458,11 @@
                 post,
                 len_length,
                 dataToReduceSize);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
       } else {
         // calling cuda kernel with ExactBlock = false, Average = false
-        sparse_length_sum_kernel<InType, T, IndexType, false, false>
-            <<<len_length, maxThreads, 0, context_.cuda_stream()>>>(
+        TORCH_DSA_KERNEL_LAUNCH(
+        (sparse_length_sum_kernel<InType, T, IndexType, false, false>)
+            ,len_length, maxThreads, 0, context_.stream(),
                 in_data,
                 out_data,
                 prefix_sum_length_data,
@@ -470,7 +471,6 @@
                 post,
                 len_length,
                 dataToReduceSize);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
       }
     } else {
       const T* in_data = dataInput.template data<T>();
@@ -584,8 +584,9 @@
         dim3 block(post, multiple);
         size_t smem = sizeof(T) * post * multiple;
         // calling cuda kernel with ExactBlock = true, Average = true
-        sparse_length_sum_kernel<InType, T, IndexType, true, true>
-            <<<len_length, block, smem, context_.cuda_stream()>>>(
+        TORCH_DSA_KERNEL_LAUNCH(
+        (sparse_length_sum_kernel<InType, T, IndexType, true, true>)
+            ,len_length, block, smem, context_.stream(),
                 in_data,
                 out_data,
                 prefix_sum_length_data,
@@ -594,11 +595,11 @@
                 post,
                 len_length,
                 dataToReduceSize);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
       } else {
         // calling cuda kernel with ExactBlock = false, Average = true
-        sparse_length_sum_kernel<InType, T, IndexType, false, true>
-            <<<len_length, maxThreads, 0, context_.cuda_stream()>>>(
+        TORCH_DSA_KERNEL_LAUNCH(
+        (sparse_length_sum_kernel<InType, T, IndexType, false, true>)
+            ,len_length, maxThreads, 0, context_.stream(),
                 in_data,
                 out_data,
                 prefix_sum_length_data,
@@ -607,7 +608,6 @@
                 post,
                 len_length,
                 dataToReduceSize);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
       }
     } else {
       const T* in_data = dataInput.template data<T>();
diff --git a/caffe2/operators/segment_reduction_op_gpu.cuh b/caffe2/operators/segment_reduction_op_gpu.cuh
index bb3f3be..399fa36 100644
--- a/caffe2/operators/segment_reduction_op_gpu.cuh
+++ b/caffe2/operators/segment_reduction_op_gpu.cuh
@@ -3,6 +3,7 @@
 #include <cub/device/device_reduce.cuh>
 #include <cub/device/device_scan.cuh>
 #include "caffe2/core/context_gpu.h"
+#include <c10/cuda/CUDADeviceAssertion.h>
 
 
 #if defined(USE_ROCM)
@@ -68,14 +69,14 @@
     int N,
     int post,
     int len_length,
-    int len_indices) {
+    int len_indices, TORCH_DSA_KERNEL_ARGS) {
   // len_length blocks
   int group = blockIdx.x;
 
   int start = group == 0 ? 0 : prefix_sum_length_data[group - 1];
   int end = prefix_sum_length_data[group];
-  CUDA_KERNEL_ASSERT(start <= len_indices);
-  CUDA_KERNEL_ASSERT(end <= len_indices);
+  CUDA_KERNEL_ASSERT2(start <= len_indices);
+  CUDA_KERNEL_ASSERT2(end <= len_indices);
 
   struct SharedMemory<OutType> smem;
   OutType* reduceVals = smem.getPointer();
diff --git a/caffe2/operators/sequence_ops.cu b/caffe2/operators/sequence_ops.cu
index 2ceb523..0fcf204 100644
--- a/caffe2/operators/sequence_ops.cu
+++ b/caffe2/operators/sequence_ops.cu
@@ -2,6 +2,7 @@
 
 #include <cub/cub.cuh>
 #include "caffe2/utils/cub_namespace.cuh"
+#include <c10/cuda/CUDADeviceAssertion.h>
 
 #include "caffe2/core/context_gpu.h"
 #include "caffe2/operators/sequence_ops.h"
@@ -24,7 +25,8 @@
     const T* padding_end_ptr,
     int end_padding_width_blocks,
     T* out,
-    int32_t* lengths_out) {
+    int32_t* lengths_out,
+    TORCH_DSA_KERNEL_ARGS) {
   int element_idx = blockIdx.x;
   int prior_padding =
       element_idx * (start_padding_width_blocks + end_padding_width_blocks);
@@ -39,7 +41,7 @@
     in_start_idx = lengths_prefix_sum[element_idx] - len_blocks;
   } else {
     // Only one element, use the outer size
-    CUDA_KERNEL_ASSERT(lengths_size == 1);
+    CUDA_KERNEL_ASSERT2(lengths_size == 1);
     len_blocks = outer_size;
     in_start_idx = 0;
   }
@@ -86,7 +88,8 @@
     int start_padding_width_blocks,
     int end_padding_width_blocks,
     T* out,
-    int32_t* lengths_out) {
+    int32_t* lengths_out,
+    TORCH_DSA_KERNEL_ARGS) {
   int element_idx = blockIdx.x;
   int prior_padding =
       element_idx * (start_padding_width_blocks + end_padding_width_blocks);
@@ -101,7 +104,7 @@
     in_start_idx = lengths_prefix_sum[element_idx] - len_blocks;
   } else {
     // Only one element, use the outer size
-    CUDA_KERNEL_ASSERT(lengths_size == 1);
+    CUDA_KERNEL_ASSERT2(lengths_size == 1);
     len_blocks = outer_size;
     in_start_idx = 0;
   }
@@ -214,8 +217,9 @@
   }
 
   // Compute the padding using the accumulated lengths
-  AddPaddingKernel<T>
-      <<<lengths_size, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+  TORCH_DSA_KERNEL_LAUNCH(
+          AddPaddingKernel<T>,
+          lengths_size, CAFFE_CUDA_NUM_THREADS, 0, context_.stream(),
           in_ptr,
           block_size,
           lengths_size,
@@ -227,7 +231,6 @@
           endPaddingWidth_,
           out_ptr,
           lengths_out_ptr);
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   return true;
 }
@@ -282,8 +285,9 @@
   }
 
   // Compute the padding using the accumulated lengths
-  RemovePaddingKernel<T>
-      <<<lengths_size, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+  TORCH_DSA_KERNEL_LAUNCH(
+          RemovePaddingKernel<T>,
+          lengths_size, CAFFE_CUDA_NUM_THREADS, 0, context_.stream(),
           in_ptr,
           block_size,
           lengths_size,
@@ -293,7 +297,6 @@
           endPaddingWidth_,
           out_ptr,
           lengths_out_ptr);
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   return true;
 }
diff --git a/caffe2/operators/softmax_ops.cu b/caffe2/operators/softmax_ops.cu
index ebf0700..75d8899 100644
--- a/caffe2/operators/softmax_ops.cu
+++ b/caffe2/operators/softmax_ops.cu
@@ -6,6 +6,7 @@
 #include "caffe2/operators/softmax_with_loss_op.h"
 #include "caffe2/operators/spatial_softmax_with_loss_op.h"
 #include "caffe2/utils/cub_namespace.cuh"
+#include <c10/cuda/CUDADeviceAssertion.h>
 
 namespace caffe2 {
 
@@ -17,9 +18,10 @@
     const float* logPdata,
     const int* labeldata,
     const float* weights,
-    float* Ydata) {
+    float* Ydata,
+    TORCH_DSA_KERNEL_ARGS) {
   CUDA_1D_KERNEL_LOOP(i, N) {
-    CUDA_KERNEL_ASSERT(labeldata[i] >= 0 && labeldata[i] < D);
+    CUDA_KERNEL_ASSERT2(labeldata[i] >= 0 && labeldata[i] < D);
     float weight = weights ? weights[i] : 1.0;
     Ydata[i] = -logPdata[i * D + labeldata[i]] * weight;
   }
@@ -59,7 +61,8 @@
     const float* Pdata,
     const float* labeldata,
     const float* weights,
-    float* Ydata) {
+    float* Ydata,
+    TORCH_DSA_KERNEL_ARGS) {
   typedef cub::BlockReduce<float, CAFFE_CUDA_NUM_THREADS> BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage;
 
@@ -69,7 +72,7 @@
     float total_prob = 0.0;
     for (int j = threadIdx.x; j < D; j += blockDim.x) {
       int idx = i * D + j;
-      CUDA_KERNEL_ASSERT(labeldata[idx] >= 0);
+      CUDA_KERNEL_ASSERT2(labeldata[idx] >= 0);
       total_prob += labeldata[idx];
       sum += -logf(fmaxf(Pdata[idx], FLT_MIN)) * labeldata[idx] * weight;
     }
@@ -79,7 +82,7 @@
     if (threadIdx.x == 0) {
       Ydata[i] = tot;
       // Sanity check
-      CUDA_KERNEL_ASSERT(fabsf(1.0 - total_prob_sum) < 1e-5f);
+      CUDA_KERNEL_ASSERT2(fabsf(1.0 - total_prob_sum) < 1e-5f);
     }
     __syncthreads();
   }
@@ -150,7 +153,8 @@
     const int* label_data,
     const float* weights,
     float* loss_data,
-    float* weight_data) {
+    float* weight_data,
+    TORCH_DSA_KERNEL_ARGS) {
   CUDA_1D_KERNEL_LOOP(index, N * W * H) {
     int x = index % W;
     int y = (index / W) % H;
@@ -158,7 +162,7 @@
     const int label = static_cast<int>(label_data[index]);
 
     if (label != DONTCARE) {
-      CUDA_KERNEL_ASSERT(label >= 0 && label < D);
+      CUDA_KERNEL_ASSERT2(label >= 0 && label < D);
       float weight = (weights == NULL ? 1.0 : weights[index]);
       loss_data[index] =
           -logf(
@@ -180,7 +184,8 @@
     const int* label_data,
     const float* weights,
     float* dX_data,
-    float* weights_) {
+    float* weights_,
+    TORCH_DSA_KERNEL_ARGS) {
   CUDA_1D_KERNEL_LOOP(index, N * W * H) {
     int x = index % W;
     int y = (index / W) % H;
@@ -356,36 +361,36 @@
       &context_);
   // Compute label xent loss per example
   if (!label_prob_mode_) {
-    LabelCrossEntropyKernel<<<
+    TORCH_DSA_KERNEL_LAUNCH(
+        LabelCrossEntropyKernel,
         CAFFE_GET_BLOCKS(N),
         CAFFE_CUDA_NUM_THREADS,
         0,
-        context_.cuda_stream()>>>(
+        context_.stream(),
         N,
         D,
         P->data<float>(),
         T.data<int>(),
         weights,
         losses_.mutable_data<float>());
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
 
     // Since we had logarithmic output, we need to exponentiate
     // them again.
     math::Exp<float, CUDAContext>(
         N * D, P->data<float>(), P->template mutable_data<float>(), &context_);
   } else {
-    ProbCrossEntropyKernel<<<
+    TORCH_DSA_KERNEL_LAUNCH(
+        ProbCrossEntropyKernel,
         std::min(N, CAFFE_MAXIMUM_NUM_BLOCKS),
         CAFFE_CUDA_NUM_THREADS,
         0,
-        context_.cuda_stream()>>>(
+        context_.stream(),
         N,
         D,
         P->data<float>(),
         T.data<float>(),
         weights,
         losses_.mutable_data<float>());
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
   }
 
   float total_weight = N;
@@ -477,11 +482,12 @@
   math::Set<float, CUDAContext>(
       1, 0.0f, total_weight_ptr_.mutable_data<float>(), &context_);
 
-  SpatialCrossEntropyLossKernel<<<
+  TORCH_DSA_KERNEL_LAUNCH(
+      SpatialCrossEntropyLossKernel,
       CAFFE_GET_BLOCKS(N * W * H),
       CAFFE_CUDA_NUM_THREADS,
       0,
-      context_.cuda_stream()>>>(
+      context_.stream(),
       N,
       D,
       W,
@@ -491,7 +497,6 @@
       weights,
       losses_.mutable_data<float>(),
       weights_.mutable_data<float>());
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   // Somewhat awkward scalar passing from device to host
   float h_total_weight;
@@ -698,13 +703,13 @@
   math::Set<float, CUDAContext>(
       1, 0.0f, total_weight_ptr_.mutable_data<float>(), &context_);
 
-  SpatialSoftmaxLossGradientKernel<<<
+  TORCH_DSA_KERNEL_LAUNCH(
+      SpatialSoftmaxLossGradientKernel,
       CAFFE_GET_BLOCKS(N * W * H),
       CAFFE_CUDA_NUM_THREADS,
       0,
-      context_.cuda_stream()>>>(
+      context_.stream(),
       N, D, W, H, label_data, weights, dX_data, weights_.mutable_data<float>());
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   math::Sum<float, CUDAContext>(
       weights_.numel(),
diff --git a/caffe2/operators/top_k_radix_selection.cuh b/caffe2/operators/top_k_radix_selection.cuh
index 710c6ff..f71fbd0 100644
--- a/caffe2/operators/top_k_radix_selection.cuh
+++ b/caffe2/operators/top_k_radix_selection.cuh
@@ -6,6 +6,7 @@
 #include "caffe2/utils/GpuScanUtils.cuh"
 #include "caffe2/utils/GpuAtomics.cuh"
 #include "caffe2/utils/math.h"
+#include <c10/cuda/CUDADeviceAssertion.h>
 #include <cuda_runtime.h>
 
 namespace caffe2 {
diff --git a/caffe2/operators/utility_ops.cu b/caffe2/operators/utility_ops.cu
index 473d645..a2448a9 100644
--- a/caffe2/operators/utility_ops.cu
+++ b/caffe2/operators/utility_ops.cu
@@ -10,6 +10,7 @@
 #include "caffe2/operators/flatten_op.h"
 #include "caffe2/utils/GpuAtomics.cuh"
 #include "caffe2/utils/math.h"
+#include <c10/cuda/CUDADeviceAssertion.h>
 
 namespace caffe2 {