blob: 4ae552a44bd3c773d5f50968af0715dc315e0551 [file] [log] [blame]
import re
import torch
"""
Instructions:
1. pytest -n 8 test/test_vmap.py test/test_ops.py test/test_aotdispatch.py > result.txt
2. python test/xfail_suggester.py
"""
# Read the pytest log produced by step 1 of the instructions above.
# Explicit encoding (with replacement of undecodable bytes) so that non-ASCII
# text in assertion messages cannot break decoding on platforms whose default
# codec is not UTF-8; only the ASCII test names are needed from each line.
with open('result.txt', encoding='utf-8', errors='replace') as f:
    lines = f.readlines()
# Short-summary lines that pytest prints for failing tests.
failed = [line for line in lines if line.startswith('FAILED')]
# Captures the test function name from lines shaped like
# "FAILED test/test_ops.py::SomeTestClass::test_name_cpu_float32 - ...".
# Raw string: avoids invalid-escape warnings for \w and \S; the dot before
# "py" is escaped so it only matches a literal ".".
p = re.compile(r'FAILED test/test_\w+\.py::\w+::(\S+)')
def get_failed_test(line):
    """Return the failing test's name parsed from one pytest output line.

    Matching is done with the module-level compiled pattern ``p``; a line that
    does not match yields ``None``.
    """
    match = p.match(line)
    return None if match is None else match.group(1)
# Test-name prefixes; failures are grouped into one suggestion list per prefix.
# Some prefixes extend others (e.g. 'test_vmapvjp_' and
# 'test_vmapvjp_has_batch_rule_'); belongs_to_base attributes a test to the
# longest prefix it starts with.
base_names = {
    'test_grad_',
    'test_jvp_',
    'test_vjp_',

    'test_jvpvjp_',
    'test_vjpvjp_',
    'test_vjpvmap_',
    'test_vmapjvp_',
    'test_vmapjvpall_',
    'test_vmapvjp_',

    'test_op_has_batch_rule_',
    'test_vmapjvpall_has_batch_rule_',
    'test_vmapvjp_has_batch_rule_',

    'test_decomposition_',
    'test_make_fx_exhaustive_',
    'test_vmap_autograd_grad_',
    'test_vmap_exhaustive_',
}
# Names of all failing tests, in sorted order. Every log line is tried; lines
# that do not match the FAILED pattern parse to None and are dropped.
failed_tests = sorted(
    name for name in map(get_failed_test, lines) if name is not None
)
# NOTE(review): not referenced elsewhere in the visible source.
suggested_xfails = {}
def remove_device_dtype(test):
    """Drop the last two underscore-separated pieces of a test name.

    Intended to strip a trailing ``_<device>_<dtype>`` suffix, e.g.
    ``'nn_functional_relu_cpu_float32'`` -> ``'nn_functional_relu'``.
    Names with fewer than three pieces collapse to ``''``.
    """
    pieces = test.split('_')
    del pieces[-2:]
    return '_'.join(pieces)
def belongs_to_base(test, base):
    """Tell whether ``test`` should be attributed to the prefix ``base``.

    ``test`` must start with ``base`` and must not also start with any longer
    prefix listed in the module-level ``base_names``. For example, a test
    starting with ``'test_vmapvjp_has_batch_rule_'`` is not attributed to
    ``'test_vmapvjp_'``.
    """
    if not test.startswith(base):
        return False
    longer_bases = (other for other in base_names if len(other) > len(base))
    return not any(test.startswith(other) for other in longer_bases)
def parse_namespace(base):
    """Split a known torch namespace prefix off a flattened op name.

    Test names appear to spell dotted op names with underscores, so a prefix
    such as ``'nn_functional_'`` is mapped back to ``'nn.functional'``:
    ``'nn_functional_relu'`` -> ``('nn.functional', 'relu')``.

    Returns:
        ``(namespace, rest)``. ``namespace`` is the dotted namespace string, or
        ``None`` when no known prefix matches. ``rest`` is ``base`` with that
        prefix removed (``base`` unchanged when ``namespace`` is ``None``).
    """
    mappings = {
        'nn_functional_': 'nn.functional',
        'fft_': 'fft',
        'linalg_': 'linalg',
        '_masked_': '_masked',
        'sparse_': 'sparse',
        # Must be spelled exactly like the test-name prefix; a misspelling
        # here silently leaves special ops without a namespace.
        'special_': 'special',
    }
    for prefix, namespace in mappings.items():
        if base.startswith(prefix):
            return namespace, base[len(prefix):]
    return None, base
