Fix invalid read in masked softmax (#82272)
Per the title. Unfortunately, testing for invalid reads is hard with the caching allocator: an out-of-bounds read typically still lands inside a cached allocation, so memory checkers do not flag it.
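
The invalid read is the unguarded mask load in the masked branch of the warp softmax: each lane handles the elements at it*WARP_SIZE + local_idx within a row, and because the iteration count is sized to the next power of two of the row length, the final iterations can point past element_count. The row values themselves are already staged into registers with bounds-checked loads, but mask[idx] was read straight from global memory, so the fix below wraps the mask access (and the corresponding max/sum updates) in an (idx + local_idx) < element_count check. A minimal standalone sketch of that guard pattern, using simplified, assumed names (masked_row_max, row_max) alongside the kernel's own element_count / local_idx / WARP_SIZE; it is an illustration, not the PyTorch kernel:

    #include <cfloat>
    #include <cuda_runtime.h>

    constexpr int WARP_SIZE = 32;
    constexpr int WARP_ITERATIONS = 4;  // enough for rows of up to 128 elements

    // One warp per row: lane `local_idx` owns the elements at
    // local_idx, local_idx + 32, local_idx + 64, ... of its row.
    __global__ void masked_row_max(const float* src, const bool* mask,
                                   float* row_max, int element_count) {
        const int row = blockIdx.x;
        const int local_idx = threadIdx.x;  // lane id, 0..31
        src  += row * element_count + local_idx;
        mask += row * element_count + local_idx;

        float max_value = -FLT_MAX;
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            const int idx = it * WARP_SIZE;
            // Guard the loads: without `(idx + local_idx) < element_count` the
            // final iterations read mask[idx] (and src[idx]) past the end of the
            // row whenever element_count < WARP_ITERATIONS * WARP_SIZE.
            if ((idx + local_idx) < element_count && !mask[idx]) {
                max_value = fmaxf(max_value, src[idx]);
            }
        }
        // Warp-wide reduction of the per-lane partial maxima.
        for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
            max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, offset));
        }
        if (local_idx == 0) {
            row_max[row] = max_value;
        }
    }

    // Launch with one warp per row, e.g.:
    //   masked_row_max<<<num_rows, WARP_SIZE>>>(d_src, d_mask, d_row_max, element_count);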
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82272
Approved by: https://github.com/cpuhrsch
diff --git a/aten/src/ATen/native/cuda/PersistentSoftmax.cuh b/aten/src/ATen/native/cuda/PersistentSoftmax.cuh
index 4f308d0..9958d4c 100644
--- a/aten/src/ATen/native/cuda/PersistentSoftmax.cuh
+++ b/aten/src/ATen/native/cuda/PersistentSoftmax.cuh
@@ -123,12 +123,14 @@
         for (int it = 0; it < WARP_ITERATIONS; ++it) {
             if (is_masked) {
                 int idx = it*WARP_SIZE;
-                if (!is_transformer_mask) {
-                    idx += i*element_count;
-                }
-                if (!mask[idx]) {
-                    max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
-                    is_meaningful_max = true;
+                if ((idx + local_idx) < element_count) {
+                    if (!is_transformer_mask) {
+                        idx += i*element_count;
+                    }
+                    if (!mask[idx]) {
+                        max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
+                        is_meaningful_max = true;
+                    }
                 }
             } else {
                 max_value[i] = max_value[i] > elements[i][it] ? max_value[i] : elements[i][it];
@@ -156,22 +158,28 @@
                 }
             } else {
                 int idx = it*WARP_SIZE;
+                bool valid = (idx + local_idx) < element_count;
                 if (!is_transformer_mask) {
                     idx += i*element_count;
                 }
-
-                if (!mask[idx]) {
-                    if (is_log_softmax) {
-                        sum[i] += std::exp(elements[i][it] - max_value[i]);
+                if (valid) {
+                    if (!mask[idx]) {
+                        if (is_log_softmax) {
+                            sum[i] += std::exp(elements[i][it] - max_value[i]);
+                        } else {
+                            elements[i][it] = std::exp(elements[i][it] - max_value[i]);
+                            sum[i] += elements[i][it];
+                        }
                     } else {
-                        elements[i][it] = std::exp(elements[i][it] - max_value[i]);
-                        sum[i] += elements[i][it];
+                        if (!is_log_softmax) {
+                            // Masked values are treated as -infinity, and std::exp(-infinity) is 0.
+                            elements[i][it] = 0;
+                        }
                     }
                 } else {
-                    if (!is_log_softmax) {
-                        // Masked values are treated as -infinity, and std::exp(-infinity) is 0.
-                        elements[i][it] = 0;
-                    }
+                    if (!is_log_softmax) {
+                        elements[i][it] = 0.;
+                    }
                 }
             }
         }
     }
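
The second hunk applies the same guard when the denominator is accumulated: a valid flag is computed from the same (idx + local_idx) < element_count test, the mask is only dereferenced when valid holds, and out-of-range positions are treated like masked ones on the non-log path (their staged value is zeroed, matching exp(-infinity) == 0), so the later normalization and store passes are unaffected. A sketch of that accumulation step, continuing the simplified names and constants from the sketch above (again an assumption-laden illustration, not the actual kernel):

    // Accumulate the softmax denominator for the elements owned by one lane,
    // reusing WARP_SIZE / WARP_ITERATIONS from the sketch above. `elements`
    // holds this lane's staged row values; `mask` is already offset by the
    // lane index, exactly as in masked_row_max.
    __device__ float masked_row_sum(float* elements, const bool* mask,
                                    float max_value, int local_idx,
                                    int element_count, bool is_log_softmax) {
        float sum = 0.0f;
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            const int idx = it * WARP_SIZE;
            const bool valid = (idx + local_idx) < element_count;  // the added guard
            if (valid && !mask[idx]) {
                if (is_log_softmax) {
                    sum += expf(elements[it] - max_value);
                } else {
                    elements[it] = expf(elements[it] - max_value);
                    sum += elements[it];
                }
            } else if (!is_log_softmax) {
                // Masked or out-of-range values behave like -infinity: exp(-infinity) is 0.
                elements[it] = 0.0f;
            }
        }
        return sum;  // a warp reduction over `sum` then yields the full denominator
    }

If a regression test is ever added, the read would presumably only surface under compute-sanitizer with the caching allocator disabled (e.g. PYTORCH_NO_CUDA_MEMORY_CACHING=1), for the reason noted in the description.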