tree 35530e2a44bd78a4807ed0bb8e01623e5a6e966b
parent 1807e61724ef014ee14ad374636bcd21f1aa1832
author Qiao Zhang <zhangqiaorjc@google.com> 1621363360 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 1621364309 -0700

[JAX] Refactor jax_jit to avoid DevicePut on pruned args.

name                                  old cpu/op  new cpu/op  delta
eager_unary_dispatch                  35.7µs ± 2%  35.9µs ± 3%     ~     (p=0.841 n=5+5)
eager_unary                           36.4µs ± 2%  36.6µs ± 3%     ~     (p=0.421 n=5+5)
eager_binary_dispatch                 45.6µs ± 1%  46.1µs ± 2%     ~     (p=0.421 n=5+5)
eager_binary                          46.6µs ± 2%  47.0µs ± 5%     ~     (p=1.000 n=5+5)
jit_trivial_dispatch                  41.4µs ± 1%  41.4µs ± 0%     ~     (p=0.690 n=5+5)
jit_trivial                           42.4µs ± 1%  42.3µs ± 1%     ~     (p=0.841 n=5+5)
jit_simple_dispatch                   8.85µs ± 3%  9.15µs ± 3%     ~     (p=0.095 n=5+5)
jit_simple                            9.77µs ± 1%  9.82µs ± 2%     ~     (p=0.548 n=5+5)
jit_simple_many_args_dispatch_10      13.4µs ± 1%  13.6µs ± 3%     ~     (p=0.222 n=5+5)
jit_simple_many_args_10               14.0µs ± 2%  14.1µs ± 1%     ~     (p=0.421 n=5+5)
jit_simple_pruned_args_dispatch_10    8.05µs ± 3%  8.07µs ± 4%     ~     (p=0.841 n=5+5)
jit_simple_pruned_args_10             9.53µs ± 2%  9.43µs ± 2%     ~     (p=0.222 n=5+5)
jit_simple_many_args_dispatch_100     55.2µs ± 1%  54.8µs ± 2%     ~     (p=0.310 n=5+5)
jit_simple_many_args_100              55.8µs ± 1%  55.8µs ± 1%     ~     (p=0.841 n=5+5)
jit_simple_pruned_args_dispatch_100   14.3µs ± 4%  12.6µs ± 1%  -11.41%  (p=0.016 n=5+4)
jit_simple_pruned_args_100            14.8µs ± 1%  13.3µs ± 2%  -10.06%  (p=0.008 n=5+5)
jit_simple_many_args_dispatch_1000     489µs ± 1%   477µs ± 3%     ~     (p=0.056 n=5+5)
jit_simple_many_args_1000              495µs ± 3%   493µs ± 3%     ~     (p=0.841 n=5+5)
jit_simple_pruned_args_dispatch_1000  85.0µs ± 3%  65.3µs ± 3%  -23.13%  (p=0.008 n=5+5)
jit_simple_pruned_args_1000           86.0µs ± 3%  66.4µs ± 3%  -22.78%  (p=0.008 n=5+5)
jit_simple_many_args_dispatch_2000    1.09ms ± 4%  1.03ms ± 3%   -5.97%  (p=0.016 n=5+5)
jit_simple_many_args_2000             1.07ms ± 3%  1.04ms ± 5%     ~     (p=0.095 n=5+5)
jit_simple_pruned_args_dispatch_2000   190µs ± 3%   144µs ± 3%  -23.96%  (p=0.008 n=5+5)
jit_simple_pruned_args_2000            195µs ± 4%   147µs ± 3%  -24.29%  (p=0.008 n=5+5)
jit_dispatch_without_transfer         76.0µs ± 1%  77.2µs ± 6%     ~     (p=0.310 n=5+5)
jit_dispatch_with_transfer            82.1µs ± 5%  81.3µs ± 2%     ~     (p=0.421 n=5+5)
sda_index_1                           8.83µs ± 1%  8.73µs ± 2%     ~     (p=0.222 n=5+5)

