match sdpa patterns from HF (#100609)

Adds sdpa patterns seen in HF models.

To actually make the patterns match, we need constant folding to remove the addition of an all-zeros attention mask, and we still need to figure out what to do with low-memory dropout.
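
For context, a minimal sketch of the kind of eager attention these patterns target (the `hf_style_attention` name, shapes, and arguments are illustrative assumptions, not code from any specific HF model). When nothing is masked, `attention_mask` is all zeros, so the `scores + attention_mask` add must be constant-folded away before a fused pattern can match:

```python
import math

import torch


def hf_style_attention(query, key, value, attention_mask, dropout_p=0.1):
    # illustrative only: mirrors the eager attention used in many HF models
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    # when no tokens are masked, attention_mask is all zeros and this add is a
    # no-op that constant folding must eliminate before the pattern can match
    scores = scores + attention_mask
    attn_weight = torch.softmax(scores, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    return attn_weight @ value
```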

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100609
Approved by: https://github.com/jansel
diff --git a/test/inductor/test_fused_attention.py b/test/inductor/test_fused_attention.py
index 03060c9..39f0600 100644
--- a/test/inductor/test_fused_attention.py
+++ b/test/inductor/test_fused_attention.py
@@ -8,7 +8,10 @@
 from torch._dynamo.utils import counters
 from torch._inductor import config
 from torch._inductor.utils import run_and_get_code
-from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_SDPA
+from torch.testing._internal.common_cuda import (
+    PLATFORM_SUPPORTS_FUSED_SDPA,
+    SM80OrLater,
+)
 from torch.testing._internal.common_utils import IS_LINUX, TEST_WITH_ROCM
 from torch.testing._internal.inductor_utils import HAS_CUDA
 
@@ -58,7 +61,8 @@
             if contains:
                 # many of the patterns get re-expanded in dispatcher
                 self.assertIn(
-                    "aten._scaled_dot_product_efficient_attention", source_code
+                    "aten._scaled_dot_product",
+                    source_code,
                 )
             self.assertEqual(result1, result2, atol=atol, rtol=1.3e-6)
 
@@ -194,6 +198,88 @@
 
         self._check_common(sfdp_pattern_6, contains=False)
 
+    def test_sdpa_rewriter_7(self):
+        def sfdp_pattern_7(query, key, value):
+            q = query.permute(0, 2, 1, 3)
+            k = key.permute(0, 2, 1, 3)
+            v = value.permute(0, 2, 1, 3)
+            div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
+            div = div.to(torch.float32)
+            attn_weight = torch.softmax(div, dim=-1)
+            # very small dropout to make sure the test passes
+            attn_weight = torch.dropout(attn_weight, 0.0001, True)
+            attn_weight = attn_weight.to(torch.float16)
+            return attn_weight @ v
+
+        args = (
+            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+        )
+
+        self._check_common(sfdp_pattern_7, args, contains=SM80OrLater)
+
+    def test_sdpa_rewriter_8(self):
+        def sfdp_pattern_8(query, key, value):
+            q = query.permute(0, 2, 1, 3)
+            k = key.permute(0, 2, 1, 3)
+            v = value.permute(0, 2, 1, 3)
+            div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
+            div = div.to(torch.float32)
+            attn_weight = torch.softmax(div, dim=-1)
+            attn_weight = attn_weight.to(torch.float16)
+            return attn_weight @ v
+
+        args = (
+            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+        )
+
+        self._check_common(sfdp_pattern_8, args)
+
+    def test_sdpa_rewriter_9(self):
+        def sfdp_pattern_9(query, key, value):
+            q = query.permute(0, 2, 1, 3)
+            k = key.permute(0, 2, 1, 3)
+            v = value.permute(0, 2, 1, 3)
+            q = q / math.sqrt(q.size(-1))
+            div = q @ k.transpose(-2, -1)
+            div = div.to(torch.float32)
+            attn_weight = torch.softmax(div, dim=-1)
+            # very small dropout to make sure the test passes
+            attn_weight = torch.dropout(attn_weight, 0.0001, True)
+            attn_weight = attn_weight.to(torch.float16)
+            return attn_weight @ v
+
+        args = (
+            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+        )
+
+        self._check_common(sfdp_pattern_9, args, contains=SM80OrLater)
+
+    def test_sdpa_rewriter_10(self):
+        def sfdp_pattern_10(query, key, value):
+            q = query.permute(0, 2, 1, 3)
+            k = key.permute(0, 2, 1, 3)
+            v = value.permute(0, 2, 1, 3)
+            q = q / math.sqrt(q.size(-1))
+            div = q @ k.transpose(-2, -1)
+            div = div.to(torch.float32)
+            attn_weight = torch.softmax(div, dim=-1)
+            attn_weight = attn_weight.to(torch.float16)
+            return attn_weight @ v
+
+        args = (
+            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+        )
+
+        self._check_common(sfdp_pattern_10, args)
+
     @config.patch(fallback_random=True, lowmem_dropout=False)
     def test_pattern_fails_with_tensor_factor(self):
         # https://github.com/pytorch/pytorch/issues/99124
diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py
index f3b92c4..019a4d7 100644
--- a/torch/_inductor/fx_passes/fuse_attention.py
+++ b/torch/_inductor/fx_passes/fuse_attention.py
@@ -142,7 +142,124 @@
     )
 
 
-# TODO(jansel): add more patterns based on what we see in real models
+def _sfdp_pattern_7(query, key, value, dropout_p):
+    # in real workloads the inputs to matmul are permuted,
+    # causing matmul to expand into a series of expand and clone calls;
+    # we want the same to happen during pattern tracing
+    q = query.permute(0, 2, 1, 3)
+    k = key.permute(0, 2, 1, 3)
+    v = value.permute(0, 2, 1, 3)
+    div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
+    div = div.to(torch.float32)
+    attn_weight = torch.softmax(div, dim=-1)
+    attn_weight = torch.dropout(attn_weight, dropout_p, True)
+    attn_weight = attn_weight.to(torch.float16)
+    return attn_weight @ v
+
+
+def _sfdp_replacement_7(query, key, value, dropout_p):
+    # sdpa prefers its inputs in permuted format and
+    # makes a copy to put them in that format if they
+    # are not already permuted; to make the replacement
+    # efficient, ensure that the inputs to sdpa are
+    # already in the required order
+    counters["inductor"]["fuse_attention"] += 1
+    q = query.permute(0, 2, 1, 3)
+    k = key.permute(0, 2, 1, 3)
+    v = value.permute(0, 2, 1, 3)
+    return aten.scaled_dot_product_attention(
+        q,
+        k,
+        v,
+        attn_mask=None,  # attn_mask,
+        dropout_p=dropout_p,
+        is_causal=False,
+    )
+
+
+def _sfdp_pattern_8(query, key, value):
+    # no-dropout version of pattern 7
+    q = query.permute(0, 2, 1, 3)
+    k = key.permute(0, 2, 1, 3)
+    v = value.permute(0, 2, 1, 3)
+    div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
+    div = div.to(torch.float32)
+    attn_weight = torch.softmax(div, dim=-1)
+    attn_weight = attn_weight.to(torch.float16)
+    return attn_weight @ v
+
+
+def _sfdp_replacement_8(query, key, value):
+    counters["inductor"]["fuse_attention"] += 1
+    q = query.permute(0, 2, 1, 3)
+    k = key.permute(0, 2, 1, 3)
+    v = value.permute(0, 2, 1, 3)
+    return aten.scaled_dot_product_attention(
+        q,
+        k,
+        v,
+        attn_mask=None,  # attn_mask,
+        dropout_p=0.0,
+        is_causal=False,
+    )
+
+
+def _sfdp_pattern_9(query, key, value, dropout_p):
+    q = query.permute(0, 2, 1, 3)
+    k = key.permute(0, 2, 1, 3)
+    v = value.permute(0, 2, 1, 3)
+    q = q / math.sqrt(q.size(-1))
+    div = q @ k.transpose(-2, -1)
+    div = div.to(torch.float32)
+    attn_weight = torch.softmax(div, dim=-1)
+    attn_weight = torch.dropout(attn_weight, dropout_p, True)
+    attn_weight = attn_weight.to(torch.float16)
+    return attn_weight @ v
+
+
+def _sfdp_replacement_9(query, key, value, dropout_p):
+    counters["inductor"]["fuse_attention"] += 1
+    q = query.permute(0, 2, 1, 3)
+    k = key.permute(0, 2, 1, 3)
+    v = value.permute(0, 2, 1, 3)
+    return aten.scaled_dot_product_attention(
+        q,
+        k,
+        v,
+        attn_mask=None,  # attn_mask,
+        dropout_p=dropout_p,
+        is_causal=False,
+    )
+
+
+def _sfdp_pattern_10(query, key, value):
+    # no-dropout version of pattern 9
+    q = query.permute(0, 2, 1, 3)
+    k = key.permute(0, 2, 1, 3)
+    v = value.permute(0, 2, 1, 3)
+    q = q / math.sqrt(q.size(-1))
+    div = q @ k.transpose(-2, -1)
+    div = div.to(torch.float32)
+    attn_weight = torch.softmax(div, dim=-1)
+    attn_weight = attn_weight.to(torch.float16)
+    return attn_weight @ v
+
+
+def _sfdp_replacement_10(query, key, value):
+    counters["inductor"]["fuse_attention"] += 1
+    q = query.permute(0, 2, 1, 3)
+    k = key.permute(0, 2, 1, 3)
+    v = value.permute(0, 2, 1, 3)
+    return aten.scaled_dot_product_attention(
+        q,
+        k,
+        v,
+        attn_mask=None,  # attn_mask,
+        dropout_p=0.0,
+        is_causal=False,
+    )
+
+
 # TODO(jansel): make these pattern work with lowmem_dropout=True
 
 
@@ -200,6 +317,9 @@
     # sizes/values don't actually matter for initial trace
     # once we get a possible match we re-trace with the actual values and verify the match still holds
     g = functools.partial(torch.empty, (2, 4, 8, 16), device=device, requires_grad=True)
+    gp = functools.partial(
+        torch.empty, (2, 8, 4, 16), device=device, requires_grad=True, dtype=torch.half
+    )
     b = functools.partial(torch.empty, (1, 1, 8, 8), device=device)
     c = functools.partial(torch.tensor, 2.0, device=device)
     # workaround https://github.com/pytorch/pytorch/issues/97894
@@ -249,6 +369,34 @@
             d,
             _sfdp_params_check,
         ),
+        (
+            _sfdp_pattern_7,
+            _sfdp_replacement_7,
+            [gp(), gp(), gp()],
+            d,
+            _sfdp_params_check,
+        ),
+        (
+            _sfdp_pattern_8,
+            _sfdp_replacement_8,
+            [gp(), gp(), gp()],
+            {},
+            _sfdp_params_check,
+        ),
+        (
+            _sfdp_pattern_9,
+            _sfdp_replacement_9,
+            [gp(), gp(), gp()],
+            d,
+            _sfdp_params_check,
+        ),
+        (
+            _sfdp_pattern_10,
+            _sfdp_replacement_10,
+            [gp(), gp(), gp()],
+            {},
+            _sfdp_params_check,
+        ),
     ]:
         args = [*args, *workaround.values()]
         register_replacement(
@@ -269,3 +417,5 @@
             extra_check=extra_check,
             scalar_workaround=workaround,
         )
+
+    counters["inductor"].clear()  # clear view matches encountered during sdpa tracing
diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py
index c2dded8..479d3e7 100644
--- a/torch/_inductor/fx_passes/joint_graph.py
+++ b/torch/_inductor/fx_passes/joint_graph.py
@@ -66,3 +66,17 @@
         repl.meta.update(node.meta)
         node.replace_all_uses_with(repl)
         match.erase_nodes(graph)
+
+
+@register_graph_pattern(
+    CallFunction(torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")),
+    pass_dict=patterns,
+)
+def pointless_view(match: Match, arg, size):
+    """Remove no-op view"""
+    graph = match.graph
+    node = match.output_node()
+    arg_size = list(node.args[0].meta["val"].shape)
+    if size == arg_size:
+        node.replace_all_uses_with(node.args[0])
+        match.erase_nodes(graph)
diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py
index 6b5240a..e2f4a0a 100644
--- a/torch/_inductor/pattern_matcher.py
+++ b/torch/_inductor/pattern_matcher.py
@@ -842,6 +842,18 @@
             enable_log=False,
         )(*args)
 
+    from .fx_passes.joint_graph import pointless_view
+
+    matcher_pass = PatternMatcherPass()
+
+    pattern = CallFunction(
+        torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")
+    )
+    GraphPatternEntry(
+        pattern=pattern, handler=pointless_view, extra_check=_return_true
+    ).register(matcher_pass.patterns)
+    matcher_pass.apply(gm.graph)
+
     # remove in/out specs
     gm.graph._codegen = torch.fx.graph.CodeGen()
     gm.graph.eliminate_dead_code()