| import copy |
| import enum |
| import pprint |
| import unittest |
| from enum import Enum |
| |
| # Importing these files make modifications to the op_db that we need |
| import test_ops # noqa: F401 |
| import test_vmap # noqa: F401 |
| import torch |
| import torch._functorch.top_operators_github_usage as top_ops |
| from functorch_additional_op_db import additional_op_db |
| from torch.testing._internal.common_device_type import toleranceOverride |
| from torch.testing._internal.common_methods_invocations import op_db |
| |
# Every callable torch.overrides provides a testing override for.
all_overridable = [*torch.overrides.get_testing_overrides()]

# Each entry: (module object, its public dotted name, the .rst docs page --
# relative to a PyTorch source checkout -- that lists its public APIs).
public_docs = [
    (torch.nn.functional, "torch.nn.functional", "docs/source/nn.functional.rst"),
    (torch.fft, "torch.fft", "docs/source/fft.rst"),
    (torch.special, "torch.special", "docs/source/special.rst"),
    (torch.linalg, "torch.linalg", "docs/source/linalg.rst"),
    (torch, "torch", "docs/source/torch.rst"),
    (torch.Tensor, "torch.Tensor", "docs/source/tensors.rst"),
]
| |
def get_public_overridable_apis(pytorch_root=None):
    """Collect documented public APIs that ``torch.overrides`` can override.

    Reads each ``.rst`` page listed in ``public_docs`` (relative to a PyTorch
    source checkout), extracts API names from it, resolves each name on the
    matching module, and keeps those that appear in
    ``torch.overrides.get_testing_overrides()``.

    ``torch.abs``, ``Tensor.abs`` and ``Tensor.abs_`` are all considered to be
    different APIs and get separate entries.

    Args:
        pytorch_root: Path to a PyTorch source checkout containing ``docs/``.
            When omitted, ``$PYTORCH_ROOT`` is used if set, otherwise the
            original hard-coded developer path.

    Returns:
        dict mapping ``"<module_name>.<api name>"`` to the API object.
    """
    import os

    if pytorch_root is None:
        pytorch_root = os.environ.get("PYTORCH_ROOT", "/raid/rzou/pt/debug-cpu")

    autofunction = ".. autofunction::"
    tensor_prefix = "Tensor."
    overridable = set(torch.overrides.get_testing_overrides().keys())

    results = {}
    for module, module_name, src in public_docs:
        with open(f"{pytorch_root}/{src}", encoding="utf-8") as f:
            lines = f.readlines()
        # APIs either begin with 4 spaces or are named by ".. autofunction::".
        names = [line.strip() for line in lines if line.startswith(" " * 4)]
        # Slice off the directive first, then strip, so extra/missing spaces
        # after "::" don't corrupt the name.
        names += [
            line[len(autofunction) :].strip()
            for line in lines
            if line.startswith(autofunction)
        ]
        # tensors.rst lists methods as "Tensor.foo"; they are looked up on
        # torch.Tensor by their bare name.
        names = [
            name[len(tensor_prefix) :] if name.startswith(tensor_prefix) else name
            for name in names
        ]
        for name in names:
            if not hasattr(module, name):
                continue
            api = getattr(module, name)
            if api in overridable:
                results[f"{module_name}.{name}"] = api
    return results
| |
| |
# Fully-qualified Tensor APIs excluded from the "ops we care about" lists
# (consulted by get_method_only_ops_we_care_about and
# get_public_overridable_outplace_we_care_about). Kept alphabetical.
denylist = {
    "torch.Tensor.as_strided",
    "torch.Tensor.backward",
    "torch.Tensor.data_ptr",
    "torch.Tensor.dense_dim",
    "torch.Tensor.dim",
    "torch.Tensor.element_size",
    "torch.Tensor.indices",
    "torch.Tensor.int_repr",
    "torch.Tensor.is_inference",
    "torch.Tensor.ndimension",
    "torch.Tensor.nelement",
    "torch.Tensor.numel",
    "torch.Tensor.q_per_channel_axis",
    "torch.Tensor.q_per_channel_scales",
    "torch.Tensor.q_per_channel_zero_points",
    "torch.Tensor.q_scale",
    "torch.Tensor.q_zero_point",
    "torch.Tensor.qscheme",
    "torch.Tensor.record_stream",
    "torch.Tensor.register_hook",
    "torch.Tensor.retain_grad",
    "torch.Tensor.size",
    "torch.Tensor.smm",
    "torch.Tensor.sparse_dim",
    "torch.Tensor.sparse_mask",
    "torch.Tensor.sspaddmm",
    "torch.Tensor.storage",
    "torch.Tensor.storage_type",
    "torch.Tensor.to_sparse",
    "torch.Tensor.values",
}
| |
| |
def get_method_only_ops_we_care_about():
    """List out-of-place Tensor methods that have no ``torch.<name>`` function.

    Considers every ``torch.Tensor.*`` key of get_public_overridable_apis()
    that is not in ``denylist`` and whose name does not end in ``_``.

    Returns:
        list of bare method names (e.g. ``"foo"`` for ``torch.Tensor.foo``).
    """
    apis = get_public_overridable_apis()
    method_only = []
    for qualified in apis:
        if not qualified.startswith("torch.Tensor") or qualified in denylist:
            continue
        method_name = qualified.split(".")[2]
        # A trailing underscore marks an in-place variant.
        is_inplace = method_name.endswith("_")
        if not is_inplace and f"torch.{method_name}" not in apis:
            method_only.append(method_name)
    return method_only
