# blob: 863a040eceada4c1118debadbcf51384963a62e6 [file] [log] [blame]
# Owner(s): ["module: inductor"]
# flake8: noqa: B950
import functools
import string
from collections import namedtuple
from typing import Callable, Optional
from unittest import expectedFailure, skip, skipUnless
from unittest.mock import patch
import torch
from torch._dynamo.testing import CompileCounterWithBackend, normalize_gm
from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
from torch._inductor import metrics
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch.nn.attention.flex_attention import (
_causal,
_compose,
_create_empty_block_mask,
_generate_alibi_bias,
_identity,
_rel_bias,
_rel_causal,
BlockMask,
create_block_mask,
flex_attention,
)
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16
from torch.utils._triton import has_triton
# Skip tests if Triton is not available
supported_platform = skipUnless(
torch.cuda.is_available()
and torch.version.hip is None
and has_triton()
and torch.cuda.get_device_capability() >= (8, 0),
"Requires CUDA and Triton",
)
Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
torch.set_float32_matmul_precision("high")
index = torch.ops.aten.index
def rmse(ref, res):
    """Return the root-mean-square error between ``ref`` and ``res``.

    The difference is squared element-wise, averaged over all elements and
    square-rooted; the result is a 0-dim tensor.
    """
    squared_diff = torch.square(ref - res)
    mean_squared = torch.mean(squared_diff)
    return torch.sqrt(mean_squared)
def create_attention(score_mod, block_mask):
return functools.partial(flex_attention, score_mod=score_mod, block_mask=block_mask)
def create_block_mask_test(score_mod, query, key):
    """Create the block mask used by ``run_test`` for ``score_mod``.

    The sequence lengths are read from dim -2 of ``query`` and ``key`` and the
    mask is created on ``query``'s device. The two literal ``1`` arguments are
    presumably the batch and head extents of the mask -- TODO confirm against
    ``create_block_mask``'s signature.
    """
    q_len = query.shape[-2]
    kv_len = key.shape[-2]
    return create_block_mask(score_mod, 1, 1, q_len, kv_len, query.device)
# dtypes for the fully parametrized tests; bfloat16 is only included (between
# float16 and float32) when the platform reports support for it.
test_dtypes = [torch.float16, torch.float32]
if PLATFORM_SUPPORTS_BF16:
    test_dtypes.insert(1, torch.bfloat16)

# Single-dtype list for tests that do not sweep over every dtype.
test_dtypes_fast = [torch.float16]
# --------- Useful score mod functions for testing ---------
def _inverse_causal(score, b, h, m, n):
return torch.where(m <= n, score, float("-inf"))
def _times_two(score, b, h, m, n):
"""Joint graph needed for correctness"""
return score * 2
def _squared(score, b, h, m, n):
"""Joint graph needed for correctness"""
return score * score
def _head_offset(dtype: torch.dtype):
    """Build a score_mod that captures a buffer (captured-buffer test helper).

    A random CUDA tensor of length ``H`` (the module-level constant) is
    created in ``dtype``; the returned score_mod multiplies each score by the
    entry selected by the head index ``h``.
    """
    per_head_scale = torch.rand(H, device="cuda", dtype=dtype)

    def score_mod(score, b, h, m, n):
        scale = per_head_scale[h]
        return score * scale

    return score_mod
def _trig(score, b, h, m, n):
"""Joint graph needed for correctness"""
return torch.sin(torch.cos(score)) + torch.tan(b)
def _trig2(score, b, h, m, n):
"""Branching joint graph"""
cos_score = torch.cos(score)
sin_score = torch.sin(score)
z = cos_score * sin_score + torch.tan(b)
return z
# Score mods swept by the parametrized "builtin score mod" tests.
test_score_mods = [
    _identity, _times_two, _squared,
    _causal, _inverse_causal,
    _rel_bias, _rel_causal,
    _generate_alibi_bias(8),
]

# Factories (keyed by name) for score mods that capture a buffer.
captured_buffers_map = dict(_head_offset=_head_offset)
# Default (B, H, S, D) extents of the q/k/v tensors created by the test helpers.
B, H, S, D = 4, 8, 2048, 64
def query_key_value_clones(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
):
    """Clone query, key and value as detached tensors of ``dtype``.

    Args:
        query, key, value: tensors to copy.
        dtype: target dtype for all three clones. When ``None`` the dtype of
            ``query`` is used (for key and value as well).

    Returns:
        ``(query_clone, key_clone, value_clone)``. Each clone is detached from
        the source's autograd graph and has the same ``requires_grad`` flag as
        its source tensor.
    """
    if dtype is None:
        dtype = query.dtype

    def _clone(tensor: torch.Tensor) -> torch.Tensor:
        # clone -> detach -> cast -> restore the source's requires_grad flag.
        copy = tensor.clone().detach().to(dtype)
        return copy.requires_grad_(tensor.requires_grad)

    return _clone(query), _clone(key), _clone(value)
class TestFlexAttention(InductorTestCase):
def _check_equal(
    self,
    golden_out: torch.Tensor,
    ref_out: torch.Tensor,
    compiled_out: torch.Tensor,
    fudge_factor: float,
    tensor_name: Optional[str] = None,
):
    """Fail if the compiled result is much further from golden than the reference.

    Both errors are mean absolute differences against ``golden_out``. The test
    fails when the compiled error exceeds ``fudge_factor`` times the reference
    error, or when the compiled error is NaN while the reference error is not.

    Args:
        golden_out: high-precision result used as ground truth.
        ref_out: reference result to compare against.
        compiled_out: result under test.
        fudge_factor: allowed ratio of compiled error to reference error.
        tensor_name: optional label prepended to the failure message.
    """
    compiled_error = (golden_out - compiled_out).abs().mean()
    ref_error = (golden_out - ref_out).abs().mean()
    # A NaN compiled error makes the `>` comparison below evaluate False, so
    # it has to be rejected explicitly.
    if torch.isnan(compiled_error).any() and not torch.isnan(ref_error).any():
        self.fail("Output/Grad with NaN")
    if compiled_error > ref_error * fudge_factor:
        name = tensor_name if tensor_name is not None else ""
        self.fail(
            f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
        )
def _check_out_and_grad(
    self,
    golden_out: torch.Tensor,
    ref_out: torch.Tensor,
    compiled_out: torch.Tensor,
    q_gold: torch.Tensor,
    q_ref: torch.Tensor,
    q: torch.Tensor,
    k_gold: torch.Tensor,
    k_ref: torch.Tensor,
    k: torch.Tensor,
    v_gold: torch.Tensor,
    v_ref: torch.Tensor,
    v: torch.Tensor,
):
    """Check the compiled output and the q/k/v gradients with ``_check_equal``.

    Each ``*_gold`` / ``*_ref`` / plain tensor triple supplies the ``.grad``
    of the golden, reference and compiled runs respectively. The fudge factor
    is chosen from ``ref_out.dtype``: 10.0 for float32, 1.1 otherwise.
    """
    # Note, it seems like we really are less accurate than the float32
    # computation, likely due to the online softmax
    fudge_factor = 10.0 if ref_out.dtype == torch.float32 else 1.1

    grad_triples = (
        ("Grad_Query", q_gold, q_ref, q),
        ("Grad_Key", k_gold, k_ref, k),
        ("Grad_Value", v_gold, v_ref, v),
    )
    with torch.no_grad():
        # Check output
        self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out")

        # Check gradients (same fudge factor as the output for every tensor)
        for label, gold, ref, compiled in grad_triples:
            self._check_equal(
                gold.grad, ref.grad, compiled.grad, fudge_factor, label
            )
