#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/TensorUtils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
#include <THC/THCTensorMathReduce.cuh>
#include <THC/THCTensorSort.cuh>
#include <THC/THCThrustAllocator.cuh>
#include <THC/THCAtomics.cuh>
#include <thrust/execution_policy.h>
#include <thrust/unique.h>
#include <c10/macros/Macros.h>
namespace at {
namespace native {
namespace {
// The maximum block size in CUDA
constexpr int MAX_BLOCK_SIZE = 1024;
/* This code computes the sum of the weights in two steps:
1) Each GPU warp sums up to `NROWS_PER_THREAD` rows given by `indices`
2) The partial sums from 1) are summed and scattered into `grad_weight`
Note that `NROWS_PER_THREAD` affects the achieved occupancy of the
kernel execution. If it is too high, too few partial segments (and thus
thread blocks) are created in step 1) to achieve good occupancy; if it is
very low, the final sum in step 2) has to reduce many small partial sums
per segment.
*/
constexpr int NROWS_PER_THREAD = 10;
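// For example, with NROWS_PER_THREAD == 10, a segment of 23 identical indices is
// covered by ceil_div(23, 10) == 3 partial sums (over 10, 10 and 3 rows); step 2)
// then adds those 3 partial sums into the single corresponding `grad_weight` row.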
// Fast ceil division (no overflow checking)
__host__ __device__ __forceinline__
int64_t ceil_div(int64_t x, int64_t y) {
return (x + y - 1) / y;
}
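// For each segment, compute how many partial segments of at most NROWS_PER_THREAD
// rows are needed to cover it (the last segment's end is `numel`).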
template <typename index_t>
__global__
void krn_partials_per_segment(index_t *ret, const index_t *segment_offsets,
int64_t num_of_segments, int64_t numel) {
const int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id < num_of_segments) {
const int64_t idx_start = segment_offsets[id];
const int64_t idx_end = (id == num_of_segments-1)?numel:segment_offsets[id+1];
const int64_t size = idx_end - idx_start;
ret[id] = ceil_div(size, NROWS_PER_THREAD);
}
}
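// For every partial segment, write its starting index into `sorted_indices`, i.e. the
// owning segment's offset plus a multiple of NROWS_PER_THREAD.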
template <typename index_t>
__global__
void krn_partial_segment_offset(
index_t *ret,
const index_t *partials_per_segment,
const index_t *partials_per_segment_offset,
const index_t *segment_offsets,
int64_t num_of_segments) {
const int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id < num_of_segments) {
index_t idx = partials_per_segment_offset[id];
const index_t num_partials = partials_per_segment[id];
const index_t segment_offset = segment_offsets[id];
for (int64_t i=0; i<num_partials; ++i) {
ret[idx++] = segment_offset + i * NROWS_PER_THREAD;
}
}
}
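// EmbeddingBag path: one group of `stride_warped` threads handles one partial segment;
// each thread accumulates, for a single feature column, the gradient of every row in
// the partial segment, scaled by 1/count[idx] (if `count` is given), by the per-sample
// weight (if given), and by 1/bag_size for mean mode.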
template <typename scalar_t, typename index_t>
__global__ void compute_grad_weight_bags(
index_t *indices, scalar_t *gradOutput,
index_t *offset2bag, index_t *count, ptrdiff_t numel,
int64_t stride, int mode_mean, const index_t *bag_size,
scalar_t* per_sample_weights, int64_t per_sample_weights_stride,
index_t* segment_offsets, int64_t num_of_segments,
acc_type<scalar_t, true> *grad_weight_per_segment,
const int64_t stride_warped) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
const int id = gid / stride_warped;
const int startFeature = gid % stride_warped;
if (startFeature >= stride) {
return;
}
if (id >= num_of_segments) {
return;
}
const int idx_begin = segment_offsets[id];
const int idx_end = (id == num_of_segments-1)?numel:segment_offsets[id+1];
acc_type<scalar_t, true> weight = 0;
for (int idx=idx_begin; idx < idx_end; ++idx) {
const int origRow = indices[idx];
const int seq_number = offset2bag[origRow];
const int gradOutputRow = seq_number * stride;
acc_type<scalar_t, true> scale = count ? 1.0 / count[idx] : 1.0;
if (per_sample_weights) {
scale *= per_sample_weights[origRow * per_sample_weights_stride];
}
acc_type<scalar_t, true> gradient = gradOutput[gradOutputRow + startFeature];
if (mode_mean) {
gradient /= bag_size[seq_number];
}
weight += gradient * scale;
}
grad_weight_per_segment[id * stride + startFeature] = weight;
}
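// Plain embedding path: one group of `stride_warped` threads handles one partial
// segment; each thread sums, for a single feature column, the gradient rows selected
// by `indices`, optionally scaled by 1/count[idx].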
template <typename scalar_t, typename index_t>
__global__ void compute_grad_weight(
index_t *indices,
scalar_t *gradOutput,
index_t *count,
ptrdiff_t numel,
int64_t stride,
index_t* segment_offsets,
int64_t num_of_segments,
acc_type<scalar_t, true> *grad_weight_per_segment,
const int64_t stride_warped) {
using accscalar_t = acc_type<scalar_t, true>;
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
const int id = gid / stride_warped;
const int startFeature = gid % stride_warped;
if (startFeature >= stride) {
return;
}
if (id >= num_of_segments) {
return;
}
const int idx_begin = segment_offsets[id];
const int idx_end = (id == num_of_segments-1)?numel:segment_offsets[id+1];
accscalar_t weight = 0;
for (int idx=idx_begin; idx < idx_end; ++idx) {
const index_t target_row = indices[idx];
const accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0;
weight += gradOutput[target_row * stride + startFeature] * scale;
}
grad_weight_per_segment[id * stride + startFeature] = weight;
}
// This kernel assumes that all input tensors are contiguous.
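// For each original segment and feature column it adds up the partial-segment sums
// computed above and writes the result into the `grad_weight` row addressed by the
// segment's index, skipping rows equal to `padding_idx`.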
template <typename scalar_t, typename index_t>
__global__ void sum_and_scatter(
index_t *input, scalar_t *gradWeight, int64_t stride,
index_t* segment_offsets, int64_t num_of_segments,
const acc_type<scalar_t, true> *grad_weight_per_segment,
const index_t *segment_sizes_offsets, int64_t num_of_partial_segments,
const int64_t padding_idx,
const int64_t stride_warped) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
const int id = gid / stride_warped;
const int startFeature = gid % stride_warped;
if (startFeature >= stride) {
return;
}
if (id >= num_of_segments) {
return;
}
const int idx_begin = segment_sizes_offsets[id];
const int idx_end = (id == num_of_segments-1)?num_of_partial_segments:segment_sizes_offsets[id+1];
acc_type<scalar_t, true> weight = 0;
for (int idx=idx_begin; idx < idx_end; ++idx) {
weight += grad_weight_per_segment[idx*stride + startFeature];
}
int64_t target_row = input[segment_offsets[id]];
if (target_row != padding_idx) {
gradWeight[target_row * stride + startFeature] = weight;
}
}
} // anon namespace
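// Shared backward for embedding and embedding_bag. `sorted_indices` holds the
// embedding indices sorted in ascending order; `orig_indices` holds, for each sorted
// position, the position of that index in the original (unsorted) input and is used to
// locate the matching rows of `grad` (directly, or via `offset2bag` for the bag
// variant). Returns the dense gradient w.r.t. the weight matrix.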
Tensor embedding_backward_cuda_kernel(
const Tensor &grad,
const Tensor &orig_indices,
const Tensor &sorted_indices,
const Tensor &count,
int64_t num_weights,
int padding_idx,
bool mode_mean,
const Tensor &offset2bag,
const Tensor &bag_size,
const Tensor &per_sample_weights) {
auto stream = at::cuda::getCurrentCUDAStream();
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
const ptrdiff_t numel = sorted_indices.numel();
auto grad_weight = at::zeros({num_weights, grad.size(-1)}, grad.options());
const int64_t stride = grad_weight.stride(0);
// Compute the number of segments and their start positions so that we do not have to
// spawn a warp per index. In this context, a segment is a run of equal values in
// `sorted_indices`, i.e. a set of rows whose gradients are summed into the same
// `grad_weight` row.
// Unit: index in `sorted_indices` and `orig_indices`
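// For example, sorted_indices = [1, 1, 1, 3, 3, 7] yields segment_offsets = [0, 3, 5]
// and num_of_segments = 3.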
AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
auto segment_offsets = at::empty({numel}, orig_indices.options());
int64_t num_of_segments;
{
auto sorted_indices_dev = thrust::device_ptr<index_t>(sorted_indices.data_ptr<index_t>());
auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto dummy_dev = thrust::device_ptr<index_t>(dummy.data_ptr<index_t>());
auto ends = thrust::unique_by_key_copy(
policy,
sorted_indices_dev,
sorted_indices_dev + numel,
thrust::make_counting_iterator(0),
dummy_dev,
thrust::device_ptr<index_t>(segment_offsets.data_ptr<index_t>()));
num_of_segments = thrust::get<0>(ends) - dummy_dev;
}
// We split each segment into partial segments of at most `NROWS_PER_THREAD` rows.
// Compute the number of partial segments per segment (the last partial segment
// of a segment may cover fewer than `NROWS_PER_THREAD` rows).
auto partials_per_segment = at::empty({num_of_segments}, orig_indices.options());
{
krn_partials_per_segment<<<ceil_div(num_of_segments, 32), 32, 0, stream>>> (
partials_per_segment.data_ptr<index_t>(),
segment_offsets.data_ptr<index_t>(),
num_of_segments,
numel);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// In order to compute `partial_segment_offset`, which is the start index
// of each partial-segment in `sorted_indices`, we need to compute the
// start position of each _segment_ in `partial_segment_offset`.
// Unit: index in `partial_segment_offset`
auto partials_per_segment_offset = at::empty({num_of_segments}, orig_indices.options());
thrust::exclusive_scan(
policy,
thrust::device_ptr<index_t>(partials_per_segment.data_ptr<index_t>()),
thrust::device_ptr<index_t>(partials_per_segment.data_ptr<index_t>()+num_of_segments),
thrust::device_ptr<index_t>(partials_per_segment_offset.data_ptr<index_t>()));
// The total number of partial segments is the last exclusive-scan offset plus
// the last per-segment count.
const int num_of_partial_segments = partials_per_segment[num_of_segments-1].item<index_t>() +
partials_per_segment_offset[num_of_segments-1].item<index_t>();
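// For example, segments of 7, 23 and 4 rows give partials_per_segment = [1, 3, 1],
// partials_per_segment_offset = [0, 1, 4] and num_of_partial_segments = 4 + 1 = 5.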
// Now we can compute the start position of each partial-segment
// Unit: index in `sorted_indices` and `orig_indices`
auto partial_segment_offset = at::empty({num_of_partial_segments}, orig_indices.options());
{
krn_partial_segment_offset<<<ceil_div(num_of_segments, 32), 32, 0, stream>>> (
partial_segment_offset.data_ptr<index_t>(),
partials_per_segment.data_ptr<index_t>(),
partials_per_segment_offset.data_ptr<index_t>(),
segment_offsets.data_ptr<index_t>(),
num_of_segments);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
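// Launch geometry: every (partial) segment gets `stride_warped` consecutive threads,
// i.e. `stride` rounded up to a multiple of the warp size, so the threads of one
// segment never share a warp with another segment; threads past `stride` exit early.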
const int stride_warped = ceil_div(stride, C10_WARP_SIZE)*C10_WARP_SIZE;
const int block = std::min(stride_warped, MAX_BLOCK_SIZE);
const int grid = ceil_div(num_of_partial_segments*stride_warped, block);
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
// For numerical stability, the dtype of `grad_weight_per_segment`
// should match `acc_type`
using partial_weight_t = acc_type<scalar_t, true>;
TensorOptions op;
if(grad.dtype() == at::kHalf || grad.dtype() == at::kBFloat16) {
op = grad.options().dtype(at::kFloat);
} else {
op = grad.options();
}
auto grad_weight_per_segment = at::empty({num_of_partial_segments, stride}, op);
// Compute the sum of each partial-segment and handle bags
if (offset2bag.defined()) {
compute_grad_weight_bags<scalar_t><<<grid, block, 0, stream>>>(
orig_indices.data_ptr<index_t>(),
grad.data_ptr<scalar_t>(),
offset2bag.data_ptr<index_t>(),
count.defined() ? count.data_ptr<index_t>() : nullptr, numel, stride,
mode_mean, bag_size.data_ptr<index_t>(),
per_sample_weights.defined() ? per_sample_weights.data_ptr<scalar_t>() : nullptr,
per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
partial_segment_offset.data_ptr<index_t>(),
num_of_partial_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
stride_warped);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
compute_grad_weight<scalar_t><<<grid, block, 0, stream>>>(
orig_indices.data_ptr<index_t>(),
grad.data_ptr<scalar_t>(),
count.defined() ? count.data_ptr<index_t>() : nullptr,
numel, stride,
partial_segment_offset.data_ptr<index_t>(),
num_of_partial_segments,
grad_weight_per_segment.data_ptr<partial_weight_t>(),
stride_warped);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Finally, we sum all the partial-sums and scatter them
// into `grad_weight`.
const int grid2 = ceil_div(num_of_segments*stride_warped, block);
sum_and_scatter<scalar_t><<<grid2, block, 0, stream>>>(
sorted_indices.data_ptr<index_t>(),
grad_weight.data_ptr<scalar_t>(),
stride,
segment_offsets.data_ptr<index_t>(),
num_of_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
partials_per_segment_offset.data_ptr<index_t>(),
num_of_partial_segments,
padding_idx,
stride_warped);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
return grad_weight;
}
}} // namespace at::native