| import yaml |
| import csv |
| import torch |
| from collections import defaultdict |
| |
| |
def get_ops_for_key(key):
    """Return the set of ``aten`` operator names registered for a dispatch key.

    Args:
        key: Dispatch key name (e.g. ``'CompositeImplicitAutograd'``), or
            ``None`` to call the registration query without passing a key.

    Returns:
        A set of operator-name strings with the leading ``aten::`` namespace
        removed and surrounding whitespace stripped. Registrations that do not
        start with ``aten::`` are dropped.
    """
    # Needs modified PyTorch C++ code to work
    if key is None:
        ops = torch._C._dispatch_get_registrations_for_dispatch_key()
    else:
        ops = torch._C._dispatch_get_registrations_for_dispatch_key(key)
    prefix = 'aten::'
    # Match the namespace only at the start of the name: a bare substring test
    # followed by a fixed-width slice would mangle any name that happened to
    # contain 'aten::' somewhere else.
    return {op[len(prefix):].strip() for op in ops if op.startswith(prefix)}
| |
| |
def _load_native_functions(path):
    """Load native_functions.yaml and add 'name', 'full_name', 'ret_type' to each entry."""
    with open(path, 'r') as f:
        # NOTE(review): CLoader is PyYAML's full (non-safe) C loader. That is
        # fine for the in-tree native_functions.yaml, but do not point this at
        # untrusted input.
        ops = yaml.load(f.read(), Loader=yaml.CLoader)
    for op in ops:
        func_str = op['func']
        # Everything before the argument list, including any '.overload' part.
        full_name = func_str[:func_str.index('(')]
        # Base operator name with the overload suffix (after the first '.') removed.
        op['name'] = full_name.partition('.')[0]
        op['full_name'] = full_name
        # Text after '-> ' in the schema string.
        op['ret_type'] = func_str[func_str.index('->') + 3:]
    return ops


def _load_annotations(path):
    """Read a two-column CSV into a ``{op_name: annotation}`` dict (fields stripped)."""
    with open(path, newline='') as f:
        return {a.strip(): b.strip() for a, b in csv.reader(f)}


def _classify_op(op, is_unique, annotated_ops):
    """Return ``(category, meta)`` for one entry produced by ``_load_native_functions``.

    Rules are evaluated in order and the first match wins.
    """
    name = op['name']
    if name.endswith('_'):
        return 'inplace', 'inplace'
    if not is_unique and 'a!' in op['func'].lower():
        return 'out', 'out'
    for keyword in ('conv', 'pool', 'backward'):
        if keyword in name:
            return keyword, keyword
    if name.startswith('_') and not name.startswith('__'):
        return 'private', 'private'
    if 'batch_norm' in name:
        return 'batch_norm', 'batch_norm'
    if 'Tensor' not in op['func'] or 'Tensor' not in op['ret_type']:
        return 'non_tensor', 'non_tensor'
    backend_markers = ('cudnn', 'mkldnn', 'miopen', 'native', 'thnn', 'slow')
    if any(marker in name for marker in backend_markers):
        return 'backend', 'backend'
    # Everything else is 'core'. Use the hand annotation when one exists.
    return 'core', 'core ' + annotated_ops.get(name, 'unknown')


def _annotate_ops(ops, is_unique, annotated_ops):
    """Set ``op['meta']`` on every op and return a per-category count."""
    categorization = defaultdict(int)
    for op in ops:
        category, meta = _classify_op(op, is_unique, annotated_ops)
        categorization[category] += 1
        op['meta'] = meta
    return categorization


def gen_data(special_op_lists, analysis_name, *,
             native_functions_path='../../aten/src/ATen/native/native_functions.yaml',
             annotated_ops_path='annotated_ops'):
    """Write one CSV-style line per native function to ``analysis_name``.

    Each line contains:

    1. the op's full name (including overload);
    2. its category from ``_classify_op``;
    3. ``True`` when the full name is not among the registered ops outside
       ``CompositeImplicitAutograd``;
    4. one extra column per predicate in ``special_op_lists``, each called
       with the op's dict.

    Args:
        special_op_lists: Callables taking an op dict. Their results are
            appended as extra columns.
        analysis_name: Output file path. The file is overwritten.
        native_functions_path: Path to native_functions.yaml.
        annotated_ops_path: Path to the two-column ``name,annotation`` CSV.
    """
    all_ops = get_ops_for_key(None)
    composite_ops = get_ops_for_key('CompositeImplicitAutograd')
    noncomposite_ops = all_ops - composite_ops

    ops = _load_native_functions(native_functions_path)
    annotated_ops = _load_annotations(annotated_ops_path)
    _annotate_ops(ops, is_unique=False, annotated_ops=annotated_ops)

    with open(analysis_name, 'w') as f:
        for op in ops:
            info = [op['full_name'], op['meta'], op['full_name'] not in noncomposite_ops]
            info += [check(op) for check in special_op_lists]
            f.write(','.join(str(i) for i in info) + '\n')
