Trace attention inference patterns with p=0, cleanup (#109118)
When dropout is traced in inference mode, it produces a clone() instead of the training-mode pattern of rand() etc. This was partially addressed manually in https://github.com/pytorch/pytorch/pull/108141; however, that did not cover all of the patterns that include dropout, and there is no reason we should have to specify them by hand.
This PR updates the generated inference patterns to trace with dropout_p = 0.0.
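For context, a minimal illustrative sketch (not part of the PR) of the behavior this relies on: tracing the same source pattern with dropout_p = 0.0 makes the dropout decompose to a clone(), matching what eval-mode / p = 0 models produce, so the training-time pattern can be reused to generate its inference variant instead of writing clone() patterns by hand. The `pattern` function and shapes below are made up for illustration, and the trace mirrors what `inference_graph` in pattern_matcher.py does.

```python
# Hypothetical repro sketch, not part of this PR: trace the same attention
# pattern with dropout_p > 0 and with dropout_p = 0.0 and compare the graphs.
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._inductor.decomposition import select_decomp_table


def pattern(query, key, value, dropout_p):
    attn = torch.matmul(query, key.transpose(-2, -1)).div(3.0).softmax(dim=-1)
    return torch.nn.functional.dropout(attn, p=dropout_p, training=True).matmul(value)


q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))

# Training-style trace (p > 0): dropout keeps its random-masking ops.
train_gm = make_fx(lambda q, k, v: pattern(q, k, v, 0.4), select_decomp_table())(q, k, v)

# Inference-style trace (p = 0.0): dropout collapses to a clone(), which is
# why the hand-written clone() patterns (_sfdp_pattern_13/14/15) can be removed.
infer_gm = make_fx(lambda q, k, v: pattern(q, k, v, 0.0), select_decomp_table())(q, k, v)

print(train_gm.graph)
print(infer_gm.graph)
```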
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109118
Approved by: https://github.com/drisspg, https://github.com/Valentine233
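One detail worth calling out in the diff below: the new `partialize_and_update_signature` helper exists because the registration code inspects the callable's signature downstream, and a plain `functools.partial` still reports the bound parameter via `inspect.signature`, so `dropout_p` would continue to look like a pattern input. An illustrative sketch of the difference (the `pattern` stub is hypothetical):

```python
# Illustrative sketch only: why functools.partial is insufficient when
# signatures are inspected downstream.
import functools
import inspect


def pattern(query, key, value, dropout_p):
    ...  # stand-in for one of the sfdp patterns


# A plain partial keeps dropout_p in the reported signature (as a
# keyword-only parameter with a default), so it would still be treated
# as a pattern argument.
partial_fn = functools.partial(pattern, dropout_p=0.0)
print(inspect.signature(partial_fn))  # (query, key, value, *, dropout_p=0.0)


def partialize_and_update_signature(func, **kwargs):
    # Same idea as the helper added in fuse_attention.py below: bind kwargs
    # with functools.partial, then override __signature__ to drop them.
    params = [p for name, p in inspect.signature(func).parameters.items() if name not in kwargs]
    partial_func = functools.partial(func, **kwargs)

    def wrapper(*args, **kw):
        return partial_func(*args, **kw)

    wrapper.__signature__ = inspect.Signature(parameters=params)
    return wrapper


inference_pattern = partialize_and_update_signature(pattern, dropout_p=0.0)
print(inspect.signature(inference_pattern))  # (query, key, value)
```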
diff --git a/test/inductor/test_fused_attention.py b/test/inductor/test_fused_attention.py
index ffbcf26..cf6699b 100644
--- a/test/inductor/test_fused_attention.py
+++ b/test/inductor/test_fused_attention.py
@@ -32,7 +32,7 @@
return x
return x.clone()
- return tuple(clone(x) for x in inputs)
+ return [clone(x) for x in inputs]
def _check_common(
self,
@@ -42,6 +42,8 @@
atol=1e-5,
has_fuse_pattern=True,
has_dropout=False,
+ check_train=True,
+ override_check_equal=False,
):
if args1 is None:
tensor_shape = (4, 2, 16, 32)
@@ -50,20 +52,24 @@
torch.randn(tensor_shape, device=self.device),
torch.randn(tensor_shape, device=self.device),
]
+ else:
+ args1 = list(args1)
args2 = self._clone_inputs(args1)
- for training in [False, True]:
+ for training in [False, True] if check_train else [False]:
for x in itertools.chain(args1[:], args2[:]):
if isinstance(x, torch.Tensor) and x.is_floating_point():
x.requires_grad = training
+ dropout_arg = [training] if has_dropout else []
torch.manual_seed(1234)
- result1 = dot_prod_attention(*args1)
+ result1 = dot_prod_attention(*(args1 + dropout_arg))
counters.clear()
torch.manual_seed(1234)
result2, (source_code,) = run_and_get_code(
- torch.compile(dot_prod_attention, fullgraph=True), *args2
+ torch.compile(dot_prod_attention, fullgraph=True),
+ *(args2 + dropout_arg),
)
if has_fuse_pattern:
self.assertGreaterEqual(counters["inductor"]["fuse_attention"], 1)
@@ -73,7 +79,9 @@
"aten._scaled_dot_product",
source_code,
)
- if not has_dropout:
+
+ # some tests configured with very low dropout where we still want to check equality
+ if not has_dropout or override_check_equal:
self.assertEqual(result1, result2, atol=atol, rtol=1.3e-6)
if training:
@@ -83,7 +91,7 @@
if (
isinstance(arg1, torch.Tensor)
and arg1.is_floating_point()
- and not has_dropout
+ and (not has_dropout or override_check_equal)
):
self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=1.3e-6)
@@ -147,12 +155,12 @@
def _test_sdpa_rewriter_3(self):
def dot_prod_attention(
- query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+ query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training: bool
) -> torch.Tensor:
return torch.nn.functional.dropout(
torch.matmul(query, key.transpose(-2, -1)).div(3.0).softmax(dim=-1),
p=0.4,
- training=True,
+ training=training,
inplace=False,
).matmul(value)
@@ -163,13 +171,16 @@
def _test_sdpa_rewriter_4(self):
def dot_prod_attention(
- query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ training: bool,
) -> torch.Tensor:
return torch.nn.functional.dropout(
torch.matmul(query, key.transpose(-2, -1)).mul(0.4).softmax(dim=-1),
p=0.2,
- training=True,
inplace=False,
+ training=training,
).matmul(value)
self._check_common(dot_prod_attention, contains=False, has_dropout=True)
@@ -209,7 +220,7 @@
@skipIfRocm
def _test_sdpa_rewriter_6(self):
- def sfdp_pattern_6(query, key, value):
+ def sfdp_pattern_6(query, key, value, training):
attn_mask = torch.ones(
query.size(-2), key.size(-2), dtype=torch.bool, device=query.device
).tril(diagonal=0)
@@ -220,7 +231,7 @@
(query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask,
dim=-1,
)
- attn_weight = torch.dropout(attn_weight, 0.5, True)
+ attn_weight = torch.nn.functional.dropout(attn_weight, 0.5, training)
return attn_weight @ value
self._check_common(sfdp_pattern_6, contains=False, has_dropout=True)
@@ -230,7 +241,7 @@
@skipIfRocm
def _test_sdpa_rewriter_7(self):
- def sfdp_pattern_7(query, key, value):
+ def sfdp_pattern_7(query, key, value, training):
q = query.permute(0, 2, 1, 3)
k = key.permute(0, 2, 1, 3)
v = value.permute(0, 2, 1, 3)
@@ -238,7 +249,7 @@
div = div.to(torch.float32)
attn_weight = torch.softmax(div, dim=-1)
# Set to False
- attn_weight = torch.dropout(attn_weight, 0.00000000001, True)
+ attn_weight = torch.dropout(attn_weight, 0.00000000001, training)
attn_weight = attn_weight.to(torch.float16)
return attn_weight @ v
@@ -247,7 +258,14 @@
torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
)
- self._check_common(sfdp_pattern_7, args, contains=SM80OrLater, atol=2e-3)
+ self._check_common(
+ sfdp_pattern_7,
+ args,
+ contains=SM80OrLater,
+ has_dropout=True,
+ override_check_equal=True,
+ atol=2e-3,
+ )
args = (
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
@@ -255,7 +273,12 @@
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
)
self._check_common(
- checkpoint_wrapper(sfdp_pattern_7), args, contains=SM80OrLater, atol=2e-3
+ checkpoint_wrapper(sfdp_pattern_7),
+ args,
+ contains=SM80OrLater,
+ has_dropout=True,
+ override_check_equal=True,
+ atol=2e-3,
)
@skipIfRocm
@@ -286,7 +309,7 @@
@skipIfRocm
def _test_sdpa_rewriter_9(self):
- def sfdp_pattern_9(query, key, value):
+ def sfdp_pattern_9(query, key, value, training):
q = query.permute(0, 2, 1, 3)
k = key.permute(0, 2, 1, 3)
v = value.permute(0, 2, 1, 3)
@@ -295,7 +318,7 @@
div = div.to(torch.float32)
attn_weight = torch.softmax(div, dim=-1)
# very low dropout to make test pass
- attn_weight = torch.dropout(attn_weight, 0.00000000001, True)
+ attn_weight = torch.dropout(attn_weight, 0.00000000001, training)
attn_weight = attn_weight.to(torch.float16)
return attn_weight @ v
@@ -304,14 +327,26 @@
torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
)
- self._check_common(sfdp_pattern_9, args, contains=SM80OrLater, atol=2e-3)
+ self._check_common(
+ sfdp_pattern_9,
+ args,
+ contains=SM80OrLater,
+ has_dropout=True,
+ override_check_equal=True,
+ atol=2e-3,
+ )
args = (
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
)
self._check_common(
- checkpoint_wrapper(sfdp_pattern_9), args, contains=SM80OrLater, atol=2e-3
+ checkpoint_wrapper(sfdp_pattern_9),
+ args,
+ contains=SM80OrLater,
+ has_dropout=True,
+ override_check_equal=True,
+ atol=2e-3,
)
@skipIfRocm
@@ -425,7 +460,10 @@
def _test_sdpa_rewriter_12(self):
def dot_prod_attention(
- query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ training: bool,
) -> torch.Tensor:
"""Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
q = query.transpose(1, 2)
@@ -437,7 +475,7 @@
.softmax(dim=-1)
.matmul(v),
p=0.4,
- training=True,
+ training=training,
inplace=False,
)
@@ -457,8 +495,8 @@
.matmul(value)
)
- self._check_common(dot_prod_attention)
- self._check_common(checkpoint_wrapper(dot_prod_attention))
+ self._check_common(dot_prod_attention, check_train=False)
+ self._check_common(checkpoint_wrapper(dot_prod_attention), check_train=False)
@skipIfRocm
def _test_sdpa_rewriter_14(self):
@@ -473,8 +511,8 @@
.matmul(value)
)
- self._check_common(dot_prod_attention)
- self._check_common(checkpoint_wrapper(dot_prod_attention))
+ self._check_common(dot_prod_attention, check_train=False)
+ self._check_common(checkpoint_wrapper(dot_prod_attention), check_train=False)
@skipIfRocm
def _test_sdpa_rewriter_15(self):
@@ -493,7 +531,7 @@
.matmul(v)
)
- self._check_common(dot_prod_attention)
+ self._check_common(dot_prod_attention, check_train=False)
if HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION:
diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py
index 0a00ada..755b288 100644
--- a/torch/_inductor/fx_passes/fuse_attention.py
+++ b/torch/_inductor/fx_passes/fuse_attention.py
@@ -1,4 +1,5 @@
import functools
+import inspect
import logging
import math
@@ -303,81 +304,6 @@
)
-def _sfdp_pattern_13(query, key, value, inv_scale):
- # dropout would create a clone() if eval() or p = 0
- return (
- torch.matmul(query, key.transpose(-2, -1))
- .div(inv_scale)
- .softmax(dim=-1)
- .clone()
- .matmul(value)
- )
-
-
-def _sfdp_replacement_13(query, key, value, inv_scale):
- counters["inductor"]["fuse_attention"] += 1
- return aten.scaled_dot_product_attention(
- query.contiguous(),
- key.contiguous(),
- value.contiguous(),
- attn_mask=None,
- dropout_p=0.0,
- is_causal=False,
- scale=1.0 / inv_scale,
- )
-
-
-def _sfdp_pattern_14(query, key, value, scale_factor):
- # dropout would create a clone() if eval() or p = 0
- return (
- torch.matmul(query, key.transpose(-2, -1))
- .mul(scale_factor)
- .softmax(dim=-1)
- .clone()
- .matmul(value)
- )
-
-
-def _sfdp_replacement_14(query, key, value, scale_factor):
- counters["inductor"]["fuse_attention"] += 1
- return aten.scaled_dot_product_attention(
- query.contiguous(),
- key.contiguous(),
- value.contiguous(),
- attn_mask=None,
- dropout_p=0.0,
- is_causal=False,
- scale=scale_factor,
- )
-
-
-def _sfdp_pattern_15(query, key, value, inv_scale):
- # dropout would create a clone() if eval() or p = 0
- q = query.permute(0, 2, 1, 3)
- k = key.permute(0, 2, 1, 3)
- v = value.permute(0, 2, 1, 3)
- return (
- torch.matmul(q, k.transpose(-2, -1))
- .div(inv_scale)
- .softmax(dim=-1)
- .clone()
- .matmul(v)
- )
-
-
-def _sfdp_replacement_15(query, key, value, inv_scale):
- counters["inductor"]["fuse_attention"] += 1
- return aten.scaled_dot_product_attention(
- query.transpose(1, 2),
- key.transpose(1, 2),
- value.transpose(1, 2),
- attn_mask=None,
- dropout_p=0.0,
- is_causal=False,
- scale=1.0 / inv_scale,
- )
-
-
def _sfdp_params_check(match):
assert all(k in match.kwargs for k in ("query", "key", "value"))
query = match.kwargs["query"].meta["val"]
@@ -418,6 +344,28 @@
return fn
+def partialize_and_update_signature(func, **kwargs):
+ """
+ Equivalent to functools.partial but also updates the signature on returned function
+ """
+ original_sig = inspect.signature(func)
+ parameters = original_sig.parameters
+
+ new_parameters = {
+ key: value for key, value in parameters.items() if key not in kwargs
+ }
+ new_sig = inspect.Signature(parameters=list(new_parameters.values()))
+
+ partial_func = functools.partial(func, **kwargs)
+
+ def wrapper(*args, **kwargs):
+ return partial_func(*args, **kwargs)
+
+ wrapper.__signature__ = new_sig # type: ignore[attr-defined]
+
+ return wrapper
+
+
@functools.lru_cache(None)
def _sfdp_init():
from .joint_graph import patterns
@@ -525,38 +473,26 @@
d,
_sfdp_scale_factor_check(aten.div.Tensor),
),
- (
- _sfdp_pattern_13,
- _sfdp_replacement_13,
- [g(), g(), g(), c()],
- {},
- _sfdp_scale_factor_check(aten.div.Tensor),
- ),
- (
- _sfdp_pattern_14,
- _sfdp_replacement_14,
- [g(), g(), g(), c()],
- {},
- _sfdp_scale_factor_check(aten.mul.Tensor),
- ),
- (
- _sfdp_pattern_15,
- _sfdp_replacement_15,
- [g(), g(), g(), c()],
- {},
- _sfdp_scale_factor_check(aten.div.Tensor),
- ),
]:
- args = [*args, *workaround.values()] # type: ignore[attr-defined]
+ training_args = [*args, *workaround.values()] # type: ignore[attr-defined]
register_replacement(
pattern,
replacement,
- args,
+ training_args,
training_graph,
patterns,
extra_check=extra_check,
scalar_workaround=workaround,
)
+
+ if workaround:
+ assert isinstance(workaround, dict)
+ assert len(workaround) == 1 and "dropout_p" in workaround
+ # functools.partial insufficient because we look at signature downstream
+ pattern = partialize_and_update_signature(pattern, dropout_p=0.0)
+ replacement = partialize_and_update_signature(replacement, dropout_p=0.0)
+ workaround = {}
+
register_replacement(
pattern,
replacement,
diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py
index 8fc1e4d..25ae1c8 100644
--- a/torch/_inductor/pattern_matcher.py
+++ b/torch/_inductor/pattern_matcher.py
@@ -12,6 +12,7 @@
import torch._guards
import torch.fx
import torch.utils._pytree as pytree
+from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import counters
from torch._prims_common import is_integer_dtype
from torch.fx import Node
@@ -1002,7 +1003,9 @@
@torch.no_grad()
def inference_graph(fn, args):
"""Build a normalized inference graph, for use with fx_to_pattern"""
- gm = make_fx(fn, select_decomp_table())(*args)
+ # TODO - look into using aot autograd, asserting no mutating ops here
+ with enable_python_dispatcher():
+ gm = make_fx(fn, select_decomp_table())(*args)
gm.graph.eliminate_dead_code()
gm.recompile()
return gm