import argparse
import random

import torch


def bench(nt_a, nt_b, niter):
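    """Time nt_a.bmm(nt_b) with CUDA events: one warmup call, then niter timed
    iterations; returns the mean per-iteration runtime in milliseconds.
    """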
    # Warmup
    nt_c = nt_a.bmm(nt_b)

    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(niter):
        nt_c = nt_a.bmm(nt_b)
    end_event.record()
    torch.cuda.synchronize()
    runtime = start_event.elapsed_time(end_event) / niter
    return runtime


def sweep_n(niter, dtype):
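    """For each batch size, build nested tensors of ragged 256 x [100, 200]
    matrices (and their transposes), benchmark bmm, and print one CSV row.
    """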
    for ntensor in [4, 8, 16, 32, 64, 128, 256]:
        tensors = [torch.randn(256, random.randint(100, 200)) for _ in range(ntensor)]
        nt_a = torch.nested.nested_tensor(
            tensors,
            dtype=dtype,
            device="cuda",
        )
        nt_b = torch.nested.nested_tensor(
            [t.t() for t in tensors],
            dtype=dtype,
            device="cuda",
        )
        runtime = bench(nt_a, nt_b, niter)
        # Each row of the (ntensor, ndim) size tensor holds one constituent's
        # shape; column 1 is the ragged inner dimension.
        nt_a_size = torch.ops.aten._nested_tensor_size(nt_a)
        lengths = nt_a_size[:, 1]
        print(
            ",".join(
                map(
                    str,
                    [
                        ntensor,
                        dtype,
                        lengths.min().item(),
                        lengths.float().mean().item(),
                        lengths.max().item(),
                        runtime,
                    ],
                )
            )
        )


if __name__ == "__main__":
    random.seed(123)
    parser = argparse.ArgumentParser(description="Nested Tensor BMM Benchmark")
    parser.add_argument("--niter", default=10, type=int)

    args = parser.parse_args()
    niter = args.niter

    print("ntensor,dtype,min_length,mean_length,max_length,runtime")
    sweep_n(niter, torch.float32)
    sweep_n(niter, torch.float16)