| |
| |
def get_public_overridable_ops():
    """Public overridable APIs with function/method duplicates removed.

    Deduplicates torch.abs and Tensor.abs: a ``torch.Tensor.<name>`` entry is
    dropped whenever ``torch.<name>`` is also present.

    Returns:
        dict mapping qualified API name to API object (a filtered
        get_public_overridable_apis() result).
    """
    results = get_public_overridable_apis()
    tensor_prefix = "torch.Tensor."
    # Snapshot the keys: entries are deleted from `results` inside the loop.
    # (A shallow key list suffices; deep-copying the callables is not needed.)
    for key in list(results):
        if not key.startswith(tensor_prefix):
            continue
        api = key[len(tensor_prefix) :].split(".")[0]
        if f"torch.{api}" in results:
            del results[key]
    return results
| |
| |
def get_public_overridable_outplace_ops():
    """get_public_overridable_ops() without in-place variants.

    NB: there are no dunder methods bcs we don't document those, so a name
    ending in ``_`` always denotes an in-place op.

    Returns:
        dict mapping qualified API name to API object, insertion order kept.
    """
    results = get_public_overridable_ops()
    return {key: api for key, api in results.items() if not key.endswith("_")}
| |
| |
def get_public_overridable_outplace_we_care_about():
    """Out-of-place public overridable ops minus ones OpInfos don't apply to.

    Drops quantization APIs (``"quant"`` or ``".q_"`` in the name),
    ``is_*`` predicates such as ``is_cpu`` (it doesn't make sense to have
    OpInfos for these), and anything in the module-level ``denylist``.

    Returns:
        dict mapping qualified API name to API object.
    """
    results = get_public_overridable_outplace_ops()

    def excluded(key):
        # Evaluate all criteria together so a key matching several of them is
        # dropped once (deleting it per criterion raised KeyError).
        is_quantization = "quant" in key or ".q_" in key
        is_predicate = ".is_" in key
        return is_quantization or is_predicate or key in denylist

    return {key: api for key, api in results.items() if not excluded(key)}
| |
| |
def get_op(dotted_name):
    """Resolve a dotted name relative to ``torch``, e.g. ``"nn.functional.softmax"``.

    Returns:
        The resolved object, or None if any component is missing.
    """
    missing = object()
    target = torch
    for part in dotted_name.split("."):
        target = getattr(target, part, missing)
        if target is missing:
            return None
    return target
| |
| |
def get_ops_covered_by_opinfos():
    """Map each callable exercised by an OpInfo in ``op_db`` to its OpInfos.

    For every OpInfo, the callables are: the function resolved from its name
    via get_op (if any), its method variant, its in-place variant, and the
    ``op`` of each alias.

    Returns:
        dict mapping callable -> list of OpInfo (in op_db order).
    """
    coverage = {}
    for opinfo in op_db:
        callables = []
        func_op = get_op(opinfo.name)
        if func_op:
            callables.append(func_op)
        if opinfo.method_variant:
            callables.append(opinfo.method_variant)
        if opinfo.inplace_variant:
            callables.append(opinfo.inplace_variant)
        callables.extend(alias.op for alias in opinfo.aliases)
        for fn in callables:
            coverage.setdefault(fn, []).append(opinfo)
    return coverage
