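"""Microbenchmark for torch.nn.functional.scaled_dot_product_attention.

Times the forward and backward pass of SDPA over a sweep of batch sizes,
sequence lengths, dtypes, and causal settings, and prints the results as a
table (times are reported in microseconds).
"""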
import itertools
from collections import defaultdict
from dataclasses import asdict, dataclass
from typing import Callable, List, Tuple
import torch
import torch.utils.benchmark as benchmark
from tabulate import tabulate
from torch.nn.functional import scaled_dot_product_attention
from tqdm import tqdm


def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Time func(*args, **kwargs) and return the median runtime in microseconds."""
    # Warm up so one-time setup (kernel selection, caching) is not measured.
    for _ in range(5):
        func(*args, **kwargs)
    t0 = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    return t0.adaptive_autorange(min_run_time=0.1).median * 1e6


@dataclass(frozen=True)
class ExperimentConfig:
    batch_size: int
    num_heads: int
    q_seq_len: int
    kv_seq_len: int
    embed_dim: int
    is_causal: bool
    dtype: torch.dtype
    device: torch.device = torch.device("cuda")

    @property
    def head_dim(self) -> int:
        return self.embed_dim // self.num_heads

    def asdict(self):
        dict_obj = asdict(self)
        dict_obj["head_dim"] = self.head_dim
        return dict_obj


@dataclass(frozen=True)
class ExperimentResults:
    forward_time: float
    backward_time: float

    def asdict(self):
        return asdict(self)


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    results: ExperimentResults

    def asdict(self):
        # Use the custom asdict() methods so the derived head_dim column is included.
        dict1 = self.config.asdict()
        dict2 = self.results.asdict()
        return {**dict1, **dict2}


def get_input(
    config: ExperimentConfig,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    q = torch.randn(
        (config.batch_size, config.num_heads, config.q_seq_len, config.head_dim),
        dtype=config.dtype,
        device=config.device,
        requires_grad=True,
    )
    k = torch.randn(
        (config.batch_size, config.num_heads, config.kv_seq_len, config.head_dim),
        dtype=config.dtype,
        device=config.device,
        requires_grad=True,
    )
    v = torch.randn(
        (config.batch_size, config.num_heads, config.kv_seq_len, config.head_dim),
        dtype=config.dtype,
        device=config.device,
        requires_grad=True,
    )
    return q, k, v


def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
    q, k, v = get_input(config)
    is_causal = config.is_causal
    forward_time = benchmark_torch_function_in_microseconds(
        scaled_dot_product_attention,
        q,
        k,
        v,
        is_causal=is_causal,
        attn_mask=None,
    )
    out_torch = scaled_dot_product_attention(
        q, k, v, is_causal=is_causal, attn_mask=None
    )
    dOut = torch.randn_like(out_torch)
    # retain_graph=True lets backward() run repeatedly on the same graph while timing.
    backward_time = benchmark_torch_function_in_microseconds(
        out_torch.backward, dOut, retain_graph=True
    )
    return ExperimentResults(
        forward_time=forward_time,
        backward_time=backward_time,
    )


def generate_experiment_configs() -> List[ExperimentConfig]:
    batch_sizes = [
        1,
        8,
    ]
    num_heads = [16]
    q_kv_seq_lens = [(128, 128), (256, 256), (512, 512), (1024, 1024)]
    embed_dims = [2048]
    dtypes = [
        torch.bfloat16,
    ]
    is_causal = [True, False]
    all_configs = []
    for (
        bsz,
        heads,
        (q_seq_len, kv_seq_len),
        embed_dim,
        causal,
        dtype,
    ) in itertools.product(
        batch_sizes, num_heads, q_kv_seq_lens, embed_dims, is_causal, dtypes
    ):
        all_configs.append(
            ExperimentConfig(
                batch_size=bsz,
                num_heads=heads,
                q_seq_len=q_seq_len,
                kv_seq_len=kv_seq_len,
                embed_dim=embed_dim,
                is_causal=causal,
                dtype=dtype,
            )
        )
    return all_configs


def print_results(experiments: List[Experiment]):
    table_data = defaultdict(list)
    for experiment in experiments:
        for key, value in experiment.asdict().items():
            table_data[key].append(value)
    # The device is the same for every run, so drop the column from the table.
    del table_data["device"]
    print(tabulate(table_data, headers="keys", tablefmt="pretty", floatfmt=".3f"))


def main():
    seed = 123
    torch.manual_seed(seed)
    results = []
    for config in tqdm(generate_experiment_configs()):
        results.append(Experiment(config, run_single_experiment(config)))
    print_results(results)
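

# A typical invocation (assumes a CUDA-capable GPU plus the `tabulate` and `tqdm`
# packages; the script name below is illustrative):
#   python sdpa_benchmark.py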
if __name__ == "__main__":
    main()