# Owner(s): ["module: inductor"]
import itertools
import math
import torch
import torch._inductor.config
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FUSED_SDPA,
SM80OrLater,
)
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CUDA
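# These tests check that inductor's "fuse_attention" pattern rewriter replaces
# hand-written attention (matmul -> scale -> softmax -> [dropout] -> matmul)
# with a fused aten._scaled_dot_product_* call in the generated code.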
@config.patch(fallback_random=True)
class TestSDPAPatternRewriter(TestCase):
def _clone_inputs(self, inputs):
def clone(x):
if not isinstance(x, torch.Tensor):
return x
return x.clone()
return tuple(clone(x) for x in inputs)
def _check_common(
self,
dot_prod_attention,
args1=None,
contains=True,
atol=1e-5,
has_fuse_pattern=True,
has_dropout=False,
):
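        """
        Run ``dot_prod_attention`` in eager mode and under ``torch.compile`` (for
        both inference and training) and compare outputs and gradients.
        ``has_fuse_pattern`` asserts that inductor's ``fuse_attention`` counter
        fired; ``contains`` asserts that the generated code contains an
        ``aten._scaled_dot_product`` call; ``has_dropout`` skips the exact output
        and gradient comparison, since dropout results are not expected to match
        eager.
        """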
if args1 is None:
tensor_shape = (4, 2, 16, 32)
args1 = [
torch.randn(tensor_shape, device="cuda"),
torch.randn(tensor_shape, device="cuda"),
torch.randn(tensor_shape, device="cuda"),
]
args2 = self._clone_inputs(args1)
for training in [False, True]:
for x in itertools.chain(args1[:], args2[:]):
if isinstance(x, torch.Tensor) and x.is_floating_point():
x.requires_grad = training
torch.manual_seed(1234)
result1 = dot_prod_attention(*args1)
counters.clear()
torch.manual_seed(1234)
result2, (source_code,) = run_and_get_code(
torch.compile(dot_prod_attention, fullgraph=True), *args2
)
if has_fuse_pattern:
self.assertGreaterEqual(counters["inductor"]["fuse_attention"], 1)
if contains:
# many of the patterns get re-expanded in dispatcher
self.assertIn(
"aten._scaled_dot_product",
source_code,
)
if not has_dropout:
self.assertEqual(result1, result2, atol=atol, rtol=1.3e-6)
if training:
result1.sum().backward()
result2.sum().backward()
for arg1, arg2 in zip(args1, args2):
if (
isinstance(arg1, torch.Tensor)
and arg1.is_floating_point()
and not has_dropout
):
self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=1.3e-6)
@skipIfRocm
def test_sdpa_rewriter_1(self):
def dot_prod_attention(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
"""Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)"""
return (
torch.matmul(query, key.transpose(-2, -1))
.div(math.sqrt(key.shape[-1]))
.softmax(dim=-1)
.matmul(value)
)
self._check_common(dot_prod_attention)
    @skipIfRocm
    def test_pattern_fails_with_reuse(self):
        """
        This test checks that the replacement is not done
        when an intermediate result is being used / returned downstream
        """
        @torch.compile(fullgraph=True)
def dot_prod_attention(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
attn_weights = (
torch.matmul(query, key.transpose(-2, -1))
.div(math.sqrt(key.shape[-1]))
.softmax(dim=-1)
)
return attn_weights.matmul(value), attn_weights
tensor_shape = (2, 4, 8, 16)
args = [
torch.randn(tensor_shape, device="cuda"),
torch.randn(tensor_shape, device="cuda"),
torch.randn(tensor_shape, device="cuda"),
]
_, (source_code,) = run_and_get_code(dot_prod_attention, *args)
self.assertNotIn("aten._scaled_dot_product_efficient_attention", source_code)
@skipIfRocm
def test_sdpa_rewriter_2(self):
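        # Same as rewriter_1, but the scale is applied with .mul(1.0 / sqrt(d))
        # instead of .div(sqrt(d)).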
def dot_prod_attention(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
return (
torch.matmul(query, key.transpose(-2, -1))
.mul(1.0 / math.sqrt(key.shape[-1]))
.softmax(dim=-1)
.matmul(value)
)
self._check_common(dot_prod_attention)
def test_sdpa_rewriter_3(self):
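        # Dropout applied to the attention weights: the fuse_attention counter is
        # still expected to fire, but the generated-code string check is skipped
        # (contains=False) and so is the exact comparison with eager
        # (has_dropout=True).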
def dot_prod_attention(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
return torch.nn.functional.dropout(
torch.matmul(query, key.transpose(-2, -1)).div(3.0).softmax(dim=-1),
p=0.4,
training=True,
inplace=False,
).matmul(value)
self._check_common(dot_prod_attention, contains=False, has_dropout=True)
def test_sdpa_rewriter_4(self):
def dot_prod_attention(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
return torch.nn.functional.dropout(
torch.matmul(query, key.transpose(-2, -1)).mul(0.4).softmax(dim=-1),
p=0.2,
training=True,
inplace=False,
).matmul(value)
self._check_common(dot_prod_attention, contains=False, has_dropout=True)
def test_sdpa_rewriter_5(self):
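        # Two ways of supplying an additive attention mask: v1 builds a
        # lower-triangular (causal-style) mask, v2 adds an all-zero boolean mask.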
def sfdp_pattern_5_v1(query, key, value):
attn_mask = torch.ones(
query.size(-2), key.size(-2), dtype=torch.bool, device=query.device
).tril(diagonal=0)
attn_mask = attn_mask.masked_fill(
torch.logical_not(attn_mask), -float("inf")
)
attn_weight = torch.softmax(
(query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask,
dim=-1,
)
return attn_weight @ value
def sfdp_pattern_5_v2(query, key, value):
# https://github.com/pytorch/pytorch/issues/100318.
attn_mask = torch.zeros(
query.size(-2), key.size(-2), dtype=torch.bool, device=query.device
).bool()
attn_weight = torch.softmax(
(query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask,
dim=-1,
)
return attn_weight @ value
self._check_common(sfdp_pattern_5_v1, contains=False)
self._check_common(sfdp_pattern_5_v2, contains=False)
@skipIfRocm
def test_sdpa_rewriter_6(self):
def sfdp_pattern_6(query, key, value):
attn_mask = torch.ones(
query.size(-2), key.size(-2), dtype=torch.bool, device=query.device
).tril(diagonal=0)
attn_mask = attn_mask.masked_fill(
torch.logical_not(attn_mask), -float("inf")
)
attn_weight = torch.softmax(
(query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask,
dim=-1,
)
attn_weight = torch.dropout(attn_weight, 0.5, True)
return attn_weight @ value
self._check_common(sfdp_pattern_6, contains=False, has_dropout=True)
@skipIfRocm
def test_sdpa_rewriter_7(self):
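        # Half-precision inputs with the softmax computed in float32; the fused
        # call is only expected in the generated code on SM80 or later
        # (contains=SM80OrLater), with a looser tolerance for fp16.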
def sfdp_pattern_7(query, key, value):
q = query.permute(0, 2, 1, 3)
k = key.permute(0, 2, 1, 3)
v = value.permute(0, 2, 1, 3)
div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
div = div.to(torch.float32)
attn_weight = torch.softmax(div, dim=-1)
# very small dropout to make sure test passes
attn_weight = torch.dropout(attn_weight, 0.0000, True)
attn_weight = attn_weight.to(torch.float16)
return attn_weight @ v
args = (
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
)
self._check_common(sfdp_pattern_7, args, contains=SM80OrLater, atol=2e-3)
@skipIfRocm
def test_sdpa_rewriter_8(self):
def sfdp_pattern_8(query, key, value):
q = query.permute(0, 2, 1, 3)
k = key.permute(0, 2, 1, 3)
v = value.permute(0, 2, 1, 3)
div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
div = div.to(torch.float32)
attn_weight = torch.softmax(div, dim=-1)
attn_weight = attn_weight.to(torch.float16)
return attn_weight @ v
args = (
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
)
self._check_common(sfdp_pattern_8, args, atol=2e-3)
@skipIfRocm
def test_sdpa_rewriter_9(self):
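        # Like rewriter_7, but the 1/sqrt(d) scale is applied to q before the
        # matmul instead of to the q @ k^T product.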
def sfdp_pattern_9(query, key, value):
q = query.permute(0, 2, 1, 3)
k = key.permute(0, 2, 1, 3)
v = value.permute(0, 2, 1, 3)
q = q / math.sqrt(q.size(-1))
div = q @ k.transpose(-2, -1)
div = div.to(torch.float32)
attn_weight = torch.softmax(div, dim=-1)
            # very low dropout to make test pass
            attn_weight = torch.dropout(attn_weight, 0.0000, True)
attn_weight = attn_weight.to(torch.float16)
return attn_weight @ v
args = (
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
)
self._check_common(sfdp_pattern_9, args, contains=SM80OrLater, atol=2e-3)
@skipIfRocm
def test_sdpa_rewriter_10(self):
def sfdp_pattern_10(query, key, value):
q = query.permute(0, 2, 1, 3)
k = key.permute(0, 2, 1, 3)
v = value.permute(0, 2, 1, 3)
q = q / math.sqrt(q.size(-1))
div = q @ k.transpose(-2, -1)
div = div.to(torch.float32)
attn_weight = torch.softmax(div, dim=-1)
attn_weight = attn_weight.to(torch.float16)
return attn_weight @ v
args = (
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
)
self._check_common(sfdp_pattern_10, args, atol=2e-3)
def test_pattern_fails_with_tensor_factor(self):
# https://github.com/pytorch/pytorch/issues/99124
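        # A tensor-valued scale factor should not be matched, so no fusion is
        # expected (has_fuse_pattern=False, contains=False).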
class Model(torch.nn.Module):
def __init__(self, is_inv_factor):
super().__init__()
self.is_inv_factor = is_inv_factor
def forward(self, query, key, value, scale_factor) -> torch.Tensor:
y = torch.matmul(query, key.transpose(-2, -1))
if self.is_inv_factor:
y = y.div(scale_factor)
else:
y = y.mul(scale_factor)
return y.softmax(dim=-1).matmul(value)
tensor_shape = (2, 4, 4, 4)
for is_inv_factor in [True, False]:
args = [
torch.randn(tensor_shape, device="cuda"),
torch.randn(tensor_shape, device="cuda"),
torch.randn(tensor_shape, device="cuda"),
torch.randn((4, 1, 1), device="cuda"),
]
model = Model(is_inv_factor).eval()
# The training path has an accuracy gap compared with eager mode.
self._check_common(
model, args1=args, contains=False, atol=1e-4, has_fuse_pattern=False
)
def test_pattern_fails_with_unsupported_mask(self):
# https://github.com/pytorch/pytorch/issues/100315
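        # Integer-tensor and plain-float masks are not supported, so no fusion
        # is expected (has_fuse_pattern=False, contains=False).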
class Model(torch.nn.Module):
def __init__(
self,
):
super().__init__()
def forward(self, query, key, value, attn_mask) -> torch.Tensor:
attn_weight = torch.softmax(
query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
+ attn_mask,
dim=-1,
)
return attn_weight @ value
tensor_shape = (2, 4, 4, 4)
        unsupported_masks = [
torch.randn((2, 4, 4, 4), device="cuda").to(dtype=torch.int),
2.0,
]
        for attn_mask in unsupported_masks:
args = [
torch.randn(tensor_shape, device="cuda"),
torch.randn(tensor_shape, device="cuda"),
torch.randn(tensor_shape, device="cuda"),
                attn_mask,
]
model = Model().eval()
# The training path has an accuracy gap compared with eager mode.
self._check_common(
model, args1=args, contains=False, atol=1e-4, has_fuse_pattern=False
)
if __name__ == "__main__":
if IS_LINUX and HAS_CUDA and PLATFORM_SUPPORTS_FUSED_SDPA:
run_tests()