Fix invalid read in masked softmax (#82272)

Per the title, this fixes an out-of-bounds read of the mask in the masked softmax kernel. Unfortunately, testing invalid reads under the caching allocator is hard.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82272
Approved by: https://github.com/cpuhrsch
diff --git a/aten/src/ATen/native/cuda/PersistentSoftmax.cuh b/aten/src/ATen/native/cuda/PersistentSoftmax.cuh
index 4f308d0..9958d4c 100644
--- a/aten/src/ATen/native/cuda/PersistentSoftmax.cuh
+++ b/aten/src/ATen/native/cuda/PersistentSoftmax.cuh
@@ -123,12 +123,14 @@
         for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
             if (is_masked) {
                 int idx = it*WARP_SIZE;
-                if (!is_transformer_mask) {
-                    idx += i*element_count;
-                }
-                if (!mask[idx]) {
-                    max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
-                    is_meaningful_max = true;
+                if ((idx + local_idx) < element_count) {
+                    if (!is_transformer_mask) {
+                        idx += i*element_count;
+                    }
+                    if (!mask[idx]) {
+                        max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
+                        is_meaningful_max = true;
+                    }
                 }
             } else {
                 max_value[i] = max_value[i] > elements[i][it] ? max_value[i] : elements[i][it];
@@ -156,22 +158,28 @@
                 }
             } else {
                 int idx = it*WARP_SIZE;
+                bool valid = (idx + local_idx) < element_count;
                 if (!is_transformer_mask) {
                     idx += i*element_count;
                 }
-
-                if (!mask[idx]) {
-                    if (is_log_softmax) {
-                        sum[i] += std::exp(elements[i][it] - max_value[i]);
+                if (valid) {
+                    if (!mask[idx]) {
+                        if (is_log_softmax) {
+                            sum[i] += std::exp(elements[i][it] - max_value[i]);
+                        } else {
+                            elements[i][it] = std::exp(elements[i][it] - max_value[i]);
+                            sum[i] += elements[i][it];
+                        }
                     } else {
-                        elements[i][it] = std::exp(elements[i][it] - max_value[i]);
-                        sum[i] += elements[i][it];
+                        if (!is_log_softmax) {
+                            // Masked values are treated as -infinity, and std::exp(-infinity) is 0.
+                            elements[i][it] = 0;
+                        }
                     }
                 } else {
-                  if (!is_log_softmax) {
-                    // Masked values are treated as -infinity, and std::exp(-infinity) is 0.
-                    elements[i][it] = 0;
-                  }
+                    if (!is_log_softmax) {
+                        elements[i][it] = 0.;
+                    }
                 }
             }
         }