# blob: 311b23395efdb27170c0fd4a17d7320667aa49c5 [file] [log] [blame]
import argparse
import torch
def bench(nt_a, nt_b, niter):
    """Return the average time in seconds of one ``nt_a.bmm(nt_b)`` call.

    The call is issued once untimed as a warmup, then ``niter`` times
    between two CUDA timing events; the elapsed time is divided by
    ``niter``.

    Args:
        nt_a: Left operand; must expose a ``bmm`` method.
        nt_b: Right operand passed to ``nt_a.bmm``.
        niter: Number of timed iterations; must be at least 1.

    Returns:
        Mean runtime per iteration, in seconds.

    Raises:
        ValueError: If ``niter`` is smaller than 1 (the average would be
            undefined or negative).
    """
    if niter < 1:
        raise ValueError(f"niter must be a positive integer, got {niter!r}")

    # Warmup: run once outside the timed region and wait for it to finish.
    nt_a.bmm(nt_b)
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for _ in range(niter):
        nt_a.bmm(nt_b)
    end_event.record()
    # Block until all queued work (including end_event) has completed so
    # the elapsed time can be read.
    torch.cuda.synchronize()

    # Event.elapsed_time() reports milliseconds; convert to seconds.
    return (start_event.elapsed_time(end_event) * 1.0e-3) / niter
def sweep_n(ntensor, niter, dtype):
    """Benchmark nested-tensor ``bmm`` over a sweep of square matrix sizes.

    For each size ``n`` two nested tensors are built, each holding
    ``ntensor`` random ``n x n`` matrices of ``dtype`` on the CUDA device.
    Their ``bmm`` is timed with ``bench`` and one result row is printed.

    Args:
        ntensor: Number of matrices in each nested tensor.
        niter: Number of timed iterations forwarded to ``bench``.
        dtype: Torch dtype of the generated matrices.
    """
    # The fourth column is reported in TFLOP (the count is divided by 1e12).
    print("n, dtype, ntensor, tflop, runtime, tflop/s")
    for n in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
        # Create the operands directly on the GPU in the target dtype
        # instead of allocating on the CPU and then casting and copying.
        nt_a = torch.nested.nested_tensor(
            [
                torch.randn(n, n, dtype=dtype, device="cuda")
                for _ in range(ntensor)
            ]
        )
        nt_b = torch.nested.nested_tensor(
            [
                torch.randn(n, n, dtype=dtype, device="cuda")
                for _ in range(ntensor)
            ]
        )
        runtime = bench(nt_a, nt_b, niter)
        # 2 * n^3 operations are counted per n x n matrix product.
        tflop = n * n * n * ntensor * 2 / 1e12
        print(n, dtype, ntensor, tflop, runtime, tflop / runtime)
def parse_args(argv=None):
    """Parse the benchmark's command-line options.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        Namespace with integer attributes ``niter`` and ``ntensor``.
    """
    parser = argparse.ArgumentParser(description="Nested Tensor BMM Benchmark")
    parser.add_argument(
        "--niter",
        default=10,
        type=int,
        help="number of timed iterations per measurement (default: 10)",
    )
    parser.add_argument(
        "--ntensor",
        default=20,
        type=int,
        help="number of matrices in each nested tensor (default: 20)",
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Run the size sweep for float32 and then float16."""
    args = parse_args(argv)
    for dtype in (torch.float32, torch.float16):
        sweep_n(args.ntensor, args.niter, dtype)


if __name__ == "__main__":
    main()