import torch
import torch._C._te as te
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import argparse

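# te expressions and statements are allocated inside the currently active
# KernelScope; this context manager owns one per benchmark configuration so
# that memory is reclaimed between runs.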
class kernel_arena_scope(object):
    def __enter__(self):
        self.scope = te.KernelScope()

    def __exit__(self, typ, val, traceback):
        self.scope = None

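# Each entry pairs the name of an NNC ExprHandle method with the torch op
# used as the eager-mode baseline for the same computation.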
unary_ops = [
    ("sin", torch.sin),
    ("cos", torch.cos),
    ("tan", torch.tan),
    ("asin", torch.asin),
    ("acos", torch.acos),
    ("atan", torch.atan),
    ("sinh", torch.sinh),
    ("cosh", torch.cosh),
    ("tanh", torch.tanh),
    ("sigmoid", torch.sigmoid),
    ("exp", torch.exp),
| ("expm1", torch.expm1), |
| ("expm1", torch.expm1), |
| ("abs", torch.abs), |
| ("log", torch.log), |
| ("fast_log", torch.log), |
| ("log2", torch.log2), |
| ("log10", torch.log10), |
| ("log1p", torch.log1p), |
| ("erf", torch.erf), |
| ("erfc", torch.erfc), |
| ("sqrt", torch.sqrt), |
| ("rsqrt", torch.rsqrt), |
| ("ceil", torch.ceil), |
| ("floor", torch.floor), |
| ("round", torch.round), |
| ("trunc", torch.trunc), |
| ("lgamma", torch.lgamma), |
| # ("frac", torch.frac), # seems unimplemented |
| # ("isnan", torch.isnan), # no out variant |
| ] |
| |
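# gen_unary_nnc_fun builds the NNC side of a unary benchmark: the returned
# compute(i, j) invokes the ExprHandle method named nnc_name, e.g.
# A.load([i, j]).sin(). B is accepted but unused so unary and binary
# generators share one calling convention.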
def gen_unary_nnc_fun(nnc_name):
    def nnc_fun(A, B):
        def compute(i, j):
            return getattr(A.load([i, j]), nnc_name)()
        return compute
    return nnc_fun

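# The torch-side generators return a zero-argument closure with its tensors
# already captured, so the timed loops measure only the op call itself.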
def gen_unary_torch_fun(torch_op):
    def torch_fun(a, b, out):
        def fun():
            return torch_op(a, out=out)
        return fun
    return torch_fun

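# Binary counterparts of the generators above; fn maps two loaded
# ExprHandles to a new expression.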
def gen_binary_nnc_fun(fn):
    def nnc_fun(A, B):
        def compute(i, j):
            return fn(A.load([i, j]), B.load([i, j]))
        return compute
    return nnc_fun

def gen_binary_torch_fun(fn):
    def pt_fun(a, b, out):
        def fun():
            return fn(a, b, out=out)
        return fun
    return pt_fun

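# Comparison benchmarks need their own inputs: a bool output buffer, and for
# eq, small integers so exact equality actually occurs.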
def gen_int_comparison_tensors(N, M):
    return (torch.randint(0, 3, (N, M)), torch.randint(0, 3, (N, M)), torch.empty((N, M), dtype=torch.bool))

def gen_float_comparison_tensors(N, M):
    return (torch.rand(N, M), torch.rand(N, M), torch.empty((N, M), dtype=torch.bool))

te_bool = te.Dtype.Bool
binary_ops = [
    ('add', (lambda a, b: a + b), torch.add),
    ('mul', (lambda a, b: a * b), torch.mul),
    ('sub', (lambda a, b: a - b), torch.sub),
    ('div', (lambda a, b: a / b), torch.div),
    # Comparisons are cast to Bool so the NNC output dtype matches torch's.
    ('eq', (lambda a, b: te.Cast.make(te_bool, a == b)), torch.eq, gen_int_comparison_tensors),
    ('gt', (lambda a, b: te.Cast.make(te_bool, a > b)), torch.gt, gen_float_comparison_tensors),
    ('lt', (lambda a, b: te.Cast.make(te_bool, a < b)), torch.lt, gen_float_comparison_tensors),
    ('gte', (lambda a, b: te.Cast.make(te_bool, a >= b)), torch.greater_equal, gen_float_comparison_tensors),
    ('lte', (lambda a, b: te.Cast.make(te_bool, a <= b)), torch.less_equal, gen_float_comparison_tensors),
    # ('neq', (lambda a, b: a != b), None),  # no single-op torch equivalent
    # ('&', (lambda a, b: a & b), torch.bitwise_and),  # requires more work to test
]

def nnc_relu(A, B):
    def f(i, j):
        zero = te.ExprHandle.float(0)
        return te.ifThenElse(A.load([i, j]) < zero, zero, A.load([i, j]))
    return f

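# pt_relu ignores its out argument; the benchmark loop captures the returned
# tensor via tR = fn() and the correctness check compares it against the NNC
# output tX.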
def pt_relu(a, b, c):
    return torch.relu(a)

custom_ops = [
    ('relu', nnc_relu, pt_relu),
    # ('nnc_mul_relu', nnc_mul_relu, pt_mul_relu),
    # ('manual_sigmoid', nnc_manual_sigmoid, lambda a, b, c: torch.sigmoid(a, out=c)),
]

def gen_custom_torch_fun(fn):
    def pt_fun(a, b, out):
        def fun():
            return fn(a, b, out)
        return fun
    return pt_fun

def normalize_benchmarks(ops):
    # Pad 3-tuples with a trailing None so every entry is (name, nnc_fn, pt_fn, shape_fn).
    return [i + (None,) if len(i) == 3 else i for i in ops]

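# Flatten the three op tables into parallel lists, then zip them into
# (name, nnc_fn, pt_fn, shape_fn) benchmark tuples.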
names = []
nnc_fns = []
pt_fns = []
shape_fns = []

for nnc_name, pt_op in unary_ops:
    names.append(nnc_name)
    nnc_fns.append(gen_unary_nnc_fun(nnc_name))
    pt_fns.append(gen_unary_torch_fun(pt_op))
    shape_fns.append(None)

for name, lmbda, pt_fn, shape_fn in normalize_benchmarks(binary_ops):
    names.append(name)
    nnc_fns.append(gen_binary_nnc_fun(lmbda))
    pt_fns.append(gen_binary_torch_fun(pt_fn))
    shape_fns.append(shape_fn)

for name, lmbda, pt_fn, shape_fn in normalize_benchmarks(custom_ops):
    names.append(name)
    nnc_fns.append(lmbda)
    pt_fns.append(gen_custom_torch_fun(pt_fn))
    shape_fns.append(shape_fn)

benchmarks = list(zip(names, nnc_fns, pt_fns, shape_fns))

def run_benchmarks(benchmarks, sizes):
    rows = []
    with torch.no_grad():
        for name, nnc_fun, torch_fun, shape_fn in benchmarks:
            for N, M in sizes:
                # Scale the iteration count down as the tensors grow, so each
                # configuration takes roughly comparable wall time.
                iters = int(1e6 / (N + M))
                with kernel_arena_scope():
                    if shape_fn is None:
                        # Clamp away from 0 and 1 so ops like log, rsqrt, and
                        # asin stay inside their domains.
                        tA = torch.rand(M, N).clamp(0.01, 0.99)
                        tB = torch.rand(M, N).clamp(0.01, 0.99)
                        tX = torch.empty(M, N)
                        tR = torch.empty(M, N)
                    else:
                        tA, tB, tX = shape_fn(M, N)
                        tR = tX.clone()

                    def get_nnc_type(dtype):
                        if dtype == torch.float:
                            return te.Dtype.Float
                        elif dtype == torch.long:
                            return te.Dtype.Long
                        else:
                            raise ValueError(f"no NNC dtype mapping for {dtype}")

                    dtype = get_nnc_type(tA.dtype)

                    dM = te.ExprHandle.int(M)
                    dN = te.ExprHandle.int(N)

                    A = te.Placeholder('A', dtype, [dM, dN])
                    B = te.Placeholder('B', dtype, [dM, dN])

                    dim_args = [te.DimArg(*args) for args in [(dM, 'm'), (dN, 'n')]]

                    # Trace the compute definition into a tensor expression,
                    # lower it to a loop nest, simplify, and codegen via LLVM.
                    compute = nnc_fun(A, B)
                    X = te.Compute('X', dim_args, compute)
                    loopnest = te.LoopNest([X])
                    loopnest.prepare_for_codegen()
                    stmt = te.simplify(loopnest.root_stmt())
                    cg = te.construct_codegen('llvm', stmt, [te.BufferArg(x) for x in [A, B, X]])

                    # warmup
                    for _ in range(10):
                        cg.call([tA, tB, tX])
                    start = time.time()
                    for it in range(iters):
                        cg.call([tA, tB, tX])
                    time1 = time.time() - start

                    fn = torch_fun(tA, tB, tR)
                    # warmup
                    for _ in range(10):
                        tR = fn()
                    start = time.time()
                    for it in range(iters):
                        tR = fn()
                    time2 = time.time() - start

                    # DataFrame.append was removed in pandas 2.0; collect rows
                    # in a list and build the frame once at the end instead.
                    rows.append({'name': name, 'N': N, 'M': M, 'nnc_time': time1,
                                 'torch_time': time2, 'ratio': time2 / time1})
                    print(name, N, M)
                    print(time2 / time1, time1, time2)
                    print()

                    assert np.allclose(tX, tR), f"{name} disagrees between NNC and torch"
    return pd.DataFrame(rows, columns=['name', 'N', 'M', 'nnc_time', 'torch_time', 'ratio'])

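# ratio is torch_time / nnc_time, so values above 1.0 mean NNC beat eager
# PyTorch; the heatmap is centered at 1.0 with green above and red below.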
def dump_plot(df, sizes):
    keys = []
    vals = []
    indexed = df[df['N'] == df['M']]
    for _, row in indexed.iterrows():
        keys.append(row['name'])
        vals.append(row['ratio'])

    # Keep one label per op: each op contributed len(sizes) consecutive rows.
    keys = keys[::len(sizes)]
    sns.set(rc={'figure.figsize': (5.0, len(keys) * 0.5)})

    cmap = sns.diverging_palette(10, 120, n=9, as_cmap=True)
    np_vals = np.array([vals]).reshape(-1, len(sizes))
    g = sns.heatmap(np_vals, annot=True, cmap=cmap, center=1.0, yticklabels=True)
    plt.yticks(rotation=0)
    plt.title('PyTorch time divided by NNC time (single core)')
    plt.xlabel('Size of NxN matrix')
    plt.ylabel('Operation')
    g.set_yticklabels(keys)
    g.set_xticklabels(sizes)

    plt.savefig('nnc.png')


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Runs NNC microbenchmarks')
    parser.add_argument('--multi_threaded', action='store_true', help='Run with more than one thread')
    args = parser.parse_args()
    if not args.multi_threaded:
        torch.set_num_threads(1)

    sizes = [1, 4, 16, 64, 256, 1024]
    df = run_benchmarks(benchmarks, [(i, i) for i in sizes])
    dump_plot(df, sizes)