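"""Microbenchmarks comparing NNC (TensorExpr) kernels against eager PyTorch ops.

For each unary, binary, and custom op, an NNC kernel is built and timed against
the equivalent torch function on square matrices of increasing size; the
resulting speedup ratios are printed and dumped as a heatmap (nnc.png).
"""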
import torch
import torch._C._te as te
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import argparse
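
# Context manager holding a te.KernelScope for the duration of one benchmark:
# NNC expressions and statements built inside it live in this arena and are
# released when the scope is dropped on exit.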
class kernel_arena_scope(object):
    def __enter__(self):
        self.scope = te.KernelScope()

    def __exit__(self, typ, val, traceback):
        self.scope = None
unary_ops = [
    ("sin", torch.sin),
    ("cos", torch.cos),
    ("tan", torch.tan),
    ("asin", torch.asin),
    ("acos", torch.acos),
    ("atan", torch.atan),
    ("sinh", torch.sinh),
    ("cosh", torch.cosh),
    ("tanh", torch.tanh),
    ("sigmoid", torch.sigmoid),
    ("exp", torch.exp),
    ("expm1", torch.expm1),
    ("abs", torch.abs),
    ("log", torch.log),
    ("fast_log", torch.log),  # NNC-side fast_log expression, timed against torch.log
    ("log2", torch.log2),
    ("log10", torch.log10),
    ("log1p", torch.log1p),
    ("erf", torch.erf),
    ("erfc", torch.erfc),
    ("sqrt", torch.sqrt),
    ("rsqrt", torch.rsqrt),
    ("ceil", torch.ceil),
    ("floor", torch.floor),
    ("round", torch.round),
    ("trunc", torch.trunc),
    ("lgamma", torch.lgamma),
    # ("frac", torch.frac),  # seems unimplemented
    # ("isnan", torch.isnan),  # no out variant
]
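
# Each generator below returns a closure in the shape the benchmark loop
# expects: NNC functions take (A, B) placeholders and return a compute(i, j)
# expression, while torch functions take (a, b, out) tensors and return a thunk.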
def gen_unary_nnc_fun(nnc_name):
    def nnc_fun(A, B):
        def compute(i, j):
            return getattr(A.load([i, j]), nnc_name)()
        return compute
    return nnc_fun

def gen_unary_torch_fun(torch_op):
    def torch_fun(a, b, out):
        def fun():
            return torch_op(a, out=out)
        return fun
    return torch_fun
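
# Binary ops are expressed twice: as a lambda over NNC ExprHandles (which
# overload +, *, <, etc.) and as the corresponding torch function.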
def gen_binary_nnc_fun(fn):
    def nnc_fun(A, B):
        def compute(i, j):
            return fn(A.load([i, j]), B.load([i, j]))
        return compute
    return nnc_fun

def gen_binary_torch_fun(fn):
    def pt_fun(a, b, out):
        def fun():
            return fn(a, b, out=out)
        return fun
    return pt_fun
def gen_int_comparison_tensors(N, M):
    return (torch.randint(0, 3, (N, M)), torch.randint(0, 3, (N, M)),
            torch.empty((N, M), dtype=torch.bool))

def gen_float_comparison_tensors(N, M):
    return (torch.rand(N, M), torch.rand(N, M),
            torch.empty((N, M), dtype=torch.bool))
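
# Comparison ops need bool output buffers, so they carry a shape_fn that
# allocates them; on the NNC side the comparison result is cast to Bool
# explicitly.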
te_bool = te.Dtype.Bool
binary_ops = [
    ('add', (lambda a, b: a + b), torch.add),
    ('mul', (lambda a, b: a * b), torch.mul),
    ('sub', (lambda a, b: a - b), torch.sub),
    ('div', (lambda a, b: a / b), torch.div),
    ('eq', (lambda a, b: te.Cast.make(te_bool, a == b)), torch.eq, gen_int_comparison_tensors),
    ('gt', (lambda a, b: te.Cast.make(te_bool, a > b)), torch.gt, gen_float_comparison_tensors),
    ('lt', (lambda a, b: te.Cast.make(te_bool, a < b)), torch.lt, gen_float_comparison_tensors),
    ('gte', (lambda a, b: te.Cast.make(te_bool, a >= b)), torch.greater_equal, gen_float_comparison_tensors),
    ('lte', (lambda a, b: te.Cast.make(te_bool, a <= b)), torch.less_equal, gen_float_comparison_tensors),
    # ('neq', (lambda a, b: a != b), None),  # no one-op equivalent
    # ('&', (lambda a, b: a & b), torch.bitwise_and),  # requires more work to test
]
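
# relu spelled out by hand with NNC's ifThenElse: select(x < 0, 0, x).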
def nnc_relu(A, B):
    def f(i, j):
        return torch._C._te.ifThenElse(A.load([i, j]) < torch._C._te.ExprHandle.float(0),
                                       torch._C._te.ExprHandle.float(0), A.load([i, j]))
    return f

def pt_relu(a, b, c):
    return torch.relu(a)
custom_ops = [
    ('relu', nnc_relu, pt_relu),
    # ('nnc_mul_relu', nnc_mul_relu, pt_mul_relu),
    # ('manual_sigmoid', nnc_manual_sigmoid, lambda a, b, c: torch.sigmoid(a, out=c)),
]
def gen_custom_torch_fun(fn):
    def pt_fun(a, b, out):
        def fun():
            return fn(a, b, out)
        return fun
    return pt_fun
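
# Pad 3-tuples (name, nnc_fn, torch_fn) to 4-tuples with shape_fn=None so the
# benchmark loop can unpack every entry uniformly.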
def normalize_benchmarks(ops):
    return [i + (None,) if len(i) == 3 else i for i in ops]
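
# Build four parallel lists and zip them into (name, nnc_fn, torch_fn, shape_fn)
# benchmark tuples.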
names = []
nnc_fns = []
pt_fns = []
shape_fns = []
for nnc_name, pt_op in unary_ops:
    names.append(nnc_name)
    nnc_fns.append(gen_unary_nnc_fun(nnc_name))
    pt_fns.append(gen_unary_torch_fun(pt_op))
    shape_fns.append(None)

for name, lmbda, pt_fn, shape_fn in normalize_benchmarks(binary_ops):
    names.append(name)
    nnc_fns.append(gen_binary_nnc_fun(lmbda))
    pt_fns.append(gen_binary_torch_fun(pt_fn))
    shape_fns.append(shape_fn)

for name, lmbda, pt_fn, shape_fn in normalize_benchmarks(custom_ops):
    names.append(name)
    nnc_fns.append(lmbda)
    pt_fns.append(gen_custom_torch_fun(pt_fn))
    shape_fns.append(shape_fn)
benchmarks = list(zip(names, nnc_fns, pt_fns, shape_fns))
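
# For example, the 'add' entry expands (roughly) to:
#   nnc:   compute(i, j) = A.load([i, j]) + B.load([i, j])
#   torch: fun() = torch.add(a, b, out=out)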
def run_benchmarks(benchmarks, sizes):
    df = pd.DataFrame(columns=['name', 'N', 'M', 'nnc_time', 'torch_time', 'ratio'])
    with torch.no_grad():
        for name, nnc_fun, torch_fun, shape_fn in benchmarks:
            for N, M in sizes:
                iters = int(1e6 / (N + M))
                with kernel_arena_scope():
                    if shape_fn is None:
                        tA = torch.rand(M, N).clamp(0.01, 0.99)
                        tB = torch.rand(M, N).clamp(0.01, 0.99)
                        tX = torch.empty(M, N)
                        tR = torch.empty(M, N)
                    else:
                        tA, tB, tX = shape_fn(M, N)
                        tR = tX.clone()

                    def get_nnc_type(dtype):
                        if dtype == torch.float:
                            return torch._C._te.Dtype.Float
                        elif dtype == torch.long:
                            return torch._C._te.Dtype.Long

                    dtype = get_nnc_type(tA.dtype)

                    dM = torch._C._te.ExprHandle.int(M)
                    dN = torch._C._te.ExprHandle.int(N)

                    A = torch._C._te.Placeholder('A', dtype, [dM, dN])
                    B = torch._C._te.Placeholder('B', dtype, [dM, dN])

                    dim_args = [torch._C._te.DimArg(*args) for args in [(dM, 'm'), (dN, 'n')]]

                    # Build the NNC kernel: compute expression -> loop nest ->
                    # simplified stmt -> LLVM codegen.
                    compute = nnc_fun(A, B)
                    X = torch._C._te.Compute('X', dim_args, compute)
                    loopnest = torch._C._te.LoopNest([X])
                    loopnest.prepare_for_codegen()
                    stmt = torch._C._te.simplify(loopnest.root_stmt())
                    cg = torch._C._te.construct_codegen(
                        'llvm', stmt, [torch._C._te.BufferArg(x) for x in [A, B, X]])

                    # Time the NNC kernel (after a short warmup).
                    for _ in range(10):
                        cg.call([tA, tB, tX])
                    start = time.time()
                    for _ in range(iters):
                        cg.call([tA, tB, tX])
                    time1 = time.time() - start

                    # Time the eager torch op (after a short warmup).
                    fn = torch_fun(tA, tB, tR)
                    for _ in range(10):
                        tR = fn()
                    start = time.time()
                    for _ in range(iters):
                        tR = fn()
                    time2 = time.time() - start

                    # DataFrame.append was removed in pandas 2.0; concat a
                    # one-row frame instead.
                    row = pd.DataFrame([{'name': name, 'N': N, 'M': M, 'nnc_time': time1,
                                         'torch_time': time2, 'ratio': time2 / time1}])
                    df = pd.concat([df, row], ignore_index=True)
                    print(name, N, M)
                    print(time2 / time1, time1, time2)
                    print()

                    def check_correctness(a, b):
                        if not np.allclose(a, b):
                            print(name)
                        assert np.allclose(a, b)

                    check_correctness(tX, tR)
    return df
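
# Render the N == M results as a heatmap of torch_time / nnc_time: values
# above 1.0 mean the NNC kernel was faster.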
def dump_plot(df, sizes):
    keys = []
    vals = []
    indexed = df[df['N'] == df['M']]
    for _, row in indexed.iterrows():
        keys.append(row['name'])
        vals.append(row['ratio'])
    keys = keys[::len(sizes)]

    sns.set(rc={'figure.figsize': (5.0, len(keys) * 0.5)})
    cmap = sns.diverging_palette(10, 120, n=9, as_cmap=True)
    np_vals = np.array([vals]).reshape(-1, len(sizes))
    g = sns.heatmap(np_vals, annot=True, cmap=cmap, center=1.0, yticklabels=True)
    plt.yticks(rotation=0)
    plt.title('PyTorch performance divided by NNC performance (single core)')
    plt.xlabel('Size of NxN matrix')
    plt.ylabel('Operation')
    g.set_yticklabels(keys)
    g.set_xticklabels(sizes)
    plt.savefig('nnc.png')
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Runs NNC microbenchmarks')
    parser.add_argument('--multi_threaded', action='store_true',
                        help='Run with more than one thread')
    args = parser.parse_args()
    if not args.multi_threaded:
        torch.set_num_threads(1)

    sizes = [1, 4, 16, 64, 256, 1024]
    df = run_benchmarks(benchmarks, [(i, i) for i in sizes])
    dump_plot(df, sizes)
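
# Example invocation (filename is hypothetical):
#   python microbenchmarks.py                   # single-threaded (default)
#   python microbenchmarks.py --multi_threaded  # use all available threads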