| |
| |
# Bare torch factory-function names; excluded from the "top ops without an
# OpInfo" report. Kept alphabetical.
factory_fns = {
    "arange",
    "bartlett_window",
    "blackman_window",
    "empty",
    "eye",
    "full",
    "hann_window",
    "linspace",
    "logspace",
    "ones",
    "rand",
    "randint",
    "randn",
    "randperm",
    "range",
    "tensor",
    "zeros",
}
| |
| |
def get_top_ops(torch_threshold, nn_fn_threshold, with_counts=False):
    """Return the most-used ops according to the GitHub usage data in ``top_ops``.

    Entries named in ``excluded`` are removed first. Then the first
    *torch_threshold* torch entries and the first *nn_fn_threshold*
    nn.functional entries are merged and sorted by usage count, highest
    first (ties keep their relative order).

    Args:
        torch_threshold: How many torch entries to keep.
        nn_fn_threshold: How many nn.functional entries to keep.
        with_counts: If true, return the raw entries (``entry[0]`` is the
            name, ``entry[1]`` the usage count) instead of just names.
    """
    # These are either not real "operators", factory functions
    # that trivially work, or not-documented ops.
    excluded = {
        # torch namespace
        "load",
        "no_grad",
        "save",
        "from_numpy",
        "manual_seed",
        "set_grad_enabled",
        "set_default_tensor_type",
        "set_num_threads",
        "set_printoptions",
        "numel",
        "set_default_dtype",
        "sparse_coo_tensor",
        "set_rng_state",
        "get_rng_state",
        "get_default_dtype",
        "initial_seed",
        "get_num_threads",
        "quantize_per_tensor",
        "hann_window",
        "is_tensor",
        "as_tensor",
        "equal",
        "enable_grad",
        "seed",
        "is_storage",
        "is_floating_point",
        "set_flush_denormal",
        "set_num_interop_threads",
        "dequantize",
        "get_num_interop_threads",
        "is_complex",
        "grad",
        "quantize_per_channel",
        "fft",  # is namespace
        # torch.nn.functional namespace
        "nn.functional.torch",
        "nn.functional.math",
        "nn.functional.threshold_",
        "nn.functional.selu_",
        "nn.functional.elu_",
        "nn.functional.rrelu_",
        "nn.functional.leaky_relu_",
        "nn.functional.hardtanh_",
        "nn.functional.has_torch_function",
        "nn.functional.has_torch_function_unary",
        "nn.functional.has_torch_function_variadic",
        "nn.functional.handle_torch_function",
        "nn.functional.adaptive_max_pool1d_with_indices",
        "nn.functional.adaptive_max_pool2d_with_indices",
        "nn.functional.adaptive_max_pool3d_with_indices",
        "nn.functional.fractional_max_pool2d_with_indices",
        "nn.functional.fractional_max_pool3d_with_indices",
        "nn.functional.max_pool2d_with_indices",
        "nn.functional.max_pool3d_with_indices",
        "nn.functional.max_pool1d_with_indices",
        "nn.functional.celu_",
        "nn.functional.grad",
        "nn.functional.relu_",
        "nn.functional.boolean_dispatch",
        "nn.functional.assert_int_or_pair",
    }

    def keep(entries):
        return [entry for entry in entries if entry[0] not in excluded]

    torch_entries = keep(top_ops.top_torch)
    nn_fn_entries = keep(top_ops.get_nn_functional_top_list())

    # Now, sort by priority (usage count).
    selected = sorted(
        torch_entries[:torch_threshold] + nn_fn_entries[:nn_fn_threshold],
        key=lambda entry: entry[1],
        reverse=True,
    )
    if with_counts:
        return selected
    return [entry[0] for entry in selected]
| |
| |
def get_ops_percentage(torch_threshold, nn_fn_threshold):
    """Fraction of all GitHub usages accounted for by the selected top ops.

    The denominator is the summed usage of every op get_top_ops keeps (i.e.
    everything not excluded); the numerator is the summed usage of
    get_top_ops(torch_threshold, nn_fn_threshold).

    Returns:
        A float in [0, 1] (not multiplied by 100); 0.0 if there are no usages.
    """
    # Index usage counts by op name once, instead of rescanning the full
    # usage list for every op (which was quadratic overall).
    counts_by_name = {}
    for entry in top_ops.top_torch + top_ops.get_nn_functional_top_list():
        counts_by_name.setdefault(entry[0], []).append(entry[1])

    def get_num_usages(opname):
        # Ignore this, this is heavily inflated
        if opname == "t":
            return 0
        counts = counts_by_name.get(opname, [])
        assert len(counts) == 1, (
            f"expected exactly one usage entry for {opname!r}, found {len(counts)}"
        )
        return counts[0]

    # get all operators that are not in the denylist
    all_ops = get_top_ops(999999, 999999)
    total_op_usages = sum(get_num_usages(op) for op in all_ops)

    # get subset of all operators
    subset_ops = get_top_ops(torch_threshold, nn_fn_threshold)
    subset_op_usages = sum(get_num_usages(op) for op in subset_ops)
    if total_op_usages == 0:
        return 0.0
    return subset_op_usages / total_op_usages
| |
| |
def get_top_ops_not_covered_by_opinfo(torch_threshold=0, nn_fn_threshold=0):
    """Top ops (see get_top_ops) whose name matches no OpInfo in ``op_db``.

    An op counts as covered if its name equals an OpInfo's name or one of
    that OpInfo's alias names. Factory functions are dropped as well.

    Returns:
        list of op names, in get_top_ops order.
    """
    candidates = get_top_ops(torch_threshold, nn_fn_threshold)

    opinfo_names = set()
    for opinfo in op_db:
        opinfo_names.add(opinfo.name)
        opinfo_names.update(alias.name for alias in opinfo.aliases)

    # NOTE(review): the module-level `denylist` holds "torch.Tensor.*" names,
    # while get_top_ops returns bare names, so that check likely never
    # matches -- confirm before relying on it.
    return [
        name
        for name in candidates
        if name not in opinfo_names
        and name not in denylist
        and name not in factory_fns
    ]
| |
| |
def get_covered_ops(ops_list, invert=False):
    """Filter *ops_list* by whether its callables are covered by an OpInfo.

    Args:
        ops_list: dict mapping a name to a callable.
        invert: If true, keep the entries that are NOT covered instead.

    Returns:
        dict with the kept ``name -> callable`` entries, order preserved.
    """
    covered = get_ops_covered_by_opinfos()
    keep_covered = not invert
    return {
        name: fn
        for name, fn in ops_list.items()
        if (fn in covered) == keep_covered
    }
