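"""Microbenchmark comparing the fused (memory-efficient) scaled dot product
attention backend against the math (eager) backend for a MultiheadAttention-style
forward + backward pass, using torch.utils.benchmark for timing and
torch.profiler for kernel-level profiles."""
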
import torch
import numpy as np
import random
import torch.utils.benchmark as benchmark
from torch.profiler import profile, record_function, ProfilerActivity


class CompositeMHA(torch.nn.Module):
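    """Self-attention equivalent of a batch-first nn.MultiheadAttention, written
    directly on top of the scaled dot product attention primitive so that the
    SDP backend can be switched via torch.backends.cuda.sdp_kernel."""
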
    def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj):
        super().__init__()
        self.in_proj_weight = in_proj_weight
        self.in_proj_bias = in_proj_bias
        self.out_proj = out_proj
        self.num_heads = num_heads

    def forward(self, query, key, value, mask):
        if not (query is key and key is value):
            raise NotImplementedError(
                "query, key and value must be the same Tensor for now."
            )
        if mask is not None:
            raise NotImplementedError("mask is currently not supported.")
        query_projected = torch.nn.functional.linear(
            query, self.in_proj_weight, self.in_proj_bias
        )

        batch_size = query_projected.size(0)
        # The in-projection packs q, k and v, so its last dim is 3 * embed_dim.
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        # The output of sdp has shape (batch, num_heads, seq_len, head_dim).
        attn, _ = torch.nn.functional._scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            need_attn_weights=False,
            is_causal=False,
        )

        # Reshape back to (batch, seq_len, embed_dim) and apply the output
        # projection, matching the attention output of nn.MHA.
        attn = attn.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim)
        return self.out_proj(attn)


def build_composite_mha_from_nn_mha(pt):
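    """Build a CompositeMHA that reuses the projection weights of an existing nn.MultiheadAttention."""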
    assert pt._qkv_same_embed_dim
    in_proj_weight = pt.in_proj_weight
    assert in_proj_weight is not None
    assert pt.batch_first
    return CompositeMHA(pt.num_heads, pt.in_proj_weight, pt.in_proj_bias, pt.out_proj)


def forw_back(model, input, upward):
    output = model(*input)
    output.backward(upward)


# The sdp_kernel context manager cannot be used inside the Timer stmt, so the
# backend selection is wrapped in these helpers instead.
def forw_back_fused(model, input, upward):
    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True):
        output = model(*input)
        output.backward(upward)


def forw_back_eager(model, input, upward):
    with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False):
        output = model(*input)
        output.backward(upward)


def run_timing(
    min_run_time, batch_size, embed_dimension, num_heads, max_sequence_len, dtype
):
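    """Time and profile a forward + backward pass with the fused and eager SDP backends."""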
    dropout_p = 0.0
    mask = None

    pt = torch.nn.MultiheadAttention(
        embed_dim=embed_dimension,
        num_heads=num_heads,
        batch_first=True,
        dropout=dropout_p,
    )
    npt = pt.cuda().to(dtype)
    cpt = build_composite_mha_from_nn_mha(npt)

    x = torch.randn(
        batch_size,
        max_sequence_len,
        embed_dimension,
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    # Use a detached forward output as the upstream gradient for backward(),
    # computed outside of the timed region.
    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True):
        rand_fused_upward = cpt(x, x, x, mask).clone().detach()
    with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False):
        rand_eager_upward = cpt(x, x, x, mask).clone().detach()

    t0 = benchmark.Timer(
        stmt="forw_back_fused(cpt, (x,x,x,mask), rand_fused_upward)",
        globals={
            "forw_back_fused": forw_back_fused,
            "cpt": cpt,
            "x": x,
            "rand_fused_upward": rand_fused_upward,
            "mask": mask,
        },
        label=f"Fused SDP forward and backward batch_size={batch_size} max_sequence_len={max_sequence_len} "
        f"num_heads={num_heads} embed_dimension={embed_dimension} dtype={dtype}",
        num_threads=torch.get_num_threads(),
    )
    t1 = benchmark.Timer(
        stmt="forw_back_eager(cpt, (x,x,x,mask), rand_eager_upward)",
        globals={
            "forw_back_eager": forw_back_eager,
            "cpt": cpt,
            "x": x,
            "rand_eager_upward": rand_eager_upward,
            "mask": mask,
        },
        label=f"Eager SDP forward and backward batch_size={batch_size} max_sequence_len={max_sequence_len} "
        f"num_heads={num_heads} embed_dimension={embed_dimension} dtype={dtype}",
        num_threads=torch.get_num_threads(),
    )

    m0 = t0.blocked_autorange(min_run_time=min_run_time)
    m1 = t1.blocked_autorange(min_run_time=min_run_time)
    print(m0)
    print(m1)

    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

    print("Profile for Fused".center(200, "-"))
    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True):
        with profile(
            activities=activities, record_shapes=False, with_stack=True
        ) as prof:
            with record_function("Fused SDP forward and backward"):
                for _ in range(20):
                    forw_back(cpt, (x, x, x, mask), rand_fused_upward)
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))

print("Profile for eager".center(200, "-"))
with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False):
with profile(
activities=activities, record_shapes=False, with_stack=True
) as prof:
with record_function("Fused SDP forward and backward"):
for _ in range(20):
forw_back(cpt, (x, x, x, mask), rand_eager_upward)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
def main():
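    """Run the benchmark for a single, fixed configuration with a fixed seed."""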
    seed = 123
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)

    min_run_time = 10
    batch_size = 64
    num_heads = 32
    max_seq_len = 256
    embed_dim = 1024
    dtype = torch.bfloat16

    print(
        f"Running timing for batch_size={batch_size} max_sequence_len={max_seq_len} "
        f"num_heads={num_heads} embed_dimension={embed_dim} dtype={dtype}"
    )
    run_timing(min_run_time, batch_size, embed_dim, num_heads, max_seq_len, dtype)


if __name__ == "__main__":
    main()