# blob: 8ff963f34e667c42fee212876a839b20fb856640 [file] [log] [blame]
import re
import torch
"""
Instructions:
1. pytest -n 8 test/test_vmap.py test/test_ops.py test/test_aotdispatch.py > result.txt
2. python test/xfail_suggester.py
"""
# Read the pytest log produced by step 1 of the instructions above.
# errors="replace" keeps a stray non-UTF-8 byte in captured test output from
# aborting the whole script.
with open("result.txt", encoding="utf-8", errors="replace") as f:
    lines = f.readlines()
failed = [line for line in lines if line.startswith("FAILED")]
# Matches pytest summary lines of the form
#   FAILED test/test_<name>.py::<Class>::<test name>...
# and captures the non-whitespace test name that follows "<Class>::".
# Raw string so the regex escapes are not (invalid) Python string escapes.
p = re.compile(r"FAILED test/test_\w+\.py::\w+::(\S+)")
def get_failed_test(line):
    """Extract the test name from one line of the pytest log.

    Uses the module-level pattern ``p``. Returns the captured test name
    (the non-whitespace run after ``<Class>::``), or ``None`` when *line* is
    not a matching ``FAILED`` summary line.
    """
    match = p.match(line)
    return match.group(1) if match else None
# Prefixes of the test families whose failures are grouped into separate
# report sections. Some prefixes extend others (e.g.
# "test_vmapvjp_has_batch_rule_" extends "test_vmapvjp_"); belongs_to_base
# attributes each test to a prefix only if no longer prefix also matches.
base_names = {
    "test_decomposition_",
    "test_grad_",
    "test_jvp_",
    "test_jvpvjp_",
    "test_make_fx_exhaustive_",
    "test_op_has_batch_rule_",
    "test_vjp_",
    "test_vjpvjp_",
    "test_vjpvmap_",
    "test_vmap_autograd_grad_",
    "test_vmap_exhaustive_",
    "test_vmapjvp_",
    "test_vmapjvpall_",
    "test_vmapjvpall_has_batch_rule_",
    "test_vmapvjp_",
    "test_vmapvjp_has_batch_rule_",
}
# Sorted names of every failing test found in the log; lines that are not
# FAILED summary lines yield None from get_failed_test and are dropped.
failed_tests = sorted(
    name for name in map(get_failed_test, lines) if name is not None
)
# NOTE(review): not read anywhere in the code shown; kept in case other code
# relies on the name.
suggested_xfails = {}
def remove_device_dtype(test):
    """Drop the last two ``_``-separated fields (device and dtype) of *test*.

    ``"add_cpu_float32"`` becomes ``"add"``. Names with fewer than three
    fields collapse to the empty string.
    """
    fields = test.split("_")
    del fields[-2:]
    return "_".join(fields)
def belongs_to_base(test, base):
    """Return True if *test* belongs to the family named by prefix *base*.

    *test* must start with *base* and must not also start with any longer
    entry of ``base_names``. For example, a test starting with
    ``"test_vmapvjp_has_batch_rule_"`` does not belong to ``"test_vmapvjp_"``.
    """
    if not test.startswith(base):
        return False
    return not any(
        test.startswith(other)
        for other in base_names
        if len(other) > len(base)
    )
def parse_namespace(base):
    """Split a known torch sub-namespace prefix off *base*.

    Returns ``(namespace, remainder)``, e.g. ``"nn_functional_conv2d"`` gives
    ``("nn.functional", "conv2d")``. Prefixes are tried in the order listed
    and the first match wins. If none matches, returns ``(None, base)``.
    """
    prefix_table = (
        ("nn_functional_", "nn.functional"),
        ("fft_", "fft"),
        ("linalg_", "linalg"),
        ("_masked_", "_masked"),
        ("sparse_", "sparse"),
        ("special_", "special"),
    )
    for prefix, namespace in prefix_table:
        if base.startswith(prefix):
            return namespace, base[len(prefix) :]
    return None, base
