Added and_masks and or_masks utilities (#131073)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131073
Approved by: https://github.com/drisspg
ghstack dependencies: #130871, #130904
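Example usage (an illustrative sketch, not code from this diff): `and_masks` / `or_masks` combine any number of `mask_mod` callables into a single `mask_mod` via logical AND / OR, which can then be passed to `create_block_mask`. The tensor shapes, the 256-token window, the `float16` dtype, and the CUDA device below are assumptions for illustration, not values taken from this PR.

```python
import torch
from torch.nn.attention.flex_attention import (
    and_masks,
    create_block_mask,
    flex_attention,
)


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def sliding_window(b, h, q_idx, kv_idx):
    # Assumed window size of 256 tokens; combined with causal_mask below,
    # only the causal side of the window survives.
    return q_idx - kv_idx <= 256


# and_masks returns a mask_mod that is True only where every input mask_mod is True.
# or_masks is the analogous union: True wherever any input mask_mod is True.
combined_mask = and_masks(causal_mask, sliding_window)
block_mask = create_block_mask(combined_mask, 1, 1, 1024, 1024)  # B, H, Q_LEN, KV_LEN

q, k, v = (
    torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16) for _ in range(3)
)

# flex_attention is typically wrapped in torch.compile for performance.
compiled_flex_attention = torch.compile(flex_attention)
out = compiled_flex_attention(q, k, v, block_mask=block_mask)
```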
diff --git a/docs/source/nn.attention.flex_attention.rst b/docs/source/nn.attention.flex_attention.rst
index 101e88d..f1c209c 100644
--- a/docs/source/nn.attention.flex_attention.rst
+++ b/docs/source/nn.attention.flex_attention.rst
@@ -14,6 +14,8 @@
.. autofunction:: create_block_mask
.. autofunction:: create_mask
+.. autofunction:: and_masks
+.. autofunction:: or_masks
BlockMask
---------
diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index 1cae48d..685f2c3 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -17,9 +17,12 @@
from torch.nn.attention.flex_attention import (
_create_empty_block_mask,
_identity,
+ and_masks,
BlockMask,
create_block_mask,
flex_attention,
+ noop_mask,
+ or_masks,
)
from torch.testing import FileCheck
from torch.testing._internal import common_utils
@@ -1011,6 +1014,35 @@
self.assertEqual(block_mask_a.q_num_blocks, block_mask_b.q_num_blocks)
@supported_platform
+ def test_mask_mod_combiners(self):
+ def causal_mask(b, h, q, kv):
+ return q >= kv
+
+ def neg_causal_mask(b, h, q, kv):
+ return q < kv
+
+ def sliding_window(b, h, q, kv):
+ return (q - kv) <= 512
+
+ block_mask = create_block_mask(
+ and_masks(causal_mask, sliding_window), 1, 1, S, S
+ )
+ self.assertExpectedInline(block_mask.kv_num_blocks.sum().item(), """28""")
+ attention = functools.partial(flex_attention, block_mask=block_mask)
+ self.run_test_with_call(attention)
+
+ block_mask = create_block_mask(
+ and_masks(causal_mask, neg_causal_mask), 1, 1, S, S
+ )
+ self.assertEqual(block_mask.kv_num_blocks.sum(), 0)
+
+ block_mask1 = create_block_mask(
+ or_masks(causal_mask, neg_causal_mask), 1, 1, S, S
+ )
+ block_mask2 = create_block_mask(noop_mask, 1, 1, S, S)
+ self.assertEqual(block_mask1.sparsity(), block_mask2.sparsity())
+
+ @supported_platform
def test_epilogue_fused(self):
@torch.compile
def f(q, k, v):
diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py
index d273e79..a85557c 100644
--- a/torch/_higher_order_ops/flex_attention.py
+++ b/torch/_higher_order_ops/flex_attention.py
@@ -145,8 +145,6 @@
mask_mod_in_dim_buffers = (None,) * len(mask_mod_other_buffers)
mask_mod = _vmap_for_bhqkv(mask_mod, prefix=(), suffix=mask_mod_in_dim_buffers)
- # todo: We wouldn't need these overrides in this file if Dynamo always did the
- # rewriting.
with TransformGetItemToIndex():
scores = (scores * scale).to(working_precision)
post_mod_scores = torch.where(
diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index 8dfe6d3..ae2ea9f 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -1142,7 +1142,6 @@
Di = tl.load(DELTA + offs_m1)
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do))
- # dpT = tl.where(offs_m1[None, :] < Q_LEN, dpT, 0.0)
dsT = pT * (dpT - Di[None, :])
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
m = offs_m1[None, :]
diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py
index 3d3d521..168e211 100644
--- a/torch/nn/attention/flex_attention.py
+++ b/torch/nn/attention/flex_attention.py
@@ -22,7 +22,14 @@
)
from torch.nn.attention._utils import _validate_sdpa_input
-__all__ = ["BlockMask", "flex_attention", "create_block_mask", "create_mask"]
+__all__ = [
+ "BlockMask",
+ "flex_attention",
+ "create_block_mask",
+ "create_mask",
+ "or_masks",
+ "and_masks",
+]
_score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor]
_mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]
@@ -37,6 +44,7 @@
SCORE_MOD = 1
MASK_MOD = 2
+ UNKNOWN = 3
@torch._dynamo.assume_constant_result
@@ -58,7 +66,7 @@
elif num_positional_args == 4:
return _ModificationType.MASK_MOD
else:
- raise AssertionError
+ return _ModificationType.UNKNOWN
# Need to define it here so that Dynamo doesn't skip it
@@ -108,13 +116,13 @@
return score
-def _no_mask(
+def noop_mask(
batch: Tensor,
head: Tensor,
token_q: Tensor,
token_kv: Tensor,
) -> Tensor:
- return token_q.new_ones(size=(), dtype=torch.bool, device=batch.device)
+ return batch.new_ones(size=(), dtype=torch.bool, device=batch.device)
_DEFAULT_SPARSE_BLOCK_SIZE = 128
@@ -266,7 +274,7 @@
BLOCK_SIZE = (BLOCK_SIZE, BLOCK_SIZE)
self.BLOCK_SIZE = BLOCK_SIZE
if mask_mod is None:
- mask_mod = _no_mask
+ mask_mod = noop_mask
self.mask_mod = mask_mod
def as_tuple(self):
@@ -476,6 +484,34 @@
return partial_blocks, None
+def or_masks(*mask_mods: _mask_mod_signature) -> _mask_mod_signature:
+ """Returns a mask_mod that's the union of provided mask_mods"""
+ if not all(callable(arg) for arg in mask_mods):
+ raise RuntimeError(f"All inputs should be callable mask_mods: {mask_mods}")
+
+ def or_mask(b, h, q_idx, kv_idx):
+ result = b.new_zeros((), dtype=torch.bool)
+ for mask in mask_mods:
+ result = result | mask(b, h, q_idx, kv_idx)
+ return result
+
+ return or_mask
+
+
+def and_masks(*mask_mods: _mask_mod_signature) -> _mask_mod_signature:
+ """Returns a mask_mod that's the intersection of provided mask_mods"""
+ if not all(callable(arg) for arg in mask_mods):
+ raise RuntimeError(f"All inputs should be callable mask_mods: {mask_mods}")
+
+ def and_mask(b, h, q_idx, kv_idx):
+ result = b.new_ones((), dtype=torch.bool)
+ for mask in mask_mods:
+ result = result & mask(b, h, q_idx, kv_idx)
+ return result
+
+ return and_mask
+
+
def _convert_block_mask_to_mask(
block_mask,
KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
@@ -640,7 +676,7 @@
mod_type = _get_mod_type(mask_mod)
assert (
mod_type == _ModificationType.MASK_MOD
- ), "create-block_mask requires a mask_mod function!"
+ ), f"create-block_mask requires a mask_mod function! Got {mask_mod}"
inner_func = _create_block_mask_inner
Q_LEN = round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE)
KV_LEN = round_up_to_multiple(KV_LEN, KV_BLOCK_SIZE)