Optimization of the Embedding and Embedding-Bag CUDA Kernel (#22016)
Summary:
Re-implementation of the `embedding_dense_backward_cuda()` and the `embedding_bag_backward_cuda_sum_avg()` functions.
#### Performance
Running a [Mortgage Workflow](https://github.com/EvenOldridge/MortgageWorkflowA) with a block size of 100K on a DGX-2 (single GPU), we see roughly a 2.8x speedup:
```
Original version:    370,168 examples/s
Optimized version: 1,034,228 examples/s
```
The original version is bottlenecked by `EmbeddingBag_accGradParametersKernel_sum_avg`, which accounts for 70% of the CUDA execution time. In the optimized version, the replacement kernels take only 17% of the time.
#### Greater Numerical Stability
An added benefit is greater numerical stability. Instead of a flat sum in which a single variable accumulates all the weights, the code now works in two steps: each GPU thread first computes a partial sum over at most `NROWS_PER_THREAD` rows, and these partial results are then accumulated into the final result.
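For intuition, here is an illustrative host-side C++ sketch (not the kernel code; the `flat_sum`/`two_step_sum` helpers and the toy data are invented for this example) comparing a single-accumulator sum against a chunked, two-step sum of one million small `float` contributions:
```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

constexpr int NROWS_PER_THREAD = 10;  // plays the same role as in the CUDA kernels

// Single accumulator, as in the old kernel: every contribution is added to one
// ever-growing value, so late additions are rounded against a large magnitude.
float flat_sum(const std::vector<float>& v) {
  float acc = 0.f;
  for (float x : v) acc += x;
  return acc;
}

// Two-step scheme: sum fixed-size chunks ("partial segments") first, then sum
// the partial results. The addends in each step have comparable magnitude.
float two_step_sum(const std::vector<float>& v) {
  std::vector<float> partials;
  for (std::size_t i = 0; i < v.size(); i += NROWS_PER_THREAD) {
    float p = 0.f;
    const std::size_t end = std::min(v.size(), i + NROWS_PER_THREAD);
    for (std::size_t j = i; j < end; ++j) p += v[j];
    partials.push_back(p);
  }
  float acc = 0.f;
  for (float p : partials) acc += p;
  return acc;
}

int main() {
  // One million contributions of 0.1; the exact sum is 100,000.
  std::vector<float> grads(1000000, 0.1f);
  std::printf("flat:     %.3f\n", flat_sum(grads));
  std::printf("two-step: %.3f\n", two_step_sum(grads));
}
```
On typical IEEE-754 hardware the flat `float` sum drifts visibly away from 100,000, while the chunked sum stays close to it. The CUDA kernels below apply the same idea: one partial sum per `NROWS_PER_THREAD` rows, followed by a final reduction in `sum_and_scatter`.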
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22016
Differential Revision: D15944339
Pulled By: mrshenli
fbshipit-source-id: 398d5f48826a017fc4b31c24c3f8b56d01830bf0
diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu
index e7a76ad..1bee496 100644
--- a/aten/src/ATen/native/cuda/Embedding.cu
+++ b/aten/src/ATen/native/cuda/Embedding.cu
@@ -12,6 +12,8 @@
#include <thrust/execution_policy.h>
#include <thrust/unique.h>
+#include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
+
namespace at { namespace native {
@@ -231,14 +233,12 @@
auto num_indices = indices.numel();
auto grad = grad_.contiguous().view({num_indices, grad_.size(-1)});
- auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options());
-
- int64_t stride = grad_weight.stride(0);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (num_indices <= 768 && !scale_grad_by_freq) {
auto indices_contig = indices.contiguous();
-
+ auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options());
+ int64_t stride = grad_weight.stride(0);
dim3 grid(THCCeilDiv(stride, (int64_t)WARP_SIZE));
dim3 block(WARP_SIZE, BLOCKDIMY);
@@ -323,23 +323,8 @@
);
}
- dim3 grid(THCCeilDiv(num_indices, (int64_t) 4), THCCeilDiv(stride, (int64_t) 128));
- dim3 block(WARP_SIZE, 4);
-
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.scalar_type(), "embedding_backward", [&] {
- embedding_backward_kernel<<<grid, block, 0, stream>>>(
- sorted_indices.data<int64_t>(),
- orig_indices.data<int64_t>(),
- grad.data<scalar_t>(),
- grad_weight.data<scalar_t>(),
- count.defined() ? count.data<int64_t>() : nullptr,
- num_indices,
- stride,
- padding_idx);
- });
- THCudaCheck(cudaGetLastError());
-
- return grad_weight;
+ return embedding_backward_cuda_kernel(grad, orig_indices,
+ sorted_indices, count, num_weights, padding_idx);
}
Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
diff --git a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
new file mode 100644
index 0000000..5924350
--- /dev/null
+++ b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
@@ -0,0 +1,326 @@
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/NativeFunctions.h>
+
+#include <ATen/AccumulateType.h>
+
+#include <THC/THCDeviceUtils.cuh>
+#include <THC/THCTensorMathReduce.cuh>
+#include <THC/THCTensorSort.cuh>
+#include <THC/THCThrustAllocator.cuh>
+#include <THC/THCAtomics.cuh>
+
+#include <thrust/execution_policy.h>
+#include <thrust/unique.h>
+#include <thrust/device_vector.h>
+
+namespace at {
+namespace native {
+
+namespace {
+
+// The maximum block size in CUDA
+constexpr int MAX_BLOCK_SIZE = 1024;
+/* This code computes the sum of the weights in two steps:
+   1) Each GPU warp sums `NROWS_PER_THREAD` rows given by `indices`
+   2) The partial sums from 1) are summed and scattered into `grad_weight`
+
+   Note that `NROWS_PER_THREAD` affects the achieved occupancy of the
+   kernel execution. If it is too high, there are too few partial segments
+   to keep the GPU busy in step 1); if it is too low, each partial sum
+   covers only a few rows and most of the reduction work is pushed into
+   the final sum in step 2).
+*/
+constexpr int NROWS_PER_THREAD = 10;
+
+#ifdef __HIP_PLATFORM_HCC__
+ constexpr int WARP_SIZE = 64;
+#else
+ constexpr int WARP_SIZE = 32;
+#endif
+
+// Fast ceil division (no overflow checking)
+__host__ __device__ __forceinline__
+int64_t ceil_div(int64_t x, int64_t y) {
+ return (x + y - 1) / y;
+}
+
+__global__
+void krn_partials_per_segment(int64_t *ret, const int64_t *segment_offsets,
+ int64_t num_of_segments, int64_t numel) {
+ const int id = blockIdx.x * blockDim.x + threadIdx.x;
+ if(id < num_of_segments) {
+ const int64_t idx_start = segment_offsets[id];
+ const int64_t idx_end = (id == num_of_segments-1)?numel:segment_offsets[id+1];
+ const int64_t size = idx_end - idx_start;
+ ret[id] = ceil_div(size, NROWS_PER_THREAD);
+ }
+}
+
+__global__
+void krn_partial_segment_offset(
+ int64_t *ret,
+ const int64_t *partials_per_segment,
+ const int64_t *partials_per_segment_offset,
+ const int64_t *segment_offsets,
+ int64_t num_of_segments) {
+ const int id = blockIdx.x * blockDim.x + threadIdx.x;
+ if(id < num_of_segments) {
+ int64_t idx = partials_per_segment_offset[id];
+ const int64_t num_partials = partials_per_segment[id];
+ const int64_t segment_offset = segment_offsets[id];
+ for (int64_t i=0; i<num_partials; ++i) {
+ ret[idx++] = segment_offset + i * NROWS_PER_THREAD;
+ }
+ }
+}
+
+
+template <typename scalar_t>
+__global__ void compute_grad_weight_bags(
+ int64_t *indices, scalar_t *gradOutput,
+ int64_t *offset2bag, int64_t *count, ptrdiff_t numel,
+ int64_t stride, int mode_mean, const int64_t *bag_size,
+ scalar_t* per_sample_weights, int64_t per_sample_weights_stride,
+ int64_t* segment_offsets, int64_t num_of_segments, scalar_t *grad_weight_per_segment,
+ const int64_t stride_warped) {
+
+ const int gid = blockIdx.x * blockDim.x + threadIdx.x;
+ const int id = gid / stride_warped;
+ const int startFeature = gid % stride_warped;
+ if (startFeature >= stride) {
+ return;
+ }
+ if (id >= num_of_segments) {
+ return;
+ }
+ const int idx_begin = segment_offsets[id];
+ const int idx_end = (id == num_of_segments-1)?numel:segment_offsets[id+1];
+
+ acc_type<scalar_t, true> weight = 0;
+ for (int idx=idx_begin; idx < idx_end; ++idx) {
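+    // `indices` here holds the original (pre-sort) positions, so `offset2bag`
+    //  maps each position to the bag it belongs to, i.e. the row of `gradOutput`.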
+ const int origRow = indices[idx];
+ const int seq_number = offset2bag[origRow];
+ const int gradOutputRow = seq_number * stride;
+
+ acc_type<scalar_t, true> scale = count ? 1.0 / count[idx] : 1.0;
+ if (per_sample_weights) {
+ scale *= per_sample_weights[origRow * per_sample_weights_stride];
+ }
+
+ acc_type<scalar_t, true> gradient = gradOutput[gradOutputRow + startFeature];
+ if (mode_mean) {
+ gradient /= bag_size[seq_number];
+ }
+ weight += gradient * scale;
+ }
+ grad_weight_per_segment[id * stride + startFeature] = weight;
+}
+
+template <typename scalar_t>
+__global__ void compute_grad_weight(
+ int64_t *indices,
+ scalar_t *gradOutput,
+ int64_t *count,
+ ptrdiff_t numel,
+ int64_t stride,
+ int64_t* segment_offsets,
+ int64_t num_of_segments,
+ scalar_t *grad_weight_per_segment,
+ const int64_t stride_warped) {
+
+ using accscalar_t = acc_type<scalar_t, true>;
+ const int gid = blockIdx.x * blockDim.x + threadIdx.x;
+ const int id = gid / stride_warped;
+ const int startFeature = gid % stride_warped;
+ if (startFeature >= stride) {
+ return;
+ }
+ if (id >= num_of_segments) {
+ return;
+ }
+ const int idx_begin = segment_offsets[id];
+ const int idx_end = (id == num_of_segments-1)?numel:segment_offsets[id+1];
+  // Padding (if any) is handled in `sum_and_scatter`, which sees the index
+  //  values; this kernel only sees positions into the gradient.
+
+ accscalar_t weight = 0;
+ for (int idx=idx_begin; idx < idx_end; ++idx) {
+ const int64_t target_row = indices[idx];
+    const accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0;
+    weight += gradOutput[target_row * stride + startFeature] * scale;
+ }
+ grad_weight_per_segment[id * stride + startFeature] = weight;
+}
+
+// This kernel assumes that all input tensors are contiguous.
+template <typename scalar_t>
+__global__ void sum_and_scatter(
+ int64_t *input, scalar_t *gradWeight, int64_t stride,
+ int64_t* segment_offsets, int64_t num_of_segments, const scalar_t *grad_weight_per_segment,
+    const int64_t *segment_sizes_offsets, int64_t num_of_partial_segments,
+    const int64_t padding_idx, const int64_t stride_warped) {
+
+ const int gid = blockIdx.x * blockDim.x + threadIdx.x;
+ const int id = gid / stride_warped;
+ const int startFeature = gid % stride_warped;
+ if (startFeature >= stride) {
+ return;
+ }
+ if (id >= num_of_segments) {
+ return;
+ }
+
+ const int idx_begin = segment_sizes_offsets[id];
+ const int idx_end = (id == num_of_segments-1)?num_of_partial_segments:segment_sizes_offsets[id+1];
+ acc_type<scalar_t, true> weight = 0;
+ for (int idx=idx_begin; idx < idx_end; ++idx) {
+ weight += grad_weight_per_segment[idx*stride + startFeature];
+ }
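+  // All the indices within a segment are identical, so the first one identifies
+  //  the destination row of `gradWeight`.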
+  const int64_t target_row = input[segment_offsets[id]];
+  // The gradient of the padding row (if any) must stay zero, so skip the write.
+  if (target_row != padding_idx) {
+    gradWeight[target_row * stride + startFeature] = weight;
+  }
+}
+
+} // anon namespace
+
+Tensor embedding_backward_cuda_kernel(
+ const Tensor &grad,
+ const Tensor &orig_indices,
+ const Tensor &sorted_indices,
+ const Tensor &count,
+ int64_t num_weights,
+ int padding_idx,
+ bool scale_grad_by_freq,
+ bool mode_mean,
+ const Tensor &offset2bag,
+ const Tensor &bag_size,
+ const Tensor &per_sample_weights) {
+
+ auto stream = at::cuda::getCurrentCUDAStream();
+ auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
+ auto policy = thrust::cuda::par(allocator).on(stream);
+ const ptrdiff_t numel = sorted_indices.numel();
+
+ auto grad_weight = at::zeros({num_weights, grad.size(-1)}, grad.options());
+ const int64_t stride = grad_weight.stride(0);
+
+  // Compute the number of segments and their start positions so that we do not have to
+  // spawn a warp per index. In this context, a segment is a run of equal values in
+  // `sorted_indices`, i.e. all the indices that accumulate into the same row of `grad_weight`.
+ // Unit: index in `sorted_indices` and `orig_indices`
+ thrust::device_vector<int64_t> segment_offsets(numel);
+ int64_t num_of_segments;
+ {
+ auto sorted_indices_dev = thrust::device_ptr<int64_t>(sorted_indices.data<int64_t>());
+ auto dummy = at::empty_like(sorted_indices);
+ auto dummy_dev = thrust::device_ptr<int64_t>(dummy.data<int64_t>());
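+    // For every unique value in `sorted_indices` (the keys), record the position of
+    //  its first occurrence (taken from the counting iterator) into `segment_offsets`.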
+ auto ends = thrust::unique_by_key_copy(
+ policy,
+ sorted_indices_dev,
+ sorted_indices_dev + numel,
+ thrust::make_counting_iterator(0),
+ dummy_dev,
+ thrust::raw_pointer_cast(segment_offsets.data()));
+ num_of_segments = thrust::get<0>(ends) - dummy_dev;
+ }
+
+  // We split each segment into partial-segments of `NROWS_PER_THREAD` rows.
+  // Compute the number of partial-segments per segment (the last partial-segment
+  // of a segment may cover fewer than `NROWS_PER_THREAD` rows).
+ thrust::device_vector<int64_t> partials_per_segment(num_of_segments);
+ {
+ krn_partials_per_segment<<<ceil_div(num_of_segments, 32), 32, 0, stream>>> (
+ thrust::raw_pointer_cast(partials_per_segment.data()),
+ thrust::raw_pointer_cast(segment_offsets.data()),
+ num_of_segments,
+ numel);
+ }
+
+ // In order to compute `partial_segment_offset`, which is the start index
+ // of each partial-segment in `sorted_indices`, we need to compute the
+ // start position of each _segment_ in `partial_segment_offset`.
+ // Unit: index in `partial_segment_offset`
+ thrust::device_vector<int64_t> partials_per_segment_offset(num_of_segments);
+ thrust::exclusive_scan(
+ policy,
+ partials_per_segment.begin(),
+ partials_per_segment.end(),
+ partials_per_segment_offset.begin());
+
+  // The total number of partial-segments is the last entry of the exclusive scan
+  // plus the number of partials in the last segment
+ const int num_of_partial_segments = partials_per_segment[num_of_segments-1] +
+ partials_per_segment_offset[num_of_segments-1];
+
+ // Now we can compute the start position of each partial-segment
+ // Unit: index in `sorted_indices` and `orig_indices`
+ thrust::device_vector<int64_t> partial_segment_offset(num_of_partial_segments);
+ {
+ krn_partial_segment_offset<<<ceil_div(num_of_segments, 32), 32, 0, stream>>> (
+ thrust::raw_pointer_cast(partial_segment_offset.data()),
+ thrust::raw_pointer_cast(partials_per_segment.data()),
+ thrust::raw_pointer_cast(partials_per_segment_offset.data()),
+ thrust::raw_pointer_cast(segment_offsets.data()),
+ num_of_segments);
+ }
+
+ auto grad_weight_per_segment = at::empty({num_of_partial_segments, stride}, grad.options());
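+  // Round the feature dimension up to a multiple of the warp size so that each
+  //  (partial-)segment's threads begin at a warp boundary; threads whose
+  //  startFeature falls beyond `stride` exit immediately.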
+ const int stride_warped = ceil_div(stride, WARP_SIZE)*WARP_SIZE;
+ const int block = std::min(stride_warped, MAX_BLOCK_SIZE);
+ const int grid = ceil_div(num_of_partial_segments*stride_warped, block);
+
+ // Compute the sum of each partial-segment and handle bags
+ if (offset2bag.defined()) {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
+ compute_grad_weight_bags<scalar_t><<<grid, block, 0, stream>>>(
+ orig_indices.data<int64_t>(),
+ grad.data<scalar_t>(),
+ offset2bag.data<int64_t>(),
+ count.defined() ? count.data<int64_t>() : nullptr, numel, stride,
+ mode_mean, bag_size.data<int64_t>(),
+ per_sample_weights.defined() ? per_sample_weights.data<scalar_t>() : NULL,
+ per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
+ thrust::raw_pointer_cast(partial_segment_offset.data()),
+ num_of_partial_segments, grad_weight_per_segment.data<scalar_t>(),
+ stride_warped);
+ });
+ } else {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
+ compute_grad_weight<scalar_t><<<grid, block, 0, stream>>>(
+ orig_indices.data<int64_t>(),
+ grad.data<scalar_t>(),
+ count.defined() ? count.data<int64_t>() : nullptr,
+ numel, stride,
+ thrust::raw_pointer_cast(partial_segment_offset.data()),
+ num_of_partial_segments,
+ grad_weight_per_segment.data<scalar_t>(),
+ stride_warped);
+ });
+ }
+ THCudaCheck(cudaGetLastError());
+
+ // Finally, we sum all the partial-sums and scatter them
+ // into `grad_weight`.
+ const int grid2 = ceil_div(num_of_segments*stride_warped, block);
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ grad.scalar_type(), "embedding_bag_backward_cuda_sum_and_scatter", [&] {
+ sum_and_scatter<scalar_t><<<grid2, block, 0, stream>>>(
+ sorted_indices.data<int64_t>(),
+ grad_weight.data<scalar_t>(),
+ stride,
+ thrust::raw_pointer_cast(segment_offsets.data()),
+ num_of_segments, grad_weight_per_segment.data<scalar_t>(),
+ thrust::raw_pointer_cast(partials_per_segment_offset.data()),
+        num_of_partial_segments, padding_idx, stride_warped);
+ });
+ THCudaCheck(cudaGetLastError());
+ return grad_weight;
+}
+
+}}
diff --git a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cuh b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cuh
new file mode 100644
index 0000000..e483550
--- /dev/null
+++ b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cuh
@@ -0,0 +1,36 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/NativeFunctions.h>
+
+#include <ATen/AccumulateType.h>
+
+#include <THC/THCDeviceUtils.cuh>
+#include <THC/THCTensorMathReduce.cuh>
+#include <THC/THCTensorSort.cuh>
+#include <THC/THCThrustAllocator.cuh>
+#include <THC/THCAtomics.cuh>
+
+#include <thrust/execution_policy.h>
+#include <thrust/unique.h>
+#include <thrust/device_vector.h>
+
+namespace at {
+namespace native {
+
+Tensor embedding_backward_cuda_kernel(
+ const Tensor &grad,
+ const Tensor &orig_indices,
+ const Tensor &sorted_indices,
+ const Tensor &count,
+ int64_t num_weights,
+ int padding_idx = -1,
+ bool scale_grad_by_freq = false,
+ bool mode_mean = false,
+ const Tensor &offset2bag = Tensor(),
+ const Tensor &bag_size = Tensor(),
+ const Tensor &per_sample_weights = Tensor());
+
+}}
\ No newline at end of file
diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu
index f0761bd..3f34d80 100644
--- a/aten/src/ATen/native/cuda/EmbeddingBag.cu
+++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu
@@ -13,17 +13,22 @@
#include <thrust/execution_policy.h>
#include <thrust/unique.h>
+#include <thrust/device_vector.h>
-const int WARP_SIZE = 32;
-const int MODE_SUM = 0;
-const int MODE_MEAN = 1;
-const int MODE_MAX = 2;
+#include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
namespace at {
namespace native {
namespace {
+constexpr int MODE_SUM = 0;
+constexpr int MODE_MEAN = 1;
+constexpr int MODE_MAX = 2;
+
+constexpr int WARP_SIZE = 32;
+
+
// This kernel assumes that all input tensors except `weight` and
// per_sample_weights are contiguous.
template <typename scalar_t>
@@ -104,85 +109,6 @@
}
}
-// FIXME: removed the accGradParametersKernelByFeature case present in
-// LookupTable. That kernel is faster at small sizes (<768 indices), which
-// does not need EmbeddingBag (LookupTable + Sum works fine), but would
-// still be nice to not be slow in that case.
-
-// This kernel assumes that all input tensors are contiguous.
-template <typename scalar_t>
-__global__ void EmbeddingBag_accGradParametersKernel_sum_avg(
- int64_t *input, int64_t *indices, scalar_t *gradOutput,
- scalar_t *gradWeight, int64_t *offset2bag, int64_t *count, ptrdiff_t numel,
- int64_t stride, int mode, const int64_t *bag_size,
- scalar_t* per_sample_weights, int64_t per_sample_weights_stride) {
-
- using accscalar_t = acc_type<scalar_t, true>;
- int idx = blockIdx.x * 4 + threadIdx.y;
-
- // Each warp is responsible for an input into the LookupTable.
- // If the preceding input has the same as this input, then the warp
- // exits immediately. The warp also processes subsequent inputs with the
- // same value. //
- // Input Warp
- // 1 <warp 1>
- // 1 <warp 1> (<warp 2> exits without doing any work)
- // 5 <warp 3>
- // 8 <warp 4>
-
- // Number of values proceessed by each thread (grain size)
- const int SZ = 4;
-
- if (idx < numel && (idx == 0 || input[idx] != input[idx - 1])) {
- do {
- const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
- const int weightRow = ((int)input[idx]) * stride;
-
- // Note: only this line changes from LookupTable_accgradParametersKernel
- const int origRow = ((int)indices[idx]);
- const int seq_number = offset2bag[origRow];
- const int gradOutputRow = ((int)seq_number) * stride;
-
- accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0;
- if (per_sample_weights) {
- scale *= per_sample_weights[origRow * per_sample_weights_stride];
- }
-
- accscalar_t gradient[SZ];
- accscalar_t weight[SZ];
-
-#pragma unroll
- for (int ii = 0; ii < SZ; ii++) {
- int featureDim = startFeature + ii * WARP_SIZE;
- if (featureDim < stride) {
- gradient[ii] =
- static_cast<accscalar_t>(gradOutput[gradOutputRow + featureDim]);
- if (mode == MODE_MEAN) {
- gradient[ii] /= bag_size[seq_number];
- }
- weight[ii] =
- static_cast<accscalar_t>(gradWeight[weightRow + featureDim]);
- }
- }
-
-#pragma unroll
- for (int ii = 0; ii < SZ; ii++) {
- weight[ii] += gradient[ii] * scale;
- }
-
-#pragma unroll
- for (int ii = 0; ii < SZ; ii++) {
- int featureDim = startFeature + ii * WARP_SIZE;
- if (featureDim < stride) {
- gradWeight[weightRow + featureDim] =
- static_cast<scalar_t>(weight[ii]);
- }
- }
-
- idx++;
- } while (idx < numel && input[idx] == input[idx - 1]);
- }
-}
Tensor embedding_bag_backward_cuda_sum_avg(
@@ -202,7 +128,7 @@
if (numel == 0) {
// all empty bags
- return grad_weight;
+ return at::zeros({num_weights, grad.size(1)}, grad.options());
}
int64_t stride = grad_weight.stride(0);
@@ -257,24 +183,9 @@
thrust::make_reverse_iterator(count_data + numel),
thrust::equal_to<int64_t>(), thrust::maximum<int64_t>());
}
-
- dim3 grid(THCCeilDiv(numel, (ptrdiff_t)4), THCCeilDiv(stride, (int64_t)128));
- dim3 block(32, 4);
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
- grad.scalar_type(), "embedding_bag_backward_cuda_sum_avg_kernel", [&] {
- EmbeddingBag_accGradParametersKernel_sum_avg<
- scalar_t><<<grid, block, 0, stream>>>(
- sorted_indices.data<int64_t>(), orig_indices.data<int64_t>(),
- grad.data<scalar_t>(), grad_weight.data<scalar_t>(),
- offset2bag.data<int64_t>(),
- count.defined() ? count.data<int64_t>() : nullptr, numel, stride,
- mode, bag_size.data<int64_t>(),
- per_sample_weights.defined() ? per_sample_weights.data<scalar_t>() : NULL,
- per_sample_weights.defined() ? per_sample_weights.stride(0) : 0);
- });
-
- THCudaCheck(cudaGetLastError());
- return grad_weight;
+ return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices,
+ count, num_weights, /* padding_idx= */ -1, scale_grad_by_freq,
+ mode == MODE_MEAN, offset2bag, bag_size, per_sample_weights);
}
template <typename scalar_t>
@@ -297,7 +208,8 @@
int64_t word_idx = max_indices[bag * stride + featureDim];
if (word_idx >= 0) {
// If bag is empty, we have max_indices[idx] set to -1 in forward.
- atomicAdd(&(gradWeight[word_idx * stride + featureDim]), gradOutput[bag * stride + featureDim]);
+ atomicAdd(&(gradWeight[word_idx * stride + featureDim]),
+ gradOutput[bag * stride + featureDim]);
}
}
}
@@ -411,7 +323,8 @@
case MODE_MEAN:
if (mode == MODE_MEAN)
AT_ASSERT(!per_sample_weights.defined());
- return embedding_bag_backward_cuda_sum_avg(grad, indices, offset2bag, bag_size_, num_weights, scale_grad_by_freq, mode, per_sample_weights);
+ return embedding_bag_backward_cuda_sum_avg(grad, indices, offset2bag,
+ bag_size_, num_weights, scale_grad_by_freq, mode, per_sample_weights);
case MODE_MAX:
AT_ASSERT(!per_sample_weights.defined());