Added and_masks and or_masks utilities (#131073)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131073
Approved by: https://github.com/drisspg
ghstack dependencies: #130871, #130904
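
These utilities compose multiple mask_mods into a single mask_mod: and_masks keeps a (b, h, q_idx, kv_idx) position only if every provided mask allows it, while or_masks keeps it if any mask does. A minimal usage sketch mirroring the new test below (the sequence length, head dim, and device are illustrative and not part of this PR):

    import torch
    from torch.nn.attention.flex_attention import (
        and_masks,
        create_block_mask,
        flex_attention,
    )

    # A mask_mod returns a bool per (batch, head, q_idx, kv_idx) position:
    # True keeps the score, False masks it out.
    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    def sliding_window(b, h, q_idx, kv_idx):
        return (q_idx - kv_idx) <= 512

    S = 2048  # illustrative sequence length
    # Intersection: attend only to keys that are both causal and in the window.
    block_mask = create_block_mask(and_masks(causal_mask, sliding_window), 1, 1, S, S)

    q = k = v = torch.randn(1, 1, S, 64, device="cuda")
    out = flex_attention(q, k, v, block_mask=block_mask)

or_masks is used the same way; as the new test checks, or-ing a mask with its negation produces the same sparsity as noop_mask.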
diff --git a/docs/source/nn.attention.flex_attention.rst b/docs/source/nn.attention.flex_attention.rst
index 101e88d..f1c209c 100644
--- a/docs/source/nn.attention.flex_attention.rst
+++ b/docs/source/nn.attention.flex_attention.rst
@@ -14,6 +14,8 @@
 
 .. autofunction:: create_block_mask
 .. autofunction:: create_mask
+.. autofunction:: and_masks
+.. autofunction:: or_masks
 
 BlockMask
 ---------
diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index 1cae48d..685f2c3 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -17,9 +17,12 @@
 from torch.nn.attention.flex_attention import (
     _create_empty_block_mask,
     _identity,
+    and_masks,
     BlockMask,
     create_block_mask,
     flex_attention,
+    noop_mask,
+    or_masks,
 )
 from torch.testing import FileCheck
 from torch.testing._internal import common_utils
@@ -1011,6 +1014,35 @@
         self.assertEqual(block_mask_a.q_num_blocks, block_mask_b.q_num_blocks)
 
     @supported_platform
+    def test_mask_mod_combiners(self):
+        def causal_mask(b, h, q, kv):
+            return q >= kv
+
+        def neg_causal_mask(b, h, q, kv):
+            return q < kv
+
+        def sliding_window(b, h, q, kv):
+            return (q - kv) <= 512
+
+        block_mask = create_block_mask(
+            and_masks(causal_mask, sliding_window), 1, 1, S, S
+        )
+        self.assertExpectedInline(block_mask.kv_num_blocks.sum().item(), """28""")
+        attention = functools.partial(flex_attention, block_mask=block_mask)
+        self.run_test_with_call(attention)
+
+        block_mask = create_block_mask(
+            and_masks(causal_mask, neg_causal_mask), 1, 1, S, S
+        )
+        self.assertEqual(block_mask.kv_num_blocks.sum(), 0)
+
+        block_mask1 = create_block_mask(
+            or_masks(causal_mask, neg_causal_mask), 1, 1, S, S
+        )
+        block_mask2 = create_block_mask(noop_mask, 1, 1, S, S)
+        self.assertEqual(block_mask1.sparsity(), block_mask2.sparsity())
+
+    @supported_platform
     def test_epilogue_fused(self):
         @torch.compile
         def f(q, k, v):
diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py
index d273e79..a85557c 100644
--- a/torch/_higher_order_ops/flex_attention.py
+++ b/torch/_higher_order_ops/flex_attention.py
@@ -145,8 +145,6 @@
     mask_mod_in_dim_buffers = (None,) * len(mask_mod_other_buffers)
     mask_mod = _vmap_for_bhqkv(mask_mod, prefix=(), suffix=mask_mod_in_dim_buffers)
 
-    # todo: We wouldn't need these overrides in this file if Dynamo always did the
-    # rewriting.
     with TransformGetItemToIndex():
         scores = (scores * scale).to(working_precision)
         post_mod_scores = torch.where(
diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index 8dfe6d3..ae2ea9f 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -1142,7 +1142,6 @@
         Di = tl.load(DELTA + offs_m1)
         # Compute dP and dS.
         dpT = tl.dot(v, tl.trans(do))
-        # dpT = tl.where(offs_m1[None, :] < Q_LEN, dpT, 0.0)
         dsT = pT * (dpT - Di[None, :])
         # ~~~~~~~~~~~~~~~~~~~ Apply joint modification  ~~~~~~~~~~~~~~~~~~~
         m = offs_m1[None, :]
diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py
index 3d3d521..168e211 100644
--- a/torch/nn/attention/flex_attention.py
+++ b/torch/nn/attention/flex_attention.py
@@ -22,7 +22,14 @@
 )
 from torch.nn.attention._utils import _validate_sdpa_input
 
-__all__ = ["BlockMask", "flex_attention", "create_block_mask", "create_mask"]
+__all__ = [
+    "BlockMask",
+    "flex_attention",
+    "create_block_mask",
+    "create_mask",
+    "or_masks",
+    "and_masks",
+]
 
 _score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor]
 _mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]
@@ -37,6 +44,7 @@
 
     SCORE_MOD = 1
     MASK_MOD = 2
+    UNKNOWN = 3
 
 
 @torch._dynamo.assume_constant_result
@@ -58,7 +66,7 @@
     elif num_positional_args == 4:
         return _ModificationType.MASK_MOD
     else:
-        raise AssertionError
+        return _ModificationType.UNKNOWN
 
 
 # Need to define it here so that Dynamo doesn't skip it
@@ -108,13 +116,13 @@
     return score
 
 
-def _no_mask(
+def noop_mask(
     batch: Tensor,
     head: Tensor,
     token_q: Tensor,
     token_kv: Tensor,
 ) -> Tensor:
-    return token_q.new_ones(size=(), dtype=torch.bool, device=batch.device)
+    return batch.new_ones(size=(), dtype=torch.bool, device=batch.device)
 
 
 _DEFAULT_SPARSE_BLOCK_SIZE = 128
@@ -266,7 +274,7 @@
             BLOCK_SIZE = (BLOCK_SIZE, BLOCK_SIZE)
         self.BLOCK_SIZE = BLOCK_SIZE
         if mask_mod is None:
-            mask_mod = _no_mask
+            mask_mod = noop_mask
         self.mask_mod = mask_mod
 
     def as_tuple(self):
@@ -476,6 +484,34 @@
         return partial_blocks, None
 
 
+def or_masks(*mask_mods: _mask_mod_signature) -> _mask_mod_signature:
+    """Returns a mask_mod that's the union of provided mask_mods"""
+    if not all(callable(arg) for arg in mask_mods):
+        raise RuntimeError(f"All inputs should be callable mask_mods: {mask_mods}")
+
+    def or_mask(b, h, q_idx, kv_idx):
+        result = b.new_zeros((), dtype=torch.bool)
+        for mask in mask_mods:
+            result = result | mask(b, h, q_idx, kv_idx)
+        return result
+
+    return or_mask
+
+
+def and_masks(*mask_mods: _mask_mod_signature) -> _mask_mod_signature:
+    """Returns a mask_mod that's the intersection of provided mask_mods"""
+    if not all(callable(arg) for arg in mask_mods):
+        raise RuntimeError(f"All inputs should be callable mask_mods: {mask_mods}")
+
+    def and_mask(b, h, q_idx, kv_idx):
+        result = b.new_ones((), dtype=torch.bool)
+        for mask in mask_mods:
+            result = result & mask(b, h, q_idx, kv_idx)
+        return result
+
+    return and_mask
+
+
 def _convert_block_mask_to_mask(
     block_mask,
     KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
@@ -640,7 +676,7 @@
     mod_type = _get_mod_type(mask_mod)
     assert (
         mod_type == _ModificationType.MASK_MOD
-    ), "create-block_mask requires a mask_mod function!"
+    ), f"create_block_mask requires a mask_mod function! Got {mask_mod}"
     inner_func = _create_block_mask_inner
     Q_LEN = round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE)
     KV_LEN = round_up_to_multiple(KV_LEN, KV_BLOCK_SIZE)