Remove sync in embedding (#70943)

Summary:
Together with https://github.com/pytorch/pytorch/pull/66580 and https://github.com/pytorch/pytorch/pull/68376, this PR removes all syncs in embedding.

This PR includes https://github.com/pytorch/pytorch/pull/68376; please review it after https://github.com/pytorch/pytorch/pull/68376 is merged.

This PR introduces performance regressions and increases memory usage:
- `exclusive_sum` now computes over all `numel` elements instead of only `num_of_segments` elements; the trailing `numel - num_of_segments` results are discarded.
- Some buffers are now allocated with `numel` elements instead of `num_of_segments` or `num_of_partial_segments` elements.

This is the price we pay for a sync-free implementation; the general pattern is sketched below.
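
For context, here is a minimal standalone sketch of the pattern, not code from this PR: the true element count stays in a device-side scalar, kernels are launched over the worst-case size `numel`, and each thread guards on the device-side count instead of the host syncing via `.item()`. The kernel names `write_count` and `process_segments` and the sizes are made up for illustration.

```cuda
// Sketch only: replace a host-side `.item<int64_t>()` sync with a
// device-side count pointer that kernels dereference themselves.
#include <cstdint>
#include <cuda_runtime.h>

__global__ void write_count(int64_t* count_ptr, int64_t value) {
  *count_ptr = value;  // analogous to producing num_of_segments on device
}

__global__ void process_segments(const int64_t* count_ptr, int64_t* out) {
  const int64_t count = *count_ptr;  // read the true count on device, no host sync
  const int id = blockIdx.x * blockDim.x + threadIdx.x;
  if (id < count) {                  // threads past `count` simply do nothing
    out[id] = 2 * id;                // stand-in for real per-segment work
  }
}

int main() {
  const int64_t numel = 1024;        // worst-case size known on the host
  int64_t *count = nullptr, *out = nullptr;
  cudaMalloc(&count, sizeof(int64_t));
  cudaMalloc(&out, numel * sizeof(int64_t));  // sized with `numel`, not the exact count

  // Suppose only 300 of the `numel` slots are valid segments; the count is
  // produced (and kept) on the GPU instead of being copied back to the host.
  write_count<<<1, 1>>>(count, 300);

  // Launch over the worst case `numel`; the in-kernel guard discards extras.
  const int block = 128;
  const int grid = static_cast<int>((numel + block - 1) / block);
  process_segments<<<grid, block>>>(count, out);
  cudaDeviceSynchronize();  // only for this demo; the real code stays async

  cudaFree(count);
  cudaFree(out);
  return 0;
}
```

In the diff below, this is what passing `num_of_segments_ptr` / `num_of_partial_segments_ptr` into the kernels and sizing the intermediate buffers and grids with `numel` accomplishes.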

I haven't run any benchmarks yet; I will follow up with numbers later.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70943

Reviewed By: H-Huang

Differential Revision: D34881660

Pulled By: ngimel

fbshipit-source-id: b0760fa33608c46cd4145ceb09878bf94a9f959d
(cherry picked from commit d959fa4783cfee84bf17c1fa6d0f5d6bde268d75)
diff --git a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
index 6035f30..202f3ad 100644
--- a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
+++ b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
@@ -39,7 +39,8 @@
 template <typename index_t>
 __global__
 void krn_partials_per_segment(index_t *ret, const index_t *segment_offsets,
-                              int64_t num_of_segments, int64_t numel) {
+                              int64_t *num_of_segments_ptr, int64_t numel) {
+  int64_t num_of_segments = *num_of_segments_ptr;
   const int id = blockIdx.x * blockDim.x + threadIdx.x;
   if(id < num_of_segments) {
     const int64_t idx_start = segment_offsets[id];
@@ -56,7 +57,8 @@
         const index_t *partials_per_segment,
         const index_t *partials_per_segment_offset,
         const index_t *segment_offsets,
-        int64_t num_of_segments) {
+        int64_t *num_of_segments_ptr) {
+  int64_t num_of_segments = *num_of_segments_ptr;
   const int id = blockIdx.x * blockDim.x + threadIdx.x;
   if(id < num_of_segments) {
     index_t idx = partials_per_segment_offset[id];
@@ -75,10 +77,11 @@
     index_t *offset2bag, index_t *count, ptrdiff_t numel,
     int64_t stride, int mode_mean, const index_t *bag_size,
     scalar_t* per_sample_weights, int64_t per_sample_weights_stride,
-    index_t* segment_offsets, int64_t num_of_segments,
+    index_t* segment_offsets, int64_t *num_of_segments_ptr,
     acc_type<scalar_t, true> *grad_weight_per_segment,
     const int64_t stride_warped) {
 
+  int64_t num_of_segments = *num_of_segments_ptr;
   const int gid = blockIdx.x * blockDim.x + threadIdx.x;
   const int id = gid / stride_warped;
   const int startFeature = gid % stride_warped;
@@ -119,10 +122,11 @@
     ptrdiff_t numel,
     int64_t stride,
     index_t* segment_offsets,
-    int64_t num_of_segments,
+    int64_t *num_of_segments_ptr,
     acc_type<scalar_t, true> *grad_weight_per_segment,
     const int64_t stride_warped) {
 
+  int64_t num_of_segments = *num_of_segments_ptr;
   using accscalar_t = acc_type<scalar_t, true>;
   const int gid = blockIdx.x * blockDim.x + threadIdx.x;
   const int id = gid / stride_warped;
@@ -149,12 +153,14 @@
 template <typename scalar_t, typename index_t>
 __global__ void sum_and_scatter(
     index_t *input, scalar_t *gradWeight, int64_t stride,
-    index_t* segment_offsets, int64_t num_of_segments,
+    index_t* segment_offsets, int64_t *num_of_segments_ptr,
     const acc_type<scalar_t, true> *grad_weight_per_segment,
-    const index_t *segment_sizes_offsets, int64_t num_of_partial_segments,
+    const index_t *segment_sizes_offsets, int64_t *num_of_partial_segments_ptr,
     const int64_t padding_idx,
     const int64_t stride_warped) {
 
+  int64_t num_of_segments = *num_of_segments_ptr;
+  int64_t num_of_partial_segments = *num_of_partial_segments_ptr;
   const int gid = blockIdx.x * blockDim.x + threadIdx.x;
   const int id = gid / stride_warped;
   const int startFeature = gid % stride_warped;
@@ -177,6 +183,17 @@
   }
 }
 
+template<typename index_t>
+__global__ void compute_num_of_partial_segments(index_t *partials_per_segment, index_t *partials_per_segment_offset, int64_t *num_of_segments_ptr, int64_t *output) {
+  int64_t num_of_segments = *num_of_segments_ptr;
+  *output = partials_per_segment[num_of_segments-1] +
+            partials_per_segment_offset[num_of_segments-1];
+}
+
+__global__ void write_num_of_segments_for_legacy_thrust_path(int64_t *num_of_segments_ptr, int64_t num_of_segments) {
+  *num_of_segments_ptr = num_of_segments;
+}
+
 } // anon namespace
 
 #if !CUB_SUPPORTS_UNIQUE_BY_KEY()
@@ -207,10 +224,13 @@
   // be summarized.
   // Unit: index in `sorted_indices` and `orig_indices`
   auto segment_offsets = at::empty({numel}, orig_indices.options());
-  int64_t num_of_segments;
+  auto num_of_segments_tensor = at::empty({}, grad.options().dtype(kLong));
+  int64_t *num_of_segments_ptr = num_of_segments_tensor.data_ptr<int64_t>();
 #if !CUB_SUPPORTS_UNIQUE_BY_KEY()
   AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
-    num_of_segments = embedding_backward_cuda_kernel_unique_by_key<index_t>(sorted_indices, segment_offsets);
+    int64_t num_of_segments = embedding_backward_cuda_kernel_unique_by_key<index_t>(sorted_indices, segment_offsets);
+    write_num_of_segments_for_legacy_thrust_path<<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(num_of_segments_ptr, num_of_segments);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
   });
 #else
   AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
@@ -218,8 +238,7 @@
     cuda::cub::unique_by_key(
       sorted_indices.data_ptr<index_t>(), thrust::make_counting_iterator(0),
       nullptr, segment_offsets.data_ptr<index_t>(),
-      num_of_segments_tensor.data_ptr<int64_t>(), sorted_indices.numel());
-    num_of_segments = num_of_segments_tensor.item<int64_t>();
+      num_of_segments_ptr, sorted_indices.numel());
   });
 #endif
 
@@ -227,12 +246,12 @@
     // We split the segments up into sizes of `NROWS_PER_THREAD`
     // Compute the number partial-segments per segment (some partial-segments
     // may not be the full `NROWS_PER_THREAD` number of rows)
-    auto partials_per_segment = at::empty({num_of_segments}, orig_indices.options());
+    auto partials_per_segment = at::empty({numel}, orig_indices.options());
     {
-      krn_partials_per_segment<<<ceil_div(num_of_segments, 32), 32, 0, stream>>> (
+      krn_partials_per_segment<<<ceil_div(numel, 32), 32, 0, stream>>> (
               partials_per_segment.data_ptr<index_t>(),
               segment_offsets.data_ptr<index_t>(),
-              num_of_segments,
+              num_of_segments_ptr,
               numel);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     }
@@ -241,33 +260,38 @@
     // of each partial-segment in `sorted_indices`, we need to compute the
     // start position of each _segment_ in `partial_segment_offset`.
     // Unit: index in `partial_segment_offset`
-    auto partials_per_segment_offset = at::empty({num_of_segments}, orig_indices.options());
+    auto partials_per_segment_offset = at::empty({numel}, orig_indices.options());
     cuda::cub::exclusive_sum(
         partials_per_segment.data_ptr<index_t>(),
         partials_per_segment_offset.data_ptr<index_t>(),
-        num_of_segments);
+        numel);
 
     // The total number of partial-segments is the sum of `partials_per_segment_offset`
-    const int num_of_partial_segments = partials_per_segment[num_of_segments-1].item<index_t>() +
-            partials_per_segment_offset[num_of_segments-1].item<index_t>();
+    auto num_of_partial_segments_tensor = at::empty({}, grad.options().dtype(kLong));
+    int64_t *num_of_partial_segments_ptr = num_of_partial_segments_tensor.data_ptr<int64_t>();
+    compute_num_of_partial_segments<index_t><<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(
+      partials_per_segment.data_ptr<index_t>(),
+      partials_per_segment_offset.data_ptr<index_t>(),
+      num_of_segments_ptr, num_of_partial_segments_ptr);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
 
     // Now we can compute the start position of each partial-segment
     // Unit: index in `sorted_indices` and `orig_indices`
-    auto partial_segment_offset = at::empty({num_of_partial_segments}, orig_indices.options());
+    auto partial_segment_offset = at::empty({numel}, orig_indices.options());
     {
-      krn_partial_segment_offset<<<ceil_div(num_of_segments, 32), 32, 0, stream>>> (
+      krn_partial_segment_offset<<<ceil_div(numel, 32), 32, 0, stream>>> (
               partial_segment_offset.data_ptr<index_t>(),
               partials_per_segment.data_ptr<index_t>(),
               partials_per_segment_offset.data_ptr<index_t>(),
               segment_offsets.data_ptr<index_t>(),
-              num_of_segments);
+              num_of_segments_ptr);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     }
 
     const int warp_size = at::cuda::warp_size();
     const int stride_warped = ceil_div(stride, warp_size)*warp_size;
     const int block = std::min(stride_warped, MAX_BLOCK_SIZE);
-    const int grid = ceil_div(num_of_partial_segments*stride_warped, block);
+    const int grid = ceil_div(numel*stride_warped, block);
 
     AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
       grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
@@ -280,7 +304,7 @@
         } else {
             op = grad.options();
         }
-        auto grad_weight_per_segment = at::empty({num_of_partial_segments, stride}, op);
+        auto grad_weight_per_segment = at::empty({numel, stride}, op);
         // Compute the sum of each partial-segment and handle bags
         if (offset2bag.defined()) {
               compute_grad_weight_bags<scalar_t><<<grid, block, 0, stream>>>(
@@ -292,7 +316,7 @@
                 per_sample_weights.defined() ? per_sample_weights.data_ptr<scalar_t>() : NULL,
                 per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
                 partial_segment_offset.data_ptr<index_t>(),
-                num_of_partial_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
+                num_of_partial_segments_ptr, grad_weight_per_segment.data_ptr<partial_weight_t>(),
                 stride_warped);
               C10_CUDA_KERNEL_LAUNCH_CHECK();
         } else {
@@ -302,7 +326,7 @@
                 count.defined() ? count.data_ptr<index_t>() : nullptr,
                 numel, stride,
                 partial_segment_offset.data_ptr<index_t>(),
-                num_of_partial_segments,
+                num_of_partial_segments_ptr,
                 grad_weight_per_segment.data_ptr<partial_weight_t>(),
                 stride_warped);
               C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -310,15 +334,15 @@
 
         // Finally, we sum all the partial-sums and scatter them
         // into `grad_weight`.
-        const int grid2 = ceil_div(num_of_segments*stride_warped, block);
+        const int grid2 = ceil_div(numel*stride_warped, block);
             sum_and_scatter<scalar_t><<<grid2, block, 0, stream>>>(
               sorted_indices.data_ptr<index_t>(),
               grad_weight.data_ptr<scalar_t>(),
               stride,
               segment_offsets.data_ptr<index_t>(),
-              num_of_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
+              num_of_segments_ptr, grad_weight_per_segment.data_ptr<partial_weight_t>(),
               partials_per_segment_offset.data_ptr<index_t>(),
-              num_of_partial_segments,
+              num_of_partial_segments_ptr,
               padding_idx,
               stride_warped);
         C10_CUDA_KERNEL_LAUNCH_CHECK();