| import timeit | |
| import torch.fx | |
# Number of chained `sin` calls traced into the benchmark graph.
N = 100_000
# Number of full passes over the graph's nodes per timing measurement.
K = 1_000
def huge_graph():
    """Symbolically trace a function that applies ``sin`` to its input N times.

    Returns whatever ``torch.fx.symbolic_trace`` produces for that function;
    the loop is unrolled during tracing, so the graph grows with N.
    """

    def fn(x):
        # Thread the value through N successive `.sin()` calls.
        result = x
        for _ in range(N):
            result = result.sin()
        return result

    traced = torch.fx.symbolic_trace(fn)
    return traced
def main():
    """Benchmark plain iteration over the nodes of a large FX graph.

    Builds the graph once via ``huge_graph()``, then times K complete passes
    over ``g.graph.nodes`` with ``timeit.repeat`` (3 repeats, keeping the
    fastest) and prints the elapsed time and the node throughput.
    """
    g = huge_graph()

    def fn():
        # The loop body is intentionally empty: only iteration cost is measured.
        for _ in g.graph.nodes:
            pass

    # Count the nodes that are really visited instead of assuming N per pass:
    # besides the N `sin` calls, the traced graph also contains a placeholder
    # node for the input and an output node.
    total_nodes = len(g.graph.nodes) * K

    # min() over the repeats gives the measurement least disturbed by noise.
    t = min(timeit.repeat(fn, number=K, repeat=3))
    print(
        f"iterating over {total_nodes} FX nodes took {t:.1f}s "
        f"({total_nodes / t:.0f} nodes/s)"
    )
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()