name                                  old time/op             new time/op             delta
eager_unary_dispatch                  35.7µs ± 2%             35.9µs ± 3%     ~             (p=0.841 n=5+5)
eager_unary                           36.5µs ± 2%             37.1µs ± 4%     ~             (p=0.222 n=5+5)
eager_binary_dispatch                 45.6µs ± 1%             46.1µs ± 2%     ~             (p=0.421 n=5+5)
eager_binary                          46.8µs ± 3%             47.1µs ± 5%     ~             (p=1.000 n=5+5)
jit_trivial_dispatch                  41.4µs ± 1%             41.4µs ± 0%     ~             (p=0.690 n=5+5)
jit_trivial                           42.4µs ± 1%             42.3µs ± 1%     ~             (p=0.841 n=5+5)
jit_simple_dispatch                   8.86µs ± 3%             9.15µs ± 3%     ~             (p=0.095 n=5+5)
jit_simple                            9.82µs ± 1%             9.91µs ± 0%     ~             (p=0.190 n=5+4)
jit_simple_many_args_dispatch_10      13.4µs ± 1%             13.6µs ± 4%     ~             (p=0.310 n=5+5)
jit_simple_many_args_10               14.1µs ± 2%             14.2µs ± 1%     ~             (p=0.421 n=5+5)
jit_simple_pruned_args_dispatch_10    8.07µs ± 4%             8.07µs ± 4%     ~             (p=0.841 n=5+5)
jit_simple_pruned_args_10             9.59µs ± 2%             9.48µs ± 2%     ~             (p=0.222 n=5+5)
jit_simple_many_args_dispatch_100     55.2µs ± 1%             54.8µs ± 2%     ~             (p=0.310 n=5+5)
jit_simple_many_args_100              55.9µs ± 1%             55.9µs ± 1%     ~             (p=0.841 n=5+5)
jit_simple_pruned_args_dispatch_100   14.3µs ± 5%             12.6µs ± 1%  -11.75%          (p=0.016 n=5+4)
jit_simple_pruned_args_100            14.8µs ± 2%             13.3µs ± 2%  -10.19%          (p=0.008 n=5+5)
jit_simple_many_args_dispatch_1000     489µs ± 1%              477µs ± 3%     ~             (p=0.056 n=5+5)
jit_simple_many_args_1000              495µs ± 3%              493µs ± 3%     ~             (p=0.841 n=5+5)
jit_simple_pruned_args_dispatch_1000  85.0µs ± 3%             65.3µs ± 3%  -23.13%          (p=0.008 n=5+5)
jit_simple_pruned_args_1000           86.1µs ± 3%             66.5µs ± 2%  -22.72%          (p=0.008 n=5+5)
jit_simple_many_args_dispatch_2000    1.09ms ± 4%             1.03ms ± 3%   -5.96%          (p=0.016 n=5+5)
jit_simple_many_args_2000             1.07ms ± 3%             1.04ms ± 5%     ~             (p=0.095 n=5+5)
jit_simple_pruned_args_dispatch_2000   190µs ± 3%              144µs ± 3%  -23.97%          (p=0.008 n=5+5)
jit_simple_pruned_args_2000            195µs ± 4%              147µs ± 3%  -24.31%          (p=0.008 n=5+5)
jit_dispatch_without_transfer         1.41ms ± 1%             1.40ms ± 1%     ~             (p=0.095 n=5+5)
jit_dispatch_with_transfer            1.40ms ± 2%             1.40ms ± 2%     ~             (p=0.841 n=5+5)
sda_index_1                           8.83µs ± 1%             8.73µs ± 2%     ~             (p=0.222 n=5+5)

PiperOrigin-RevId: 374468578
Change-Id: I0a45af35b936a72f8271bd3e3a66e0d778619132
