Prep PR for cutlass 3.5 update (#124412)
# Summary
These changes are needed for the upgrade to CUTLASS 3.5; see #123458.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124412
Approved by: https://github.com/Skylion007, https://github.com/nWEIdia, https://github.com/malfet
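For context, the diff suggests that with CUTLASS 3.5 the sub-byte integer types (such as `cutlass::uint1b_t`) no longer convert implicitly from plain `int`, so call sites must construct the element type explicitly. Below is a minimal, hypothetical host-side sketch of the pattern (not part of the PR itself); it assumes the CUTLASS headers are on the include path and the `mask` array is an illustrative stand-in for the kernel's dropout keep-mask.

```cpp
#include <cutlass/array.h>
#include <cutlass/integer_subbyte.h>

int main() {
  // A small bit-mask of 1-bit unsigned elements, mirroring the
  // dropout keep-mask in kernel_backward.h.
  cutlass::Array<cutlass::uint1b_t, 8> mask;

  // Pre-3.5 style, relying on an implicit int -> uint1b_t conversion:
  //   mask.fill(1);
  // With CUTLASS 3.5 the element type is constructed explicitly instead:
  mask.fill(cutlass::uint1b_t{1});

  // The same pattern applies to individual element writes:
  mask[0] = cutlass::uint1b_t{0};

  return 0;
}
```

The second hunk's switch from `cutlass::uint1b_t(0)` to `cutlass::uint1b_t{0}` keeps the two call sites consistent; list-initialization also has the side benefit of rejecting narrowing argument conversions at compile time.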
diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h
index 55f3f9a..564e3f2 100644
--- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h
+++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h
@@ -1429,7 +1429,7 @@
uint8_t lane_id) {
cutlass::Array<cutlass::uint1b_t, MatmulDOIVJ::Mma::FragmentC::kElements>
dropout_keep_mask_doivj;
- dropout_keep_mask_doivj.fill(1);
+ dropout_keep_mask_doivj.fill(cutlass::uint1b_t{1});
const float dropout_scale =
kApplyDropout ? 1.0 / (1.0 - p.dropout_prob) : 1.0f;
@@ -1752,7 +1752,7 @@
[&](int accum_m) {},
[&](int accum_m /*q*/, int accum_n /*k*/, int idx) {
if (zij.at({accum_n, accum_m}) == scalar_t(0)) {
- dropout_keep_mask_doivj[idx] = cutlass::uint1b_t(0);
+ dropout_keep_mask_doivj[idx] = cutlass::uint1b_t{0};
}
},
[&](int accum_m) {});