Fix: out-of-bounds memory access on the ROCm platform (#76100)
Fixes #76095, an out-of-bounds memory access on the ROCm platform. In the embedding backward kernel, every thread in the warp read `indices_batch[chunk_start - batch_start + threadIdx.x]` and only afterwards discarded the result for threads with `threadIdx.x >= n_this_chunk`; the unguarded load itself can index past the end of the buffer. The patch checks the bound before loading.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76100
Approved by: https://github.com/jeffdaily, https://github.com/kit1980
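
For illustration, a minimal standalone sketch of the access pattern the patch below fixes. The kernel, buffer size, and setup here are simplified, hypothetical stand-ins for the real embedding backward kernel; only the guarded-read idiom mirrors the actual change:

    // Sketch only (hypothetical kernel): shows the load-then-mask bug and the
    // check-then-load fix applied in the patch below.
    #include <cstdint>

    __global__ void guarded_read_demo(const int64_t* indices, int n_indices,
                                      int64_t dst_row) {
      // Staging buffer for one chunk; entries past n_this_chunk are never
      // written and therefore must never be read.
      __shared__ int64_t indices_batch[32];
      int n_this_chunk = n_indices < 32 ? n_indices : 32;

      if (threadIdx.x < n_this_chunk)
        indices_batch[threadIdx.x] = indices[threadIdx.x];
      __syncthreads();

      // Buggy order (before the patch): load first, mask afterwards. The
      // load itself is the out-of-bounds access, so zeroing the result
      // later does not make it safe:
      //   int match = (dst_row == indices_batch[threadIdx.x]);
      //   if (threadIdx.x >= n_this_chunk) match = 0;
      //
      // Fixed order (as in the patch): check the bound before loading.
      int match_found_this_thread = 0;
      if (threadIdx.x < n_this_chunk)
        match_found_this_thread = (dst_row == indices_batch[threadIdx.x]);

      // ... the real kernel then ballots over match_found_this_thread ...
      (void)match_found_this_thread;
    }

A plausible reason the bug surfaced on ROCm (issue #76095) is the 64-wide wavefront versus CUDA's 32-wide warp, which doubles the number of lanes that can satisfy `threadIdx.x >= n_this_chunk` and perform the stray read; in any case, the guarded order is correct on both platforms.
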
diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu
index 8a241ca..568669e 100644
--- a/aten/src/ATen/native/cuda/Embedding.cu
+++ b/aten/src/ATen/native/cuda/Embedding.cu
@@ -98,10 +98,9 @@
// then finishes by adding the accumulated buffer to dst_row in grad_weight.
if(dst_row != padding_idx && src_row < n) // Per-warp exit condition, safe with ballot_sync
{
- int match_found_this_thread =
- (dst_row == indices_batch[chunk_start - batch_start + threadIdx.x]);
- if(threadIdx.x >= n_this_chunk)
- match_found_this_thread = 0;
+ int match_found_this_thread = 0;
+ if(threadIdx.x < n_this_chunk)
+ match_found_this_thread = (dst_row == indices_batch[chunk_start - batch_start + threadIdx.x]);
#if defined(USE_ROCM)
unsigned long long int matchmask = WARP_BALLOT(match_found_this_thread);
int first_remaining_peer = __ffsll(matchmask) - 1;