import csv
from collections import defaultdict

import torch
import yaml


def get_ops_for_key(key):
    # Needs modified PyTorch C++ code to work
    if key is None:
        ops = torch._C._dispatch_get_registrations_for_dispatch_key()
    else:
        ops = torch._C._dispatch_get_registrations_for_dispatch_key(key)
    cleaned_ops = []
    for i in ops:
        if 'aten::' not in i:
            continue
        cleaned_ops.append(i[6:].strip())
    return set(cleaned_ops)
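

# Note: raw registration strings look like "aten::add.Tensor"; get_ops_for_key
# strips the 6-character "aten::" prefix, keeping only "add.Tensor".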


def gen_data(special_op_lists, analysis_name):
    all_ops = get_ops_for_key(None)
    composite_ops = get_ops_for_key('CompositeImplicitAutograd')
    noncomposite_ops = all_ops - composite_ops

    with open('../../aten/src/ATen/native/native_functions.yaml') as f:
        ops = yaml.load(f.read(), Loader=yaml.CLoader)

    with open('annotated_ops') as f:
        annotated_ops = {a.strip(): b.strip() for a, b in csv.reader(f)}

    uniq_ops = []
    uniq_names = set()
    overload_types = defaultdict(list)
    cnt = 0
    for op in ops:
        func_str = op['func']
        # The part of the schema before '(' is the full overloaded name.
        name = func_str[:func_str.index('(')]
        if '.' in name:
            uniq_name = name[:name.index('.')]
            overload_types[name[name.index('.') + 1:]].append(name)
        else:
            uniq_name = name
        op['name'] = uniq_name
        op['full_name'] = name
        op['ret_type'] = func_str[func_str.index('->') + 3:]
        cnt += 1
        if uniq_name in uniq_names:
            continue
        uniq_names.add(uniq_name)
        uniq_ops.append(op)
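
    # Each op dict now carries 'name' (overload suffix stripped), 'full_name',
    # and 'ret_type'. For example, the entry
    # "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
    # yields name 'add', full_name 'add.Tensor', and ret_type 'Tensor'.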

    def annotate_ops(ops, is_unique):
        categorization = defaultdict(int)
        for op in ops:
            if op['name'].endswith('_'):
                categorization['inplace'] += 1
                op['meta'] = 'inplace'
                continue
            if not is_unique and 'a!' in op['func'].lower():
                categorization['out'] += 1
                op['meta'] = 'out'
                continue
            if 'conv' in op['name']:
                categorization['conv'] += 1
                op['meta'] = 'conv'
                continue
            if 'pool' in op['name']:
                categorization['pool'] += 1
                op['meta'] = 'pool'
                continue
            if 'backward' in op['name']:
                categorization['backward'] += 1
                op['meta'] = 'backward'
                continue
            if op['name'].startswith('_') and not op['name'].startswith('__'):
                categorization['private'] += 1
                op['meta'] = 'private'
                continue
            if 'batch_norm' in op['name']:
                categorization['batch_norm'] += 1
                op['meta'] = 'batch_norm'
                continue
            if 'Tensor' not in op['func'] or 'Tensor' not in op['ret_type']:
                categorization['non_tensor'] += 1
                op['meta'] = 'non_tensor'
                continue
            backend_keywords = ('cudnn', 'mkldnn', 'miopen', 'native', 'thnn', 'slow')
            if any(s in op['name'] for s in backend_keywords):
                categorization['backend'] += 1
                op['meta'] = 'backend'
                continue
            if op['name'] in annotated_ops:
                categorization['core'] += 1
                op['meta'] = 'core ' + annotated_ops[op['name']]
                continue
            categorization['core'] += 1
            op['meta'] = 'core unknown'
        return categorization
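
    # Categorization is first-match-wins: an in-place conv, for instance, is
    # counted as 'inplace' rather than 'conv'.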
    annotate_ops(ops, is_unique=False)

    with open(analysis_name, 'w') as f:
        for op in ops:
            info = [
                op['full_name'],
                op['meta'],
                op['full_name'] not in noncomposite_ops,
            ] + [check(op) for check in special_op_lists]
            f.write(','.join([str(i) for i in info]) + '\n')
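
# Each output row is: full_name, meta category, whether the op is composite
# (i.e. not in the noncomposite set), then one column per special_op_lists check.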


def name_check(lst):
    return lambda x: x['name'] in lst


def full_name_check(lst):
    return lambda x: x['full_name'] in lst
# Generates batching rule data
gen_data([full_name_check(get_ops_for_key('FuncTorchBatched'))], 'vmap.txt')
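
# vmap.txt thus records, per op, whether a batching rule is registered for the
# FuncTorchBatched dispatch key.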


def remove_suffix(input_string, suffix):
    if suffix and input_string.endswith(suffix):
        return input_string[:-len(suffix)]
    return input_string


def remove_prefix(input_string, prefix):
    if prefix and input_string.startswith(prefix):
        return input_string[len(prefix):]
    return input_string
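
# The two helpers above mirror str.removesuffix / str.removeprefix, which are
# only available on Python 3.9+.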


if True:  # manual toggle for the decomposition-coverage analysis below
    with open('run_ops.txt') as f:
        opinfo_ops = [remove_suffix(i.strip(), '.default') for i in f]
    with open('count_ops.txt') as f:
        opinfo_counts = [i.strip() for i in f]
    # Note: counts are kept as strings; ops missing from the file default to 0.
    opinfo_counts = defaultdict(int, dict(zip(opinfo_ops, opinfo_counts)))

    def count_fn(x):
        return opinfo_counts[x['full_name']]

    with open('run_decompositions.txt') as f:
        decomposed_ops = [remove_suffix(i.strip(), '.default') for i in f]

    with open('public_api') as f:
        ref_api = [i.strip() for i in f]

    def has_ref_impl(x):
        name = x['name']
        for prefix in ['linalg_', 'special_']:
            name = remove_prefix(name, prefix)
        prefixes = ['nn.functional', 'fft', 'special', 'linalg']
        return any(f"{prefix}.{name}" in ref_api for prefix in prefixes) or name in ref_api

    # Generates decomposition coverage data
    gen_data(
        [full_name_check(opinfo_ops), full_name_check(decomposed_ops), count_fn, has_ref_impl],
        'decompositions.txt',
    )
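
# decompositions.txt columns beyond the standard three: whether the op appears
# in run_ops.txt (presumably ops exercised by OpInfo tests), whether it appears
# in run_decompositions.txt, its count from count_ops.txt, and whether a Python
# reference implementation exists in the public_api list.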