Remove sync in embedding (#70943)
Summary:
Together with https://github.com/pytorch/pytorch/pull/66580 and https://github.com/pytorch/pytorch/pull/68376, this PR removes all syncs in embedding.
This PR includes https://github.com/pytorch/pytorch/pull/68376; please review it only after https://github.com/pytorch/pytorch/pull/68376 has been merged.
This PR introduces perf regressions and increases memory usage:
- `exclusive_sum` now computes over all `numel` elements instead of only `num_of_segments` elements; the trailing `numel - num_of_segments` results are discarded.
- Some allocations now need `numel` entries instead of `num_of_segments` or `num_of_partial_segments`.
This is the price we have to pay to get a sync-free implementation.
I haven't benchmarked this yet; I will do it later.
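
For context, here is a minimal standalone sketch of the pattern (hypothetical toy code, not the ATen code in the diff below): the segment count stays in a device-side scalar, each kernel dereferences the pointer itself, and grid sizes and temporary buffers are sized by the worst-case `numel`, so the host never has to call `.item()` and synchronize.

```cuda
// Toy illustration of the sync-free pattern, assuming a precomputed
// device-side count. Names (fill_segments, num_of_segments_ptr) are made up.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

// The kernel reads the count from device memory instead of receiving a
// host-side value, so no device-to-host copy (and thus no sync) is needed
// between producing the count and consuming it.
__global__ void fill_segments(int64_t* out, const int64_t* num_of_segments_ptr,
                              int64_t numel) {
  const int64_t num_of_segments = *num_of_segments_ptr;  // read on device
  const int64_t id = blockIdx.x * blockDim.x + threadIdx.x;
  if (id < num_of_segments) {        // real work only for valid segments
    out[id] = id;
  } else if (id < numel) {           // trailing slots are wasted work, discarded
    out[id] = 0;
  }
}

int main() {
  const int64_t numel = 1024;        // worst case: every index is its own segment
  int64_t *out, *num_of_segments;    // the count never leaves the device
  cudaMalloc((void**)&out, numel * sizeof(int64_t));  // sized by numel, not the count
  cudaMalloc((void**)&num_of_segments, sizeof(int64_t));
  const int64_t host_count = 10;     // stands in for the value unique_by_key would write
  cudaMemcpy(num_of_segments, &host_count, sizeof(int64_t), cudaMemcpyHostToDevice);

  // The grid is sized by numel because the true count is unknown on the host.
  const int block = 128;
  const int grid = static_cast<int>((numel + block - 1) / block);
  fill_segments<<<grid, block>>>(out, num_of_segments, numel);
  cudaDeviceSynchronize();           // only for the demo printout below

  int64_t first;
  cudaMemcpy(&first, out, sizeof(int64_t), cudaMemcpyDeviceToHost);
  printf("out[0] = %lld\n", static_cast<long long>(first));
  cudaFree(out);
  cudaFree(num_of_segments);
  return 0;
}
```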
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70943
Reviewed By: H-Huang
Differential Revision: D34881660
Pulled By: ngimel
fbshipit-source-id: b0760fa33608c46cd4145ceb09878bf94a9f959d
(cherry picked from commit d959fa4783cfee84bf17c1fa6d0f5d6bde268d75)
diff --git a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
index 6035f30..202f3ad 100644
--- a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
+++ b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
@@ -39,7 +39,8 @@
template <typename index_t>
__global__
void krn_partials_per_segment(index_t *ret, const index_t *segment_offsets,
- int64_t num_of_segments, int64_t numel) {
+ int64_t *num_of_segments_ptr, int64_t numel) {
+ int64_t num_of_segments = *num_of_segments_ptr;
const int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id < num_of_segments) {
const int64_t idx_start = segment_offsets[id];
@@ -56,7 +57,8 @@
const index_t *partials_per_segment,
const index_t *partials_per_segment_offset,
const index_t *segment_offsets,
- int64_t num_of_segments) {
+ int64_t *num_of_segments_ptr) {
+ int64_t num_of_segments = *num_of_segments_ptr;
const int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id < num_of_segments) {
index_t idx = partials_per_segment_offset[id];
@@ -75,10 +77,11 @@
index_t *offset2bag, index_t *count, ptrdiff_t numel,
int64_t stride, int mode_mean, const index_t *bag_size,
scalar_t* per_sample_weights, int64_t per_sample_weights_stride,
- index_t* segment_offsets, int64_t num_of_segments,
+ index_t* segment_offsets, int64_t *num_of_segments_ptr,
acc_type<scalar_t, true> *grad_weight_per_segment,
const int64_t stride_warped) {
+ int64_t num_of_segments = *num_of_segments_ptr;
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
const int id = gid / stride_warped;
const int startFeature = gid % stride_warped;
@@ -119,10 +122,11 @@
ptrdiff_t numel,
int64_t stride,
index_t* segment_offsets,
- int64_t num_of_segments,
+ int64_t *num_of_segments_ptr,
acc_type<scalar_t, true> *grad_weight_per_segment,
const int64_t stride_warped) {
+ int64_t num_of_segments = *num_of_segments_ptr;
using accscalar_t = acc_type<scalar_t, true>;
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
const int id = gid / stride_warped;
@@ -149,12 +153,14 @@
template <typename scalar_t, typename index_t>
__global__ void sum_and_scatter(
index_t *input, scalar_t *gradWeight, int64_t stride,
- index_t* segment_offsets, int64_t num_of_segments,
+ index_t* segment_offsets, int64_t *num_of_segments_ptr,
const acc_type<scalar_t, true> *grad_weight_per_segment,
- const index_t *segment_sizes_offsets, int64_t num_of_partial_segments,
+ const index_t *segment_sizes_offsets, int64_t *num_of_partial_segments_ptr,
const int64_t padding_idx,
const int64_t stride_warped) {
+ int64_t num_of_segments = *num_of_segments_ptr;
+ int64_t num_of_partial_segments = *num_of_partial_segments_ptr;
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
const int id = gid / stride_warped;
const int startFeature = gid % stride_warped;
@@ -177,6 +183,17 @@
}
}
+template<typename index_t>
+__global__ void compute_num_of_partial_segments(index_t *partials_per_segment, index_t *partials_per_segment_offset, int64_t *num_of_segments_ptr, int64_t *output) {
+ int64_t num_of_segments = *num_of_segments_ptr;
+ *output = partials_per_segment[num_of_segments-1] +
+ partials_per_segment_offset[num_of_segments-1];
+}
+
+__global__ void write_num_of_segments_for_legacy_thrust_path(int64_t *num_of_segments_ptr, int64_t num_of_segments) {
+ *num_of_segments_ptr = num_of_segments;
+}
+
} // anon namespace
#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
@@ -207,10 +224,13 @@
// be summarized.
// Unit: index in `sorted_indices` and `orig_indices`
auto segment_offsets = at::empty({numel}, orig_indices.options());
- int64_t num_of_segments;
+ auto num_of_segments_tensor = at::empty({}, grad.options().dtype(kLong));
+ int64_t *num_of_segments_ptr = num_of_segments_tensor.data_ptr<int64_t>();
#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
- num_of_segments = embedding_backward_cuda_kernel_unique_by_key<index_t>(sorted_indices, segment_offsets);
+ int64_t num_of_segments = embedding_backward_cuda_kernel_unique_by_key<index_t>(sorted_indices, segment_offsets);
+ write_num_of_segments_for_legacy_thrust_path<<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(num_of_segments_ptr, num_of_segments);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
});
#else
AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
@@ -218,8 +238,7 @@
cuda::cub::unique_by_key(
sorted_indices.data_ptr<index_t>(), thrust::make_counting_iterator(0),
nullptr, segment_offsets.data_ptr<index_t>(),
- num_of_segments_tensor.data_ptr<int64_t>(), sorted_indices.numel());
- num_of_segments = num_of_segments_tensor.item<int64_t>();
+ num_of_segments_ptr, sorted_indices.numel());
});
#endif
@@ -227,12 +246,12 @@
// We split the segments up into sizes of `NROWS_PER_THREAD`
// Compute the number partial-segments per segment (some partial-segments
// may not be the full `NROWS_PER_THREAD` number of rows)
- auto partials_per_segment = at::empty({num_of_segments}, orig_indices.options());
+ auto partials_per_segment = at::empty({numel}, orig_indices.options());
{
- krn_partials_per_segment<<<ceil_div(num_of_segments, 32), 32, 0, stream>>> (
+ krn_partials_per_segment<<<ceil_div(numel, 32), 32, 0, stream>>> (
partials_per_segment.data_ptr<index_t>(),
segment_offsets.data_ptr<index_t>(),
- num_of_segments,
+ num_of_segments_ptr,
numel);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
@@ -241,33 +260,38 @@
// of each partial-segment in `sorted_indices`, we need to compute the
// start position of each _segment_ in `partial_segment_offset`.
// Unit: index in `partial_segment_offset`
- auto partials_per_segment_offset = at::empty({num_of_segments}, orig_indices.options());
+ auto partials_per_segment_offset = at::empty({numel}, orig_indices.options());
cuda::cub::exclusive_sum(
partials_per_segment.data_ptr<index_t>(),
partials_per_segment_offset.data_ptr<index_t>(),
- num_of_segments);
+ numel);
// The total number of partial-segments is the sum of `partials_per_segment_offset`
- const int num_of_partial_segments = partials_per_segment[num_of_segments-1].item<index_t>() +
- partials_per_segment_offset[num_of_segments-1].item<index_t>();
+ auto num_of_partial_segments_tensor = at::empty({}, grad.options().dtype(kLong));
+ int64_t *num_of_partial_segments_ptr = num_of_partial_segments_tensor.data_ptr<int64_t>();
+ compute_num_of_partial_segments<index_t><<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(
+ partials_per_segment.data_ptr<index_t>(),
+ partials_per_segment_offset.data_ptr<index_t>(),
+ num_of_segments_ptr, num_of_partial_segments_ptr);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
// Now we can compute the start position of each partial-segment
// Unit: index in `sorted_indices` and `orig_indices`
- auto partial_segment_offset = at::empty({num_of_partial_segments}, orig_indices.options());
+ auto partial_segment_offset = at::empty({numel}, orig_indices.options());
{
- krn_partial_segment_offset<<<ceil_div(num_of_segments, 32), 32, 0, stream>>> (
+ krn_partial_segment_offset<<<ceil_div(numel, 32), 32, 0, stream>>> (
partial_segment_offset.data_ptr<index_t>(),
partials_per_segment.data_ptr<index_t>(),
partials_per_segment_offset.data_ptr<index_t>(),
segment_offsets.data_ptr<index_t>(),
- num_of_segments);
+ num_of_segments_ptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
const int warp_size = at::cuda::warp_size();
const int stride_warped = ceil_div(stride, warp_size)*warp_size;
const int block = std::min(stride_warped, MAX_BLOCK_SIZE);
- const int grid = ceil_div(num_of_partial_segments*stride_warped, block);
+ const int grid = ceil_div(numel*stride_warped, block);
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
@@ -280,7 +304,7 @@
} else {
op = grad.options();
}
- auto grad_weight_per_segment = at::empty({num_of_partial_segments, stride}, op);
+ auto grad_weight_per_segment = at::empty({numel, stride}, op);
// Compute the sum of each partial-segment and handle bags
if (offset2bag.defined()) {
compute_grad_weight_bags<scalar_t><<<grid, block, 0, stream>>>(
@@ -292,7 +316,7 @@
per_sample_weights.defined() ? per_sample_weights.data_ptr<scalar_t>() : NULL,
per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
partial_segment_offset.data_ptr<index_t>(),
- num_of_partial_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
+ num_of_partial_segments_ptr, grad_weight_per_segment.data_ptr<partial_weight_t>(),
stride_warped);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
@@ -302,7 +326,7 @@
count.defined() ? count.data_ptr<index_t>() : nullptr,
numel, stride,
partial_segment_offset.data_ptr<index_t>(),
- num_of_partial_segments,
+ num_of_partial_segments_ptr,
grad_weight_per_segment.data_ptr<partial_weight_t>(),
stride_warped);
C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -310,15 +334,15 @@
// Finally, we sum all the partial-sums and scatter them
// into `grad_weight`.
- const int grid2 = ceil_div(num_of_segments*stride_warped, block);
+ const int grid2 = ceil_div(numel*stride_warped, block);
sum_and_scatter<scalar_t><<<grid2, block, 0, stream>>>(
sorted_indices.data_ptr<index_t>(),
grad_weight.data_ptr<scalar_t>(),
stride,
segment_offsets.data_ptr<index_t>(),
- num_of_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
+ num_of_segments_ptr, grad_weight_per_segment.data_ptr<partial_weight_t>(),
partials_per_segment_offset.data_ptr<index_t>(),
- num_of_partial_segments,
+ num_of_partial_segments_ptr,
padding_idx,
stride_warped);
C10_CUDA_KERNEL_LAUNCH_CHECK();