| |
| |
def name_check(lst):
    """Build a predicate: is an op dict's base ``'name'`` one of ``lst``?

    ``lst`` is any iterable of hashable names. It is snapshotted into a
    frozenset once, so each predicate call is an O(1) lookup rather than a
    linear scan. A one-shot iterator also stays valid across calls.
    """
    members = frozenset(lst)

    def check(x):
        return x['name'] in members

    return check
| |
| |
def full_name_check(lst):
    """Build a predicate: is an op dict's ``'full_name'`` one of ``lst``?

    ``lst`` is any iterable of hashable full names (e.g. the lists read from
    the run_*.txt files). It is snapshotted into a frozenset once, so each
    predicate call is an O(1) lookup rather than a linear scan. A one-shot
    iterator also stays valid across calls.
    """
    members = frozenset(lst)

    def check(x):
        return x['full_name'] in members

    return check
| |
| |
# Generates batching rule data: vmap.txt gets one extra column per op that
# says whether its full name appears among the FuncTorchBatched registrations.
_batched_rule_check = full_name_check(get_ops_for_key('FuncTorchBatched'))
gen_data([_batched_rule_check], 'vmap.txt')
| |
| |
def remove_suffix(input_string, suffix):
    """Return ``input_string`` with one trailing ``suffix`` removed, if present.

    An empty ``suffix`` leaves the string unchanged.
    """
    # Guard on empty suffix: slicing with [:-0] would return '' instead.
    if not suffix or not input_string.endswith(suffix):
        return input_string
    keep = len(input_string) - len(suffix)
    return input_string[:keep]
| |
def remove_prefix(input_string, prefix):
    """Return ``input_string`` with one leading ``prefix`` removed, if present.

    An empty ``prefix`` leaves the string unchanged.
    """
    if not prefix or not input_string.startswith(prefix):
        return input_string
    start = len(prefix)
    return input_string[start:]
| |
| |
# NOTE(review): `if True:` looks like a manual switch for toggling this
# analysis on/off; kept as-is.
if True:
    # Ops from run_ops.txt, one per line, with a trailing '.default' overload
    # suffix stripped (presumably so they line up with gen_data's full_name).
    with open('run_ops.txt', 'r') as f:
        opinfo_ops = [remove_suffix(line.strip(), '.default') for line in f]
    # Counts from count_ops.txt, paired line-by-line with run_ops.txt. Values
    # are kept as strings. Ops without an entry fall back to 0.
    # NOTE(review): zip() silently stops at the shorter file if the two files
    # ever disagree in length.
    with open('count_ops.txt', 'r') as f:
        opinfo_counts = [line.strip() for line in f]
    opinfo_counts = defaultdict(int, zip(opinfo_ops, opinfo_counts))

    def count_fn(x):
        return opinfo_counts[x['full_name']]

    with open('run_decompositions.txt', 'r') as f:
        decomposed_ops = [remove_suffix(line.strip(), '.default') for line in f]

    with open('public_api', 'r') as f:
        ref_api = [line.strip() for line in f]
    # has_ref_impl probes this up to five times per op, so use a hashed lookup.
    # The ref_api list itself is left intact.
    _ref_api_lookup = frozenset(ref_api)

    def has_ref_impl(x):
        """True if the op's base name (minus linalg_/special_ prefixes) is in public_api,
        either bare or under one of the nn.functional/fft/special/linalg namespaces."""
        name = x['name']
        for prefix in ("linalg_", "special_"):
            name = remove_prefix(name, prefix)
        namespaces = ('nn.functional', 'fft', 'special', 'linalg')
        return (name in _ref_api_lookup
                or any(f"{ns}.{name}" in _ref_api_lookup for ns in namespaces))

    gen_data([full_name_check(opinfo_ops), full_name_check(decomposed_ops), count_fn, has_ref_impl], 'decompositions.txt')