| |
| |
# Coverage status labels (not referenced elsewhere in this script).
Status = Enum("Status", [("Correct", 0), ("Fast", 1)])
| |
| |
# functorch test names whose per-op skip/xfail status is tracked below.
tests = {
    "test_jvp",
    "test_op_has_batch_rule",
    "test_vjp",
    "test_vmap_exhaustive",
    "test_vmapjvp",
    "test_vmapvjp",
    "test_vmapvjp_has_batch_rule",
}
| |
| |
def is_decorateinfo_skip_or_xfail(decorateinfo):
    """Whether a single-decorator DecorateInfo skips or xfails its test.

    Only ``toleranceOverride`` is treated as benign; every other decorator
    (``unittest.expectedFailure`` or anything else) is assumed to be a skip
    or xfail.

    Raises:
        AssertionError: if the DecorateInfo does not hold exactly one decorator.
    """
    assert len(decorateinfo.decorators) == 1
    (decorator,) = decorateinfo.decorators
    return not isinstance(decorator, toleranceOverride)
| |
| |
def get_all_tested_ops():
    """Names of every OpInfo covering an overridable outplace op we care about.

    Returns:
        set of OpInfo names.
    """
    wanted = get_public_overridable_outplace_we_care_about()
    op_to_opinfo = get_ops_covered_by_opinfos()
    return {
        opinfo.name
        for fn in get_covered_ops(wanted).values()
        for opinfo in op_to_opinfo[fn]
    }
| |
| |
def get_skipped_or_xfailed_ops_for(test_name):
    """Names of OpInfos that skip or xfail *test_name*.

    Only OpInfos covering an overridable outplace op we care about are
    inspected. A DecorateInfo counts when its ``test_name`` equals
    *test_name* and is_decorateinfo_skip_or_xfail() says so.

    Returns:
        set of OpInfo names.
    """
    wanted = get_public_overridable_outplace_we_care_about()
    op_to_opinfo = get_ops_covered_by_opinfos()
    skipped = set()
    for fn in get_covered_ops(wanted).values():
        for opinfo in op_to_opinfo[fn]:
            targeted = (
                decorator
                for decorator in opinfo.decorators
                if hasattr(decorator, "test_name") and decorator.test_name == test_name
            )
            for decorator in targeted:
                if is_decorateinfo_skip_or_xfail(decorator):
                    skipped.add(opinfo.name)
    return skipped
| |
| |
def get_statuses(for_subset=None, invert=False):
    """Per-op sets of tracked tests (from ``tests``) that pass or fail.

    A test "fails" for an op when any OpInfo covering that op carries a
    skip/xfail DecorateInfo for it. Tolerance overrides do not count, which
    matches get_skipped_or_xfailed_ops_for and Operator.no_opinfos_skip_test.

    Args:
        for_subset: Optional collection of op names without the ``"torch."``
            prefix; when given, only those ops are reported.
        invert: If true, report the failing tests instead of the passing ones.

    Returns:
        dict mapping ``"torch.<name>"`` to a set of test names, for every
        selected op that is covered by at least one OpInfo.
    """
    candidates = get_public_overridable_outplace_we_care_about()
    if for_subset is not None:
        candidates = {
            k: v
            for k, v in candidates.items()
            # Removes "torch."
            if k[len("torch.") :] in for_subset
        }
    op_to_opinfo = get_ops_covered_by_opinfos()

    def get_covered_tests(op):
        """Tests from ``tests`` that no OpInfo for *op* skips or xfails."""
        passing = set(tests)
        for opinfo in op_to_opinfo[op]:
            for decorator in opinfo.decorators:
                test_name = getattr(decorator, "test_name", None)
                if test_name not in passing:
                    continue
                # A toleranceOverride still runs the test, so it is not a failure.
                if is_decorateinfo_skip_or_xfail(decorator):
                    passing.remove(test_name)
        return passing

    statuses = {}
    for name, op in get_covered_ops(candidates).items():
        passing = get_covered_tests(op)
        statuses[name] = (tests - passing) if invert else passing
    return statuses
| |
| |
def transpose_statuses(for_subset=None, invert=False):
    """Invert get_statuses(): map each tracked test to the set of op names.

    Every test in ``tests`` gets an entry, possibly an empty set.
    """
    by_op = get_statuses(for_subset, invert=invert)
    by_test = {test_name: set() for test_name in tests}
    for op_name, test_names in by_op.items():
        for test_name in test_names:
            by_test[test_name].add(op_name)
    return by_test
| |
| |
# Headline counts for the public, overridable API surface and its OpInfo coverage.
overridable_apis = get_public_overridable_apis()
overridable_ops = get_public_overridable_ops()
overridable_outplace_ops = get_public_overridable_outplace_ops()
overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()

tested_overridable_outplace_ops = get_covered_ops(overridable_outplace_we_care_about)
untested_overridable_outplace_ops = get_covered_ops(
    overridable_outplace_we_care_about, invert=True
)

# To list the ops that still need an OpInfo:
#     for key in untested_overridable_outplace_ops:
#         print(key)

# One line per count; a single print of the newline-joined lines emits the
# same text as printing each line separately.
print(
    "\n".join(
        f"{label}: {len(collection)}"
        for label, collection in (
            ("Overridable public APIs", overridable_apis),
            ("Overridable public ops", overridable_ops),
            ("Overridable public outplace ops", overridable_outplace_ops),
            (
                "Overridable public outplace ops we care about",
                overridable_outplace_we_care_about,
            ),
            (
                "OpInfo-tested overridable public outplace ops",
                tested_overridable_outplace_ops,
            ),
        )
    )
)
| |
| |
def remove_torch(name):
    """Return *name* without its leading ``"torch."`` (asserted to be present)."""
    prefix = "torch."
    assert name.startswith(prefix)
    return name[len(prefix) :]
