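# Benchmark of torch.autograd.profiler overhead: times a sequential
# matmul loop and a fork/join parallel workload over several tensor
# sizes, with profiling enabled and disabled, on CPU and (if available)
# CUDA.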
from functools import partial
import itertools
import statistics
import timeit

import torch
TENSOR_SIZES = [1, 32, 128, 256, 512]  # square input sizes (size x size)
INTERNAL_ITER = 256  # matmuls per workload invocation
PARALLEL_TASKS_NUM = 4  # concurrent tasks in the parallel workload
N = 100  # timed repetitions per configuration
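
# Sequential workload: INTERNAL_ITER chained matrix multiplications.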
def loop_workload(x):
    for i in range(INTERNAL_ITER):
        x = torch.mm(x, x)
    return x
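
# Placeholder; replaced in __main__ with a torch.jit.trace of the workload.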
traced_loop_workload = None
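# Runs the traced workload once, optionally under the autograd profiler.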
def run_profiler_benchmark_loop(input_x, use_cuda, profiling_enabled):
    if profiling_enabled:
        with torch.autograd.profiler.profile(use_cuda=use_cuda) as prof:
            traced_loop_workload(input_x)
    else:
        traced_loop_workload(input_x)
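
# Each forked task performs its share of the INTERNAL_ITER matmuls.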
def parallel_task(x):
    for i in range(INTERNAL_ITER // PARALLEL_TASKS_NUM):
        x = torch.mm(x, x)
    return x
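
# Fork/join workload: launches PARALLEL_TASKS_NUM tasks concurrently with
# torch.jit._fork, then blocks on each returned future with torch.jit._wait.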
def parallel_workload(x):
    futs = []
    for i in range(PARALLEL_TASKS_NUM):
        futs.append(torch.jit._fork(parallel_task, x))
    for i in range(PARALLEL_TASKS_NUM):
        torch.jit._wait(futs[i])
    return x
traced_parallel_workload = None
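# Same as run_profiler_benchmark_loop, but exercises the parallel workload.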
def run_profiler_benchmark_parallel(input_x, use_cuda, profiling_enabled):
    if profiling_enabled:
        with torch.autograd.profiler.profile(use_cuda=use_cuda) as prof:
            traced_parallel_workload(input_x)
    else:
        traced_parallel_workload(input_x)
if __name__ == '__main__':
    for workload_name in ["loop", "parallel"]:
        print("Payload: {}; {} iterations, N = {}\n".format(
            workload_name, INTERNAL_ITER, N))
        for use_cuda, profiling_tensor_size, profiling_enabled in itertools.product(
                [False, True], TENSOR_SIZES, [False, True]):
            # Skip CUDA configurations on machines without a GPU.
            if use_cuda and not torch.cuda.is_available():
                continue

            print("Profiling {}, tensor size {}x{}, use cuda: {}".format(
                "enabled" if profiling_enabled else "disabled",
                profiling_tensor_size, profiling_tensor_size, use_cuda))

            input_x = torch.rand(profiling_tensor_size, profiling_tensor_size)
            if use_cuda:
                input_x = input_x.cuda()

            # Trace the workload for this input so the benchmark times the
            # traced (TorchScript) graph rather than eager Python.
            workload = None
            if workload_name == "loop":
                workload = partial(
                    run_profiler_benchmark_loop, input_x, use_cuda, profiling_enabled)
                traced_loop_workload = torch.jit.trace(loop_workload, input_x)
            elif workload_name == "parallel":
                workload = partial(
                    run_profiler_benchmark_parallel, input_x, use_cuda, profiling_enabled)
                traced_parallel_workload = torch.jit.trace(
                    parallel_workload, input_x)

            # Time N single-shot runs and report mean and standard deviation.
            runtimes = timeit.repeat(workload, repeat=N, number=1)
            avg_time = statistics.mean(runtimes) * 1000.0
            stddev_time = statistics.stdev(runtimes) * 1000.0
            print("\tavg. time: {:.3f} ms, stddev: {:.3f} ms".format(
                avg_time, stddev_time))
            if workload_name == "loop":
                print("\ttime per iteration: {:.3f} ms".format(
                    avg_time / INTERNAL_ITER))
            print()