# blob: 2c575c69418c759d482e3982800f65d6aedec99c [file] [log] [blame]
import timeit
import torch.fx
# Number of chained ``sin`` calls traced into the benchmark graph.
N = 100_000
# Number of full passes over the graph's nodes per timed run.
K = 1_000
def huge_graph(n=None):
    """Symbolically trace a function made of a long chain of ``sin`` calls.

    Args:
        n: Number of chained ``x.sin()`` calls to trace. When ``None``
            (the default) the module-level constant ``N`` is used, which
            reproduces the original hard-coded behavior.

    Returns:
        Whatever ``torch.fx.symbolic_trace`` returns for the chained
        function (its ``.graph.nodes`` is what the benchmark iterates).
    """
    # Resolve the default at call time so a small graph can be requested
    # explicitly without touching the module-level constant.
    depth = N if n is None else n

    def fn(x):
        # Each iteration records one more ``sin`` call on the traced value.
        for _ in range(depth):
            x = x.sin()
        return x

    return torch.fx.symbolic_trace(fn)
def main():
    """Benchmark iteration over the nodes of the graph built by ``huge_graph``.

    Builds the graph once, times ``K`` full passes over ``graph.nodes``
    three times via ``timeit.repeat``, and prints the fastest run together
    with the derived nodes-per-second figure.
    """
    traced = huge_graph()

    def walk_nodes():
        # Empty loop body on purpose: only the iteration itself is timed.
        for _node in traced.graph.nodes:
            pass

    # Take the best of three runs, each doing K passes over the graph.
    timings = timeit.repeat(walk_nodes, repeat=3, number=K)
    best = min(timings)
    total_nodes = N * K
    print(f"iterating over {total_nodes} FX nodes took {best:.1f}s ({total_nodes / best:.0f} nodes/s)")
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()