match sdpa patterns from HF (#100609)
Adds SDPA patterns seen in HF models.
To actually make the patterns match, we need constant folding to remove the addition of an all-zeros mask, and to figure out what to do with low-memory dropout.
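For illustration, the HF-style attention blocks these patterns target look roughly like the sketch below (hypothetical and simplified; the mask argument stands in for the all-zeros mask whose addition constant folding has to strip before the SDPA pattern can match):

    # hypothetical HF-style attention fragment, simplified for illustration
    import math
    import torch

    def hf_style_attention(query, key, value, attention_mask, dropout_p):
        # query/key/value: [batch, heads, seq, head_dim]
        scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
        # HF models often add a broadcasted additive mask here; when it is all
        # zeros, the add must be constant-folded away for the pattern to match
        scores = scores + attention_mask
        attn_weight = torch.softmax(scores.to(torch.float32), dim=-1).to(query.dtype)
        attn_weight = torch.dropout(attn_weight, dropout_p, True)
        return attn_weight @ value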
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100609
Approved by: https://github.com/jansel
diff --git a/test/inductor/test_fused_attention.py b/test/inductor/test_fused_attention.py
index 03060c9..39f0600 100644
--- a/test/inductor/test_fused_attention.py
+++ b/test/inductor/test_fused_attention.py
@@ -8,7 +8,10 @@
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.utils import run_and_get_code
-from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_SDPA
+from torch.testing._internal.common_cuda import (
+ PLATFORM_SUPPORTS_FUSED_SDPA,
+ SM80OrLater,
+)
from torch.testing._internal.common_utils import IS_LINUX, TEST_WITH_ROCM
from torch.testing._internal.inductor_utils import HAS_CUDA
@@ -58,7 +61,8 @@
if contains:
# many of the patterns get re-expanded in dispatcher
self.assertIn(
- "aten._scaled_dot_product_efficient_attention", source_code
+ "aten._scaled_dot_product",
+ source_code,
)
self.assertEqual(result1, result2, atol=atol, rtol=1.3e-6)
@@ -194,6 +198,88 @@
self._check_common(sfdp_pattern_6, contains=False)
+ def test_sdpa_rewriter_7(self):
+ def sfdp_pattern_7(query, key, value):
+ q = query.permute(0, 2, 1, 3)
+ k = key.permute(0, 2, 1, 3)
+ v = value.permute(0, 2, 1, 3)
+ div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
+ div = div.to(torch.float32)
+ attn_weight = torch.softmax(div, dim=-1)
+ # very small dropout to make sure the test passes
+ attn_weight = torch.dropout(attn_weight, 0.0001, True)
+ attn_weight = attn_weight.to(torch.float16)
+ return attn_weight @ v
+
+ args = (
+ torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+ torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+ torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+ )
+
+ self._check_common(sfdp_pattern_7, args, contains=SM80OrLater)
+
+ def test_sdpa_rewriter_8(self):
+ def sfdp_pattern_8(query, key, value):
+ q = query.permute(0, 2, 1, 3)
+ k = key.permute(0, 2, 1, 3)
+ v = value.permute(0, 2, 1, 3)
+ div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
+ div = div.to(torch.float32)
+ attn_weight = torch.softmax(div, dim=-1)
+ attn_weight = attn_weight.to(torch.float16)
+ return attn_weight @ v
+
+ args = (
+ torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+ torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+ torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+ )
+
+ self._check_common(sfdp_pattern_8, args)
+
+ def test_sdpa_rewriter_9(self):
+ def sfdp_pattern_9(query, key, value):
+ q = query.permute(0, 2, 1, 3)
+ k = key.permute(0, 2, 1, 3)
+ v = value.permute(0, 2, 1, 3)
+ q = q / math.sqrt(q.size(-1))
+ div = q @ k.transpose(-2, -1)
+ div = div.to(torch.float32)
+ attn_weight = torch.softmax(div, dim=-1)
+ # very low dropout to make the test pass
+ attn_weight = torch.dropout(attn_weight, 0.0001, True)
+ attn_weight = attn_weight.to(torch.float16)
+ return attn_weight @ v
+
+ args = (
+ torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+ torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+ torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+ )
+
+ self._check_common(sfdp_pattern_9, args, contains=SM80OrLater)
+
+ def test_sdpa_rewriter_10(self):
+ def sfdp_pattern_10(query, key, value):
+ q = query.permute(0, 2, 1, 3)
+ k = key.permute(0, 2, 1, 3)
+ v = value.permute(0, 2, 1, 3)
+ q = q / math.sqrt(q.size(-1))
+ div = q @ k.transpose(-2, -1)
+ div = div.to(torch.float32)
+ attn_weight = torch.softmax(div, dim=-1)
+ attn_weight = attn_weight.to(torch.float16)
+ return attn_weight @ v
+
+ args = (
+ torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+ torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+ torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
+ )
+
+ self._check_common(sfdp_pattern_10, args)
+
@config.patch(fallback_random=True, lowmem_dropout=False)
def test_pattern_fails_with_tensor_factor(self):
# https://github.com/pytorch/pytorch/issues/99124
diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py
index f3b92c4..019a4d7 100644
--- a/torch/_inductor/fx_passes/fuse_attention.py
+++ b/torch/_inductor/fx_passes/fuse_attention.py
@@ -142,7 +142,124 @@
)
-# TODO(jansel): add more patterns based on what we see in real models
+def _sfdp_pattern_7(query, key, value, dropout_p):
+ # in real workloads, inputs to matmul are permuted,
+ # causing matmul to expand into a series of expand and clone calls;
+ # we want the same to happen during pattern tracing
+ q = query.permute(0, 2, 1, 3)
+ k = key.permute(0, 2, 1, 3)
+ v = value.permute(0, 2, 1, 3)
+ div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
+ div = div.to(torch.float32)
+ attn_weight = torch.softmax(div, dim=-1)
+ attn_weight = torch.dropout(attn_weight, dropout_p, True)
+ attn_weight = attn_weight.to(torch.float16)
+ return attn_weight @ v
+
+
+def _sfdp_replacement_7(query, key, value, dropout_p):
+ # sdpa prefers its inputs in permuted format and makes a copy
+ # to put them in this format if they aren't already;
+ # to make the replacement efficient, ensure that the inputs to sdpa
+ # are already in the required order
+ counters["inductor"]["fuse_attention"] += 1
+ q = query.permute(0, 2, 1, 3)
+ k = key.permute(0, 2, 1, 3)
+ v = value.permute(0, 2, 1, 3)
+ return aten.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=None, # attn_mask,
+ dropout_p=dropout_p,
+ is_causal=False,
+ )
+
+
+def _sfdp_pattern_8(query, key, value):
+ # no dropout version of pattern 7
+ q = query.permute(0, 2, 1, 3)
+ k = key.permute(0, 2, 1, 3)
+ v = value.permute(0, 2, 1, 3)
+ div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
+ div = div.to(torch.float32)
+ attn_weight = torch.softmax(div, dim=-1)
+ attn_weight = attn_weight.to(torch.float16)
+ return attn_weight @ v
+
+
+def _sfdp_replacement_8(query, key, value):
+ counters["inductor"]["fuse_attention"] += 1
+ q = query.permute(0, 2, 1, 3)
+ k = key.permute(0, 2, 1, 3)
+ v = value.permute(0, 2, 1, 3)
+ return aten.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=None, # attn_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ )
+
+
+def _sfdp_pattern_9(query, key, value, dropout_p):
+ q = query.permute(0, 2, 1, 3)
+ k = key.permute(0, 2, 1, 3)
+ v = value.permute(0, 2, 1, 3)
+ q = q / math.sqrt(q.size(-1))
+ div = q @ k.transpose(-2, -1)
+ div = div.to(torch.float32)
+ attn_weight = torch.softmax(div, dim=-1)
+ attn_weight = torch.dropout(attn_weight, dropout_p, True)
+ attn_weight = attn_weight.to(torch.float16)
+ return attn_weight @ v
+
+
+def _sfdp_replacement_9(query, key, value, dropout_p):
+ counters["inductor"]["fuse_attention"] += 1
+ q = query.permute(0, 2, 1, 3)
+ k = key.permute(0, 2, 1, 3)
+ v = value.permute(0, 2, 1, 3)
+ return aten.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=None, # attn_mask,
+ dropout_p=dropout_p,
+ is_causal=False,
+ )
+
+
+def _sfdp_pattern_10(query, key, value):
+ # no dropout version of 9
+ q = query.permute(0, 2, 1, 3)
+ k = key.permute(0, 2, 1, 3)
+ v = value.permute(0, 2, 1, 3)
+ q = q / math.sqrt(q.size(-1))
+ div = q @ k.transpose(-2, -1)
+ div = div.to(torch.float32)
+ attn_weight = torch.softmax(div, dim=-1)
+ attn_weight = attn_weight.to(torch.float16)
+ return attn_weight @ v
+
+
+def _sfdp_replacement_10(query, key, value):
+ counters["inductor"]["fuse_attention"] += 1
+ q = query.permute(0, 2, 1, 3)
+ k = key.permute(0, 2, 1, 3)
+ v = value.permute(0, 2, 1, 3)
+ return aten.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=None, # attn_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ )
+
+
# TODO(jansel): make these pattern work with lowmem_dropout=True
@@ -200,6 +317,9 @@
# sizes/values don't actually matter for initial trace
# once we get a possible match we re-trace with the actual values and verify the match still holds
g = functools.partial(torch.empty, (2, 4, 8, 16), device=device, requires_grad=True)
+ gp = functools.partial(
+ torch.empty, (2, 8, 4, 16), device=device, requires_grad=True, dtype=torch.half
+ )
b = functools.partial(torch.empty, (1, 1, 8, 8), device=device)
c = functools.partial(torch.tensor, 2.0, device=device)
# workaround https://github.com/pytorch/pytorch/issues/97894
@@ -249,6 +369,34 @@
d,
_sfdp_params_check,
),
+ (
+ _sfdp_pattern_7,
+ _sfdp_replacement_7,
+ [gp(), gp(), gp()],
+ d,
+ _sfdp_params_check,
+ ),
+ (
+ _sfdp_pattern_8,
+ _sfdp_replacement_8,
+ [gp(), gp(), gp()],
+ {},
+ _sfdp_params_check,
+ ),
+ (
+ _sfdp_pattern_9,
+ _sfdp_replacement_9,
+ [gp(), gp(), gp()],
+ d,
+ _sfdp_params_check,
+ ),
+ (
+ _sfdp_pattern_10,
+ _sfdp_replacement_10,
+ [gp(), gp(), gp()],
+ {},
+ _sfdp_params_check,
+ ),
]:
args = [*args, *workaround.values()]
register_replacement(
@@ -269,3 +417,5 @@
extra_check=extra_check,
scalar_workaround=workaround,
)
+
+ counters["inductor"].clear() # clear view matches encountered during sdpa tracing
diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py
index c2dded8..479d3e7 100644
--- a/torch/_inductor/fx_passes/joint_graph.py
+++ b/torch/_inductor/fx_passes/joint_graph.py
@@ -66,3 +66,17 @@
repl.meta.update(node.meta)
node.replace_all_uses_with(repl)
match.erase_nodes(graph)
+
+
+@register_graph_pattern(
+ CallFunction(torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")),
+ pass_dict=patterns,
+)
+def pointless_view(match: Match, arg, size):
+ """Remove no-op view"""
+ graph = match.graph
+ node = match.output_node()
+ arg_size = list(node.args[0].meta["val"].shape)
+ if size == arg_size:
+ node.replace_all_uses_with(node.args[0])
+ match.erase_nodes(graph)
diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py
index 6b5240a..e2f4a0a 100644
--- a/torch/_inductor/pattern_matcher.py
+++ b/torch/_inductor/pattern_matcher.py
@@ -842,6 +842,18 @@
enable_log=False,
)(*args)
+ from .fx_passes.joint_graph import pointless_view
+
+ matcher_pass = PatternMatcherPass()
+
+ pattern = CallFunction(
+ torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")
+ )
+ GraphPatternEntry(
+ pattern=pattern, handler=pointless_view, extra_check=_return_true
+ ).register(matcher_pass.patterns)
+ matcher_pass.apply(gm.graph)
+
# remove in/out specs
gm.graph._codegen = torch.fx.graph.CodeGen()
gm.graph.eliminate_dead_code()
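As a quick standalone check of the idea behind the new patterns (a hedged sketch, not part of the PR: it mirrors the shapes and dtype used in the tests, and the tolerances are assumptions), the eager form of pattern 8 can be compared against the fused kernel directly:

    # sanity check: eager pattern-8 math vs. the fused SDPA kernel (illustrative only)
    import math
    import torch

    def pattern_8_eager(query, key, value):
        q = query.permute(0, 2, 1, 3)
        k = key.permute(0, 2, 1, 3)
        v = value.permute(0, 2, 1, 3)
        div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
        attn_weight = torch.softmax(div.to(torch.float32), dim=-1).to(torch.float16)
        return attn_weight @ v

    q, k, v = (torch.randn(2, 8, 4, 16, device="cuda", dtype=torch.half) for _ in range(3))
    fused = torch.nn.functional.scaled_dot_product_attention(
        q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
    )
    # assumed half-precision tolerances; loosen if the check is too tight
    torch.testing.assert_close(pattern_8_eager(q, k, v), fused, atol=2e-3, rtol=2e-3)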