Update FlexAttention with masking semantics (#133373)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133373
Approved by: https://github.com/yanboliang
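
Fully masked-out query rows now return zeros for both the attention output and the logsumexp instead of producing NaNs: the eager path switches to `torch._safe_softmax` and zeroes the logsumexp of masked rows, and the Triton kernels set `l_i = 1.0` / `m_i = 0.0` for those rows. A minimal sketch of the new behavior, mirroring the updated tests below (the `B, H, S, D` sizes and the CUDA device are illustrative):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S, D = 2, 4, 128, 64  # illustrative sizes
M = S // 2                  # queries with index >= M attend to nothing

query = torch.randn(B, H, S, D, device="cuda", requires_grad=True)
key = torch.randn(B, H, S, D, device="cuda", requires_grad=True)
value = torch.randn(B, H, S, D, device="cuda", requires_grad=True)

def mask_mod(b, h, q_idx, kv_idx):
    # Rows with q_idx >= M are fully masked out.
    return q_idx < M

block_mask = create_block_mask(mask_mod, 1, 1, S, S)
out, lse = flex_attention(query, key, value, block_mask=block_mask, return_lse=True)

# Fully masked rows are zeros (not NaN) and their logsumexp is 0.0.
assert out[:, :, M:, :].sum() == 0
assert (lse[:, :, M:] == 0.0).all()

(out.sum() + lse.sum()).backward()
assert query.grad[:, :, M:, :].sum() == 0
```

Before this change the eager/reference path could still produce NaNs for such rows, which is why the NaN check in `_check_out` had been relaxed (see the removed TODO below).
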
diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index 428aca1..82e404f 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -17,6 +17,7 @@
     _create_empty_block_mask,
     _DEFAULT_SPARSE_BLOCK_SIZE,
     _identity,
+    _score_mod_signature,
     and_masks,
     BlockMask,
     create_block_mask,
@@ -212,8 +213,7 @@
     ):
         compiled_error = (golden_out - compiled_out).abs().mean()
         ref_error = (golden_out - ref_out).abs().mean()
-        # TODO: Make this check stricter after updating eager SDPA masked_softmax semantics
-        if torch.isnan(compiled_error).any() and not torch.isnan(ref_error).any():
+        if torch.isnan(compiled_error).any() or torch.isnan(ref_error).any():
             self.assertTrue(False, "Output/Grad with NaN")
         if compiled_error > ref_error * fudge_factor:
             name = tensor_name if tensor_name is not None else ""
@@ -263,7 +263,7 @@
 
     def run_test(
         self,
-        score_mod: Callable,
+        score_mod: _score_mod_signature,
         dtype: torch.dtype = torch.float16,
         Q_B: int = B,
         Q_H: int = H,
@@ -273,6 +273,7 @@
         KV_H: int = H,
         KV_S: int = S,
         KV_D: int = D,
+        block_mask: Optional[BlockMask] = None,
     ):
         q = torch.randn(
             (Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True
@@ -285,7 +286,6 @@
         )
         q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
         q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
-        block_mask = None
         sdpa_partial = create_attention(
             score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
         )
@@ -1437,7 +1437,8 @@
         out.sum().backward()
 
     @supported_platform
-    def test_fully_masked_out_rows(self):
+    @common_utils.parametrize("compile", [True, False])
+    def test_fully_masked_out_rows_0_check(self, compile: bool):
         # Ensure fully masked out rows won't cause NaNs.
         query = torch.randn(
             (B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
@@ -1448,7 +1449,6 @@
         value = torch.randn(
             (B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
         )
-        do = torch.randn((B, H, S, D), dtype=torch.float32, device="cuda")
 
         M = S // 2
 
@@ -1456,16 +1456,34 @@
             return q < M
 
         block_mask = create_block_mask(mask_mod, 1, 1, S, S)
-        out = torch.compile(flex_attention, dynamic=False)(
-            query, key, value, block_mask=block_mask
-        )
-        # TODO: Switch to self.run_test_with_call after updating eager SDPA masked_softmax semantics
-        self.assertEqual(out[:, :, M:, :].sum(), 0)
 
-        out.backward(do)
+        flex = (
+            torch.compile(flex_attention, dynamic=False) if compile else flex_attention
+        )
+        out, lse = flex(query, key, value, block_mask=block_mask, return_lse=True)
+        self.assertEqual(out[:, :, M:, :].sum(), 0)
+        self.assertTrue((lse[:, :, M:] == 0.0).all())
+
+        loss = out.sum() + lse.sum()
+        loss.backward()
         self.assertEqual(query.grad[:, :, M:, :].sum(), 0)
 
     @supported_platform
+    @common_utils.parametrize("compile", [True, False])
+    def test_fully_masked_out_rows(self, compile: bool):
+        M = S // 2
+
+        def mask_mod(b, h, q, kv):
+            return q < M
+
+        block_mask = create_block_mask(mask_mod, 1, 1, S, S)
+
+        def noop_mod(score, b, h, q_idx, kv_idx):
+            return score
+
+        self.run_test(noop_mod, torch.float32, B, H, S, D, B, H, S, D, block_mask)
+
+    @supported_platform
     def test_comparison_vs_sdpa(self):
         def causal(score, b, h, q_idx, kv_idx):
             return torch.where(q_idx >= kv_idx, score, -float("inf"))
diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py
index 7832442..9a51a95 100644
--- a/test/inductor/test_flex_decoding.py
+++ b/test/inductor/test_flex_decoding.py
@@ -284,15 +284,20 @@
             score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
         )
         compiled_sdpa = torch.compile(sdpa_partial)
-        golden_out = sdpa_partial(q_gold, k_gold, v_gold)
-        ref_out = sdpa_partial(q_ref, k_ref, v_ref)
-        compiled_out = compiled_sdpa(q, k, v)
+        golden_out, gold_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True)
+        ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
+        compiled_out, compiled_lse = compiled_sdpa(q, k, v, return_lse=True)
 
         self._check_out(
             golden_out,
             ref_out,
             compiled_out,
         )
+        self._check_out(
+            gold_lse,
+            ref_lse,
+            compiled_lse,
+        )
 
     def run_test_with_call(
         self,
@@ -763,6 +768,38 @@
         self.run_test(bias_mod)
 
     @supported_platform
+    def test_fully_masked_out_rows_0_check_gqa(self):
+        # Ensure fully masked out rows won't cause NaNs.
+        query = torch.randn(
+            (B, Hq, S, D), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+        key = torch.randn(
+            (B, Hkv, S, D), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+        value = torch.randn(
+            (B, Hkv, S, D), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+
+        M = S // 2
+
+        def mask_mod(b, h, q, kv):
+            return q < M
+
+        block_mask = create_block_mask(mask_mod, 1, 1, S, S)
+
+        flex = torch.compile(flex_attention, dynamic=False)
+
+        out, lse = flex(
+            query, key, value, block_mask=block_mask, enable_gqa=True, return_lse=True
+        )
+        self.assertEqual(out[:, :, M:, :].sum(), 0)
+        self.assertTrue((lse[:, :, M:] == 0.0).all())
+
+        loss = out.sum() + lse.sum()
+        loss.backward()
+        self.assertEqual(query.grad[:, :, M:, :].sum(), 0)
+
+    @supported_platform
     def test_windowed_no_mask_vs_sdpa(self):
         score_mod = _generate_windowed(1000)
         attention = functools.partial(flex_attention, score_mod=score_mod)
diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py
index 4d76e13..b651d7d 100644
--- a/torch/_higher_order_ops/flex_attention.py
+++ b/torch/_higher_order_ops/flex_attention.py
@@ -204,11 +204,12 @@
         mask_mod_other_buffers,
     )
 
-    # TODO Unconditionally return logsumexp for backwards
-    # if any(t.requires_grad for t in (query, key, value)):
+    # Set fully masked rows' logsumexp to 0.0
     logsumexp = post_mod_scores.logsumexp(dim=-1)
+    masked_rows = torch.all(post_mod_scores == -float("inf"), dim=-1)
+    logsumexp = torch.where(masked_rows, 0.0, logsumexp)
 
-    post_mod_scores = post_mod_scores.softmax(dim=-1)
+    post_mod_scores = torch._safe_softmax(post_mod_scores, dim=-1)
 
     return post_mod_scores.to(query.dtype) @ value, logsumexp / math.log(2)
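
A small, CPU-runnable sketch of what the eager math path above now computes for a fully masked row (the `scores` values are made up for illustration):

```python
import math
import torch

# Row 0 is partially masked; row 1 is fully masked (all scores are -inf).
scores = torch.tensor([[1.0, float("-inf")],
                       [float("-inf"), float("-inf")]])

# Regular softmax yields NaNs for the fully masked row ...
print(scores.softmax(dim=-1))               # tensor([[1., 0.], [nan, nan]])
# ... while _safe_softmax returns zeros instead.
print(torch._safe_softmax(scores, dim=-1))  # tensor([[1., 0.], [0., 0.]])

# The -inf logsumexp of the fully masked row is replaced with 0.0, then the
# result is rescaled to base 2, matching what the fused kernels emit.
lse = scores.logsumexp(dim=-1)
masked_rows = torch.all(scores == float("-inf"), dim=-1)
lse = torch.where(masked_rows, 0.0, lse) / math.log(2)
print(lse)                                  # tensor([1.4427, 0.0000])
```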
 
diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index 8a12860..05e9f1c 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -302,8 +302,13 @@
         )
 
 
-    # Store output and logsumexp
-    l_i = tl.where(l_i == 0, 1, l_i)
+    # [Note] Handle fully masked out rows:
+    # For fully masked out rows, l_i == sum(e^(-inf)) == 0.0 and m_i == -inf.
+    # We set l_i to 1.0 and m_i to 0.0 so that out = acc / l_i == 0.0 and
+    # lse = log(l_i) + m_i == 0.0 for those rows.
+    l_i = tl.where(l_i == 0.0, 1, l_i)
+    masked_out_rows = (m_i == float("-inf"))
+    m_i = tl.where(masked_out_rows, 0, m_i)
+
     acc = acc / l_i[:, None]
     idx_z = tl.program_id(1) // HQ
     idx_hq = tl.program_id(1) % HQ
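
Not the Triton code itself, just a scalar sketch of the epilogue arithmetic the note above describes, assuming the kernel goes on to compute `out = acc / l_i` and `lse = m_i + log2(l_i)`:

```python
import torch

# For a fully masked row every probability is exp2(-inf) == 0, so the
# accumulators reach the epilogue as acc == 0, l_i == 0.0, m_i == -inf.
acc = torch.zeros(4)
l_i = torch.tensor(0.0)
m_i = torch.tensor(float("-inf"))

# Without the fix: acc / l_i == 0 / 0 == NaN and log2(l_i) + m_i == -inf.
l_i = torch.where(l_i == 0.0, torch.tensor(1.0), l_i)
m_i = torch.where(m_i == float("-inf"), torch.tensor(0.0), m_i)

out = acc / l_i              # zeros instead of NaN
lse = torch.log2(l_i) + m_i  # 0.0, matching the eager path
print(out, lse)              # tensor([0., 0., 0., 0.]) tensor(0.)
```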
diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py
index 249b9b8..96e1dbc 100644
--- a/torch/_inductor/kernel/flex_decoding.py
+++ b/torch/_inductor/kernel/flex_decoding.py
@@ -524,11 +524,17 @@
     # Reduction
 
     g_M = lowerings[aten.max](buf_M, dim=1, keepdim=True)[0]
+    # See [Note] Handle fully masked out rows:
+    # g_M is the global max across the split KV blocks.
+    masked_rows = lowerings[aten.eq](g_M, -float("inf"))
+    g_M = lowerings[aten.where](masked_rows, 0.0, g_M)
     adj_M = lowerings[aten.sub](buf_M, g_M)
     alpha = lowerings[aten.exp2](adj_M)
 
     buf_L = lowerings[aten.mul](buf_L, alpha)
     g_L = lowerings[aten.sum](buf_L, axis=1)
+    masked_rows_squeezed = lowerings[aten.squeeze](masked_rows, dim=1)
+    g_L = lowerings[aten.where](masked_rows_squeezed, 1.0, g_L)
     logsumexp = lowerings[aten.log2](g_L)
     logsumexp = lowerings[aten.add](logsumexp, lowerings[aten.squeeze](g_M, dim=1))
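
A plain-tensor sketch of the split-KV reduction above; the actual lowering builds inductor IR via `lowerings[aten.*]`, and the shapes and values here are illustrative:

```python
import torch

rows, splits = 2, 3
buf_M = torch.randn(rows, splits)       # per-split running max (base-2 domain)
buf_L = torch.rand(rows, splits) + 0.1  # per-split running sum of exp2
buf_M[1] = float("-inf")                # pretend row 1 is fully masked in every split
buf_L[1] = 0.0

g_M = buf_M.max(dim=1, keepdim=True).values
masked_rows = g_M == float("-inf")
g_M = torch.where(masked_rows, 0.0, g_M)             # keep -inf out of the rescaling

alpha = torch.exp2(buf_M - g_M)                      # rescale each split to the global max
g_L = (buf_L * alpha).sum(dim=1)
g_L = torch.where(masked_rows.squeeze(1), 1.0, g_L)  # log2(1.0) == 0 for masked rows

logsumexp = torch.log2(g_L) + g_M.squeeze(1)
print(logsumexp)  # the fully masked row comes out as exactly 0.0
```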