| |
| |
def get_list_of_all_tests():
    """Names (without ``"torch."``) of every OpInfo-tested overridable outplace op.

    Returns:
        set of names; despite the function name, a set rather than a list.
    """
    return set(map(remove_torch, tested_overridable_outplace_ops))
| |
| |
# Tests whose per-test counts are printed below; the jvp-based entries of
# `tests` (test_jvp, test_vmapjvp) are not included.
mytest = {
    "test_vmap_exhaustive",
    "test_op_has_batch_rule",
    "test_vjp",
    "test_vmapvjp",
    "test_vmapvjp_has_batch_rule",
}

print("*" * 80)
all_tests = get_list_of_all_tests()
# For each test, print how many OpInfo-tested op names remain after removing
# the OpInfo names that skip/xfail that test.
for test in mytest:
    result = get_skipped_or_xfailed_ops_for(test)
    diff = len(all_tests - result)
    print(f"{test}: {diff}")
| |
| |
def get_jvp_coverage(subset=None):
    """Print one comma-separated line summarising jvp coverage.

    Columns: ``test_jvp``; how many ops support forward AD and are not
    skipped/xfailed for ``test_jvp``; how many support autograd but not
    forward AD; and the total number of ops considered.

    Args:
        subset: Optional collection of op names without ``"torch."``; when
            given, only those OpInfo-tested ops are considered.
    """
    op_to_opinfo = get_ops_covered_by_opinfos()
    selected = tested_overridable_outplace_ops
    if subset is not None:
        selected = {
            name: fn for name, fn in selected.items() if remove_torch(name) in subset
        }

    def names_whose_first_opinfo_has(flag):
        # Only the first OpInfo registered for each callable is consulted.
        return {
            remove_torch(name)
            for name, fn in selected.items()
            if getattr(op_to_opinfo[fn][0], flag)
        }

    supports_autograd = names_whose_first_opinfo_has("supports_autograd")
    supports_forward_ad = names_whose_first_opinfo_has("supports_forward_ad")
    ops = {remove_torch(name) for name in selected}
    assert supports_forward_ad <= supports_autograd
    assert supports_autograd <= ops

    failed_ops = get_skipped_or_xfailed_ops_for("test_jvp")
    coverage = len(supports_forward_ad - failed_ops)
    no_forward_ad = len(supports_autograd) - len(supports_forward_ad)
    print(f"test_jvp, {coverage}, {no_forward_ad}, {len(ops)}")
| |
| |
# jvp coverage over all OpInfo-tested ops, then over the top 100 torch + top 25
# nn.functional ops; each call prints "test_jvp, <covered>, <no fwd AD>, <total>".
get_jvp_coverage()
get_jvp_coverage(get_top_ops(100, 25))
for op in get_top_ops(100, 25):
    print(op)
print("*" * 80)

# For debugging, get_skipped_or_xfailed_ops_for(<test name>) returns the
# skipped/xfailed ops for a test, e.g. 'test_vmap_exhaustive',
# 'test_op_has_batch_rule', 'test_vjp', 'test_vmapvjp',
# 'test_vmapvjp_has_batch_rule'.

# For each tracked test, the number of ops that do not skip/xfail it.
statuses = transpose_statuses()
for test in tests:
    print(f"{test} coverage {len(statuses[test])}")

# Tensor methods without a torch.* function counterpart (uncomment to print):
method_only_ops = get_method_only_ops_we_care_about()
# for op in method_only_ops:
#     print(f'    {op},')

# Top ops (100 torch + 25 nn.functional) that lack an OpInfo, with their
# GitHub usage counts.
top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(100, 25)
print("=" * 80)
for op in top_ops_not_covered_by_opinfo:
    print(f"{op}, {top_ops.usage_count[op]}")
| |
| # print("top ops not covered by opinfo: ") |
| # top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(200, 50) |
| # for op in top_ops_not_covered_by_opinfo: |
| # print(f'{op}, {top_ops.usage_count[op]}') |
| |
| # print("top ops not covered by opinfo: ") |
| # top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(220, 92) |
| # for op in top_ops_not_covered_by_opinfo: |
| # print(f'{op}, {top_ops.usage_count[op]}') |
| |
| # print("top ops not covered by opinfo: ") |
| # top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(999, 999) |
| # for op in top_ops_not_covered_by_opinfo: |
| # print(f'{op}, {top_ops.usage_count[op]}') |
| |
| |
def remove_from_set(parent, to_remove):
    """Remove from *parent*, in place, each element of *to_remove* it contains.

    Membership is re-checked before every removal, so absent elements are
    ignored rather than raising.
    """
    for element in to_remove:
        if element not in parent:
            continue
        parent.remove(element)
