[Inductor][FlexAttention] Correct partial/full blocks naming (#131993)
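The tuple returned by `_convert_mask_to_block_mask(..., separate_full_blocks=True)` is ordered (partial blocks, full blocks): the first element holds the partially-masked blocks where `mask_mod` still has to be applied element-wise, and the second, optional element holds the fully-unmasked blocks where it can be skipped. The previous code unpacked this tuple under swapped names (`full_blocks, partial_blocks`), so the variable names contradicted the data they held. This change renames the variables and reorders the tuples at the call sites so the names match the contents; the values ultimately passed to `BlockMask` are the same as before, so behavior is unchanged.

For context, a minimal usage sketch of the public API these helpers feed into (the `causal` mask_mod, sizes, and device are illustrative, not taken from this patch; assumes a CUDA device):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Illustrative mask_mod: keep only causal (lower-triangular) positions.
def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# create_block_mask turns the dense mask into (partial, full) block data:
# partial blocks still evaluate mask_mod per element, full blocks skip it.
block_mask = create_block_mask(causal, B=1, H=8, Q_LEN=2048, KV_LEN=2048, device="cuda")

q, k, v = (torch.randn(1, 8, 2048, 64, device="cuda") for _ in range(3))
out = flex_attention(q, k, v, block_mask=block_mask)
```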
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131993
Approved by: https://github.com/drisspg
diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py
index 3cce513..2829419 100644
--- a/torch/nn/attention/flex_attention.py
+++ b/torch/nn/attention/flex_attention.py
@@ -538,19 +538,19 @@
     KV_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE,
     Q_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE,
 ) -> BlockMask:
-    full_blocks, partial_blocks = block_mask
+    partial_blocks, full_blocks = block_mask
-    full_bm = _dense_to_ordered(full_blocks)
-    if partial_blocks is not None:
-        partial_bm = _dense_to_ordered(partial_blocks)
+    partial_bm = _dense_to_ordered(partial_blocks)
+    if full_blocks is not None:
+        full_bm = _dense_to_ordered(full_blocks)
     else:
-        partial_bm = (None, None)
+        full_bm = (None, None)
     return BlockMask(  # type: ignore[call-arg]
-        full_bm[0],
-        full_bm[1],
         partial_bm[0],
         partial_bm[1],
+        full_bm[0],
+        full_bm[1],
         BLOCK_SIZE=(KV_BLOCK_SIZE, Q_BLOCK_SIZE),
         mask_mod=mask_mod,
     )
@@ -622,14 +622,14 @@
     with the __torch_function__ mode.
     """
     mask_tensor = create_mask(mask_mod, B, H, Q_LEN, KV_LEN, device, _compile=True)
-    full_block_mask, partial_block_mask = _convert_mask_to_block_mask(
+    partial_block_mask, full_block_mask = _convert_mask_to_block_mask(
         mask_tensor,
         KV_BLOCK_SIZE=KV_BLOCK_SIZE,
         Q_BLOCK_SIZE=Q_BLOCK_SIZE,
         separate_full_blocks=True,
     )
     return _create_sparse_block_from_block_mask(
-        (full_block_mask, partial_block_mask), mask_mod
+        (partial_block_mask, full_block_mask), mask_mod
     )
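
With the corrected ordering, the partial block data feeds `BlockMask`'s leading `kv_num_blocks` / `kv_indices` fields and the optional full-block data feeds `full_kv_num_blocks` / `full_kv_indices`. A quick sketch of inspecting that split through the public API (field semantics are my reading of the `BlockMask` layout, not text from this patch; assumes a CUDA device):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

bm = create_block_mask(causal, B=1, H=1, Q_LEN=1024, KV_LEN=1024, device="cuda")

# kv_num_blocks / kv_indices describe the partially-masked blocks, where the
# kernel still applies mask_mod element-wise; full_kv_num_blocks / full_kv_indices
# describe fully-unmasked blocks and may be None when none are tracked separately.
print(bm.kv_num_blocks.shape, bm.kv_indices.shape)
if bm.full_kv_num_blocks is not None:
    print(bm.full_kv_num_blocks.shape, bm.full_kv_indices.shape)
```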