def get_torch_module(namespace):
    """Resolve a namespace string from ``parse_namespace`` to a module object.

    ``None`` means top-level ``torch``. ``'nn.functional'`` is resolved
    explicitly because it contains a dot. Any other name is looked up as a
    direct attribute of ``torch``.
    """
    if namespace is None:
        module = torch
    elif namespace == 'nn.functional':
        module = torch.nn.functional
    else:
        module = getattr(torch, namespace)
    return module
def parse_base(base):
    """Split a de-parametrized test suffix into ``(namespace, api, variant)``.

    ``base`` is a test name whose test-kind prefix and trailing device/dtype
    have already been removed (e.g. ``'nn_functional_relu'`` or
    ``'var_mean_unbiased'``). Steps:

    1. ``parse_namespace`` peels off a known namespace prefix.
    2. The longest attribute name of the matching torch module that matches
       the start of the remainder becomes ``api``.
    3. Everything after the separating underscore becomes ``variant``.

    A candidate matches only on a whole-word boundary: it must equal the
    remainder or be followed by ``'_'``. Without that check, short attributes
    such as ``torch.exp`` would swallow unrelated ops (``'expand_as'`` would
    become api ``'exp'`` with variant ``'nd_as'``).

    If no attribute matches, the whole remainder is the api and the variant
    is ``''``.
    """
    namespace, rest = parse_namespace(base)
    # Longest names first, so that e.g. 'var_mean' is preferred over 'var'.
    apis = sorted(dir(get_torch_module(namespace)), key=len, reverse=True)
    api = rest
    variant = ''
    for candidate in apis:
        if rest == candidate or rest.startswith(candidate + '_'):
            api = candidate
            # Skip the '_' separating the api from the variant name.
            variant = rest[len(candidate) + 1:]
            break
    # Debug trace of each parse.
    print(base, namespace, api, variant)
    return namespace, api, variant
def any_starts_with(strs, thing):
    """Return True if at least one string in ``strs`` begins with ``thing``.

    Stops at the first match; an empty ``strs`` yields False.
    """
    return any(s.startswith(thing) for s in strs)
def get_suggested_xfails(base, tests):
    """Build xfail/skip suggestion lines for the failing tests of one prefix.

    Args:
        base: A test-name prefix from ``base_names`` (e.g. ``'test_grad_'``).
        tests: Failing test names, as produced by ``get_failed_test``.

    Returns:
        One line per distinct op/variant, ordered by op name, e.g.
        ``"xfail('linalg.norm', 'nuc'),"``:

        - float32 failures on both CPU and CUDA give an unconditional xfail;
        - a float32 failure on only one of them gives an xfail restricted to
          that ``device_type``;
        - no float32 CPU/CUDA failure at all (only other device/dtype
          combinations failed) gives a ``skip``.
    """
    result = []
    # Keep only tests attributed to this prefix (not to a longer prefix that
    # extends it), with the prefix stripped off.
    suffixes = {test[len(base):] for test in tests if belongs_to_base(test, base)}
    # Distinct op names with device/dtype removed. Sorted so the output order
    # does not depend on set iteration order.
    op_names = sorted({remove_device_dtype(suffix) for suffix in suffixes})
    for op_name in op_names:
        failed_on_cpu = (op_name + '_cpu_float32') in suffixes
        failed_on_cuda = (op_name + '_cuda_float32') in suffixes
        namespace, api, variant = parse_base(op_name)
        if namespace is not None:
            api = f'{namespace}.{api}'
        if failed_on_cpu and failed_on_cuda:
            result.append(f"xfail('{api}', '{variant}'),")
        elif failed_on_cpu:
            result.append(f"xfail('{api}', '{variant}', device_type='cpu'),")
        elif failed_on_cuda:
            result.append(f"xfail('{api}', '{variant}', device_type='cuda'),")
        else:
            # No device-specific float32 xfail fits; suggest a skip. The call
            # is closed so the line pastes as valid Python, like the xfails.
            result.append(f"skip('{api}', '{variant}'),")
    return result
# Suggestions for every prefix are computed before anything is printed.
# base_names is a set of strings, whose iteration order can change between
# interpreter runs (string hash randomization). Iterating it sorted keeps the
# section order stable, so outputs from different runs can be diffed.
result = {base: get_suggested_xfails(base, failed_tests) for base in sorted(base_names)}
for k, v in result.items():
    print('=' * 50)
    print(k)
    print('=' * 50)
    print('\n'.join(v))