[functorch] Add tests for vmapjvpall batch rule coverage (pytorch/functorch#602)
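
This adds a test_vmapjvpall_has_batch_rule test and extends
test/discover_coverage.py with helpers for asking which ops are skipped or
xfailed for a given test. A rough sketch of how the new helpers might be used
(the import path is an assumption, and importing the script also runs its
module-level reporting, which is not shown here):

    # run from functorch/test/ (hypothetical usage, not part of this patch)
    from discover_coverage import get_all_tested_ops, get_skipped_or_xfailed_ops_for

    tested = get_all_tested_ops()
    failed = get_skipped_or_xfailed_ops_for('test_vmapjvpall_has_batch_rule')
    print(f'vmapjvpall batch-rule coverage: {len(tested - failed)} / {len(tested)}')
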
diff --git a/functorch/test/discover_coverage.py b/functorch/test/discover_coverage.py
index fb633ea..1c26981 100644
--- a/functorch/test/discover_coverage.py
+++ b/functorch/test/discover_coverage.py
@@ -4,6 +4,7 @@
from enum import Enum
import functorch._src.top_operators_github_usage as top_ops
import pprint
+import unittest
# Importing these files make modifications to the op_db that we need
import test_ops # noqa: F401
@@ -276,6 +277,46 @@
}
+def is_decorateinfo_skip_or_xfail(decorateinfo):
+    assert len(decorateinfo.decorators) == 1
+    actual_decorator = decorateinfo.decorators[0]
+    if actual_decorator == unittest.expectedFailure:
+        return True
+    # unittest.skip(...) returns a fresh closure on every call, so comparing
+    # against unittest.skip("Skipped!") with == never matches; identify skip
+    # decorators by their qualified name instead.
+    if getattr(actual_decorator, '__qualname__', '') == 'skip.<locals>.decorator':
+        return True
+    return False
+
+
+def get_all_tested_ops():
+    overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()
+    op_to_opinfo = get_ops_covered_by_opinfos()
+    result = set()
+    for name, op in get_covered_ops(overridable_outplace_we_care_about).items():
+        opinfos = op_to_opinfo[op]
+        for opinfo in opinfos:
+            result.add(opinfo.name)
+    return result
+
+
+def get_skipped_or_xfailed_ops_for(test_name):
+    overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()
+    op_to_opinfo = get_ops_covered_by_opinfos()
+    result = set()
+    for name, op in get_covered_ops(overridable_outplace_we_care_about).items():
+        opinfos = op_to_opinfo[op]
+        for opinfo in opinfos:
+            for decorator in opinfo.decorators:
+                if not hasattr(decorator, 'test_name'):
+                    continue
+                if decorator.test_name != test_name:
+                    continue
+                if is_decorateinfo_skip_or_xfail(decorator):
+                    result.add(opinfo.name)
+    return result
+
+
def get_statuses(for_subset=None, invert=False):
    overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()
    if for_subset is not None:
@@ -350,6 +391,74 @@
print(f'OpInfo-tested overridable public outplace ops: {len(tested_overridable_outplace_ops)}')
+def remove_torch(name):
+    assert name[:6] == 'torch.'
+    return name[6:]
+
+
+def get_list_of_all_tests():
+    all_tests = list(tested_overridable_outplace_ops.keys())
+    return {remove_torch(test) for test in all_tests}
+
+
+mytest = {
+    'test_vmap_exhaustive',
+    'test_op_has_batch_rule',
+    'test_vjp',
+    'test_vmapvjp',
+    'test_vmapvjp_has_batch_rule',
+}
+
+print('*' * 80)
+all_tests = get_list_of_all_tests()
+for test in mytest:
+    result = get_skipped_or_xfailed_ops_for(test)
+    diff = len(all_tests - result)
+    print(f'{test}: {diff}')
+
+
+def get_jvp_coverage(subset=None):
+    # - number that support autograd
+    # - number that support forward_ad (in pytorch core)
+    # - number that support functorch.jvp
+    op_to_opinfo = get_ops_covered_by_opinfos()
+    ops_dct = tested_overridable_outplace_ops
+    if subset is not None:
+        ops_dct = {name: op for name, op in ops_dct.items()
+                   if remove_torch(name) in subset}
+    supports_autograd_ops_dct = {name: op_to_opinfo[fn] for name, fn in ops_dct.items()
+                                 if op_to_opinfo[fn][0].supports_autograd}
+    supports_forwardad_ops_dct = {name: op_to_opinfo[fn] for name, fn in ops_dct.items()
+                                  if op_to_opinfo[fn][0].supports_forward_ad}
+
+    ops = {remove_torch(name) for name in ops_dct.keys()}
+    supports_autograd = {remove_torch(name)
+                         for name in supports_autograd_ops_dct.keys()}
+    supports_forward_ad = {remove_torch(name)
+                           for name in supports_forwardad_ops_dct.keys()}
+    assert supports_forward_ad.issubset(supports_autograd)
+    assert supports_autograd.issubset(ops)
+
+    failed_ops = get_skipped_or_xfailed_ops_for('test_jvp')
+
+    coverage = len(supports_forward_ad - failed_ops)
+    no_forward_ad = len(supports_autograd) - len(supports_forward_ad)
+    print(f'test_jvp, {coverage}, {no_forward_ad}, {len(ops)}')
+
+
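+# Report test_jvp coverage for all tested ops and again restricted to the
+# get_top_ops(100, 25) subset.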
+get_jvp_coverage()
+get_jvp_coverage(get_top_ops(100, 25))
+for op in get_top_ops(100, 25):
+    print(op)
+print('*' * 80)
+
+
statuses = transpose_statuses()
for test in tests:
print(f'{test} coverage {len(statuses[test])}')
@@ -435,3 +544,41 @@
# print_coverage_info(200, 50)
# pprint.pprint(get_top_ops(100, 25))
+
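+# Map each OpInfo name (and its aliases) to its OpInfo entries so that the
+# top-op names below can be matched against op_db.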
+dct = {}
+for op in op_db:
+    def add(name, op):
+        if name not in dct:
+            dct[name] = []
+        dct[name].append(op)
+    add(op.name, op)
+    for alias in op.aliases:
+        add(alias.name, op)
+
+top_ops_125 = set(get_top_ops(100, 25))
+dct_keys = set(dct.keys())
+
+# dct_125 only has 110 of the 125 top ops, but that's OK
+dct_125 = {k: v for k, v in dct.items() if k in top_ops_125}
+failed_ops = get_skipped_or_xfailed_ops_for('test_vmapjvpall_has_batch_rule')
+failed_ops = {k for k in failed_ops if k in dct_125}
+supports_bwd = {k: v for k, v in dct_125.items()
+                if any(opinfo.supports_autograd for opinfo in v)}
+supports_fwd = {k: v for k, v in dct_125.items()
+                if all(opinfo.supports_forward_ad for opinfo in v)}
+supports_fwd = set(supports_fwd.keys())
+supports_bwd = set(supports_bwd.keys())
+assert supports_fwd.issubset(supports_bwd)
+supports_bwd_but_not_fwd = supports_bwd - supports_fwd
+unsupported = failed_ops.union(supports_bwd_but_not_fwd)
+supported = set(dct_125.keys()) - unsupported
+
+print(len(supported) + len(unsupported))
+
+print("&" * 80)
+for x in supported:
+    print(x)
+
+print("&" * 80)
+for x in unsupported:
+    print(x)
diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py
index ee17757..9aa719d 100644
--- a/functorch/test/test_ops.py
+++ b/functorch/test/test_ops.py
@@ -705,12 +705,7 @@
            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op, bdims=(0,)):
                self.assertEqual(loop_out, batched_out)
-    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
-    @opsToleranceOverride('TestOperators', 'test_vjp', (
-        tol1('nn.functional.conv_transpose3d',
-             {torch.float32: tol(atol=2e-04, rtol=9e-3)}, device_type='cuda'),
-    ))
-    @skipOps('TestOperators', 'test_vmapjvpall', {
+    vmapjvpall_fail = {
        skip('nn.functional.dropout'), # randomness
        skip('nn.functional.rrelu'), # randomness
        xfail('nn.functional.fractional_max_pool2d'), # Cannot access data pointer of Tensor that doesn't have storage
@@ -760,7 +755,14 @@
        # Some kind of issue with unsymmetric tangent type
        # Runtime Error: The tangent part of the matrix A should also be symmetric.
        xfail('linalg.eigh'),
-    })
+    }
+
+    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
+    @opsToleranceOverride('TestOperators', 'test_vmapjvpall', (
+        tol1('nn.functional.conv_transpose3d',
+             {torch.float32: tol(atol=2e-04, rtol=9e-3)}, device_type='cuda'),
+    ))
+    @skipOps('TestOperators', 'test_vmapjvpall', vmapjvpall_fail)
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
# This is technically a superset of test_vmapjvp. We should either delete test_vmapjvp
# or figure out if we can split vmapjvpall. It's useful to keep test_vmapjvp intact
@@ -785,6 +787,61 @@
            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op):
                self.assertEqual(loop_out, batched_out)
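+    # Companion to test_vmapjvpall: exercises vmap over the jvp variant of each
+    # op in order to track which ops still lack batch rules (see the xfail list
+    # below).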
+    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
+    @skipOps('TestOperators', 'test_vmapjvpall_has_batch_rule', vmapjvpall_fail.union({
+        xfail('linalg.solve_triangular'),
+        xfail('nn.functional.huber_loss'),
+        xfail('nn.functional.poisson_nll_loss'),
+        xfail('lu'),
+        xfail('cumprod'),
+        xfail('lu_solve'),
+        xfail('linalg.lstsq', 'grad_oriented'),
+        xfail('cross'),
+        xfail('qr'),
+        xfail('linalg.pinv'),
+        xfail('masked_fill'),
+        xfail('copysign'),
+        xfail('linalg.solve'),
+        xfail('linalg.eig'),
+        xfail('complex'),
+        xfail('linalg.pinv', 'hermitian'),
+        xfail('pinverse'),
+        skip('_masked.mean'), # ???
+        xfail('linalg.cholesky_ex'),
+        xfail('masked_scatter'),
+        xfail('index_fill'),
+        xfail('take'),
+        xfail('linalg.eigvals'),
+        xfail('linalg.qr'),
+        xfail('linalg.tensorsolve'),
+        xfail('nn.functional.max_pool3d'),
+        xfail('vdot'),
+        xfail('linalg.cross'),
+    }))
+    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
+    def test_vmapjvpall_has_batch_rule(self, device, dtype, op):
+        if is_inplace(op, op.get_op()):
+            # TODO: test in-place
+            self.skipTest("Skipped! NYI: inplace-testing not supported.")
+            return
+
+        samples = op.sample_inputs(device, dtype, requires_grad=False)
+
+        if not op.supports_forward_ad:
+            self.skipTest("Skipped! Forward AD not supported.")
+            return
+
+        def test():
+            for sample in samples:
+                fn, args = get_jvp_variant_primals_tangents(op, sample)
+                # Only exercise vmap here; the loop-vs-batched comparison is
+                # already covered by test_vmapjvpall.
+                for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
+                        fn, args, {}, opinfo=op, compute_loop_out=False):
+                    pass
+
+        check_vmap_fallback(self, test, op, dry_run=False)
+
    @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
    @skipOps('TestOperators', 'test_vmapvjp_has_batch_rule', vmapvjp_fail.union({