import timeit

import torch

# Process-wide benchmark configuration, applied once before any op is traced.
# NOTE(review): both torch._C calls are private, undocumented toggles. Going by
# their names, the first presumably allows the JIT fuser to run on CPU and the
# second presumably keeps fusion groups from being inlined back into the graph
# -- confirm against the installed PyTorch version.
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._debug_set_fusion_group_inlining(False)
# Restrict intra-op parallelism to a single thread for the timed calls.
torch.set_num_threads(1)


def hardswish(x: torch.Tensor) -> torch.Tensor:
    """Return ``x * clamp(x + 3, 0, 6) / 6``, computed elementwise.

    This is hardswish written out with basic tensor arithmetic and
    ``torch.clamp``.

    Args:
        x: Input tensor.

    Returns:
        A tensor holding the elementwise result.
    """
    return x * torch.clamp(x + 3.0, 0.0, 6.0) / 6.0


# Ops to benchmark. Each entry is called with a single tensor argument and its
# ``__name__`` is used as the row label in the printed table.
unary_ops = [
    hardswish,  # the Python-level definition from this file
    torch._C._nn.hardswish,  # the builtin counterpart
    torch.sigmoid,
    torch.reciprocal,
    torch.neg,
    torch.relu,
    torch.isnan,
    torch.log,
    torch.log10,
    torch.log1p,
    torch.log2,
    torch.exp,
    torch.expm1,
    torch.erf,
    torch.erfc,
    torch.cos,
    torch.sin,
    torch.tan,
    torch.acos,
    torch.asin,
    torch.cosh,
    torch.sinh,
    torch.atan,
    torch.tanh,
    torch.sqrt,
    torch.rsqrt,
    torch.abs,
    torch.ceil,
    torch.floor,
    torch.round,
    torch.trunc,
    torch.lgamma,
]

# Table header: op name, total eager time, total traced ("nnc") time, and the
# eager/traced ratio.
print("{:20s} {:>10s} {:>10s} {:>10s}".format("op", "eager", "nnc", "speedup"))

# Iteration counts are the same for every op, so set them once up front.
warmup_iters = 8
bench_iters = 100

for op in unary_ops:
    x = torch.rand((1024, 1024))
    # Trace a thin lambda around the op, with the example input passed as an
    # explicit one-element tuple.
    # NOTE(review): the lambda is presumably needed because some entries are
    # builtins that trace() may not accept directly -- confirm.
    traced = torch.jit.trace(lambda x: op(x), (x,))

    # Warmup: run both paths a few times before anything is timed.
    for _ in range(warmup_iters):
        op(x)
        traced(x)

    # Validate result: the traced function must agree with eager execution.
    torch.testing.assert_allclose(op(x), traced(x))

    # Benchmark. The timed statements resolve their names from an explicit
    # namespace, so timing does not depend on these being module globals.
    # timeit returns the total seconds taken by ``bench_iters`` executions.
    timing_namespace = {"op": op, "traced": traced, "x": x}
    teager = timeit.timeit(
        stmt="op(x)", globals=timing_namespace, number=bench_iters
    )
    tjit = timeit.timeit(
        stmt="traced(x)", globals=timing_namespace, number=bench_iters
    )
    print(f"{op.__name__:20s} {teager:10.3f} {tjit:10.3f} {teager/tjit:10.2f}")