| |
| |
def print_coverage_info(th=100, nn=25):
    """Print failing-test coverage for the top *th* torch + *nn* nn.functional ops.

    Known testing artifacts and accepted vmap limitations are removed from the
    failure sets before printing. The jvp-based tests are summarised but
    omitted from the final pretty-printed dict.
    """
    print("=" * 80)
    print(f"top {th}, {nn} coverage")
    statuses = transpose_statuses(get_top_ops(th, nn), invert=True)
    top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(th, nn)

    # testing problems
    exemptions = {
        "torch.nn.functional.dropout",  # randomness
    }

    # Allowed exemptions
    vmap_exemptions = {
        "torch.randn_like",  # randomness
        "torch.rand_like",  # randomness
        "torch.allclose",  # number output
        "torch.unique",  # dynamic
        "torch.nonzero",  # dynamic
        "torch.masked_select",  # dynamic
        "torch.prod",  # dynamic (backward)
        "torch.norm",  # norm with nuc is not commonly used; we support the other cases.
        "torch.svd",  # There isn't a bug, it is just nondeterministic so we can't test it.
        "torch.nn.functional.embedding",  # We support everything except the sparse option.
    }
    vmap_based_tests = (
        "test_vmap_exhaustive",
        "test_vmapvjp",
        "test_vmapvjp_has_batch_rule",
        "test_op_has_batch_rule",
        "test_vmapjvp",
    )
    for test_name in vmap_based_tests:
        remove_from_set(statuses[test_name], vmap_exemptions)
    for test_name in tests:
        remove_from_set(statuses[test_name], exemptions)

    print(f"total ops in set: {th + nn}")
    print(f"tested by OpInfo: {th + nn - len(top_ops_not_covered_by_opinfo)}")
    not_reported_yet = {"test_jvp", "test_vmapjvp"}
    for test_name in tests:
        if test_name not in not_reported_yet:
            print(f"{test_name} failing coverage {len(statuses[test_name])}")

    # We don't care about these yet
    del statuses["test_jvp"]
    del statuses["test_vmapjvp"]

    pprint.pprint(statuses)
| |
| |
def get_name_to_opinfo_map():
    """Map every OpInfo name and alias name to the OpInfos registered under it.

    Both ``op_db`` and ``additional_op_db`` are scanned, in that order.

    Returns:
        dict mapping name -> list of OpInfo.
    """
    name_to_opinfos = {}
    for opinfo in op_db + additional_op_db:
        for name in [opinfo.name, *(alias.name for alias in opinfo.aliases)]:
            name_to_opinfos.setdefault(name, []).append(opinfo)
    return name_to_opinfos


# Built once at import; Operator looks up its OpInfos here.
NAME_TO_OPINFO = get_name_to_opinfo_map()
| |
| |
# Tri-state result returned by the Operator.supports_* checks.
Support = enum.Enum("Support", [("NO", 0), ("YES", 1), ("UNKNOWN", 2)])
| |
| |
# Bare factory-function names; every Operator.supports_* check reports YES
# for these. Kept alphabetical.
FACTORY_FNS = {
    "arange",
    "empty",
    "eye",
    "full",
    "linspace",
    "logspace",
    "ones",
    "rand",
    "randint",
    "randn",
    "randperm",
    "range",
    "tensor",
    "zeros",
}

# Treated as supported by Operator.supports_vjp.
# All of these are not actually problems, just randomness testing artifacts.
VJP_EXEMPTIONS = {
    "bernoulli",
    "nn.functional.dropout",
    "nn.functional.dropout2d",
    "nn.functional.rrelu",
    "normal",
}

# Treated as supported by the vmap / vmapvjp checks on Operator.
VMAP_EXEMPTIONS = {
    "allclose",  # number output
    "bernoulli",  # randomness
    "masked_select",  # dynamic
    "multinomial",  # randomness
    "nn.functional.dropout",  # randomness
    "nn.functional.dropout2d",  # randomness
    "nn.functional.embedding",  # We support everything except the sparse option.
    "nonzero",  # dynamic
    "norm",  # norm with nuc is not commonly used; we support the other cases.
    "normal",  # randomness
    "prod",  # dynamic (backward)
    "rand_like",  # randomness
    "randn_like",  # randomness
    "svd",  # There isn't a bug, it is just nondeterministic so we can't test it.
    "unique",  # dynamic
}