def get_torch_module(namespace):
    """Return the torch module object named by *namespace*.

    ``None`` selects top-level ``torch``. ``"nn.functional"`` is special-cased
    because ``getattr`` does not follow dotted paths. Any other value is
    looked up as a direct attribute of ``torch``.
    """
    if namespace == "nn.functional":
        return torch.nn.functional
    return torch if namespace is None else getattr(torch, namespace)
def parse_base(base):
    """Split an op fragment of a test name into ``(namespace, api, variant)``.

    *base* is a test name with the family prefix and the device/dtype suffix
    removed (e.g. ``"nn_functional_conv2d"``). A namespace prefix is peeled
    off with ``parse_namespace``. The API is the longest attribute name of the
    corresponding torch module that either equals the remainder or is
    followed by ``"_"`` in it; everything after that underscore is the
    variant. If no attribute matches, the whole remainder is the API and the
    variant is empty. The parse is also printed, for debugging.
    """
    namespace, rest = parse_namespace(base)
    # Longest names first so that a longer API name wins over a shorter one
    # that is also a prefix of `rest`.
    apis = sorted(dir(get_torch_module(namespace)), key=len, reverse=True)
    api = rest
    variant = ""
    for candidate in apis:
        # Require an "_" boundary after the API name. A bare startswith would
        # let e.g. an op "foobar" (not itself an attribute) match an API
        # "foo", and the +1 slice below would then yield the bogus variant
        # "ar" by discarding a character that is not a separator.
        if rest == candidate or rest.startswith(candidate + "_"):
            api = candidate
            variant = rest[len(candidate) + 1 :]
            break
    print(base, namespace, api, variant)
    return namespace, api, variant
def any_starts_with(strs, thing):
    """Return True if at least one string in *strs* starts with *thing*."""
    return any(candidate.startswith(thing) for candidate in strs)
def get_suggested_xfails(base, tests):
    """Build xfail/skip suggestion lines for the failing tests of one family.

    Args:
        base: a family prefix from ``base_names``, e.g. ``"test_grad_"``.
        tests: failing test names, as returned by ``get_failed_test``.

    Returns:
        A list of source lines, one per op (sorted by op name), such as
        ``"xfail('nn.functional.conv2d', ''),"``. An op gets an ``xfail``
        (restricted to a device when only that device failed) if its
        ``_cpu_float32`` and/or ``_cuda_float32`` test failed. Otherwise (only
        other dtypes failed) it gets a ``skip`` suggestion.
    """
    result = []
    # Strip the family prefix, keeping only tests attributed to this family.
    suffixes = [test[len(base) :] for test in tests if belongs_to_base(test, base)]
    failed_variants = set(suffixes)
    # One entry per op: drop the trailing device/dtype fields.
    op_names = {remove_device_dtype(suffix) for suffix in suffixes}
    # Sorted so the suggestions come out in a stable order between runs.
    for op_name in sorted(op_names):
        namespace, api, variant = parse_base(op_name)
        if namespace is not None:
            api = f"{namespace}.{api}"
        cpu_failed = op_name + "_cpu_float32" in failed_variants
        cuda_failed = op_name + "_cuda_float32" in failed_variants
        if cpu_failed and cuda_failed:
            result.append(f"xfail('{api}', '{variant}'),")
        elif cpu_failed:
            result.append(f"xfail('{api}', '{variant}', device_type='cpu'),")
        elif cuda_failed:
            result.append(f"xfail('{api}', '{variant}', device_type='cuda'),")
        else:
            # The previous version emitted "skip('api', 'variant'," without
            # the closing parenthesis, which is not valid Python when pasted.
            result.append(f"skip('{api}', '{variant}'),")
    return result
# Build and print one report section per test family. base_names is a set,
# so iterating it directly gives an order that changes between runs (str
# hashes are randomized per process); sorting keeps the report stable.
result = {
    base: get_suggested_xfails(base, failed_tests) for base in sorted(base_names)
}
for k, v in result.items():
    print("=" * 50)
    print(k)
    print("=" * 50)
    print("\n".join(v))