def run_test(
    self,
    score_mod: Callable,
    dtype: torch.dtype = torch.float16,
    Q_B: int = B,
    Q_H: int = H,
    Q_S: int = S,
    Q_D: int = D,
    KV_B: int = B,
    KV_H: int = H,
    KV_S: int = S,
    KV_D: int = D,
):
    """Compare compiled flex attention with eager for one ``score_mod``.

    Random CUDA q/k/v are created with the given shapes. The attention
    callable (score_mod plus a block mask built from it) is run three times:
    eagerly on float64 clones ("golden"), eagerly on ``dtype`` clones ("ref")
    and through ``torch.compile`` on the originals. All three are
    back-propagated with the same random gradient, then outputs and gradients
    are checked via ``_check_out_and_grad``.
    """

    def make_input(shape):
        return torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)

    q_shape = (Q_B, Q_H, Q_S, Q_D)
    kv_shape = (KV_B, KV_H, KV_S, KV_D)
    q = make_input(q_shape)
    k = make_input(kv_shape)
    v = make_input(kv_shape)
    q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
    q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)

    mask = create_block_mask_test(score_mod, q, k)
    eager_attention = create_attention(score_mod, mask)
    compiled_attention = torch.compile(eager_attention)

    golden_out = eager_attention(q_gold, k_gold, v_gold)
    ref_out = eager_attention(q_ref, k_ref, v_ref)
    compiled_out = compiled_attention(q, k, v)

    # One shared upstream gradient; the golden run gets it in float64.
    grad_out = torch.randn(q_shape, dtype=dtype, device="cuda")
    golden_out.backward(grad_out.to(torch.float64))
    ref_out.backward(grad_out)
    compiled_out.backward(grad_out)

    self._check_out_and_grad(
        golden_out, ref_out, compiled_out,
        q_gold, q_ref, q,
        k_gold, k_ref, k,
        v_gold, v_ref, v,
    )
def run_test_with_call(
    self,
    sdpa_call: Callable,
    dtype: torch.dtype = torch.float16,
    Q_B: int = B,
    Q_H: int = H,
    Q_S: int = S,
    Q_D: int = D,
    KV_B: int = B,
    KV_H: int = H,
    KV_S: int = S,
    KV_D: int = D,
):
    """Compare a compiled attention callable with its eager runs.

    Unlike ``run_test`` the caller supplies a ready ``sdpa_call(q, k, v)``.
    It is run eagerly on float64 clones ("golden"), eagerly on ``dtype``
    clones ("ref") and through ``torch.compile`` on the originals; outputs and
    gradients are then checked via ``_check_out_and_grad``.
    """

    def make_input(shape):
        return torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)

    q_shape = (Q_B, Q_H, Q_S, Q_D)
    kv_shape = (KV_B, KV_H, KV_S, KV_D)
    q = make_input(q_shape)
    k = make_input(kv_shape)
    v = make_input(kv_shape)
    q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
    q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)

    compiled_call = torch.compile(sdpa_call)
    golden_out = sdpa_call(q_gold, k_gold, v_gold)
    ref_out = sdpa_call(q_ref, k_ref, v_ref)
    compiled_out = compiled_call(q, k, v)

    # One shared upstream gradient; the golden run gets it in float64.
    grad_out = torch.randn(q_shape, dtype=dtype, device="cuda")
    golden_out.backward(grad_out.to(torch.float64))
    ref_out.backward(grad_out)
    compiled_out.backward(grad_out)

    self._check_out_and_grad(
        golden_out, ref_out, compiled_out,
        q_gold, q_ref, q,
        k_gold, k_ref, k,
        v_gold, v_ref, v,
    )
