Update FlexAttention with masking semantic (#133373)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133373
Approved by: https://github.com/yanboliang
diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index 428aca1..82e404f 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -17,6 +17,7 @@
_create_empty_block_mask,
_DEFAULT_SPARSE_BLOCK_SIZE,
_identity,
+ _score_mod_signature,
and_masks,
BlockMask,
create_block_mask,
@@ -212,8 +213,7 @@
):
compiled_error = (golden_out - compiled_out).abs().mean()
ref_error = (golden_out - ref_out).abs().mean()
- # TODO: Make this check stricter after updating eager SDPA masked_softmax semantics
- if torch.isnan(compiled_error).any() and not torch.isnan(ref_error).any():
+ if torch.isnan(compiled_error).any() or torch.isnan(ref_error).any():
self.assertTrue(False, "Output/Grad with NaN")
if compiled_error > ref_error * fudge_factor:
name = tensor_name if tensor_name is not None else ""
@@ -263,7 +263,7 @@
def run_test(
self,
- score_mod: Callable,
+ score_mod: _score_mod_signature,
dtype: torch.dtype = torch.float16,
Q_B: int = B,
Q_H: int = H,
@@ -273,6 +273,7 @@
KV_H: int = H,
KV_S: int = S,
KV_D: int = D,
+ block_mask: Optional[BlockMask] = None,
):
q = torch.randn(
(Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True
@@ -285,7 +286,6 @@
)
q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
- block_mask = None
sdpa_partial = create_attention(
score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
)
@@ -1437,7 +1437,8 @@
out.sum().backward()
@supported_platform
- def test_fully_masked_out_rows(self):
+ @common_utils.parametrize("compile", [True, False])
+ def test_fully_masked_out_rows_0_check(self, compile: bool):
# Ensure fully masked out rows won't cause NaNs.
query = torch.randn(
(B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
@@ -1448,7 +1449,6 @@
value = torch.randn(
(B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
)
- do = torch.randn((B, H, S, D), dtype=torch.float32, device="cuda")
M = S // 2
@@ -1456,16 +1456,34 @@
return q < M
block_mask = create_block_mask(mask_mod, 1, 1, S, S)
- out = torch.compile(flex_attention, dynamic=False)(
- query, key, value, block_mask=block_mask
- )
- # TODO: Switch to self.run_test_with_call after updating eager SDPA masked_softmax semantics
- self.assertEqual(out[:, :, M:, :].sum(), 0)
- out.backward(do)
+ flex = (
+ torch.compile(flex_attention, dynamic=False) if compile else flex_attention
+ )
+ out, lse = flex(query, key, value, block_mask=block_mask, return_lse=True)
+ self.assertEqual(out[:, :, M:, :].sum(), 0)
+ self.assertTrue((lse[:, :, M:] == 0.0).all())
+
+ loss = out.sum() + lse.sum()
+ loss.backward()
self.assertEqual(query.grad[:, :, M:, :].sum(), 0)
@supported_platform
+ @common_utils.parametrize("compile", [True, False])
+ def test_fully_masked_out_rows(self, compile: bool):
+ M = S // 2
+
+ def mask_mod(b, h, q, kv):
+ return q < M
+
+ block_mask = create_block_mask(mask_mod, 1, 1, S, S)
+
+ def noop_mod(score, b, h, q_idx, kv_idx):
+ return score
+
+ self.run_test(noop_mod, torch.float32, B, H, S, D, B, H, S, D, block_mask)
+
+ @supported_platform
def test_comparison_vs_sdpa(self):
def causal(score, b, h, q_idx, kv_idx):
return torch.where(q_idx >= kv_idx, score, -float("inf"))
diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py
index 7832442..9a51a95 100644
--- a/test/inductor/test_flex_decoding.py
+++ b/test/inductor/test_flex_decoding.py
@@ -284,15 +284,20 @@
score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
)
compiled_sdpa = torch.compile(sdpa_partial)
- golden_out = sdpa_partial(q_gold, k_gold, v_gold)
- ref_out = sdpa_partial(q_ref, k_ref, v_ref)
- compiled_out = compiled_sdpa(q, k, v)
+ golden_out, gold_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True)
+ ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
+ compiled_out, compiled_lse = compiled_sdpa(q, k, v, return_lse=True)
self._check_out(
golden_out,
ref_out,
compiled_out,
)
+ self._check_out(
+ gold_lse,
+ ref_lse,
+ compiled_lse,
+ )
def run_test_with_call(
self,
@@ -763,6 +768,38 @@
self.run_test(bias_mod)
@supported_platform
+ def test_fully_masked_out_rows_0_check_gqa(self):
+ # Ensure fully masked out rows won't cause NaNs.
+ query = torch.randn(
+ (B, Hq, S, D), dtype=torch.float32, device="cuda", requires_grad=True
+ )
+ key = torch.randn(
+ (B, Hkv, S, D), dtype=torch.float32, device="cuda", requires_grad=True
+ )
+ value = torch.randn(
+ (B, Hkv, S, D), dtype=torch.float32, device="cuda", requires_grad=True
+ )
+
+ M = S // 2
+
+ def mask_mod(b, h, q, kv):
+ return q < M
+
+ block_mask = create_block_mask(mask_mod, 1, 1, S, S)
+
+ flex = torch.compile(flex_attention, dynamic=False)
+
+ out, lse = flex(
+ query, key, value, block_mask=block_mask, enable_gqa=True, return_lse=True
+ )
+ self.assertEqual(out[:, :, M:, :].sum(), 0)
+ self.assertTrue((lse[:, :, M:] == 0.0).all())
+
+ loss = out.sum() + lse.sum()
+ loss.backward()
+ self.assertEqual(query.grad[:, :, M:, :].sum(), 0)
+
+ @supported_platform
def test_windowed_no_mask_vs_sdpa(self):
score_mod = _generate_windowed(1000)
attention = functools.partial(flex_attention, score_mod=score_mod)
diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py
index 4d76e13..b651d7d 100644
--- a/torch/_higher_order_ops/flex_attention.py
+++ b/torch/_higher_order_ops/flex_attention.py
@@ -204,11 +204,12 @@
mask_mod_other_buffers,
)
- # TODO Unconditionally return logsumexp for backwards
- # if any(t.requires_grad for t in (query, key, value)):
+ # Set fully masked rows' sumexp to 0.0
logsumexp = post_mod_scores.logsumexp(dim=-1)
+ masked_rows = torch.all(post_mod_scores == -float("inf"), dim=-1)
+ logsumexp = torch.where(masked_rows, 0.0, logsumexp)
- post_mod_scores = post_mod_scores.softmax(dim=-1)
+ post_mod_scores = torch._safe_softmax(post_mod_scores, dim=-1)
return post_mod_scores.to(query.dtype) @ value, logsumexp / math.log(2)
diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index 8a12860..05e9f1c 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -302,8 +302,13 @@
)
- # Store output and logsumexp
- l_i = tl.where(l_i == 0, 1, l_i)
+ # [Note] Handle fully masked out rows:
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
+ l_i = tl.where(l_i == 0.0, 1, l_i)
+ masked_out_rows = (m_i == float("-inf"))
+ m_i = tl.where(masked_out_rows, 0, m_i)
+
acc = acc / l_i[:, None]
idx_z = tl.program_id(1) // HQ
idx_hq = tl.program_id(1) % HQ
diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py
index 249b9b8..96e1dbc 100644
--- a/torch/_inductor/kernel/flex_decoding.py
+++ b/torch/_inductor/kernel/flex_decoding.py
@@ -524,11 +524,17 @@
# Reduction
g_M = lowerings[aten.max](buf_M, dim=1, keepdim=True)[0]
+ # See [Note] Handle fully masked out rows:
+ # g_M Is the global max among split kv blocks.
+ masked_rows = lowerings[aten.eq](g_M, -float("inf"))
+ g_M = lowerings[aten.where](masked_rows, 0.0, g_M)
adj_M = lowerings[aten.sub](buf_M, g_M)
alpha = lowerings[aten.exp2](adj_M)
buf_L = lowerings[aten.mul](buf_L, alpha)
g_L = lowerings[aten.sum](buf_L, axis=1)
+ masked_rows_squeezed = lowerings[aten.squeeze](masked_rows, dim=1)
+ g_L = lowerings[aten.where](masked_rows_squeezed, 1.0, g_L)
logsumexp = lowerings[aten.log2](g_L)
logsumexp = lowerings[aten.add](logsumexp, lowerings[aten.squeeze](g_M, dim=1))