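"""Benchmark torch.nn.functional.scaled_dot_product_attention with a causal mask that
is materialized as a dense tensor versus the equivalent CausalBias attn_mask subclass,
sweeping over batch size, head count, sequence lengths, and embedding dimension."""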
import itertools
from dataclasses import asdict, dataclass
from functools import partial
from typing import Callable, List, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from tabulate import tabulate
from torch.nn.attention.bias import CausalBias, CausalVariant
from torch.nn.parameter import Parameter
from tqdm import tqdm


def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    # Warm up outside the timed region
    for _ in range(5):
        func(*args, **kwargs)
    t0 = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    # adaptive_autorange reports seconds; convert the median to microseconds
    return t0.adaptive_autorange(min_run_time=0.1).median * 1e6
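

# Each run is described by an ExperimentConfig; ExperimentResults holds the two
# measured times (dense materialized mask vs. CausalBias subclass), in microseconds.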
@dataclass(frozen=True)
class ExperimentConfig:
    batch_size: int
    num_heads: int
    q_seq_len: int
    k_seq_len: int
    embed_dim: int
    dtype: torch.dtype

    @property
    def head_dim(self) -> int:
        return self.embed_dim // self.num_heads

    def asdict(self):
        dict_obj = asdict(self)
        dict_obj["head_dim"] = self.head_dim
        return dict_obj


@dataclass(frozen=True)
class ExperimentResults:
    materialized_mask_time: float
    attn_mask_subclass_time: float

    def get_entries(self) -> List:
        return [
            f"{self.materialized_mask_time:.2f}",
            f"{self.attn_mask_subclass_time:.2f}",
        ]


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    results: ExperimentResults

    def get_entries(self) -> List:
        # ExperimentConfig exposes asdict() rather than get_entries(), so take its values
        return list(self.config.asdict().values()) + self.results.get_entries()


def generate_inputs(
    batch_size, q_sequence_length, kv_sequence_length, embed_dim, dtype, device
):
    q_shape = (batch_size, q_sequence_length, embed_dim)
    kv_shape = (batch_size, kv_sequence_length, embed_dim)
    make_q = partial(torch.rand, q_shape, device=device, dtype=dtype)
    make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
    return make_q(), make_kv(), make_kv()
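

# CompositeMHA is a minimal multi-head attention module built from F.linear and
# scaled_dot_product_attention, so the attention mask can be passed either as a
# materialized dense tensor or as a CausalBias subclass.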
class CompositeMHA(torch.nn.Module):
    def __init__(self, num_heads, embed_dim, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.head_dim = embed_dim // num_heads
        self.embed_dim = embed_dim
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.q_proj_weight = Parameter(
            torch.empty((embed_dim, embed_dim), **factory_kwargs)
        )
        self.k_proj_weight = Parameter(
            torch.empty((embed_dim, embed_dim), **factory_kwargs)
        )
        self.v_proj_weight = Parameter(
            torch.empty((embed_dim, embed_dim), **factory_kwargs)
        )
        self.out_proj = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
        self.num_heads = num_heads

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Union[torch.Tensor, CausalBias],
    ):
        query_projected = F.linear(query, self.q_proj_weight)
        key_projected = F.linear(key, self.k_proj_weight)
        value_projected = F.linear(value, self.v_proj_weight)
        # Reshape the projected tensors to (batch, num_heads, seq_len, head_dim)
        query = query_projected.view(
            query_projected.size(0), -1, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key = key_projected.view(
            key_projected.size(0), -1, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value = value_projected.view(
            value_projected.size(0), -1, self.num_heads, self.head_dim
        ).transpose(1, 2)
        # attn_mask may be a dense tensor or a CausalBias subclass; SDPA accepts both
        attn = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=mask,
            dropout_p=0.0,
        )
        # Merge the heads back into the embedding dimension
        attn = attn.transpose(1, 2).reshape(query.size(0), -1, self.embed_dim)
        # Match return signature of nn.MHA
        return F.linear(attn, self.out_proj)

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.q_proj_weight)
        nn.init.xavier_uniform_(self.k_proj_weight)
        nn.init.xavier_uniform_(self.v_proj_weight)
        nn.init.constant_(self.out_proj, 0.0)
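

# Time one configuration twice: once with the causal mask materialized as a dense
# tensor and once with the CausalBias subclass, then verify both paths produce the
# same output.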
def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
    device = torch.device("cuda")
    composite_mha = CompositeMHA(
        config.num_heads, config.embed_dim, device, config.dtype
    )
    composite_mha.reset_parameters()
    query, key, value = generate_inputs(
        config.batch_size,
        config.q_seq_len,
        config.k_seq_len,
        config.embed_dim,
        config.dtype,
        device,
    )
    attn_mask = CausalBias(
        CausalVariant.LOWER_RIGHT, config.q_seq_len, config.k_seq_len
    )
    attn_mask_tensor = attn_mask._materialize(device)
    materialized_mask_time = benchmark_torch_function_in_microseconds(
        composite_mha, query, key, value, attn_mask_tensor
    )
    attn_mask_subclass_time = benchmark_torch_function_in_microseconds(
        composite_mha, query, key, value, attn_mask
    )
    torch.testing.assert_close(
        composite_mha(query, key, value, attn_mask_tensor),
        composite_mha(query, key, value, attn_mask),
    )
    return ExperimentResults(
        materialized_mask_time=materialized_mask_time,
        attn_mask_subclass_time=attn_mask_subclass_time,
    )
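

# Build the full sweep of configurations (batch size x heads x sequence lengths x
# embedding dim x dtype).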
def generate_experiment_configs() -> List[ExperimentConfig]:
    batch_sizes = [1, 8, 16, 128]
    num_heads = [16, 32]
    q_kv_seq_lens = [(128, 256), (256, 416), (512, 4097), (1024, 2048), (1, 2048)]
    embed_dims = [2048, 4096]
    dtypes = [
        torch.bfloat16,
    ]
    all_configs = []
    for bsz, heads, (q_seq_len, kv_seq_len), embed_dim, dtype in itertools.product(
        batch_sizes, num_heads, q_kv_seq_lens, embed_dims, dtypes
    ):
        all_configs.append(
            ExperimentConfig(
                batch_size=bsz,
                num_heads=heads,
                q_seq_len=q_seq_len,
                k_seq_len=kv_seq_len,
                embed_dim=embed_dim,
                dtype=dtype,
            )
        )
    return all_configs
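

# Speedup > 1 means the CausalBias subclass path is faster than materializing the mask.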
def calculate_speedup(results: ExperimentResults) -> float:
    return results.materialized_mask_time / results.attn_mask_subclass_time
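

# Summarize the sweep: report the mean speedup plus the configs that produced the
# largest and smallest speedups.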
def print_results(results: List[Experiment]):
    # Calculate speedups
    speedups = [calculate_speedup(r.results) for r in results]
    # Find indices of max and min speedups
    max_speedup_index = np.argmax(speedups)
    min_speedup_index = np.argmin(speedups)
    # Get the config dictionaries
    max_config_dict = results[max_speedup_index].config.asdict()
    min_config_dict = results[min_speedup_index].config.asdict()
    # Create table data
    table_data = [
        {
            "Type": "Average",
            "Speedup": np.mean(speedups),
            **dict.fromkeys(max_config_dict),
        },
        {"Type": "Max", "Speedup": speedups[max_speedup_index], **max_config_dict},
        {"Type": "Min", "Speedup": speedups[min_speedup_index], **min_config_dict},
    ]
    # Print table
    print(tabulate(table_data, headers="keys", tablefmt="pretty"))


def main():
    seed = 123
    np.random.seed(seed)
    torch.manual_seed(seed)
    results = []
    # Time each config with the materialized mask and with the CausalBias subclass
    for config in tqdm(generate_experiment_configs()):
        results.append(Experiment(config, run_single_experiment(config)))
    print_results(results)


if __name__ == "__main__":
    main()