# Treated as supported by Operator.supports_jvp.
# All of these are not actually problems, just randomness testing artifacts.
JVP_EXEMPTIONS = {
    "bernoulli",
    "nn.functional.dropout",
    "nn.functional.dropout2d",
    "nn.functional.rrelu",
    "normal",
}
| |
| |
class Operator:
    """A public operator name together with the OpInfos registered for it.

    ``opinfos`` is the list stored in ``NAME_TO_OPINFO`` under ``name``
    (OpInfo names and alias names), or None when there is no such OpInfo.

    Each ``supports_*`` method returns a ``Support`` value:
      * YES for factory functions and for names on the relevant exemption list;
      * UNKNOWN when there is no OpInfo;
      * otherwise NO/YES depending on skip/xfail decorators (and, for the
        jvp-based checks, on forward-AD support).

    Equality and hashing are both by ``name``.
    """

    # Names treated as supported by supports_jvpvjp.
    _JVPVJP_EXEMPTIONS = frozenset(
        {
            # we have support (see OpInfo), testing artifact
            "nn.functional.dropout2d",
            "nn.functional.dropout",
            # exception: we dont even support double backward for this
            "nn.functional.hardswish",
            "bernoulli",  # this isn't differentiable
            "normal",  # not differentiable
        }
    )

    # Names treated as supported by supports_vmapjvp / supports_fast_vmapjvp.
    _VMAPJVP_EXEMPTIONS = frozenset(
        {
            "prod",  # dynamic (backward)
            "nn.functional.batch_norm",  # testing problem
            "normal",  # not actually problem, randomness testing artifact
            "bernoulli",  # not actually problem, randomness testing artifact
            "nn.functional.dropout2d",  # not actually problem, randomness testing artifact
            "nn.functional.dropout",  # not actually problem, randomness testing artifact
            # Not a problem.
            # It's just that the max_norm testing mutates inputs...
            # (we have our own functorch variant of the OpInfo without max_norm)
            "nn.functional.embedding",
        }
    )

    def __init__(self, name):
        self.name = name
        # None when no OpInfo (or OpInfo alias) is registered under this name.
        self.opinfos = NAME_TO_OPINFO.get(name, None)
        assert self.opinfos is None or len(self.opinfos) > 0

    def has_opinfo(self):
        """Whether at least one OpInfo is registered for this operator."""
        return self.opinfos is not None

    def __repr__(self):
        return f'Operator("{self.name}")'

    def __eq__(self, other):
        # Pairs with __hash__: two Operators are equal iff their names match.
        if not isinstance(other, Operator):
            return NotImplemented
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)

    def no_opinfos_skip_test(self, test_name):
        """Support status for *test_name* derived from the OpInfos' decorators.

        Returns UNKNOWN without OpInfos, NO if any OpInfo has a skip or xfail
        DecorateInfo targeting *test_name*, and YES otherwise.
        """
        if not self.has_opinfo():
            return Support.UNKNOWN
        for opinfo in self.opinfos:
            for decorator in opinfo.decorators:
                if not hasattr(decorator, "test_name"):
                    continue
                if decorator.test_name != test_name:
                    continue
                if is_decorateinfo_skip_or_xfail(decorator):
                    return Support.NO
        return Support.YES

    def _require_opinfo(self):
        """Raise RuntimeError if this operator has no OpInfo."""
        if not self.has_opinfo():
            raise RuntimeError(f"{self!r} has no OpInfo")

    def any_opinfo_attr(self, attr):
        """True if *attr* is truthy on any of this operator's OpInfos.

        Raises:
            RuntimeError: if the operator has no OpInfo.
        """
        self._require_opinfo()
        return any(getattr(opinfo, attr) for opinfo in self.opinfos)

    def all_opinfo_attr(self, attr):
        """True if *attr* is truthy on every one of this operator's OpInfos.

        Raises:
            RuntimeError: if the operator has no OpInfo.
        """
        self._require_opinfo()
        return all(getattr(opinfo, attr) for opinfo in self.opinfos)

    def _support_from_skips(self, exemptions, test_name):
        """YES for factory fns or *exemptions*; else skip/xfail status for *test_name*."""
        if self.name in FACTORY_FNS or self.name in exemptions:
            return Support.YES
        return self.no_opinfos_skip_test(test_name)

    def _support_with_forward_ad(self, exemptions, test_name):
        """Like _support_from_skips, but NO when forward AD is missing.

        Forward AD counts as missing when some OpInfo supports autograd but
        not every OpInfo supports forward AD.
        """
        if self.name in FACTORY_FNS or self.name in exemptions:
            return Support.YES
        if not self.has_opinfo():
            return Support.UNKNOWN
        if self.any_opinfo_attr("supports_autograd") and not self.all_opinfo_attr(
            "supports_forward_ad"
        ):
            return Support.NO
        return self.no_opinfos_skip_test(test_name)

    def supports_vjp(self):
        return self._support_from_skips(VJP_EXEMPTIONS, "test_vjp")

    def supports_vmap(self):
        return self._support_from_skips(VMAP_EXEMPTIONS, "test_vmap_exhaustive")

    def supports_fast_vmap(self):
        return self._support_from_skips(VMAP_EXEMPTIONS, "test_op_has_batch_rule")

    def supports_vmapvjp(self):
        return self._support_from_skips(VMAP_EXEMPTIONS, "test_vmapvjp")

    def supports_fast_vmapvjp(self):
        return self._support_from_skips(
            VMAP_EXEMPTIONS, "test_vmapvjp_has_batch_rule"
        )

    def supports_jvp(self):
        return self._support_with_forward_ad(JVP_EXEMPTIONS, "test_jvp")

    def supports_jvpvjp(self):
        return self._support_from_skips(self._JVPVJP_EXEMPTIONS, "test_jvpvjp")

    def _supports_vmapjvp_base(self, test):
        return self._support_with_forward_ad(self._VMAPJVP_EXEMPTIONS, test)

    def supports_vmapjvp(self):
        return self._supports_vmapjvp_base("test_vmapjvpall")

    def supports_fast_vmapjvp(self):
        return self._supports_vmapjvp_base("test_vmapjvpall_has_batch_rule")
