from torch.utils.benchmark import Timer


def time_with_torch_timer(fn, args, kwargs=None, iters=100, num_threads=1):
    """Time ``fn(*args, **kwargs)`` end to end with ``torch.utils.benchmark.Timer``.

    Args:
        fn: Callable to benchmark.
        args: Positional arguments unpacked into every call of ``fn``.
        kwargs: Keyword arguments unpacked into every call of ``fn``.
            A falsy value (including the default ``None``) means no keyword
            arguments.
        iters: Number of calls in the timed run, passed to ``Timer.timeit``.
            Must be at least 1.
        num_threads: Size of the PyTorch thread pool used while timing,
            forwarded to ``Timer``. The default of 1 matches ``Timer``'s own
            default, so existing callers see unchanged behavior.

    Returns:
        The ``Measurement`` object returned by ``Timer.timeit``.

    Raises:
        ValueError: If ``iters`` is less than 1.
    """
    if iters < 1:
        # A non-positive call count cannot produce a meaningful per-call time.
        raise ValueError(f"iters must be >= 1, got {iters!r}")

    # Keep the original ``or`` semantics so any falsy kwargs value (None,
    # an empty mapping) becomes a real empty dict that ``**`` can unpack.
    kwargs = kwargs or {}

    # The statement is compiled by Timer and resolved against these globals,
    # so the names used in ``stmt`` must match the dict keys exactly.
    timer = Timer(
        stmt="fn(*args, **kwargs)",
        globals={"fn": fn, "args": args, "kwargs": kwargs},
        num_threads=num_threads,
    )
    # Measure end-to-end time of the call.
    return timer.timeit(iters)