Optimization of the Embedding and Embedding-Bag CUDA Kernel (#22016)

Summary:
Re-implementation of the `embedding_dense_backward_cuda()` and the `embedding_bag_backward_cuda_sum_avg()` functions.

#### Performance
Running a [Mortgage Workflow](https://github.com/EvenOldridge/MortgageWorkflowA) with a block size of 100K on a DGX-2 (single GPU), we see roughly a 2.8x throughput improvement:
```
Original version:    370,168 example/s
Optimized version: 1,034,228 example/s
```
The original version is bottlenecked by `EmbeddingBag_accGradParametersKernel_sum_avg`, which accounts for 70% of the CUDA execution time. In the optimized version, the replacement kernels take only 17% of the time.

#### Greater Numerical Stability
An added benefit is greater numerical stability. Instead of a flat sum in which a single variable accumulates all the weights, this code works in two steps: each GPU thread first computes a partial sum over at most `NROWS_PER_THREAD` rows, and these partial results are then accumulated into the final result.
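
To illustrate the accumulation pattern, here is a minimal standalone CPU sketch (not code from this PR); it reuses the name `NROWS_PER_THREAD` only to mirror the kernel's chunking, and compares a single flat `float` accumulator against per-chunk partial sums that are reduced afterwards:

```cpp
// Standalone sketch (not from this PR): compare a flat float sum with the
// two-step scheme used by the new kernels, where each "thread" first reduces
// a chunk of NROWS_PER_THREAD rows and the partial results are summed after.
#include <cstdio>
#include <vector>

constexpr int NROWS_PER_THREAD = 10;  // same role as the constant in the kernel

int main() {
  std::vector<float> rows(10000000, 0.1f);

  // Flat sum: one accumulator, a single long accumulation chain.
  float flat = 0.0f;
  for (float r : rows) flat += r;

  // Two-step sum: short per-chunk chains, then a reduction over the partials.
  float two_step = 0.0f;
  for (size_t i = 0; i < rows.size(); i += NROWS_PER_THREAD) {
    float partial = 0.0f;
    for (size_t j = i; j < i + NROWS_PER_THREAD && j < rows.size(); ++j)
      partial += rows[j];
    two_step += partial;
  }

  // Double-precision reference.
  double ref = 0.0;
  for (float r : rows) ref += r;

  std::printf("flat:     %.3f (error %.3f)\n", flat, flat - ref);
  std::printf("two-step: %.3f (error %.3f)\n", two_step, two_step - ref);
}
```

In the kernels below, the same effect comes from `compute_grad_weight`/`compute_grad_weight_bags` producing per-partial-segment sums and `sum_and_scatter` reducing them per output row, which keeps every accumulation chain short.
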
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22016

Differential Revision: D15944339

Pulled By: mrshenli

fbshipit-source-id: 398d5f48826a017fc4b31c24c3f8b56d01830bf0
diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu
index e7a76ad..1bee496 100644
--- a/aten/src/ATen/native/cuda/Embedding.cu
+++ b/aten/src/ATen/native/cuda/Embedding.cu
@@ -12,6 +12,8 @@
 #include <thrust/execution_policy.h>
 #include <thrust/unique.h>
 
+#include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
+
 
 namespace at { namespace native {
 
@@ -231,14 +233,12 @@
 
   auto num_indices = indices.numel();
   auto grad = grad_.contiguous().view({num_indices, grad_.size(-1)});
-  auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options());
-
-  int64_t stride = grad_weight.stride(0);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   if (num_indices <= 768 && !scale_grad_by_freq) {
     auto indices_contig = indices.contiguous();
-
+    auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options());
+    int64_t stride = grad_weight.stride(0);
     dim3 grid(THCCeilDiv(stride, (int64_t)WARP_SIZE));
     dim3 block(WARP_SIZE, BLOCKDIMY);
 
@@ -323,23 +323,8 @@
     );
   }
 
-  dim3 grid(THCCeilDiv(num_indices, (int64_t) 4), THCCeilDiv(stride, (int64_t) 128));
-  dim3 block(WARP_SIZE, 4);
-
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.scalar_type(), "embedding_backward", [&] {
-    embedding_backward_kernel<<<grid, block, 0, stream>>>(
-      sorted_indices.data<int64_t>(),
-      orig_indices.data<int64_t>(),
-      grad.data<scalar_t>(),
-      grad_weight.data<scalar_t>(),
-      count.defined() ? count.data<int64_t>() : nullptr,
-      num_indices,
-      stride,
-      padding_idx);
-  });
-  THCudaCheck(cudaGetLastError());
-
-  return grad_weight;
+  return embedding_backward_cuda_kernel(grad, orig_indices,
+      sorted_indices, count, num_weights, padding_idx);
 }
 
 Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
diff --git a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
new file mode 100644
index 0000000..5924350
--- /dev/null
+++ b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
@@ -0,0 +1,326 @@
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/NativeFunctions.h>
+
+#include <ATen/AccumulateType.h>
+
+#include <THC/THCDeviceUtils.cuh>
+#include <THC/THCTensorMathReduce.cuh>
+#include <THC/THCTensorSort.cuh>
+#include <THC/THCThrustAllocator.cuh>
+#include <THC/THCAtomics.cuh>
+
+#include <thrust/execution_policy.h>
+#include <thrust/unique.h>
+#include <thrust/device_vector.h>
+
+namespace at {
+namespace native {
+
+namespace {
+
+// The maximum block size in CUDA
+constexpr int MAX_BLOCK_SIZE = 1024;
+/* This code computes the sum of the weights in two steps:
+  1) Each GPU warp sums `NROWS_PER_THREAD` rows given by `indices`
+  2) The partial sums from 1) are summed and scattered into `grad_weight`
+
+  Note that `NROWS_PER_THREAD` impacts the achieved occupancy of the
+  kernel execution: if it is too high, the thread blocks will be too
+  small to achieve good occupancy, while a very low value will make the
+  thread blocks of the final sum in step 2) too small.
+*/
+constexpr int NROWS_PER_THREAD = 10;
+
+#ifdef __HIP_PLATFORM_HCC__
+    constexpr int WARP_SIZE = 64;
+#else
+    constexpr int WARP_SIZE = 32;
+#endif
+
+// Fast ceil division (no overflow checking)
+__host__ __device__ __forceinline__
+int64_t ceil_div(int64_t x, int64_t y) {
+  return (x + y - 1) / y;
+}
+
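+// For each segment, compute how many partial segments (chunks of up to
+// `NROWS_PER_THREAD` rows) it is split into.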
+__global__
+void krn_partials_per_segment(int64_t *ret, const int64_t *segment_offsets,
+                              int64_t num_of_segments, int64_t numel) {
+  const int id = blockIdx.x * blockDim.x + threadIdx.x;
+  if(id < num_of_segments) {
+    const int64_t idx_start = segment_offsets[id];
+    const int64_t idx_end = (id == num_of_segments-1)?numel:segment_offsets[id+1];
+    const int64_t size = idx_end - idx_start;
+    ret[id] = ceil_div(size, NROWS_PER_THREAD);
+  }
+}
+
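+// For each segment, write the start index (into `sorted_indices`) of each of
+// its partial segments into `ret`.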
+__global__
+void krn_partial_segment_offset(
+        int64_t *ret,
+        const int64_t *partials_per_segment,
+        const int64_t *partials_per_segment_offset,
+        const int64_t *segment_offsets,
+        int64_t num_of_segments) {
+  const int id = blockIdx.x * blockDim.x + threadIdx.x;
+  if(id < num_of_segments) {
+    int64_t idx = partials_per_segment_offset[id];
+    const int64_t num_partials = partials_per_segment[id];
+    const int64_t segment_offset = segment_offsets[id];
+    for (int64_t i=0; i<num_partials; ++i) {
+      ret[idx++] = segment_offset + i * NROWS_PER_THREAD;
+    }
+  }
+}
+
+
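+// One thread per (partial segment, feature) pair: accumulate the scaled
+// gradients of the rows in the partial segment for that feature and write
+// the result to `grad_weight_per_segment`.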
+template <typename scalar_t>
+__global__ void compute_grad_weight_bags(
+    int64_t *indices, scalar_t *gradOutput,
+    int64_t *offset2bag, int64_t *count, ptrdiff_t numel,
+    int64_t stride, int mode_mean, const int64_t *bag_size,
+    scalar_t* per_sample_weights, int64_t per_sample_weights_stride,
+    int64_t* segment_offsets, int64_t num_of_segments, scalar_t *grad_weight_per_segment,
+    const int64_t stride_warped) {
+
+  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
+  const int id = gid / stride_warped;
+  const int startFeature = gid % stride_warped;
+  if (startFeature >= stride) {
+    return;
+  }
+  if (id >= num_of_segments) {
+    return;
+  }
+  const int idx_begin = segment_offsets[id];
+  const int idx_end = (id == num_of_segments-1)?numel:segment_offsets[id+1];
+
+  acc_type<scalar_t, true> weight = 0;
+  for (int idx=idx_begin; idx < idx_end; ++idx) {
+    const int origRow = indices[idx];
+    const int seq_number = offset2bag[origRow];
+    const int gradOutputRow = seq_number * stride;
+
+    acc_type<scalar_t, true> scale = count ? 1.0 / count[idx] : 1.0;
+    if (per_sample_weights) {
+      scale *= per_sample_weights[origRow * per_sample_weights_stride];
+    }
+
+    acc_type<scalar_t, true> gradient = gradOutput[gradOutputRow + startFeature];
+    if (mode_mean) {
+      gradient /= bag_size[seq_number];
+    }
+    weight += gradient * scale;
+  }
+  grad_weight_per_segment[id * stride + startFeature] = weight;
+}
+
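+// Same as compute_grad_weight_bags, but for the non-bag embedding backward:
+// no bags or per-sample weights, only the optional frequency count and
+// `padding_idx` handling.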
+template <typename scalar_t>
+__global__ void compute_grad_weight(
+    int64_t *indices,
+    scalar_t *gradOutput,
+    int64_t *count,
+    ptrdiff_t numel,
+    int64_t stride,
+    int64_t* segment_offsets,
+    int64_t num_of_segments,
+    scalar_t *grad_weight_per_segment,
+    int padding_idx,
+    const int64_t stride_warped) {
+
+  using accscalar_t = acc_type<scalar_t, true>;
+  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
+  const int id = gid / stride_warped;
+  const int startFeature = gid % stride_warped;
+  if (startFeature >= stride) {
+    return;
+  }
+  if (id >= num_of_segments) {
+    return;
+  }
+  const int idx_begin = segment_offsets[id];
+  const int idx_end = (id == num_of_segments-1)?numel:segment_offsets[id+1];
+  if (idx_begin == padding_idx) {
+    return;
+  }
+
+  accscalar_t weight = 0;
+  for (int idx=idx_begin; idx < idx_end; ++idx) {
+    const int64_t target_row = indices[idx];
+    if (target_row != padding_idx) {
+      const accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0;
+      weight += gradOutput[target_row * stride + startFeature] * scale;
+    }
+  }
+  grad_weight_per_segment[id * stride + startFeature] = weight;
+}
+
+// This kernel assumes that all input tensors are contiguous.
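+// For each (segment, feature) pair, sum the corresponding partial results and
+// scatter the total into the segment's row of `gradWeight`.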
+template <typename scalar_t>
+__global__ void sum_and_scatter(
+    int64_t *input, scalar_t *gradWeight, int64_t stride,
+    int64_t* segment_offsets, int64_t num_of_segments, const scalar_t *grad_weight_per_segment,
+    const int64_t *segment_sizes_offsets, int64_t num_of_partial_segments,
+    const int64_t stride_warped) {
+
+  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
+  const int id = gid / stride_warped;
+  const int startFeature = gid % stride_warped;
+  if (startFeature >= stride) {
+    return;
+  }
+  if (id >= num_of_segments) {
+    return;
+  }
+
+  const int idx_begin = segment_sizes_offsets[id];
+  const int idx_end = (id == num_of_segments-1)?num_of_partial_segments:segment_sizes_offsets[id+1];
+  acc_type<scalar_t, true> weight = 0;
+  for (int idx=idx_begin; idx < idx_end; ++idx) {
+    weight += grad_weight_per_segment[idx*stride + startFeature];
+  }
+  const int weightRow = input[segment_offsets[id]] * stride;
+  gradWeight[weightRow + startFeature] = weight;
+}
+
+} // anon namespace
+
+Tensor embedding_backward_cuda_kernel(
+        const Tensor &grad,
+        const Tensor &orig_indices,
+        const Tensor &sorted_indices,
+        const Tensor &count,
+        int64_t num_weights,
+        int padding_idx,
+        bool scale_grad_by_freq,
+        bool mode_mean,
+        const Tensor &offset2bag,
+        const Tensor &bag_size,
+        const Tensor &per_sample_weights) {
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+  auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
+  auto policy = thrust::cuda::par(allocator).on(stream);
+  const ptrdiff_t numel = sorted_indices.numel();
+
+  auto grad_weight = at::zeros({num_weights, grad.size(-1)}, grad.options());
+  const int64_t stride = grad_weight.stride(0);
+
+  // Compute the number of segments and their start positions so that we do not have to
+  // spawn a warp per index. In this context, a segment is a run of rows in `sorted_indices`
+  // that share the same index and are therefore summed into the same row of `grad_weight`.
+  // Unit: index in `sorted_indices` and `orig_indices`
+  thrust::device_vector<int64_t> segment_offsets(numel);
+  int64_t num_of_segments;
+  {
+    auto sorted_indices_dev = thrust::device_ptr<int64_t>(sorted_indices.data<int64_t>());
+    auto dummy = at::empty_like(sorted_indices);
+    auto dummy_dev = thrust::device_ptr<int64_t>(dummy.data<int64_t>());
+    auto ends = thrust::unique_by_key_copy(
+            policy,
+            sorted_indices_dev,
+            sorted_indices_dev + numel,
+            thrust::make_counting_iterator(0),
+            dummy_dev,
+            thrust::raw_pointer_cast(segment_offsets.data()));
+    num_of_segments = thrust::get<0>(ends) - dummy_dev;
+  }
+
+  // We split the segments up into partial segments of `NROWS_PER_THREAD` rows each.
+  // Compute the number of partial segments per segment (the last partial segment
+  // of a segment may span fewer than `NROWS_PER_THREAD` rows).
+  thrust::device_vector<int64_t> partials_per_segment(num_of_segments);
+  {
+    krn_partials_per_segment<<<ceil_div(num_of_segments, 32), 32, 0, stream>>> (
+            thrust::raw_pointer_cast(partials_per_segment.data()),
+            thrust::raw_pointer_cast(segment_offsets.data()),
+            num_of_segments,
+            numel);
+  }
+
+  // In order to compute `partial_segment_offset`, which is the start index
+  // of each partial-segment in `sorted_indices`, we need to compute the
+  // start position of each _segment_ in `partial_segment_offset`.
+  // Unit: index in `partial_segment_offset`
+  thrust::device_vector<int64_t> partials_per_segment_offset(num_of_segments); 
+  thrust::exclusive_scan(
+          policy,
+          partials_per_segment.begin(),
+          partials_per_segment.end(),
+          partials_per_segment_offset.begin());
+
+  // The total number of partial segments is the sum of `partials_per_segment`,
+  // i.e. the last element of the exclusive scan plus the last count.
+  const int num_of_partial_segments = partials_per_segment[num_of_segments-1] +
+          partials_per_segment_offset[num_of_segments-1];
+
+  // Now we can compute the start position of each partial-segment
+  // Unit: index in `sorted_indices` and `orig_indices`
+  thrust::device_vector<int64_t> partial_segment_offset(num_of_partial_segments);
+  {
+    krn_partial_segment_offset<<<ceil_div(num_of_segments, 32), 32, 0, stream>>> (
+            thrust::raw_pointer_cast(partial_segment_offset.data()),
+            thrust::raw_pointer_cast(partials_per_segment.data()),
+            thrust::raw_pointer_cast(partials_per_segment_offset.data()),
+            thrust::raw_pointer_cast(segment_offsets.data()),
+            num_of_segments);
+  }
+
+  auto grad_weight_per_segment = at::empty({num_of_partial_segments, stride}, grad.options());
+  const int stride_warped = ceil_div(stride, WARP_SIZE)*WARP_SIZE;
+  const int block = std::min(stride_warped, MAX_BLOCK_SIZE);
+  const int grid = ceil_div(num_of_partial_segments*stride_warped, block);
+
+  // Compute the sum of each partial-segment and handle bags
+  if (offset2bag.defined()) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
+        compute_grad_weight_bags<scalar_t><<<grid, block, 0, stream>>>(
+          orig_indices.data<int64_t>(),
+          grad.data<scalar_t>(),
+          offset2bag.data<int64_t>(),
+          count.defined() ? count.data<int64_t>() : nullptr, numel, stride,
+          mode_mean, bag_size.data<int64_t>(),
+          per_sample_weights.defined() ? per_sample_weights.data<scalar_t>() : NULL,
+          per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
+          thrust::raw_pointer_cast(partial_segment_offset.data()),
+          num_of_partial_segments, grad_weight_per_segment.data<scalar_t>(),
+          stride_warped);
+    });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
+        compute_grad_weight<scalar_t><<<grid, block, 0, stream>>>(
+          orig_indices.data<int64_t>(),
+          grad.data<scalar_t>(),
+          count.defined() ? count.data<int64_t>() : nullptr,
+          numel, stride,
+          thrust::raw_pointer_cast(partial_segment_offset.data()),
+          num_of_partial_segments,
+          grad_weight_per_segment.data<scalar_t>(),
+          padding_idx,
+          stride_warped);
+    });
+  }
+  THCudaCheck(cudaGetLastError());
+
+  // Finally, we sum all the partial-sums and scatter them
+  // into `grad_weight`.
+  const int grid2 = ceil_div(num_of_segments*stride_warped, block);
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    grad.scalar_type(), "embedding_bag_backward_cuda_sum_and_scatter", [&] {
+      sum_and_scatter<scalar_t><<<grid2, block, 0, stream>>>(
+        sorted_indices.data<int64_t>(),
+        grad_weight.data<scalar_t>(),
+        stride,
+        thrust::raw_pointer_cast(segment_offsets.data()),
+        num_of_segments, grad_weight_per_segment.data<scalar_t>(),
+        thrust::raw_pointer_cast(partials_per_segment_offset.data()),
+        num_of_partial_segments, stride_warped);
+  });
+  THCudaCheck(cudaGetLastError());
+  return grad_weight;
+}
+
+}}
diff --git a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cuh b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cuh
new file mode 100644
index 0000000..e483550
--- /dev/null
+++ b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cuh
@@ -0,0 +1,36 @@
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/NativeFunctions.h>
+
+#include <ATen/AccumulateType.h>
+
+#include <THC/THCDeviceUtils.cuh>
+#include <THC/THCTensorMathReduce.cuh>
+#include <THC/THCTensorSort.cuh>
+#include <THC/THCThrustAllocator.cuh>
+#include <THC/THCAtomics.cuh>
+
+#include <thrust/execution_policy.h>
+#include <thrust/unique.h>
+#include <thrust/device_vector.h>
+
+#pragma once
+
+namespace at {
+namespace native {
+
+Tensor embedding_backward_cuda_kernel(
+    const Tensor &grad,
+    const Tensor &orig_indices,
+    const Tensor &sorted_indices,
+    const Tensor &count,
+    int64_t num_weights,
+    int padding_idx = -1,
+    bool scale_grad_by_freq = false,
+    bool mode_mean = false,
+    const Tensor &offset2bag = Tensor(),
+    const Tensor &bag_size = Tensor(),
+    const Tensor &per_sample_weights = Tensor());
+
+}}
\ No newline at end of file
diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu
index f0761bd..3f34d80 100644
--- a/aten/src/ATen/native/cuda/EmbeddingBag.cu
+++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu
@@ -13,17 +13,22 @@
 
 #include <thrust/execution_policy.h>
 #include <thrust/unique.h>
+#include <thrust/device_vector.h>
 
-const int WARP_SIZE = 32;
-const int MODE_SUM = 0;
-const int MODE_MEAN = 1;
-const int MODE_MAX = 2;
+#include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
 
 namespace at {
 namespace native {
 
 namespace {
 
+constexpr int MODE_SUM = 0;
+constexpr int MODE_MEAN = 1;
+constexpr int MODE_MAX = 2;
+
+constexpr int WARP_SIZE = 32;
+
+
 // This kernel assumes that all input tensors except `weight` and
 // per_sample_weights are contiguous.
 template <typename scalar_t>
@@ -104,85 +109,6 @@
   }
 }
 
-// FIXME: removed the accGradParametersKernelByFeature case present in
-// LookupTable. That kernel is faster at small sizes (<768 indices), which
-// does not need EmbeddingBag (LookupTable + Sum works fine), but would
-// still be nice to not be slow in that case.
-
-// This kernel assumes that all input tensors are contiguous.
-template <typename scalar_t>
-__global__ void EmbeddingBag_accGradParametersKernel_sum_avg(
-    int64_t *input, int64_t *indices, scalar_t *gradOutput,
-    scalar_t *gradWeight, int64_t *offset2bag, int64_t *count, ptrdiff_t numel,
-    int64_t stride, int mode, const int64_t *bag_size,
-    scalar_t* per_sample_weights, int64_t per_sample_weights_stride) {
-
-  using accscalar_t = acc_type<scalar_t, true>;
-  int idx = blockIdx.x * 4 + threadIdx.y;
-
-  // Each warp is responsible for an input into the LookupTable.
-  // If the preceding input has the same as this input, then the warp
-  // exits immediately. The warp also processes subsequent inputs with the
-  // same value.  //
-  // Input Warp
-  // 1     <warp 1>
-  // 1     <warp 1> (<warp 2> exits without doing any work)
-  // 5     <warp 3>
-  // 8     <warp 4>
-
-  // Number of values proceessed by each thread (grain size)
-  const int SZ = 4;
-
-  if (idx < numel && (idx == 0 || input[idx] != input[idx - 1])) {
-    do {
-      const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
-      const int weightRow = ((int)input[idx]) * stride;
-
-      // Note: only this line changes from LookupTable_accgradParametersKernel
-      const int origRow = ((int)indices[idx]);
-      const int seq_number = offset2bag[origRow];
-      const int gradOutputRow = ((int)seq_number) * stride;
-
-      accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0;
-      if (per_sample_weights) {
-        scale *= per_sample_weights[origRow * per_sample_weights_stride];
-      }
-
-      accscalar_t gradient[SZ];
-      accscalar_t weight[SZ];
-
-#pragma unroll
-      for (int ii = 0; ii < SZ; ii++) {
-        int featureDim = startFeature + ii * WARP_SIZE;
-        if (featureDim < stride) {
-          gradient[ii] =
-              static_cast<accscalar_t>(gradOutput[gradOutputRow + featureDim]);
-          if (mode == MODE_MEAN) {
-            gradient[ii] /= bag_size[seq_number];
-          }
-          weight[ii] =
-              static_cast<accscalar_t>(gradWeight[weightRow + featureDim]);
-        }
-      }
-
-#pragma unroll
-      for (int ii = 0; ii < SZ; ii++) {
-        weight[ii] += gradient[ii] * scale;
-      }
-
-#pragma unroll
-      for (int ii = 0; ii < SZ; ii++) {
-        int featureDim = startFeature + ii * WARP_SIZE;
-        if (featureDim < stride) {
-          gradWeight[weightRow + featureDim] =
-              static_cast<scalar_t>(weight[ii]);
-        }
-      }
-
-      idx++;
-    } while (idx < numel && input[idx] == input[idx - 1]);
-  }
-}
 
 
 Tensor embedding_bag_backward_cuda_sum_avg(
@@ -202,7 +128,7 @@
 
   if (numel == 0) {
     // all empty bags
-    return grad_weight;
+    return at::zeros({num_weights, grad.size(1)}, grad.options());
   }
 
   int64_t stride = grad_weight.stride(0);
@@ -257,24 +183,9 @@
         thrust::make_reverse_iterator(count_data + numel),
         thrust::equal_to<int64_t>(), thrust::maximum<int64_t>());
   }
-
-  dim3 grid(THCCeilDiv(numel, (ptrdiff_t)4), THCCeilDiv(stride, (int64_t)128));
-  dim3 block(32, 4);
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      grad.scalar_type(), "embedding_bag_backward_cuda_sum_avg_kernel", [&] {
-        EmbeddingBag_accGradParametersKernel_sum_avg<
-            scalar_t><<<grid, block, 0, stream>>>(
-            sorted_indices.data<int64_t>(), orig_indices.data<int64_t>(),
-            grad.data<scalar_t>(), grad_weight.data<scalar_t>(),
-            offset2bag.data<int64_t>(),
-            count.defined() ? count.data<int64_t>() : nullptr, numel, stride,
-            mode, bag_size.data<int64_t>(),
-            per_sample_weights.defined() ? per_sample_weights.data<scalar_t>() : NULL,
-            per_sample_weights.defined() ? per_sample_weights.stride(0) : 0);
-      });
-
-  THCudaCheck(cudaGetLastError());
-  return grad_weight;
+  return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices,
+      count, num_weights, /* padding_idx= */ -1, scale_grad_by_freq,
+      mode == MODE_MEAN, offset2bag, bag_size, per_sample_weights);
 }
 
 template <typename scalar_t>
@@ -297,7 +208,8 @@
       int64_t word_idx = max_indices[bag * stride + featureDim];
       if (word_idx >= 0) {
         // If bag is empty, we have max_indices[idx] set to -1 in forward.
-        atomicAdd(&(gradWeight[word_idx * stride + featureDim]), gradOutput[bag * stride + featureDim]);
+        atomicAdd(&(gradWeight[word_idx * stride + featureDim]),
+                gradOutput[bag * stride + featureDim]);
       }
     }
   }
@@ -411,7 +323,8 @@
     case MODE_MEAN:
       if (mode == MODE_MEAN)
         AT_ASSERT(!per_sample_weights.defined());
-      return embedding_bag_backward_cuda_sum_avg(grad, indices, offset2bag, bag_size_, num_weights, scale_grad_by_freq, mode, per_sample_weights);
+      return embedding_bag_backward_cuda_sum_avg(grad, indices, offset2bag,
+              bag_size_, num_weights, scale_grad_by_freq, mode, per_sample_weights);
 
     case MODE_MAX:
       AT_ASSERT(!per_sample_weights.defined());