| |
| |
class OperatorSet:
    """An unordered collection of ``Operator`` objects with reporting helpers."""

    def __init__(self, operators):
        self.data = set(operators)

    @classmethod
    def from_names(cls, names):
        """Build a set holding one ``Operator`` per name."""
        # Use cls (not a hard-coded OpereatorSet) so subclasses build themselves.
        return cls([Operator(name) for name in names])

    @classmethod
    def from_top_ops_threshold(cls, torch_threshold, nn_fn_threshold):
        """Operators for get_top_ops(torch_threshold, nn_fn_threshold)."""
        names = get_top_ops(torch_threshold, nn_fn_threshold)
        return cls.from_names(names)

    @classmethod
    def from_top125(cls):
        """Top 100 torch + top 25 nn.functional ops."""
        return cls.from_top_ops_threshold(100, 25)

    @classmethod
    def from_top160(cls):
        """Top 107 torch + top 53 nn.functional ops."""
        return cls.from_top_ops_threshold(107, 53)

    @classmethod
    def all(cls):
        """Every public overridable outplace op we care about.

        Names are stripped of their ``"torch.Tensor."`` or ``"torch."`` prefix.

        Raises:
            AssertionError: if a name has neither prefix.
        """
        names_sanitized = []
        for n in get_public_overridable_outplace_we_care_about():
            # Longer prefix first: "torch.Tensor.x" also starts with "torch.".
            for prefix in ("torch.Tensor.", "torch."):
                if n.startswith(prefix):
                    names_sanitized.append(n[len(prefix) :])
                    break
            else:
                raise AssertionError(f"unexpected API name without torch prefix: {n!r}")
        return cls.from_names(names_sanitized)

    def query(self, operator_method, filter=(Support.NO, Support.YES, Support.UNKNOWN)):
        """Bucket the operators by the value *operator_method* returns for them.

        Args:
            operator_method: Callable taking an ``Operator`` (e.g. an unbound
                ``Operator.supports_*`` method).
            filter: Return values to keep; each gets a (possibly empty)
                bucket, and operators with any other value are dropped.

        Returns:
            dict mapping each value in *filter* to a set of Operators.
        """
        result = {key: set() for key in filter}
        for op in self.data:
            support_status = operator_method(op)
            if support_status in filter:
                result[support_status].add(op)
        return result

    def summary(self):
        """CSV-style table of YES/NO/UNKNOWN counts for every supports_* check."""
        checks = [
            "supports_vjp",
            "supports_vmap",
            "supports_fast_vmap",
            "supports_vmapvjp",
            "supports_fast_vmapvjp",
            "supports_jvp",
            "supports_vmapjvp",
            "supports_fast_vmapjvp",
            "supports_jvpvjp",
        ]
        result = ["test, yes, no, unknown"]
        for check in checks:
            accessor = getattr(Operator, check)
            all_results = self.query(accessor)
            yes_amt = len(all_results[Support.YES])
            no_amt = len(all_results[Support.NO])
            unknown_amt = len(all_results[Support.UNKNOWN])
            result.append(f"{check}, {yes_amt}, {no_amt}, {unknown_amt}")
        return "\n".join(result)
| |
| |
# ---------------------------------------------------------------------------
# Summary over every public overridable outplace op we care about.
# ---------------------------------------------------------------------------
opset = OperatorSet.all()
# Only the False bucket is requested: operators that have no OpInfo.
has_no_opinfo = opset.query(Operator.has_opinfo, (False,))

print("=" * 30 + " Summary " + "=" * 30)
# NOTE: get_ops_percentage returns a fraction in [0, 1], not a percentage.
print(f"% of usages on github: {get_ops_percentage(99999, 99999)}")
print(opset.summary())

# sanity checks: operators whose vjp support is NO or UNKNOWN.
result = opset.query(Operator.supports_vjp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)

# ---------------------------------------------------------------------------
# Summary over the top 35 torch + top 25 nn.functional ops.
# ---------------------------------------------------------------------------
print("=" * 30 + " Top 60 Summary " + "=" * 30)
print(f"% of usages on github: {get_ops_percentage(35, 25)}")
opset = OperatorSet.from_top_ops_threshold(35, 25)
# To inspect an individual check, e.g.:
# result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
# (likewise for supports_jvp, supports_jvpvjp, supports_fast_vmapjvp)
print(opset.summary())

# ---------------------------------------------------------------------------
# Summary over the top 100 torch + top 25 nn.functional ops, plus the
# operators that are NO/UNKNOWN for several individual checks.
# ---------------------------------------------------------------------------
print("=" * 30 + " Top 125 Summary " + "=" * 30)
print(f"% of usages on github: {get_ops_percentage(100, 25)}")
opset = OperatorSet.from_top125()
# Other checks can be inspected the same way, e.g.:
# result = opset.query(Operator.supports_vmap, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
print("supports_vjp")
result = opset.query(Operator.supports_vjp, (Support.NO, Support.UNKNOWN))
pprint.pprint(result)
print("supports_jvp")
result = opset.query(Operator.supports_jvp, (Support.NO, Support.UNKNOWN))
pprint.pprint(result)
print("supports_vmapjvp")
result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
pprint.pprint(result)
print("supports_jvpvjp")
result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
pprint.pprint(result)
# result = opset.query(Operator.supports_fast_vmapjvp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
print(opset.summary())
| |
| # print("=" * 30 + " Top 160 Summary " + "=" * 30) |
| # opset = OperatorSet.from_top160() |
| # result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN)) |
| # pprint.pprint(result) |
| # print(opset.summary()) |
| |
| # Print list of everything in order |
| # all_ops = get_top_ops(999999, 999999, with_counts=True) |
| # for op, count in all_ops: |
| # print(f'{op}, {count}') |