| # Owner(s): ["module: inductor"] |
| import functools |
| import itertools |
| import math |
| |
| import torch |
| import torch._inductor.config |
| import torch.utils.checkpoint |
| from torch._dynamo.debug_utils import aot_graph_input_parser |
| from torch._dynamo.utils import counters |
| from torch._inductor.test_case import run_tests, TestCase |
| from torch._inductor.utils import run_and_get_code |
| from torch.testing._internal.common_cuda import ( |
| PLATFORM_SUPPORTS_FUSED_ATTENTION, |
| SM80OrLater, |
| ) |
| from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm |
| from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA |
| |
| |
| def checkpoint_wrapper(fn): |
| def inner(*args): |
| return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True) |
| |
| return inner |
| |
| |
| class TestSDPAPatternRewriterTemplate(TestCase): |
| use_static_shapes = True |
| |
| def _clone_inputs(self, inputs): |
| def clone(x): |
| if not isinstance(x, torch.Tensor): |
| return x |
| return x.clone() |
| |
| return [clone(x) for x in inputs] |
| |
| def _check_common( |
| self, |
| dot_prod_attention, |
| args1=None, |
| contains=True, |
| atol=1e-5, |
| has_fuse_pattern=True, |
| has_dropout=False, |
| check_train=True, |
| override_check_equal=False, |
| dtype=torch.float, |
| rtol=1.3e-6, |
| ): |
| if args1 is None: |
| tensor_shape = (4, 2, 16, 32) |
| args1 = [ |
| torch.randn(tensor_shape, device=self.device, dtype=dtype), |
| torch.randn(tensor_shape, device=self.device, dtype=dtype), |
| torch.randn(tensor_shape, device=self.device, dtype=dtype), |
| ] |
| else: |
| args1 = list(args1) |
| args2 = self._clone_inputs(args1) |
| |
| for training in [False, True] if check_train else [False]: |
| for x in itertools.chain(args1[:], args2[:]): |
| if isinstance(x, torch.Tensor) and x.is_floating_point(): |
| x.requires_grad = training |
| |
| if not self.use_static_shapes: |
| torch._dynamo.mark_dynamic(args2[0], 0) |
| torch._dynamo.mark_dynamic(args2[1], 0) |
| torch._dynamo.mark_dynamic(args2[2], 0) |
| |
| dropout_arg = [training] if has_dropout else [] |
| torch.manual_seed(1234) |
| result1 = dot_prod_attention(*(args1 + dropout_arg)) |
| |
| counters.clear() |
| torch.manual_seed(1234) |
| result2, source_code = run_and_get_code( |
| torch.compile(dot_prod_attention, fullgraph=True), |
| *(args2 + dropout_arg), |
| ) |
| source_code = "\n".join(source_code) |
| if has_fuse_pattern: |
| self.assertGreaterEqual(counters["inductor"]["fuse_attention"], 1) |
| if contains: |
| # many of the patterns get re-expanded in dispatcher |
| self.assertIn( |
| "aten._scaled_dot_product", |
| source_code, |
| ) |
| |
| # some tests configured with very low dropout where we still want to check equality |
| if not has_dropout or override_check_equal: |
| self.assertEqual(result1, result2, atol=atol, rtol=1.3e-6) |
| |
| if training: |
| result1.sum().backward() |
| result2.sum().backward() |
| for arg1, arg2 in zip(args1, args2): |
| if ( |
| isinstance(arg1, torch.Tensor) |
| and arg1.is_floating_point() |
| and (not has_dropout or override_check_equal) |
| ): |
| self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol) |
| |
| @skipIfRocm |
| def _test_sdpa_rewriter_1(self): |
| def dot_prod_attention( |
| query: torch.Tensor, key: torch.Tensor, value: torch.Tensor |
| ) -> torch.Tensor: |
| """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)""" |
| return ( |
| torch.matmul(query, key.transpose(-2, -1)) |
| .div(math.sqrt(key.shape[-1])) |
| .softmax(dim=-1) |
| .matmul(value) |
| ) |
| |
| for dtype in [torch.float, torch.half]: |
| atol = 0.001 |
| rtol = 1.3e-6 if dtype == torch.float else 0.7 |
| if self.device == "cpu" and dtype == torch.half: |
| atol = 2e-3 |
| rtol = 1e-2 |
| self._check_common(dot_prod_attention, dtype=dtype, atol=atol, rtol=rtol) |
| self._check_common( |
| checkpoint_wrapper(dot_prod_attention), |
| dtype=dtype, |
| atol=atol, |
| rtol=rtol, |
| ) |
| |
| @skipIfRocm |
| @torch._inductor.config.patch("freezing", True) |
| def _test_sdpa_rewriter_1_freezing(self): |
| def dot_prod_attention( |
| query: torch.Tensor, key: torch.Tensor, value: torch.Tensor |
| ) -> torch.Tensor: |
| """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)""" |
| return ( |
| torch.matmul(query, key.transpose(-2, -1)) |
| .div(math.sqrt(key.shape[-1])) |
| .softmax(dim=-1) |
| .matmul(value) |
| ) |
| |
| for dtype in [torch.float, torch.half]: |
| atol = 0.001 |
| rtol = 1.3e-6 if dtype == torch.float else 0.7 |
| if self.device == "cpu" and dtype == torch.half: |
| atol = 2e-3 |
| rtol = 1e-2 |
| with torch.no_grad(): |
| self._check_common( |
| dot_prod_attention, |
| dtype=dtype, |
| atol=atol, |
| rtol=rtol, |
| check_train=False, |
| ) |
| |
| @skipIfRocm |
| def _test_insignificant_strides(self): |
| f32 = torch.float32 |
| |
| # repro taken from https://github.com/pytorch/pytorch/issues/124289 |
| # constant_pad_nd is a single element tensor that gets expanded |
| |
| def forward( |
| permute_3: "f32[1, 32, 1, 128]", |
| permute_4: "f32[1, 32, 1, 128]", |
| permute_5: "f32[1, 32, 1, 128]", |
| permute_6: "f32[1, 1, 64]", |
| mul_2: "f32[1, 1, 1, 1]", |
| ): |
| cat = torch.ops.aten.cat.default([permute_6, permute_6], 2) |
| permute_6 = None |
| cos = torch.ops.aten.cos.default(cat) |
| sin = torch.ops.aten.sin.default(cat) |
| unsqueeze_10 = torch.ops.aten.unsqueeze.default(cos, 1) |
| cos = None |
| unsqueeze_11 = torch.ops.aten.unsqueeze.default(sin, 1) |
| sin = None |
| mul_5 = torch.ops.aten.mul.Tensor(permute_3, unsqueeze_10) |
| slice_10 = torch.ops.aten.slice.Tensor(permute_3, 3, 0, 64) |
| slice_11 = torch.ops.aten.slice.Tensor( |
| permute_3, 3, 64, 9223372036854775807 |
| ) |
| permute_3 = None |
| neg = torch.ops.aten.neg.default(slice_11) |
| slice_11 = None |
| cat_1 = torch.ops.aten.cat.default([neg, slice_10], 3) |
| neg = slice_10 = None |
| mul_6 = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_11) |
| cat_1 = None |
| add_1 = torch.ops.aten.add.Tensor(mul_5, mul_6) |
| mul_5 = mul_6 = None |
| mul_7 = torch.ops.aten.mul.Tensor(permute_4, unsqueeze_10) |
| unsqueeze_10 = None |
| slice_12 = torch.ops.aten.slice.Tensor(permute_4, 3, 0, 64) |
| slice_13 = torch.ops.aten.slice.Tensor( |
| permute_4, 3, 64, 9223372036854775807 |
| ) |
| permute_4 = None |
| neg_1 = torch.ops.aten.neg.default(slice_13) |
| slice_13 = None |
| cat_2 = torch.ops.aten.cat.default([neg_1, slice_12], 3) |
| neg_1 = slice_12 = None |
| mul_8 = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_11) |
| cat_2 = unsqueeze_11 = None |
| add_2 = torch.ops.aten.add.Tensor(mul_7, mul_8) |
| mul_7 = mul_8 = None |
| slice_14 = torch.ops.aten.slice.Tensor(mul_2, 0, 0, 9223372036854775807) |
| mul_2 = None |
| slice_15 = torch.ops.aten.slice.Tensor(slice_14, 1, 0, 9223372036854775807) |
| slice_14 = None |
| slice_16 = torch.ops.aten.slice.Tensor(slice_15, 2, 0, 9223372036854775807) |
| slice_15 = None |
| constant_pad_nd = torch.ops.aten.constant_pad_nd.default( |
| slice_16, [0, 7], 0.0 |
| ) |
| slice_16 = None |
| slice_17 = torch.ops.aten.slice.Tensor(constant_pad_nd, -1, 0, 1) |
| constant_pad_nd = None |
| expand_5 = torch.ops.aten.expand.default(slice_17, [1, 32, 1, 1]) |
| _scaled_dot_product_efficient_attention = ( |
| torch.ops.aten._scaled_dot_product_efficient_attention.default( |
| add_1, add_2, permute_5, expand_5, True |
| ) |
| ) |
| return _scaled_dot_product_efficient_attention |
| |
| kwargs = aot_graph_input_parser(forward, device="cuda") |
| # runs successfully |
| out_eager = forward(**kwargs) |
| out_c = torch.compile(forward)(**kwargs) |
| # dont compare philox_seed/offset |
| torch.testing.assert_close(out_eager[0:2], out_c[0:2]) |
| |
| def _test_pattern_fails_with_reuse(self): |
| """ |
| This test checks that the replacement is not done |
| when an intermediate result is being used / returned downstream |
| """ |
| |
| @torch.compile(fullgraph=True) |
| def dot_prod_attention( |
| query: torch.Tensor, key: torch.Tensor, value: torch.Tensor |
| ) -> torch.Tensor: |
| attn_weights = ( |
| torch.matmul(query, key.transpose(-2, -1)) |
| .div(math.sqrt(key.shape[-1])) |
| .softmax(dim=-1) |
| ) |
| return attn_weights.matmul(value), attn_weights |
| |
| tensor_shape = (2, 4, 8, 16) |
| args = [ |
| torch.randn(tensor_shape, device=self.device), |
| torch.randn(tensor_shape, device=self.device), |
| torch.randn(tensor_shape, device=self.device), |
| ] |
| _, (source_code,) = run_and_get_code(dot_prod_attention, *args) |
| self.assertNotIn("aten._scaled_dot_product_efficient_attention", source_code) |
| |
| @skipIfRocm |
| def _test_sdpa_rewriter_2(self): |
| def dot_prod_attention( |
| query: torch.Tensor, key: torch.Tensor, value: torch.Tensor |
| ) -> torch.Tensor: |
| return ( |
| torch.matmul(query, key.transpose(-2, -1)) |
| .mul(1.0 / math.sqrt(key.shape[-1])) |
| .softmax(dim=-1) |
| .matmul(value) |
| ) |
| |
| self._check_common(dot_prod_attention) |
| self._check_common(checkpoint_wrapper(dot_prod_attention)) |
| |
| @skipIfRocm # AssertionError: expected size 4==4, stride 32==64 at dim=0 |
| def _test_sdpa_rewriter_3(self): |
| def dot_prod_attention( |
| query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training: bool |
| ) -> torch.Tensor: |
| return torch.nn.functional.dropout( |
| torch.matmul(query, key.transpose(-2, -1)).div(3.0).softmax(dim=-1), |
| p=0.4, |
| training=training, |
| inplace=False, |
| ).matmul(value) |
| |
| self._check_common(dot_prod_attention, contains=False, has_dropout=True) |
| self._check_common( |
| checkpoint_wrapper(dot_prod_attention), contains=False, has_dropout=True |
| ) |
| |
| @skipIfRocm # AssertionError: expected size 4==4, stride 32==64 at dim=0 |
| def _test_sdpa_rewriter_4(self): |
| def dot_prod_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| training: bool, |
| ) -> torch.Tensor: |
| return torch.nn.functional.dropout( |
| torch.matmul(query, key.transpose(-2, -1)).mul(0.4).softmax(dim=-1), |
| p=0.2, |
| inplace=False, |
| training=training, |
| ).matmul(value) |
| |
| self._check_common(dot_prod_attention, contains=False, has_dropout=True) |
| self._check_common( |
| checkpoint_wrapper(dot_prod_attention), contains=False, has_dropout=True |
| ) |
| |
| def _test_sdpa_rewriter_5(self): |
| def sfdp_pattern_5_v1(query, key, value): |
| attn_mask = torch.ones( |
| query.size(-2), key.size(-2), dtype=torch.bool, device=query.device |
| ).tril(diagonal=0) |
| attn_mask = attn_mask.masked_fill( |
| torch.logical_not(attn_mask), -float("inf") |
| ) |
| attn_weight = torch.softmax( |
| (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, |
| dim=-1, |
| ) |
| return attn_weight @ value |
| |
| def sfdp_pattern_5_v2(query, key, value): |
| # https://github.com/pytorch/pytorch/issues/100318. |
| attn_mask = torch.zeros( |
| query.size(-2), key.size(-2), dtype=torch.bool, device=query.device |
| ).bool() |
| attn_weight = torch.softmax( |
| (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, |
| dim=-1, |
| ) |
| return attn_weight @ value |
| |
| self._check_common(sfdp_pattern_5_v1, contains=False) |
| self._check_common(checkpoint_wrapper(sfdp_pattern_5_v1), contains=False) |
| self._check_common(sfdp_pattern_5_v2, contains=False) |
| self._check_common(checkpoint_wrapper(sfdp_pattern_5_v2), contains=False) |
| |
| @skipIfRocm |
| def _test_sdpa_rewriter_6(self): |
| def sfdp_pattern_6(query, key, value, training): |
| attn_mask = torch.ones( |
| query.size(-2), key.size(-2), dtype=torch.bool, device=query.device |
| ).tril(diagonal=0) |
| attn_mask = attn_mask.masked_fill( |
| torch.logical_not(attn_mask), -float("inf") |
| ) |
| attn_weight = torch.softmax( |
| (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, |
| dim=-1, |
| ) |
| attn_weight = torch.nn.functional.dropout(attn_weight, 0.5, training) |
| return attn_weight @ value |
| |
| self._check_common(sfdp_pattern_6, contains=False, has_dropout=True) |
| self._check_common( |
| checkpoint_wrapper(sfdp_pattern_6), contains=False, has_dropout=True |
| ) |
| |
| @skipIfRocm |
| def _test_sdpa_rewriter_7(self): |
| def sfdp_pattern_7(query, key, value, training): |
| q = query.permute(0, 2, 1, 3) |
| k = key.permute(0, 2, 1, 3) |
| v = value.permute(0, 2, 1, 3) |
| div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) |
| div = div.to(torch.float32) |
| attn_weight = torch.softmax(div, dim=-1) |
| # Set to False |
| attn_weight = torch.dropout(attn_weight, 0.00000000001, training) |
| attn_weight = attn_weight.to(torch.float16) |
| return attn_weight @ v |
| |
| args = ( |
| torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), |
| torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), |
| torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), |
| ) |
| self._check_common( |
| sfdp_pattern_7, |
| args, |
| contains=SM80OrLater, |
| has_dropout=True, |
| override_check_equal=True, |
| atol=2e-3, |
| ) |
| |
| args = ( |
| torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half), |
| torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half), |
| torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half), |
| ) |
| self._check_common( |
| checkpoint_wrapper(sfdp_pattern_7), |
| args, |
| contains=SM80OrLater, |
| has_dropout=True, |
| override_check_equal=True, |
| atol=2e-3, |
| ) |
| |
| @skipIfRocm |
| def _test_sdpa_rewriter_8(self): |
| def sfdp_pattern_8(query, key, value): |
| q = query.permute(0, 2, 1, 3) |
| k = key.permute(0, 2, 1, 3) |
| v = value.permute(0, 2, 1, 3) |
| div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) |
| div = div.to(torch.float32) |
| attn_weight = torch.softmax(div, dim=-1) |
| attn_weight = attn_weight.to(torch.float16) |
| return attn_weight @ v |
| |
| args = ( |
| torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), |
| torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), |
| torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), |
| ) |
| self._check_common(sfdp_pattern_8, args, atol=2e-3) |
| |
| args = ( |
| torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half), |
| torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half), |
| torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half), |
| ) |
| self._check_common(checkpoint_wrapper(sfdp_pattern_8), args, atol=2e-3) |
| |
| @skipIfRocm |
| def _test_sdpa_rewriter_9(self): |
| def sfdp_pattern_9(query, key, value, training): |
| q = query.permute(0, 2, 1, 3) |
| k = key.permute(0, 2, 1, 3) |
| v = value.permute(0, 2, 1, 3) |
| q = q / math.sqrt(q.size(-1)) |
| div = q @ k.transpose(-2, -1) |
| div = div.to(torch.float32) |
| attn_weight = torch.softmax(div, dim=-1) |
| # very low dropout to make test pass |
| attn_weight = torch.dropout(attn_weight, 0.00000000001, training) |
| attn_weight = attn_weight.to(torch.float16) |
| return attn_weight @ v |
| |
| args = ( |
| torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), |
| torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), |
| torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), |
| ) |
| self._check_common( |
| sfdp_pattern_9, |
| args, |
| contains=SM80OrLater, |
| has_dropout=True, |
| override_check_equal=True, |
| atol=2e-3, |
| ) |
| args = ( |
| torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half), |
| torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half), |
| torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half), |
| ) |
| self._check_common( |
| checkpoint_wrapper(sfdp_pattern_9), |
| args, |
| contains=SM80OrLater, |
| has_dropout=True, |
| override_check_equal=True, |
| atol=2e-3, |
| ) |
| |
| @skipIfRocm |
| def _test_sdpa_rewriter_10(self): |
| def sfdp_pattern_10(query, key, value): |
| q = query.permute(0, 2, 1, 3) |
| k = key.permute(0, 2, 1, 3) |
| v = value.permute(0, 2, 1, 3) |
| q = q / math.sqrt(q.size(-1)) |
| div = q @ k.transpose(-2, -1) |
| div = div.to(torch.float32) |
| attn_weight = torch.softmax(div, dim=-1) |
| attn_weight = attn_weight.to(torch.float16) |
| return attn_weight @ v |
| |
| args = ( |
| torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), |
| torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), |
| torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), |
| ) |
| self._check_common(sfdp_pattern_10, args, atol=2e-3) |
| |
| args = ( |
| torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half), |
| torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half), |
| torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half), |
| ) |
| self._check_common(checkpoint_wrapper(sfdp_pattern_10), args, atol=2e-3) |
| |
| def _test_pattern_fails_with_tensor_factor(self): |
| # https://github.com/pytorch/pytorch/issues/99124 |
| class Model(torch.nn.Module): |
| def __init__(self, is_inv_factor): |
| super().__init__() |
| self.is_inv_factor = is_inv_factor |
| |
| def forward(self, query, key, value, scale_factor) -> torch.Tensor: |
| # Dividing by scale_factor makes scale_factor gradients very |
| # unstable |
| scale_factor = scale_factor.detach() |
| y = torch.matmul(query, key.transpose(-2, -1)) |
| if self.is_inv_factor: |
| y = y.div(scale_factor) |
| else: |
| y = y.mul(scale_factor) |
| return y.softmax(dim=-1).matmul(value) |
| |
| tensor_shape = (2, 4, 4, 4) |
| for is_inv_factor in [True, False]: |
| args = [ |
| torch.randn(tensor_shape, device=self.device), |
| torch.randn(tensor_shape, device=self.device), |
| torch.randn(tensor_shape, device=self.device), |
| torch.randn((4, 1, 1), device=self.device), |
| ] |
| model = Model(is_inv_factor).eval() |
| # The training path has an accuracy gap compared with eager mode. |
| self._check_common( |
| model, args1=args, contains=False, atol=1e-3, has_fuse_pattern=False |
| ) |
| |
| def _test_pattern_fails_with_unsupported_mask(self): |
| if not self.use_static_shapes: |
| self.skipTest("Causes shape specialization. TODO: investigate") |
| |
| # https://github.com/pytorch/pytorch/issues/100315 |
| class Model(torch.nn.Module): |
| def __init__( |
| self, |
| ): |
| super().__init__() |
| |
| def forward(self, query, key, value, attn_mask) -> torch.Tensor: |
| attn_weight = torch.softmax( |
| query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) |
| + attn_mask, |
| dim=-1, |
| ) |
| return attn_weight @ value |
| |
| tensor_shape = (2, 4, 4, 4) |
| |
| upsupported_masks = [ |
| torch.randn((2, 4, 4, 4), device=self.device).to(dtype=torch.int), |
| 2.0, |
| ] |
| for atte_mask in upsupported_masks: |
| args = [ |
| torch.randn(tensor_shape, device=self.device), |
| torch.randn(tensor_shape, device=self.device), |
| torch.randn(tensor_shape, device=self.device), |
| atte_mask, |
| ] |
| model = Model().eval() |
| # The training path has an accuracy gap compared with eager mode. |
| self._check_common( |
| model, args1=args, contains=False, atol=1e-4, has_fuse_pattern=False |
| ) |
| |
| @skipIfRocm |
| def _test_sdpa_rewriter_11(self): |
| def dot_prod_attention( |
| query: torch.Tensor, key: torch.Tensor, value: torch.Tensor |
| ) -> torch.Tensor: |
| """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)""" |
| q = query.transpose(1, 2) |
| k = key.transpose(1, 2) |
| v = value.transpose(1, 2) |
| return ( |
| torch.matmul(q, k.transpose(-2, -1)) |
| .div(math.sqrt(key.shape[-1])) |
| .softmax(dim=-1) |
| .matmul(v) |
| ) |
| |
| self._check_common(dot_prod_attention) |
| |
| def _test_sdpa_rewriter_12(self): |
| def dot_prod_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| training: bool, |
| ) -> torch.Tensor: |
| """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)""" |
| q = query.transpose(1, 2) |
| k = key.transpose(1, 2) |
| v = value.transpose(1, 2) |
| return torch.nn.functional.dropout( |
| torch.matmul(q, k.transpose(-2, -1)) |
| .div(math.sqrt(key.shape[-1])) |
| .softmax(dim=-1) |
| .matmul(v), |
| p=0.4, |
| training=training, |
| inplace=False, |
| ) |
| |
| self._check_common(dot_prod_attention, contains=False, has_dropout=True) |
| |
| @skipIfRocm |
| def _test_sdpa_prev_13(self): |
| def dot_prod_attention( |
| query: torch.Tensor, key: torch.Tensor, value: torch.Tensor |
| ) -> torch.Tensor: |
| """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)""" |
| return ( |
| torch.matmul(query, key.transpose(-2, -1)) |
| .div(math.sqrt(key.shape[-1])) |
| .softmax(dim=-1) |
| .clone() |
| .matmul(value) |
| ) |
| |
| self._check_common(dot_prod_attention, check_train=False) |
| self._check_common(checkpoint_wrapper(dot_prod_attention), check_train=False) |
| |
| @skipIfRocm |
| def _test_sdpa_prev_14(self): |
| def dot_prod_attention( |
| query: torch.Tensor, key: torch.Tensor, value: torch.Tensor |
| ) -> torch.Tensor: |
| return ( |
| torch.matmul(query, key.transpose(-2, -1)) |
| .mul(1.0 / math.sqrt(key.shape[-1])) |
| .softmax(dim=-1) |
| .clone() |
| .matmul(value) |
| ) |
| |
| self._check_common(dot_prod_attention, check_train=False) |
| self._check_common(checkpoint_wrapper(dot_prod_attention), check_train=False) |
| |
| @skipIfRocm |
| def _test_sdpa_prev_15(self): |
| def dot_prod_attention( |
| query: torch.Tensor, key: torch.Tensor, value: torch.Tensor |
| ) -> torch.Tensor: |
| """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)""" |
| q = query.transpose(1, 2) |
| k = key.transpose(1, 2) |
| v = value.transpose(1, 2) |
| return ( |
| torch.matmul(q, k.transpose(-2, -1)) |
| .div(math.sqrt(key.shape[-1])) |
| .softmax(dim=-1) |
| .clone() |
| .matmul(v) |
| ) |
| |
| self._check_common(dot_prod_attention, check_train=False) |
| |
| @skipIfRocm |
| def _test_sdpa_rewriter_13(self, dtype): |
| def dot_prod_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| training: bool, |
| ) -> torch.Tensor: |
| """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)""" |
| attn_weight = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1) |
| attn_weight = torch.nn.functional.dropout( |
| attn_weight, p=0.5, training=training |
| ) |
| return torch.bmm(attn_weight, value) |
| |
| tensor_shape = (4, 8, 16) |
| args = [ |
| torch.randn(tensor_shape, device=self.device, dtype=dtype), |
| torch.randn(tensor_shape, device=self.device, dtype=dtype), |
| torch.randn(tensor_shape, device=self.device, dtype=dtype), |
| ] |
| |
| self._check_common( |
| dot_prod_attention, |
| check_train=False, |
| args1=args, |
| has_dropout=True, |
| override_check_equal=True, |
| atol=1e-2, |
| rtol=1e-2, |
| ) |
| |
| @skipIfRocm |
| def _test_sdpa_rewriter_14(self): |
| def dot_prod_attention( |
| query: torch.Tensor, key: torch.Tensor, value: torch.Tensor |
| ) -> torch.Tensor: |
| """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)""" |
| attn_mask = torch.ones( |
| query.size(1), key.size(1), dtype=torch.bool, device=query.device |
| ).tril(diagonal=0) |
| attn_mask = attn_mask.masked_fill( |
| torch.logical_not(attn_mask), -float("inf") |
| ) |
| q = query.permute(0, 2, 1, 3) |
| k = key.permute(0, 2, 1, 3) |
| v = value.permute(0, 2, 1, 3) |
| return ( |
| (torch.matmul(q, k.transpose(-2, -1)).div(3.0) + attn_mask) |
| .softmax(dim=-1) |
| .matmul(v) |
| ) |
| |
| self._check_common(dot_prod_attention) |
| |
| @skipIfRocm |
| def _test_sdpa_rewriter_15(self): |
| def dot_prod_attention( |
| query: torch.Tensor, key: torch.Tensor, value: torch.Tensor |
| ) -> torch.Tensor: |
| """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)""" |
| q = query.transpose(1, 2) |
| k = key.transpose(1, 2) |
| v = value.transpose(1, 2) |
| bs = q.size(0) |
| k_len = k.size(-2) |
| attn_mask = torch.ones( |
| bs, k_len, dtype=torch.bool, device=query.device |
| ).tril(diagonal=0) |
| scores = torch.matmul(q, k.transpose(-2, -1)) / 3.0 |
| attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores) |
| scores = scores.masked_fill(attn_mask, -float("inf")) |
| weights = torch.nn.functional.softmax(scores, dim=-1) |
| return torch.matmul(weights, v) |
| |
| self._check_common(dot_prod_attention, check_train=False) |
| |
| @skipIfRocm |
| def _test_sdpa_rewriter_16(self): |
| def dot_prod_attention( |
| query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training |
| ) -> torch.Tensor: |
| """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)""" |
| attn_mask = torch.ones( |
| query.size(1), key.size(1), dtype=torch.bool, device=query.device |
| ).tril(diagonal=0) |
| attn_mask = attn_mask.masked_fill( |
| torch.logical_not(attn_mask), -float("inf") |
| ) |
| q = query.permute(0, 2, 1, 3) |
| k = key.permute(0, 2, 1, 3) |
| v = value.permute(0, 2, 1, 3) |
| return torch.nn.functional.dropout( |
| (torch.matmul(q, k.transpose(-2, -1)).div(3.0) + attn_mask).softmax( |
| dim=-1 |
| ), |
| p=0.4, |
| training=training, |
| inplace=False, |
| ).matmul(v) |
| |
| self._check_common(dot_prod_attention, contains=False, has_dropout=True) |
| |
| # also check batch_size=1 because the graph is slightly different |
| tensor_shape = (1, 2, 16, 32) |
| args = [ |
| torch.randn(tensor_shape, device=self.device), |
| torch.randn(tensor_shape, device=self.device), |
| torch.randn(tensor_shape, device=self.device), |
| ] |
| self._check_common( |
| dot_prod_attention, args1=args, contains=False, has_dropout=True |
| ) |
| |
| @skipIfRocm |
| def _test_sdpa_rewriter_16_fp32_mask(self): |
| def dot_prod_attention( |
| query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training |
| ) -> torch.Tensor: |
| """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)""" |
| attn_mask = torch.randn( |
| query.size(1), key.size(1), dtype=torch.float, device=query.device |
| ).tril(diagonal=0) |
| q = query.permute(0, 2, 1, 3) |
| k = key.permute(0, 2, 1, 3) |
| v = value.permute(0, 2, 1, 3) |
| return torch.nn.functional.dropout( |
| (torch.matmul(q, k.transpose(-2, -1)).div(3.0) + attn_mask).softmax( |
| dim=-1 |
| ), |
| p=0.4, |
| training=training, |
| inplace=False, |
| ).matmul(v) |
| |
| self._check_common(dot_prod_attention, contains=False, has_dropout=True) |
| |
| # also check batch_size=1 because the graph is slightly different |
| tensor_shape = (1, 2, 16, 32) |
| args = [ |
| torch.randn(tensor_shape, device=self.device), |
| torch.randn(tensor_shape, device=self.device), |
| torch.randn(tensor_shape, device=self.device), |
| ] |
| self._check_common( |
| dot_prod_attention, args1=args, contains=False, has_dropout=True |
| ) |
| |
| @skipIfRocm |
| def _test_sdpa_rewriter_17(self): |
| def dot_prod_attention( |
| query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training |
| ) -> torch.Tensor: |
| """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)""" |
| q = query.transpose(1, 2) |
| k = key.transpose(1, 2) |
| v = value.transpose(1, 2) |
| bs = q.size(0) |
| k_len = k.size(-2) |
| attn_mask = torch.ones( |
| bs, k_len, dtype=torch.bool, device=query.device |
| ).tril(diagonal=0) |
| scores = torch.matmul(q, k.transpose(-2, -1)) / 3.0 |
| attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores) |
| scores = scores.masked_fill(attn_mask, -float("inf")) |
| weights = torch.nn.functional.softmax(scores, dim=-1) |
| weights = torch.nn.functional.dropout( |
| weights, |
| p=0.4, |
| training=training, |
| inplace=False, |
| ) |
| return torch.matmul(weights, v) |
| |
| self._check_common(dot_prod_attention, check_train=False, has_dropout=True) |
| |
| @skipIfRocm |
| def _test_sdpa_rewriter_18(self): |
| def dot_prod_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| causal_mask: torch.Tensor, |
| ) -> torch.Tensor: |
| # for hf_GPT2 with dropout |
| query = query.permute([0, 2, 1, 3]) |
| key = key.permute([0, 2, 1, 3]) |
| value = value.permute([0, 2, 1, 3]) |
| attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2)) |
| inv_scale = torch.full( |
| (), math.sqrt(value.size(-1)), dtype=query.dtype, device=query.device |
| ) |
| attn_weights = attn_weights.div(inv_scale) |
| causal_mask_value = torch.full( |
| (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device |
| ) |
| attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value) |
| return ( |
| ( |
| torch.nn.functional.dropout( |
| attn_weights.softmax(dim=-1), 0.0 |
| ).matmul(value) |
| ), |
| key.permute([0, 2, 1, 3]), |
| value.permute([0, 2, 1, 3]), |
| ) |
| |
| tensor_shape = (4, 2, 16, 32) |
| causal_mask = torch.ones(2, 2, dtype=torch.bool, device=self.device).tril( |
| diagonal=0 |
| ) |
| args = [ |
| torch.randn(tensor_shape, device=self.device), |
| torch.randn(tensor_shape, device=self.device), |
| torch.randn(tensor_shape, device=self.device), |
| causal_mask, |
| ] |
| self._check_common( |
| dot_prod_attention, |
| args1=args, |
| contains=False, |
| has_dropout=False, |
| check_train=False, |
| ) |
| |
| # also check batch_size=1 because the graph is slightly different |
| tensor_shape = (1, 2, 16, 32) |
| args = [ |
| torch.randn(tensor_shape, device=self.device), |
| torch.randn(tensor_shape, device=self.device), |
| torch.randn(tensor_shape, device=self.device), |
| causal_mask, |
| ] |
| self._check_common( |
| dot_prod_attention, |
| args1=args, |
| contains=False, |
| has_dropout=False, |
| check_train=False, |
| ) |
| |
| @skipIfRocm |
| def _test_sdpa_rewriter_19(self): |
| def dot_prod_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| causal_mask: torch.Tensor, |
| attn_mask: torch.Tensor, |
| training, |
| ) -> torch.Tensor: |
| attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2)) |
| inv_scale = torch.full( |
| (), |
| math.sqrt(value.size(-1)), |
| dtype=attn_weights.dtype, |
| device=attn_weights.device, |
| ) |
| attn_weights = attn_weights.div(inv_scale) |
| causal_mask_value = torch.full( |
| (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device |
| ) |
| attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value) |
| attn_weights = attn_weights + attn_mask |
| attn_weights = attn_weights.softmax(dim=-1).type(value.dtype) |
| return torch.nn.functional.dropout( |
| attn_weights, |
| p=0.4, |
| training=training, |
| inplace=False, |
| ).matmul(value) |
| |
| tensor_shape = (4, 2, 16, 32) |
| causal_mask = torch.ones(16, 16, dtype=torch.bool, device=self.device).tril( |
| diagonal=0 |
| ) |
| attn_mask = torch.randn((16, 16), dtype=torch.float, device=self.device) |
| args = [ |
| torch.randn(tensor_shape, device=self.device), |
| torch.randn(tensor_shape, device=self.device), |
| torch.randn(tensor_shape, device=self.device), |
| causal_mask, |
| attn_mask, |
| ] |
| self._check_common( |
| dot_prod_attention, |
| args1=args, |
| contains=False, |
| has_dropout=True, |
| check_train=False, |
| ) |
| |
| |
| if HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION: |
| |
| class SDPAPatternRewriterCudaTests(TestSDPAPatternRewriterTemplate): |
| device = "cuda" |
| test_sdpa_rewriter_1_cuda = ( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1 |
| ) |
| test_sdpa_rewriter_1_freezing = ( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1_freezing |
| ) |
| test_insignificant_strides = ( |
| TestSDPAPatternRewriterTemplate._test_insignificant_strides |
| ) |
| test_pattern_fails_with_reuse_cuda = ( |
| TestSDPAPatternRewriterTemplate._test_pattern_fails_with_reuse |
| ) |
| test_sdpa_rewriter_2_cuda = ( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_2 |
| ) |
| test_sdpa_rewriter_3_cuda = ( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_3 |
| ) |
| test_sdpa_rewriter_4_cuda = ( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_4 |
| ) |
| test_sdpa_rewriter_5_cuda = ( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_5 |
| ) |
| test_sdpa_rewriter_6_cuda = ( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_6 |
| ) |
| test_sdpa_rewriter_7_cuda = ( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_7 |
| ) |
| test_sdpa_rewriter_8_cuda = ( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_8 |
| ) |
| test_sdpa_rewriter_9_cuda = ( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_9 |
| ) |
| test_sdpa_rewriter_10_cuda = ( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_10 |
| ) |
| test_pattern_fails_with_tensor_factor_cuda = ( |
| TestSDPAPatternRewriterTemplate._test_pattern_fails_with_tensor_factor |
| ) |
| test_pattern_fails_with_unsupported_mask_cuda = ( |
| TestSDPAPatternRewriterTemplate._test_pattern_fails_with_unsupported_mask |
| ) |
| test_sdpa_rewriter_11_cuda = ( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_11 |
| ) |
| test_sdpa_rewriter_12_cuda = ( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_12 |
| ) |
| test_sdpa_prev_13_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_13 |
| test_sdpa_prev_14_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_14 |
| test_sdpa_prev_15_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_15 |
| test_sdpa_rewriter_13_cuda = functools.partialmethod( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13, dtype=torch.half |
| ) |
| test_sdpa_rewriter_14_cuda = functools.partialmethod( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_14 |
| ) |
| test_sdpa_rewriter_15_cuda = functools.partialmethod( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_15 |
| ) |
| test_sdpa_rewriter_17_cuda = functools.partialmethod( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_17 |
| ) |
| test_sdpa_rewriter_19_cuda = functools.partialmethod( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_19 |
| ) |
| |
| class SDPAPatternRewriterCudaDynamicTests(SDPAPatternRewriterCudaTests): |
| use_static_shapes = False |
| |
| |
| if HAS_CPU: |
| |
| class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate): |
| device = "cpu" |
| test_sdpa_rewriter_1_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1 |
| test_pattern_fails_with_reuse_cpu = ( |
| TestSDPAPatternRewriterTemplate._test_pattern_fails_with_reuse |
| ) |
| test_sdpa_rewriter_2_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_2 |
| test_sdpa_rewriter_5_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_5 |
| test_pattern_fails_with_tensor_factor_cpu = ( |
| TestSDPAPatternRewriterTemplate._test_pattern_fails_with_tensor_factor |
| ) |
| test_pattern_fails_with_unsupported_mask_cpu = ( |
| TestSDPAPatternRewriterTemplate._test_pattern_fails_with_unsupported_mask |
| ) |
| test_sdpa_rewriter_11_cpu = ( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_11 |
| ) |
| test_sdpa_rewriter_12_cpu = ( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_12 |
| ) |
| test_sdpa_prev_13_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_13 |
| test_sdpa_prev_14_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_14 |
| test_sdpa_prev_15_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_15 |
| test_sdpa_rewriter_13_cpu = functools.partialmethod( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13, dtype=torch.float32 |
| ) |
| test_sdpa_rewriter_14_cpu = functools.partialmethod( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_14 |
| ) |
| test_sdpa_rewriter_15_cpu = functools.partialmethod( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_15 |
| ) |
| test_sdpa_rewriter_16_cpu = functools.partialmethod( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_16 |
| ) |
| test_sdpa_rewriter_16_fp32_mask_cpu = functools.partialmethod( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_16_fp32_mask |
| ) |
| test_sdpa_rewriter_17_cpu = functools.partialmethod( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_17 |
| ) |
| test_sdpa_rewriter_18_cpu = functools.partialmethod( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_18 |
| ) |
| test_sdpa_rewriter_19_cpu = functools.partialmethod( |
| TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_19 |
| ) |
| |
| class SDPAPatternRewriterCpuDynamicTests(SDPAPatternRewriterCpuTests): |
| use_static_shapes = False |
| |
| |
| if __name__ == "__main__": |
| if IS_LINUX: |
| run_tests() |