def run_dynamic_test(
    self,
    score_mod: Callable,
    dtype: torch.dtype = torch.float16,
    B: int = B,
    H: int = H,
    S: int = S,
    D: int = D,
):
    """Check that one ``dynamic=True`` compilation serves two input shapes.

    Two batches, ``(B, H, S, D)`` and ``(B * 2, H, S // 2, D)``, are first run
    eagerly (float64 "golden" and ``dtype`` "ref", forward and backward).
    The attention callable is then compiled with ``dynamic=True`` and run on
    both batches; after each run the results are checked with
    ``_check_out_and_grad`` and the dynamo ``frames/ok`` counter must be 1,
    i.e. no recompilation for the second shape.
    """
    # create_attention() declares block_mask as a required parameter, so it
    # must be passed; no block mask is built for this test.
    # NOTE(review): assumes flex_attention accepts block_mask=None -- confirm.
    sdpa_partial = create_attention(score_mod, block_mask=None)

    def eager_batch(batch, seq_len):
        # Create one batch and run the eager ref/golden forward + backward.
        shape = (batch, H, seq_len, D)
        q, k, v = (
            torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
            for _ in range(3)
        )
        q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
        q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
        ref_out = sdpa_partial(q_ref, k_ref, v_ref)
        golden_out = sdpa_partial(q_gold, k_gold, v_gold)
        backward_grad = torch.randn(shape, dtype=dtype, device="cuda")
        golden_out.backward(backward_grad.to(torch.float64))
        ref_out.backward(backward_grad)
        # Ordered exactly like the trailing arguments of _check_out_and_grad.
        tensors = (q_gold, q_ref, q, k_gold, k_ref, k, v_gold, v_ref, v)
        return golden_out, ref_out, tensors, backward_grad

    def check_compiled(compiled_sdpa, eager_results):
        # Run the compiled callable on the batch and compare with eager.
        golden_out, ref_out, tensors, backward_grad = eager_results
        q, k, v = tensors[2], tensors[5], tensors[8]
        compiled_out = compiled_sdpa(q, k, v)
        compiled_out.backward(backward_grad)
        self._check_out_and_grad(golden_out, ref_out, compiled_out, *tensors)
        # A single compiled frame in total: the second shape must reuse it.
        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)

    # The first eager batch, shape (B, H, S, D)
    first_batch = eager_batch(B, S)
    # The second eager batch, shape (B * 2, H, S / 2, D)
    second_batch = eager_batch(B * 2, S // 2)

    # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
    # We check dynamo counters["frames"]["ok"] to ensure there is no re-compilation.
    # NOTE(review): verify that torch._dynamo.reset() really clears
    # torch._dynamo.utils.counters in the targeted torch version.
    torch._dynamo.reset()

    # Compiling with dynamic shape in the first batch.
    compiled_sdpa = torch.compile(sdpa_partial, dynamic=True)
    check_compiled(compiled_sdpa, first_batch)
    # No re-compilation, use the compiled dynamic shape version.
    check_compiled(compiled_sdpa, second_batch)
def run_automatic_dynamic_test(
    self,
    score_mod: Callable,
    dtype: torch.dtype = torch.float16,
    B: int = B,
    H: int = H,
    S: int = S,
    D: int = D,
):
    """Check automatic-dynamic recompilation behaviour over three shapes.

    Three batches of shape ``(B, H, S, D)``, ``(B * 2, H, S // 2, D)`` and
    ``(B * 4, H, S // 4, D)`` are run eagerly (float64 "golden" and ``dtype``
    "ref", forward only). The attention callable is then compiled without
    ``dynamic=True`` and run on each batch in turn. The outputs are checked
    with ``_check_equal`` and the dynamo ``frames/ok`` counter is expected to
    read 1, 2, 2: static compile, one dynamic recompile, then reuse.
    """
    # create_attention() declares block_mask as a required parameter, so it
    # must be passed; no block mask is built for this test.
    # NOTE(review): assumes flex_attention accepts block_mask=None -- confirm.
    sdpa_partial = create_attention(score_mod, block_mask=None)

    def eager_batch(batch, seq_len):
        # Create one batch and compute its eager golden/ref outputs.
        shape = (batch, H, seq_len, D)
        q, k, v = (
            torch.randn(shape, dtype=dtype, device="cuda") for _ in range(3)
        )
        golden_out = sdpa_partial(
            q.to(torch.float64), k.to(torch.float64), v.to(torch.float64)
        )
        ref_out = sdpa_partial(q, k, v)
        return q, k, v, golden_out, ref_out

    batches = [
        eager_batch(B, S),  # first batch, shape (B, H, S, D)
        eager_batch(B * 2, S // 2),  # second batch, shape (B * 2, H, S / 2, D)
        eager_batch(B * 4, S // 4),  # third batch, shape (B * 4, H, S / 4, D)
    ]

    # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
    # We check dynamo counters["frames"]["ok"] to ensure:
    # 1, the first batch is compiled with static shape
    # 2, the second batch is compiled with dynamic shape
    # 3, no re-compilation in the third batch
    # NOTE(review): verify that torch._dynamo.reset() really clears
    # torch._dynamo.utils.counters in the targeted torch version.
    torch._dynamo.reset()

    # Note, it seems like we really are less accurate than the float32
    # computation, likely due to the online softmax
    fudge_factor = 10.0 if dtype == torch.float32 else 1.1

    compiled_sdpa = torch.compile(sdpa_partial)
    expected_ok_frames = (1, 2, 2)
    for (q, k, v, golden_out, ref_out), ok_frames in zip(batches, expected_ok_frames):
        compiled_out = compiled_sdpa(q, k, v)
        self._check_equal(golden_out, ref_out, compiled_out, fudge_factor)
        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], ok_frames)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
@common_utils.parametrize("score_mod", test_score_mods)
def test_builtin_score_mods(self, dtype: torch.dtype, score_mod: Callable):
    """Run every entry of ``test_score_mods`` through ``run_test`` for each dtype."""
    self.run_test(score_mod=score_mod, dtype=dtype)
@expectedFailure  # TODO: supports block sparsity with dynamic shapes
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
@common_utils.parametrize("score_mod", test_score_mods)
def test_builtin_score_mods_dynamic(self, dtype: torch.dtype, score_mod: Callable):
    """Run every builtin score mod through ``run_dynamic_test`` (expected to fail)."""
    self.run_dynamic_test(score_mod=score_mod, dtype=dtype)
@expectedFailure  # TODO: supports block sparsity with dynamic shapes
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
@common_utils.parametrize("score_mod", test_score_mods)
def test_builtin_score_mods_automatic_dynamic(
    self, dtype: torch.dtype, score_mod: Callable
):
    """Run every builtin score mod through ``run_automatic_dynamic_test`` (expected to fail)."""
    self.run_automatic_dynamic_test(score_mod=score_mod, dtype=dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
@common_utils.parametrize("score_mod", test_score_mods)
def test_builtin_score_mods_different_seqlen(
    self, dtype: torch.dtype, score_mod: Callable
):
    """Run the builtin score mods with a query half as long as key/value."""
    self.run_test(
        score_mod,
        dtype,
        Q_B=B,
        Q_H=H,
        Q_S=S // 2,  # Seqlen of Q is different from seqlen of K/V
        Q_D=D,
        KV_B=B,
        KV_H=H,
        KV_S=S,
        KV_D=D,
    )
# Each entry is ((stride_0, stride_1, stride_2, stride_3), storage_offset);
# test_strided_inputs applies them to a flat buffer with torch.as_strided.
test_strides = [
    # offset
    ((H * S * D, S * D, D, 1), 997),
    # transposed dimensions
    ((H * D, D, B * H * D, 1), 499),
    # heads/sequence transposed
    ((H * S * D, D, H * D, 1), 0),
    # additional buffer on one dim
    ((S * (D + 1), B * S * (D + 1), (D + 1), 1), 293),
    # additional buffer on multiple dim + shared dimension
    ((1, D, (B + 1) * (H + 1) * D, 1), 97),
]
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
@common_utils.parametrize(
    "q_s", test_strides[:2]
)  # TODO: fix layout for query broadcasting
@common_utils.parametrize(
    "k_s,v_s",
    [
        (test_strides[0], test_strides[0]),
        (test_strides[0], test_strides[1]),
        (test_strides[2], test_strides[3]),
        (test_strides[3], test_strides[1]),
        # (test_strides[2], test_strides[4]),  # TODO: Doesn't work for
        # broadcasting reasons i think
    ],
)
@common_utils.parametrize("do_s", test_strides[:3])
def test_strided_inputs(self, dtype: torch.dtype, q_s, k_s, v_s, do_s):
    """Compare compiled and eager flex attention on non-contiguous inputs.

    q, k, v and the upstream gradient are ``as_strided`` views into flat
    random buffers, laid out according to the parametrized
    ``(strides, offset)`` pairs. Outputs and q/k/v gradients of the compiled
    and eager runs must agree within atol = rtol = 2e-1.
    """
    storage_numel = B * H * S * D * 2
    # Buffers are drawn in the order q, k, v, do.
    q1, k1, v1, do1 = (
        torch.randn(storage_numel, dtype=dtype, device="cuda") for _ in range(4)
    )
    q_shape = do_shape = (B, H, S // 2, D)
    k_shape = v_shape = (B, H, S, D)

    def coerce_to_strides(val, shape, strides):
        strides, offset = strides
        # Largest flat index the view can touch must stay inside the buffer.
        last_index = offset + sum(
            stride * (extent - 1) for stride, extent in zip(strides, shape)
        )
        assert last_index < storage_numel
        assert strides[-1] == 1
        view = torch.as_strided(val, shape, strides, offset)
        return view.requires_grad_(True)

    q = coerce_to_strides(q1, q_shape, q_s)
    k = coerce_to_strides(k1, k_shape, k_s)
    v = coerce_to_strides(v1, v_shape, v_s)
    do = coerce_to_strides(do1, do_shape, do_s)

    sdpa_partial = create_attention(
        score_mod=_generate_alibi_bias(8),
        block_mask=_create_empty_block_mask(q, k, v),
    )
    compiled_sdpa = torch.compile(sdpa_partial)
    ref_out = sdpa_partial(q, k, v)
    compiled_out = compiled_sdpa(q, k, v)

    tolerance = Tolerances(atol=2e-1, rtol=2e-1)
    torch.testing.assert_close(
        ref_out, compiled_out, atol=tolerance.atol, rtol=tolerance.rtol
    )

    def pop_grads():
        # Snapshot the accumulated grads and clear them for the next backward.
        grads = [q.grad, k.grad, v.grad]
        q.grad = None
        k.grad = None
        v.grad = None
        return grads

    ref_out.backward(do)
    ref_grads = pop_grads()
    compiled_out.backward(do)
    compiled_grads = pop_grads()

    for compiled_grad, ref_grad in zip(compiled_grads, ref_grads):
        torch.testing.assert_close(
            compiled_grad, ref_grad, atol=tolerance.atol, rtol=tolerance.rtol
        )
@supported_platform
def test_doc_mask_sparse(self):
    """Causal masking restricted to tokens of the same 256-token document."""
    doc_len = 256
    document_id = torch.zeros(S, dtype=torch.int, device="cuda")
    # Consecutive runs of `doc_len` positions get ids 0, 1, 2, ...
    for doc, start in enumerate(range(0, S, doc_len)):
        document_id[start : start + doc_len] = doc

    def document_masking_causal(score, b, h, q_idx, kv_idx):
        is_causal = q_idx >= kv_idx
        same_document = document_id[q_idx] == document_id[kv_idx]
        return torch.where(is_causal & same_document, score, -float("inf"))

    self.run_test(document_masking_causal, torch.float16)
@supported_platform
def test_index_multiple(self):
    """Score mod that indexes a captured (B, S) bias twice in a row."""
    bias = torch.randn(B, S, device="cuda")

    def index_multiple(score, b, h, q_idx, kv_idx):
        batch_row = bias[b]
        return score + batch_row[q_idx]

    self.run_test(index_multiple, torch.float16)
@supported_platform
def test_index_weird1(self):
    """Score mod mixing a constant index, a tuple index and a plain index."""
    bias = torch.randn(4, B, H, S, device="cuda")

    def index_weird1(score, b, h, q_idx, kv_idx):
        first_plane = bias[0]
        per_head = first_plane[b, h]
        return score + per_head[q_idx]

    self.run_test(index_weird1, torch.float16)
@supported_platform
def test_index_weird2(self):
    """Score mod whose last index uses a captured 0-dim tensor."""
    bias = torch.randn(B, H, 4, S, device="cuda")
    which_bias = torch.tensor(0, device="cuda")

    def index_weird2(score, b, h, q_idx, kv_idx):
        per_head = bias[b][h]
        return score + per_head[which_bias, q_idx]

    self.run_test(index_weird2, torch.float16)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_skip_odd_keys(self, dtype: torch.dtype):
    """Mask out every key/value position with an odd index."""

    def score_mod(score, b, h, q, kv):
        is_even_key = kv % 2 == 0
        return torch.where(is_even_key, score, float("-inf"))

    self.run_test(score_mod, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_function_composition(self, dtype: torch.dtype):
    """Run a score mod built with ``_compose`` from a bias and a mask."""

    def add_relative_position(score, b, h, m, n):
        return score + (m - n)

    def keep_m_le_n(score, b, h, m, n):
        return torch.where(m <= n, score, float("-inf"))

    self.run_test(_compose(add_relative_position, keep_m_le_n), dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_captured_buffers(self, dtype: torch.dtype):
    """Score mod that adds a captured per-head offset of the tested dtype."""
    head_offset = torch.rand(H, device="cuda", dtype=dtype)

    def score_mod(score, b, h, m, n):
        offset = head_offset[h]
        return score + offset

    self.run_test(score_mod, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_captured_buffers_all_dims(self, dtype: torch.dtype):
    """Score mod adding captured buffers indexed by token, batch and head."""
    head_scale = torch.randn(H, device="cuda")
    batch_scale = torch.randn(B, device="cuda")
    tok_scale = torch.randn(S, device="cuda")

    def all_bias(score, batch, head, token_q, token_kv):
        # Additions are applied in the order token, batch, head.
        return (
            score
            + tok_scale[token_q]
            + batch_scale[batch]
            + head_scale[head]
        )

    self.run_test(all_bias, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_seq_masking(self, dtype):
    """Only attend between positions in the same half of the sequence."""
    # False for the first half of the positions, True for the second half.
    seq_idx = torch.zeros(S, device="cuda", dtype=torch.bool)
    seq_idx[S // 2 :] = 1

    def seq_mask_mod(score, b, h, q, kv):
        same_segment = seq_idx[q] == seq_idx[kv]
        return torch.where(same_segment, score, float("-inf"))

    self.run_test(seq_mask_mod, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_load_from_bias_seq_only(self, dtype):
    """Add a captured (S, S) bias indexed by (q, kv)."""
    bias = torch.randn(S, S, device="cuda", dtype=dtype)

    def bias_mod(score, b, h, q, kv):
        loaded = bias[q, kv]
        return score + loaded

    self.run_test(bias_mod, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_load_from_bias_seq_batch(self, dtype):
    """Add a captured (B, S, S) bias indexed by (b, q, kv)."""
    bias = torch.randn(B, S, S, device="cuda", dtype=dtype)

    def bias_mod(score, b, h, q, kv):
        loaded = bias[b, q, kv]
        return score + loaded

    self.run_test(bias_mod, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_load_from_bias_head_seq_batch(self, dtype):
    """Add a captured (B, H, S, S) bias indexed by (b, h, q, kv)."""
    bias = torch.randn(B, H, S, S, device="cuda", dtype=dtype)

    def bias_mod(score, b, h, q, kv):
        loaded = bias[b, h, q, kv]
        return score + loaded

    self.run_test(bias_mod, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_load_rel_bias(self, dtype):
    """Add a captured bias looked up by relative position ``q - kv``."""
    rel_bias = torch.randn(2 * S, device="cuda", dtype=dtype)

    def bias_mod(score, b, h, q, kv):
        # Shift by S so the relative offset is a non-negative table index.
        table_index = (q - kv) + S
        return score + rel_bias[table_index]

    self.run_test(bias_mod, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_dependent_causal_bidirectional(self, dtype):
    """Causal mask plus a per-batch bidirectional prefix.

    For batch ``b`` the positions up to ``num_bidirectional[b]`` attend to
    each other in both directions; all other pairs are causal.
    """
    num_bidirectional = torch.randint(0, S, (B,), device="cuda", dtype=torch.int32)

    def bias_mod(score, b, h, q, kv):
        causal_attention = q >= kv
        prefix_len = num_bidirectional[b]
        q_in_prefix = q <= prefix_len
        kv_in_prefix = kv <= prefix_len
        bidirectional_attention_on_video = q_in_prefix & kv_in_prefix
        allowed = bidirectional_attention_on_video | causal_attention
        return torch.where(allowed, score, -float("inf"))

    self.run_test(bias_mod, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_natten_2d(self, dtype):
    """Window mask over positions interpreted as a 32 x (S // 32) grid."""
    grid_h = 32
    grid_w = S // grid_h
    window = 3
    assert grid_w * grid_h == S

    def get_x_y(idx):
        # This should be a floor divide, but we don't support that properly
        return idx / grid_w, idx % grid_w

    def natten_mask(score, b, h, q, kv):
        q_x, q_y = get_x_y(q)
        kv_x, kv_y = get_x_y(kv)
        near_in_x = (q_x - kv_x).abs() <= window
        near_in_y = (q_y - kv_y).abs() <= window
        return torch.where(near_in_x | near_in_y, score, float("-inf"))

    self.run_test(natten_mask, dtype)
# Checks that the score_mod subgraph traced by make_fx honours the
# decomposition table it is given: `//` stays aten.floor_divide with an empty
# table and becomes aten.div(rounding_mode='floor') with the core ATen table.
# The parametrized `dtype` is not used below; the inputs are always float64.
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_subgraph_respect_decompostion(self, dtype):
from torch._decomp import core_aten_decompositions
from torch.fx.experimental.proxy_tensor import make_fx
# Score mod containing a floor division, the op under test.
def score_mod_func(score, b, h, q, kv):
return score - q // (1 + kv)
make_tensor = functools.partial(
torch.randn,
(2, 2, 128, 4),
device="cuda",
dtype=torch.float64,
requires_grad=True,
)
query, key, value = make_tensor(), make_tensor(), make_tensor()
# floor_div is not decomposed if decomposition_table is empty
attention = functools.partial(flex_attention, score_mod=score_mod_func)
gm = make_fx(attention, decomposition_table={})(query, key, value)
# The expected text must match the generated subgraph code exactly.
self.assertExpectedInline(
gm.sdpa_score0.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
add = torch.ops.aten.add.Tensor(arg4_1, 1); arg4_1 = None
floor_divide = torch.ops.aten.floor_divide.default(arg3_1, add); arg3_1 = add = None
sub = torch.ops.aten.sub.Tensor(arg0_1, floor_divide); arg0_1 = floor_divide = None
return sub""",
)
# floor_div is decomposed when core_aten_decompositions() is used
gm = make_fx(attention, decomposition_table=core_aten_decompositions())(
query, key, value
)
self.assertExpectedInline(
gm.sdpa_score0.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
add = torch.ops.aten.add.Tensor(arg4_1, 1); arg4_1 = None
div = torch.ops.aten.div.Tensor_mode(arg3_1, add, rounding_mode = 'floor'); arg3_1 = add = None
sub = torch.ops.aten.sub.Tensor(arg0_1, div); arg0_1 = div = None
return sub""",
)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_silu_on_score(self, dtype):
    """Apply SiLU to the raw attention score."""

    def silu_score(score, b, h, q, kv):
        activated = torch.nn.functional.silu(score)
        return activated

    self.run_test(silu_score, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_padded_dense_causal(self, dtype):
    """Wrap ``_causal`` in a mask that depends on a captured per-batch length."""
    # seq_len[b] == b + 1
    seq_len = torch.arange(B, device="cuda", dtype=torch.int32) + 1

    def create_padded_dense_wrapper(orig_score_mod):
        def njt_score_mod(qk, b, h, q, kv):
            # NOTE(review): the condition compares the score value `qk` (not
            # the query index `q`) with seq_len[b] -- confirm this is intended.
            within_length = qk <= seq_len[b]
            wrapped = orig_score_mod(qk, b, h, q, kv)
            return torch.where(within_length, wrapped, -float("inf"))

        return njt_score_mod

    self.run_test(create_padded_dense_wrapper(_causal), dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_captured_scale(self, dtype):
    """Score mod that adds a captured 0-dim int32 tensor."""
    scale = torch.ones((), device="cuda", dtype=torch.int32)

    def score_mod_scale(qk, b, h, q, kv):
        shifted = qk + scale
        return shifted

    self.run_test(score_mod_scale, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_recompile_changed_score_mod(self, dtype):
    """Run the same score mod twice, flipping a closed-over flag in between.

    The score mod reads ``ADD`` from the enclosing scope on every call, so the
    second ``run_test`` must take the multiply branch.
    """
    scale = torch.ones((), device="cuda", dtype=torch.int32)
    ADD = True

    def score_mod_scale(qk, b, h, q, kv):
        if not ADD:
            return qk * scale
        return qk + scale

    self.run_test(score_mod_scale, dtype)
    # Rebinding the closure variable changes the branch taken by score_mod_scale.
    ADD = False
    self.run_test(score_mod_scale, dtype)
@supported_platform
@expectedFailure  # If we capture a tensor then we can perform a reduction on it, and that shouldn't be allowed
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_captured_reduction(self, dtype):
    """Score mod that reduces over a captured (B, 8) tensor (expected to fail)."""
    scale = torch.randn((B, 8), device="cuda")

    def score_mod_scale(qk, b, h, q, kv):
        batch_row = scale[b]
        return qk + batch_row.sum(dim=-1)

    self.run_test(score_mod_scale, dtype)
@supported_platform
def test_multiple_score_mod_calls(self):
    """Two chained flex_attention calls with different score mods in one graph."""
    make_input = functools.partial(
        torch.randn, (1, 8, 1024, 64), dtype=torch.float32, device="cuda"
    )
    # Draw order: query, then both keys, then both values.
    query = make_input()
    keys = [make_input() for _ in range(2)]
    values = [make_input() for _ in range(2)]

    def scoremod_1(qk, b, h, q, kv):
        return qk + (q - kv)

    def scoremod_2(qk, b, h, q, kv):
        return torch.where(q >= kv, qk, -float("inf"))

    def f(q, k1, k2, v1, v2):
        # The first attention's output is the query of the second one.
        q2 = flex_attention(q, k1, v1, score_mod=scoremod_1)
        return flex_attention(q2, k2, v2, score_mod=scoremod_2)

    eager_out = f(query, *keys, *values)
    compiled_out = torch.compile(f)(query, *keys, *values)
    tolerance = Tolerances(atol=2e-1, rtol=2e-1)
    torch.testing.assert_close(
        eager_out, compiled_out, atol=tolerance.atol, rtol=tolerance.rtol
    )
@supported_platform
def test_multiple_score_mod_calls2(self):
    """Three chained attention calls; the first goes through a functools.partial."""
    make_input = functools.partial(
        torch.randn, (1, 8, 1024, 64), dtype=torch.float32, device="cuda"
    )
    # Draw order: query, then the three keys, then the three values.
    query = make_input()
    keys = [make_input() for _ in range(3)]
    values = [make_input() for _ in range(3)]

    def scoremod_1(qk, b, h, q, kv):
        return qk + (q - kv)

    def scoremod_2(qk, b, h, q, kv):
        return torch.where(q >= kv, qk, -float("inf"))

    attention1 = functools.partial(flex_attention, score_mod=scoremod_1)

    def f(q, k1, k2, k3, v1, v2, v3):
        q2 = attention1(q, k1, v1)
        q3 = flex_attention(q2, k2, v2, score_mod=scoremod_2)
        return flex_attention(q3, k3, v3, score_mod=scoremod_1)

    eager_out = f(query, *keys, *values)
    compiled_out = torch.compile(f)(query, *keys, *values)
    mean_abs_diff = (eager_out - compiled_out).abs().mean()
    self.assertTrue(mean_abs_diff < 1e-2)
@supported_platform
def test_inputs_are_realized(self):
    """Attention fed by computed tensors (``q.sin()``, ``x * 2``) matches eager.

    Both the forward output and the q/k/v gradients must agree with eager to
    a mean absolute difference below 1e-2.
    """

    def f(q, k, v):
        x = torch.randn(1024, device="cuda")
        x = x * 2

        def func(qk, b, h, q, kv):
            return qk + x[q]

        return flex_attention(q.sin(), k, v, score_mod=func).cos()

    q, k, v = (
        torch.randn(1, 8, 1024, 64, device="cuda", requires_grad=True)
        for _ in range(3)
    )
    ref = f(q, k, v)
    out = torch.compile(f)(q, k, v)
    self.assertTrue((ref - out).abs().mean() < 1e-2)

    gradOut = torch.randn_like(q)
    ref_grads = torch.autograd.grad(ref, (q, k, v), gradOut)
    out_grads = torch.autograd.grad(out, (q, k, v), gradOut)
    for ref_grad, out_grad in zip(ref_grads, out_grads):
        self.assertTrue((ref_grad - out_grad).abs().mean() < 1e-2)
@supported_platform
def test_make_block_mask(self):
    """``create_block_mask`` gives the same mask with and without ``_compile``."""

    def causal_mask(score, b, h, q_idx, kv_idx):
        return torch.where(q_idx >= kv_idx, score, -float("inf"))

    compiled_mask = create_block_mask(causal_mask, 1, 1, 512, 512, _compile=True)
    eager_mask = create_block_mask(causal_mask, 1, 1, 512, 512, _compile=False)
    for field in ("kv_num_blocks", "kv_indices", "q_num_blocks"):
        self.assertEqual(getattr(compiled_mask, field), getattr(eager_mask, field))
@supported_platform
def test_epilogue_fused(self):
    """The ``cos`` after flex_attention must not become its own pointwise kernel.

    The generated code has to contain a ``triton_tem_fused`` template call and
    no ``poi_fused_cos`` kernel, and the byte count reported by the inductor
    metrics must stay below the budget computed at the end of the test.
    """

    @torch.compile
    def f(q, k, v):
        out = flex_attention(q, k, v)
        return out.cos()

    q, k, v = (torch.randn(1, 8, 1024, 64, device="cuda") for _ in range(3))
    metrics.reset()
    _, code = run_and_get_code(f, q, k, v)
    # TODO: attention output is not being DCE'd
    checker = FileCheck()
    checker.check("triton_tem_fused")  # template call
    checker.check_not("poi_fused_cos")  # No cos pointwise operation
    checker.run(code[0])

    bytes_per_tensor = 1 * 8 * 1024 * 64 * torch.float32.itemsize
    expected_accesses = 4  # q, k, v reads, one output.
    # TODO: Get rid of this fudge factor
    # We need this fudge factor for now, since
    # 1. For some reason we materialize the output of the attention unnecessarily (it's related to the mutation somehow)
    # 2. We also write the extraneous logsumexp
    expected_accesses += 2
    byte_budget = bytes_per_tensor * expected_accesses
    self.assertLess(metrics.num_bytes_accessed, byte_budget)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_njt_causal(self, dtype):
    """Causal masking over three sequences packed into one length-``S`` axis.

    ``offsets`` holds the sequence boundaries and ``seq_idx`` maps each
    position to its sequence; the wrapper rebases q/kv to sequence-local
    positions before calling ``_causal``.
    """
    offsets = torch.tensor(
        [0, 1024, 1024 + 512, S], device="cuda", dtype=torch.int32
    )
    seq_idx = torch.zeros(S, device="cuda", dtype=torch.int32)
    for seq in range(len(offsets) - 1):
        begin, end = offsets[seq], offsets[seq + 1]
        seq_idx[begin:end] = seq

    def create_njt_wrapper(orig_score_mod, offsets, seq_idx):
        def njt_score_mod(qk, b, h, q, kv):
            q_start = offsets[seq_idx[q]]
            q_nested = q - q_start
            kv_start = offsets[seq_idx[kv]]
            kv_nested = kv - kv_start
            return orig_score_mod(qk, b, h, q_nested, kv_nested)

        return njt_score_mod

    self.run_test(create_njt_wrapper(_causal, offsets, seq_idx), dtype)
@supported_platform
def test_mixed_dtypes_fails(self):
    """flex_attention raises ValueError when q/k/v dtypes differ."""
    shape = (1, 1, 1024, 64)
    query = torch.randn(shape, dtype=torch.float32, device="cuda")
    key = torch.randn(shape, dtype=torch.float16, device="cuda")
    value = torch.randn(shape, dtype=torch.float16, device="cuda")
    expected_message = "Expected query, key, and value to have the same dtype"
    with self.assertRaisesRegex(ValueError, expected_message):
        flex_attention(query, key, value, _identity)
@supported_platform
@patch.object(torch._inductor.config, "max_autotune", True)
def test_max_autotune(self):
    """Run a simple score mod with inductor's ``max_autotune`` patched to True."""

    def score_mod(score, b, h, m, n):
        doubled = score * 2
        return doubled

    self.run_test(score_mod)
@supported_platform
@skip("TODO: Figure out why this is erroring")
@patch.object(torch._inductor.config, "max_autotune", True)
def test_max_autotune_with_captured(self):
    """``max_autotune`` with a score mod capturing three buffers (currently skipped)."""
    head_scale = torch.randn(H, device="cuda")
    batch_scale = torch.randn(B, device="cuda")
    tok_scale = torch.randn(S, device="cuda")

    def bias_mod(score, batch, head, token_q, token_kv):
        # Additions are applied in the order token, batch, head.
        return (
            score
            + tok_scale[token_q]
            + batch_scale[batch]
            + head_scale[head]
        )

    self.run_test(bias_mod)
@supported_platform
def test_autograd_function_in_score_mod(self):
    """Smoke test: compiled flex_attention runs with ``fullgraph=True``.

    NOTE(review): ``score_mod`` (which uses the custom ``ApplyMask``
    autograd.Function) is defined but never passed to ``func``; only the
    default score mod is exercised. See the trailing comment for why.
    """

    class ApplyMask(torch.autograd.Function):
        generate_vmap_rule = True

        @staticmethod
        def forward(a, mask):
            return torch.where(mask, a, -float("inf"))

        @staticmethod
        def setup_context(ctx, inputs, output):
            _, mask = inputs
            ctx.mark_non_differentiable(mask)

        @staticmethod
        def backward(ctx, i):
            # Gradient passes through for `a`; the mask gets no gradient.
            return i, None

    def score_mod(score, b, h, q, kv):
        return ApplyMask.apply(score, q <= kv)

    func = torch.compile(flex_attention, fullgraph=True)

    q, k, v = (
        torch.randn(1, 8, 1024, 64, device="cuda", requires_grad=True)
        for _ in range(3)
    )

    # Just checking that it runs
    func(q, k, v)

    # expectedFailure
    # This doesn't work due to vmap + autograd.Function + torch.compile not composing
    # self.run_test(score_mod)
@supported_platform
def test_causal_block(self):
    """Causality expressed only through a block mask, with no score mod."""

    def mask_mod(b, h, q, kv):
        return q >= kv

    causal_block_mask = create_block_mask(mask_mod, 1, 1, S, S)
    self.run_test_with_call(
        functools.partial(flex_attention, block_mask=causal_block_mask)
    )
@supported_platform
def test_custom_block_mask_generator(self):
    """A hand-built causal ``BlockMask`` densifies like the auto-generated one."""

    def mask_mod(b, h, q, kv):
        return q >= kv

    auto_mask = create_block_mask(mask_mod, 1, 1, S, S)
    BLOCK_SIZE = 128

    def causal_constructor(S):
        blocks_per_side = S // BLOCK_SIZE
        block_ids = torch.arange(blocks_per_side, device="cuda")
        # Entry i of num_blocks is i + 1; every row of indices is 0..n-1.
        num_blocks = (block_ids + 1)[None, None, :]
        indices = block_ids.expand(blocks_per_side, blocks_per_side)[None, None, :]
        return BlockMask(
            num_blocks, indices, BLOCK_SIZE=BLOCK_SIZE, mask_mod=mask_mod
        )

    manual_mask = causal_constructor(S)
    self.assertEqual(auto_mask.to_dense(), manual_mask.to_dense())
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
@common_utils.parametrize("score_mod", [_identity, _causal])
def test_logsumexp_correctness(self, dtype, score_mod):
    """Compare output and logsumexp of the flex attention HOP: inductor vs aot_eager.

    The reference runs in float64 under the ``aot_eager`` backend; the tested
    path uses the default ``torch.compile`` backend in ``dtype``. Both results
    are compared in float32 with atol = rtol = 2e-2.
    """
    make_tensor = functools.partial(
        torch.randn,
        (B, H, S, D),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    q, k, v = make_tensor(), make_tensor(), make_tensor()
    block_mask = _create_empty_block_mask(q, k, v)

    @torch.compile
    def sdpa_hop(q, k, v, score_mod, block_mask):
        return flex_attention_hop(q, k, v, score_mod, block_mask.as_tuple(), 1.0)

    @torch.compile(backend="aot_eager")
    def eager_sdpa_hop(q, k, v, score_mod, block_mask):
        """The main entrypoint for FlexAttention doesnt return LSE.
        Besides dropping LSE it also ensures that the hop is compiled with aot-eager
        backend. We need to replicate this.
        """
        return flex_attention_hop(q, k, v, score_mod, block_mask.as_tuple(), 1.0)

    q64, k64, v64 = q.to(torch.float64), k.to(torch.float64), v.to(torch.float64)
    ref_out, ref_lse = eager_sdpa_hop(q64, k64, v64, score_mod, block_mask)
    compiled_out, compiled_lse = sdpa_hop(q, k, v, score_mod, block_mask)

    # Comparing LSE for the ref and the compiled version
    # The compiled uses a change of base trick to more efficiently compute the LSE
    # this means that the base for the LSE computed by ref is e while for the compiled
    # version it is 2. To compare we use the change of base formula
    # log_2(x_compiled) = log_e(x_ref) * log_2(e) where
    # x_ref = sum(_i e^(scores[i]))
    # x_compiled = sum(_i 2^(log2(e) * scores[i]))
    self.assertTrue(ref_lse.dtype == torch.float64)
    self.assertTrue(compiled_lse.dtype == torch.float32)
    log2_e = torch.log2(torch.tensor(torch.e))
    ref_lse = ref_lse * log2_e

    tolerance = Tolerances(atol=2e-2, rtol=2e-2)
    comparisons = ((ref_out, compiled_out), (ref_lse, compiled_lse))
    for expected, actual in comparisons:
        torch.testing.assert_close(
            expected.to(dtype=torch.float32),
            actual.to(dtype=torch.float32),
            atol=tolerance.atol,
            rtol=tolerance.rtol,
        )
@supported_platform
def test_logsumexp_only_return(self):
    """Returning only the doubled LSE must still yield exactly two ``.run(`` calls in the generated code."""
    q, k, v = (
        torch.randn(
            (B, H, S, D),
            dtype=torch.float32,
            device="cuda",
            requires_grad=True,
        )
        for _ in range(3)
    )
    block_mask = _create_empty_block_mask(q, k, v)

    @torch.compile
    def func(q, k, v, score_mod, block_mask):
        hop_result = flex_attention_hop(
            q, k, v, score_mod, block_mask.as_tuple(), scale=1.0
        )
        # Discard the attention output; only the LSE is post-processed.
        _, lse = hop_result
        return lse * 2

    generated = run_and_get_code(func, q, k, v, _identity, block_mask)[1]
    # Exactly two kernel launches are expected in the first generated module.
    FileCheck().check_count(".run(", 2, True).run(generated[0])
@supported_platform
def test_logsumexp_is_not_fused(self):
    """Returning the output plus the doubled LSE must yield exactly two ``.run(`` calls in the generated code."""
    q, k, v = (
        torch.randn(
            (B, H, S, D),
            dtype=torch.float32,
            device="cuda",
            requires_grad=True,
        )
        for _ in range(3)
    )
    block_mask = _create_empty_block_mask(q, k, v)

    @torch.compile
    def func(q, k, v, score_mod, block_mask):
        hop_result = flex_attention_hop(
            q, k, v, score_mod, block_mask.as_tuple(), 1.0
        )
        out, lse = hop_result
        return out, lse * 2

    generated = run_and_get_code(func, q, k, v, _identity, block_mask)[1]
    # Exactly two kernel launches are expected in the first generated module.
    FileCheck().check_count(".run(", 2, True).run(generated[0])
@supported_platform
@common_utils.parametrize(
    "score_mod", [_identity, _causal, _times_two, _squared, _trig, _trig2]
)
def test_aot_eager_gradcheck(self, score_mod):
    """Run ``torch.autograd.gradcheck`` on aot_eager-compiled ``flex_attention`` with float64 inputs."""
    query, key, value = (
        torch.randn(
            (2, 2, 128, 4),
            device="cuda",
            dtype=torch.float64,
            requires_grad=True,
        )
        for _ in range(3)
    )
    func = torch.compile(flex_attention, backend="aot_eager", fullgraph=True)
    gradients_match = torch.autograd.gradcheck(
        func, (query, key, value, score_mod), raise_exception=True
    )
    self.assertTrue(gradients_match)
@supported_platform
@common_utils.parametrize("score_mod_name", ["_head_offset"])
@common_utils.parametrize("mode", ["eager", "aot_eager"])
def test_captured_score_mod_aot_eager_gradcheck(
    self, score_mod_name: str, mode: str
):
    """Run ``torch.autograd.gradcheck`` on compiled ``flex_attention`` using a score_mod looked up in ``captured_buffers_map``."""
    query, key, value = (
        torch.randn(
            (2, 2, 128, 4),
            device="cuda",
            dtype=torch.float64,
            requires_grad=True,
        )
        for _ in range(3)
    )
    func = torch.compile(flex_attention, backend=mode, fullgraph=True)
    # The map entry is a factory taking the dtype and returning the score_mod.
    build_score_mod = captured_buffers_map[score_mod_name]
    score_mod = build_score_mod(torch.float64)
    gradients_match = torch.autograd.gradcheck(
        func, (query, key, value, score_mod), raise_exception=True
    )
    self.assertTrue(gradients_match)
@supported_platform
def test_comparison_vs_sdpa(self):
    """Check that causal FlexAttention is about as accurate as fp16 SDPA.

    Three compiled FlexAttention variants are exercised: score_mod only,
    score_mod plus a block mask, and a mask_mod block mask only. For the
    output and each of the three input gradients, the RMSE against a float64
    SDPA run must stay below 1.2x the RMSE of fp16 SDPA against that same
    float64 run.
    """

    def causal(score, b, h, q_idx, kv_idx):
        return torch.where(q_idx >= kv_idx, score, -float("inf"))

    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    # Built once from the score_mod and used by the score_mod + block_mask variant.
    block_mask = create_block_mask(causal, 1, 1, 2048, 2048)
    no_sparse_flex = functools.partial(flex_attention, score_mod=causal)
    score_mod_sparse_flex = functools.partial(
        flex_attention,
        score_mod=causal,
        block_mask=block_mask,
    )
    mask_mod_sparse_flex = functools.partial(
        flex_attention, block_mask=create_block_mask(causal_mask, 1, 1, 2048, 2048)
    )
    for attention_call in [
        no_sparse_flex,
        score_mod_sparse_flex,
        mask_mod_sparse_flex,
    ]:
        inputs = [
            torch.randn(
                2,
                2,
                2048,
                64,
                device="cuda",
                dtype=torch.float16,
                requires_grad=True,
            )
            for _ in range(3)
        ]
        gradOut = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float16)

        # fp16 SDPA: the accuracy baseline FlexAttention is measured against.
        out_ref = torch.nn.functional.scaled_dot_product_attention(
            *inputs, is_causal=True
        )
        out_ref.backward(gradOut)

        # Compiled FlexAttention on independent copies so gradients do not mix.
        inputs_flex = [i.detach().clone().requires_grad_(True) for i in inputs]
        out_flex = torch.compile(attention_call)(*inputs_flex)
        out_flex.backward(gradOut)

        # float64 SDPA: the "golden" result both fp16 paths are scored against.
        inputs_golden = [
            i.detach().clone().to(dtype=torch.float64).requires_grad_(True)
            for i in inputs
        ]
        out_golden = torch.nn.functional.scaled_dot_product_attention(
            *inputs_golden, is_causal=True
        )
        out_golden.backward(gradOut.to(dtype=torch.float64))

        for ref, flex, golden in [
            (out_ref, out_flex, out_golden),
            (inputs[0].grad, inputs_flex[0].grad, inputs_golden[0].grad),
            (inputs[1].grad, inputs_flex[1].grad, inputs_golden[1].grad),
            (inputs[2].grad, inputs_flex[2].grad, inputs_golden[2].grad),
        ]:
            ref_error = rmse(ref, golden)
            flex_error = rmse(flex, golden)
            # Note: This has been carefully tested that FlexAttention is within
            # 20% of the average error of SDPA! Do not bump this tolerance
            # unless you are absolutely sure you are not worsening the accuracy
            # of FlexAttention!
            self.assertTrue(
                ref_error * 1.2 > flex_error,
                f"Ref error: {ref_error}, Flex Error: {flex_error}",
            )
@supported_platform
def test_block_mask_attributes(self):
    """Check ``shape``, ``numel()`` and ``sparsity()`` of block masks and of indexed sub-masks."""

    def make_causal(offset):
        # Causal rule whose diagonal is shifted per batch entry by offset[b] * 128.
        def causal(score, b, h, q, kv):
            return torch.where(q + offset[b] * 128 >= kv, score, -float("inf"))

        return causal

    # Zero offsets: every batch entry uses the same causal pattern.
    block_mask = create_block_mask(
        make_causal(torch.zeros(8, device="cuda")), 4, 2, 2048, 2048
    )
    self.assertEqual(block_mask.shape, (4, 2, 2048, 2048))
    self.assertEqual(block_mask[0].shape, (2, 2048, 2048))
    self.assertEqual(block_mask[0, 0].shape, (2048, 2048))
    self.assertEqual(block_mask.numel(), 4 * 2 * 2048 * 2048)
    for view in (block_mask, block_mask[0], block_mask[1, 0]):
        self.assertEqual(view.sparsity(), 46.875)
    self.assertEqual(block_mask.sparsity(), block_mask[1].sparsity())

    # Offsets 0..7: sparsity now differs between batch entries.
    block_mask = create_block_mask(
        make_causal(torch.arange(8, device="cuda")), 8, 1, 2048, 2048
    )
    self.assertEqual(block_mask.sparsity(), 29.1015625)
    self.assertTrue(block_mask.sparsity() < block_mask[0].sparsity())
    self.assertTrue(block_mask[0].sparsity() > block_mask[1].sparsity())
@supported_platform
def test_block_mask_viz(self):
    """Check the ``str()`` rendering of a causal block mask and the sparsity shown for an offset one."""

    def causal(score, b, h, q, kv):
        return torch.where(q >= kv, score, -float("inf"))

    def replace_non_printable(s):
        # "@" stands in for anything outside string.printable and "s" for a
        # space, so the expected text below carries no trailing whitespace.
        return "".join(
            "@" if ch not in string.printable else ("s" if ch == " " else ch)
            for ch in s
        )

    block_mask = create_block_mask(causal, 1, 1, 2048, 2048)
    rendered = replace_non_printable(str(block_mask))
    self.assertExpectedInline(
        rendered,
        """\
BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
(0,s0)
@@ssssssssssssssssssssssssssssss
@@@@ssssssssssssssssssssssssssss
@@@@@@ssssssssssssssssssssssssss
@@@@@@@@ssssssssssssssssssssssss
@@@@@@@@@@ssssssssssssssssssssss
@@@@@@@@@@@@ssssssssssssssssssss
@@@@@@@@@@@@@@ssssssssssssssssss
@@@@@@@@@@@@@@@@ssssssssssssssss
@@@@@@@@@@@@@@@@@@ssssssssssssss
@@@@@@@@@@@@@@@@@@@@ssssssssssss
@@@@@@@@@@@@@@@@@@@@@@ssssssssss
@@@@@@@@@@@@@@@@@@@@@@@@ssssssss
@@@@@@@@@@@@@@@@@@@@@@@@@@ssssss
@@@@@@@@@@@@@@@@@@@@@@@@@@@@ssss
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ss
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
)""",
    )

    offset = torch.arange(8, device="cuda")

    def causal_offset(score, b, h, q, kv):
        return torch.where(q + offset[b] * 128 >= kv, score, -float("inf"))

    offset_mask = create_block_mask(causal_offset, 8, 1, 2048, 2048)
    # Only the reported sparsity is checked for the batched, offset mask.
    self.assertTrue("sparsity=29.10" in str(offset_mask))
@supported_platform
def test_fw_bw_graph_correctness(self):
# Snapshot test in two parts:
#   1. compile flex_attention with a counting "aot_eager" backend, run forward
#      and backward once, and compare the normalized captured graph with the
#      expected text;
#   2. compile again through compile_fx with an inner_compile hook that records
#      every graph it is handed, and compare the second recorded graph with the
#      expected text.
cnt = CompileCounterWithBackend("aot_eager")
make_tensor = functools.partial(
torch.randn,
(2, 2, 128, 4),
device="cuda",
dtype=torch.float64,
requires_grad=True,
)
query, key, value = make_tensor(), make_tensor(), make_tensor()
def causal(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
block_mask = create_block_mask(causal, 1, 1, 128, 128)
func = torch.compile(flex_attention, backend=cnt, fullgraph=True)
out = func(query, key, value, _squared, block_mask=block_mask)
out.sum().backward()
# One frame and one graph: the whole call was captured without a graph break.
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(len(cnt.graphs), 1)
graph = cnt.graphs[0]
norm_graph = normalize_gm(graph.print_readable(print_output=False))
# Expected text: the outer module calling the flex_attention HOP, followed by
# the score_mod (`child * child`) and mask (`child_7 >= child_8`) submodules.
self.assertExpectedInline(
norm_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_args_0_: "f64[2, 2, 128, 4]", L_args_1_: "f64[2, 2, 128, 4]", L_args_2_: "f64[2, 2, 128, 4]", L_kwargs_block_mask_kv_num_blocks: "i32[1, 1, 1]", L_kwargs_block_mask_kv_indices: "i32[1, 1, 1, 1]", L_kwargs_block_mask_full_kv_num_blocks: "i32[1, 1, 1]", L_kwargs_block_mask_full_kv_indices: "i32[1, 1, 1, 1]", L_kwargs_block_mask_q_num_blocks: "i32[1, 1, 1]", L_kwargs_block_mask_q_indices: "i32[1, 1, 1, 1]", L_kwargs_block_mask_full_q_num_blocks: "i32[1, 1, 1]", L_kwargs_block_mask_full_q_indices: "i32[1, 1, 1, 1]"):
l_args_0_ = L_args_0_
l_args_1_ = L_args_1_
l_args_2_ = L_args_2_
l_kwargs_block_mask_kv_num_blocks = L_kwargs_block_mask_kv_num_blocks
l_kwargs_block_mask_kv_indices = L_kwargs_block_mask_kv_indices
l_kwargs_block_mask_full_kv_num_blocks = L_kwargs_block_mask_full_kv_num_blocks
l_kwargs_block_mask_full_kv_indices = L_kwargs_block_mask_full_kv_indices
l_kwargs_block_mask_q_num_blocks = L_kwargs_block_mask_q_num_blocks
l_kwargs_block_mask_q_indices = L_kwargs_block_mask_q_indices
l_kwargs_block_mask_full_q_num_blocks = L_kwargs_block_mask_full_q_num_blocks
l_kwargs_block_mask_full_q_indices = L_kwargs_block_mask_full_q_indices
child_1: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
child_2: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
child_3: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
child_4: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
child: "f64[]" = l_args_0_.new_empty([], requires_grad = True)
score_mod_0 = self.score_mod_0
child_5: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
child_6: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
child_7: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
child_8: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
mask_fn_0 = self.mask_fn_0
flex_attention = torch.ops.higher_order.flex_attention(l_args_0_, l_args_1_, l_args_2_, score_mod_0, (l_kwargs_block_mask_kv_num_blocks, l_kwargs_block_mask_kv_indices, l_kwargs_block_mask_full_kv_num_blocks, l_kwargs_block_mask_full_kv_indices, l_kwargs_block_mask_q_num_blocks, l_kwargs_block_mask_q_indices, l_kwargs_block_mask_full_q_num_blocks, l_kwargs_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, (), ()); l_args_0_ = l_args_1_ = l_args_2_ = score_mod_0 = l_kwargs_block_mask_kv_num_blocks = l_kwargs_block_mask_kv_indices = l_kwargs_block_mask_full_kv_num_blocks = l_kwargs_block_mask_full_kv_indices = l_kwargs_block_mask_q_num_blocks = l_kwargs_block_mask_q_indices = l_kwargs_block_mask_full_q_num_blocks = l_kwargs_block_mask_full_q_indices = mask_fn_0 = None
out: "f64[2, 2, 128, 4]" = flex_attention[0]; flex_attention = None
return (out,)
class GraphModule(torch.nn.Module):
def forward(self, child: "f64[]", child_1: "i32[]", child_2: "i32[]", child_3: "i32[]", child_4: "i32[]"):
mul: "f64[]" = child * child; child = None
return mul
class GraphModule(torch.nn.Module):
def forward(self, child_5: "i32[]", child_6: "i32[]", child_7: "i32[]", child_8: "i32[]"):
ge: "b8[]" = child_7 >= child_8; child_7 = child_8 = None
return ge
""", # noqa: B950
)
# Save the AOT graphs
aot_graphs = []
from torch._inductor import compile_fx
# Stand-in for Inductor's inner compile step: record the graph and return it
# unchanged instead of generating code.
def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs):
aot_graphs.append(graph)
return graph
backend = functools.partial(
compile_fx.compile_fx, inner_compile=debug_compile_fx_inner
)
# NOTE(review): `func` is already the torch.compile wrapper created above, so
# this wraps it a second time — confirm that is intended.
func = torch.compile(func, backend=backend, fullgraph=True)
# No block_mask is passed on this call, unlike the first run.
out = func(query, key, value, _squared)
out.sum().backward()
# Index 1 is the second graph handed to the hook; the expected text below shows
# it calling flex_attention_backward.
joint_graph = normalize_gm(aot_graphs[1].print_readable(print_output=False))
self.assertExpectedInline(
joint_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]", primals_3: "f64[2, 2, 128, 4]", full: "i32[1, 1, 1]", full_default: "i32[1, 1, 1, 1]", convert_element_type: "i32[1, 1, 1]", convert_element_type_1: "i32[1, 1, 1, 1]", getitem_2: "f64[2, 2, 128, 4]", getitem_3: "f32[2, 2, 128]", tangents_1: "f64[2, 2, 128, 4]"):
fw_graph = self.fw_graph
joint_graph = self.joint_graph
mask_graph = self.mask_graph
flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, fw_graph, joint_graph, (full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 128, 128, mask_graph), 0.5, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = fw_graph = joint_graph = full = full_default = convert_element_type = convert_element_type_1 = mask_graph = None
getitem_4: "f64[2, 2, 128, 4]" = flex_attention_backward[0]
getitem_5: "f64[2, 2, 128, 4]" = flex_attention_backward[1]
getitem_6: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None
return [getitem_4, getitem_5, getitem_6]
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]"):
mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
return mul
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]", arg5_1: "f64[]"):
mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1)
mul_1: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1)
mul_2: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1); arg5_1 = arg0_1 = None
add: "f64[]" = torch.ops.aten.add.Tensor(mul_2, mul_1); mul_2 = mul_1 = None
return [add, None, None, None, None]
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"):
full: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
return full
""", # noqa: B950
)
@supported_platform
def test_nyi_for_non_divisible_seq_lens(self):
    """Eager and compiled ``flex_attention`` must both raise the "multiple of 128" ``NotImplementedError``."""
    # First pass calls flex_attention directly, second pass goes through torch.compile.
    for wrap in (lambda fn: fn, torch.compile):
        with self.assertRaisesRegex(
            NotImplementedError, "NYI: L must be a multiple of 128"
        ):
            attention = wrap(flex_attention)
            attention(
                torch.randn((2, 3, 4)),
                torch.randn((2, 10, 5)),
                torch.randn((2, 10, 5)),
                score_mod=_identity,
            )
# Register the @common_utils.parametrize-decorated tests of TestFlexAttention
# with the parametrization machinery.
common_utils.instantiate_parametrized_tests(TestFlexAttention)
# When executed as a script, run the suite with the Inductor test runner
# (imported here so a plain import of this module does not pull it in).
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
run_tests()