| import re |
| |
| import torch |
| |
| |
| """ |
| Instructions: |
| |
| 1. pytest -n 8 test/test_vmap.py test/test_ops.py test/test_aotdispatch.py > result.txt |
| 2. python test/xfail_suggester.py |
| """ |
| |
# Load the captured pytest output (produced by step 1 of the instructions above).
with open("result.txt") as f:
    lines = f.readlines()

failed = [line for line in lines if line.startswith("FAILED")]
# Raw string: \w and \S must reach the regex engine verbatim rather than being
# parsed as (invalid) string escapes. The dot of ".py" is escaped so it only
# matches a literal dot. Group 1 captures the test-function name.
p = re.compile(r"FAILED test/test_\w+\.py::\w+::(\S+)")
| |
| |
def get_failed_test(line):
    """Return the test-function name captured from a pytest ``FAILED`` line.

    The module-level pattern ``p`` is matched at the start of ``line``; its
    first group is returned. Lines that do not match yield ``None``.
    """
    found = p.match(line)
    return found.group(1) if found is not None else None
| |
| |
# Test-name prefixes ("families") that failures are grouped under. Some
# prefixes extend others (e.g. "test_vmapvjp_" vs.
# "test_vmapvjp_has_batch_rule_"); belongs_to_base() resolves that overlap.
base_names = {
    # first-order transforms
    "test_grad_",
    "test_vjp_",
    "test_jvp_",
    # composed autodiff transforms
    "test_jvpvjp_",
    "test_vjpvjp_",
    "test_vjpvmap_",
    # vmap combined with autodiff
    "test_vmapvjp_",
    "test_vmapvjp_has_batch_rule_",
    "test_vmapjvp_",
    "test_vmapjvpall_",
    "test_vmapjvpall_has_batch_rule_",
    "test_vmap_autograd_grad_",
    # plain vmap / batching rules
    "test_vmap_exhaustive_",
    "test_op_has_batch_rule_",
    # tracing and decompositions
    "test_decomposition_",
    "test_make_fx_exhaustive_",
}
| |
# Names of every failed test found in the log, sorted; lines that are not
# pytest failure lines parse to None and are dropped.
failed_tests = sorted(
    name for name in map(get_failed_test, lines) if name is not None
)

suggested_xfails = {}
| |
| |
def remove_device_dtype(test):
    """Strip the last two underscore-separated tokens from ``test``.

    Callers pass names ending in a ``_<device>_<dtype>`` suffix such as
    ``_cpu_float32``. A name with fewer than three tokens becomes ``""``.
    """
    tokens = test.split("_")
    del tokens[-2:]
    return "_".join(tokens)
| |
| |
def belongs_to_base(test, base):
    """Tell whether ``test`` belongs to the family ``base`` and to no longer one.

    ``test`` must start with ``base``. It is rejected if it also starts with
    any entry of ``base_names`` that is longer than ``base``, so that a test
    is attributed only to its most specific prefix.
    """
    if not test.startswith(base):
        return False
    return not any(
        len(other) > len(base) and test.startswith(other) for other in base_names
    )
| |
| |
def parse_namespace(base):
    """Split a leading namespace prefix off an underscore-joined op name.

    Returns ``(namespace, remainder)``, where ``namespace`` is the dotted
    torch sub-namespace for the first matching prefix. If no known prefix
    matches, returns ``(None, base)`` with ``base`` untouched.
    """
    known_prefixes = (
        ("nn_functional_", "nn.functional"),
        ("fft_", "fft"),
        ("linalg_", "linalg"),
        ("_masked_", "_masked"),
        ("sparse_", "sparse"),
        ("special_", "special"),
    )
    for prefix, namespace in known_prefixes:
        if base.startswith(prefix):
            return namespace, base[len(prefix) :]
    return None, base
| |
| |
def get_torch_module(namespace):
    """Resolve a namespace string from ``parse_namespace`` to a torch module.

    ``None`` maps to ``torch`` itself and ``"nn.functional"`` to
    ``torch.nn.functional``. Any other value is looked up as an attribute of
    ``torch``, which raises ``AttributeError`` if it does not exist.
    """
    if namespace is None:
        return torch
    return (
        torch.nn.functional
        if namespace == "nn.functional"
        else getattr(torch, namespace)
    )
| |
| |
def parse_base(base):
    """Split an underscore-joined op name into ``(namespace, api, variant)``.

    After the namespace prefix is removed, ``api`` is the longest attribute
    name of the corresponding torch module that the remainder starts with.
    ``variant`` is what follows it, minus one separator character. If no
    attribute name matches, the whole remainder is the ``api`` and
    ``variant`` is ``""``. The parse is also printed.
    """
    namespace, rest = parse_namespace(base)

    # Longest names first, so e.g. a longer op name wins over one of its prefixes.
    longest_first = sorted(dir(get_torch_module(namespace)), key=len, reverse=True)
    matched = next((name for name in longest_first if rest.startswith(name)), None)

    if matched is None:
        api, variant = rest, ""
    else:
        api, variant = matched, rest[len(matched) + 1 :]
    print(base, namespace, api, variant)
    return namespace, api, variant
| |
| |
def any_starts_with(strs, thing):
    """Return ``True`` if at least one string in ``strs`` starts with ``thing``."""
    return any(candidate.startswith(thing) for candidate in strs)
| |
| |
def get_suggested_xfails(base, tests):
    """Build xfail/skip suggestion lines for the failures of one test family.

    Args:
        base: Test-name prefix identifying the family (an entry of
            ``base_names``).
        tests: Names of all failed tests, across every family.

    Returns:
        A list of source snippets, one per failing op/variant, ordered by
        op name:

        * ``xfail('<api>', '<variant>'),`` when both the ``_cpu_float32``
          and ``_cuda_float32`` tests failed;
        * the same with ``device_type='cpu'`` or ``device_type='cuda'`` when
          only one of them failed;
        * ``skip('<api>', '<variant>'),`` when neither float32 test is among
          the failures.
    """
    result = []
    # Failures of this family only, with the family prefix removed.
    suffixes = {test[len(base) :] for test in tests if belongs_to_base(test, base)}
    op_names = {remove_device_dtype(suffix) for suffix in suffixes}

    # Sorted so that the suggestions come out in a reproducible order.
    for op_name in sorted(op_names):
        namespace, api, variant = parse_base(op_name)
        if namespace is not None:
            api = f"{namespace}.{api}"

        failed_on_cpu = f"{op_name}_cpu_float32" in suffixes
        failed_on_cuda = f"{op_name}_cuda_float32" in suffixes

        if failed_on_cpu and failed_on_cuda:
            result.append(f"xfail('{api}', '{variant}'),")
        elif failed_on_cpu:
            result.append(f"xfail('{api}', '{variant}', device_type='cpu'),")
        elif failed_on_cuda:
            result.append(f"xfail('{api}', '{variant}', device_type='cuda'),")
        else:
            # The failure came from some other device/dtype suffix; suggest a
            # skip. The closing parenthesis keeps the snippet valid code.
            result.append(f"skip('{api}', '{variant}'),")
    return result
| |
| |
# Compute the suggestions for every test family, then print each family's
# name between two rulers followed by its suggestions, one per line.
result = {base: get_suggested_xfails(base, failed_tests) for base in base_names}
for family, suggestions in result.items():
    print("=" * 50, family, "=" * 50, sep="\n")
    print("\n".join(suggestions))