[functorch] Add tests for vmapjvpall batch rule coverage (pytorch/functorch#602)

diff --git a/functorch/test/discover_coverage.py b/functorch/test/discover_coverage.py
index fb633ea..1c26981 100644
--- a/functorch/test/discover_coverage.py
+++ b/functorch/test/discover_coverage.py
@@ -4,6 +4,7 @@
 from enum import Enum
 import functorch._src.top_operators_github_usage as top_ops
 import pprint
+import unittest
 
 # Importing these files make modifications to the op_db that we need
 import test_ops  # noqa: F401
@@ -276,6 +277,46 @@
 }
 
 
+def is_decorateinfo_skip_or_xfail(decorateinfo):
+    assert len(decorateinfo.decorators) == 1
+    actual_decorator = decorateinfo.decorators[0]
+    if actual_decorator == unittest.expectedFailure:
+        return True
+    # unittest.skip(...) returns a fresh closure on every call, so comparing
+    # against unittest.skip("Skipped!") with == can never match. Identify the
+    # skip decorator by its qualified name instead.
+    if getattr(actual_decorator, '__qualname__', '') == 'skip.<locals>.decorator':
+        return True
+    return False
+
+
+def get_all_tested_ops():
+    overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()
+    op_to_opinfo = get_ops_covered_by_opinfos()
+    result = set()
+    for op in get_covered_ops(overridable_outplace_we_care_about).values():
+        for opinfo in op_to_opinfo[op]:
+            result.add(opinfo.name)
+    return result
+
+
+def get_skipped_or_xfailed_ops_for(test_name):
+    overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()
+    op_to_opinfo = get_ops_covered_by_opinfos()
+    result = set()
+    for op in get_covered_ops(overridable_outplace_we_care_about).values():
+        for opinfo in op_to_opinfo[op]:
+            for decorator in opinfo.decorators:
+                if not hasattr(decorator, 'test_name'):
+                    continue
+                if decorator.test_name != test_name:
+                    continue
+                if is_decorateinfo_skip_or_xfail(decorator):
+                    result.add(opinfo.name)
+    return result
+
+
 def get_statuses(for_subset=None, invert=False):
     overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()
     if for_subset is not None:
@@ -350,6 +391,74 @@
 print(f'OpInfo-tested overridable public outplace ops: {len(tested_overridable_outplace_ops)}')
 
 
+def remove_torch(name):
+    assert name[:6] == 'torch.'
+    return name[6:]
+
+
+def get_list_of_all_tests():
+    all_tests = list(tested_overridable_outplace_ops.keys())
+    return {remove_torch(name) for name in all_tests}
+
+
+tests_of_interest = {
+    'test_vmap_exhaustive',
+    'test_op_has_batch_rule',
+    'test_vjp',
+    'test_vmapvjp',
+    'test_vmapvjp_has_batch_rule',
+}
+
+print('*' * 80)
+all_tests = get_list_of_all_tests()
+# For each test, count the OpInfo-tested ops that are not skipped or xfailed.
+for test in tests_of_interest:
+    skipped_or_xfailed = get_skipped_or_xfailed_ops_for(test)
+    num_passing = len(all_tests - skipped_or_xfailed)
+    print(f'{test}: {num_passing}')
+
+
+def get_jvp_coverage(subset=None):
+    # - number that support autograd
+    # - number that support forward_ad (in pytorch core)
+    # - number that support functorch.jvp
+    op_to_opinfo = get_ops_covered_by_opinfos()
+    ops_dct = tested_overridable_outplace_ops
+    if subset is not None:
+        ops_dct = {name: op for name, op in ops_dct.items()
+                   if remove_torch(name) in subset}
+    supports_autograd_ops_dct = {name: op_to_opinfo[fn] for name, fn in ops_dct.items()
+                                 if op_to_opinfo[fn][0].supports_autograd}
+    supports_forwardad_ops_dct = {name: op_to_opinfo[fn] for name, fn in ops_dct.items()
+                                  if op_to_opinfo[fn][0].supports_forward_ad}
+
+    ops = {remove_torch(name) for name in ops_dct}
+    supports_autograd = {remove_torch(name)
+                         for name in supports_autograd_ops_dct}
+    supports_forward_ad = {remove_torch(name)
+                           for name in supports_forwardad_ops_dct}
+    assert supports_forward_ad.issubset(supports_autograd)
+    assert supports_autograd.issubset(ops)
+
+    failed_ops = get_skipped_or_xfailed_ops_for('test_jvp')
+
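+    # Printed columns: test name, ops with forward AD that are not skipped or
+    # xfailed in test_jvp, ops with autograd but no forward AD, total ops.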
+    coverage = len(supports_forward_ad - failed_ops)
+    no_forward_ad = len(supports_autograd) - len(supports_forward_ad)
+    print(f'test_jvp, {coverage}, {no_forward_ad}, {len(ops)}')
+
+
+get_jvp_coverage()
+get_jvp_coverage(get_top_ops(100, 25))
+for op in get_top_ops(100, 25):
+    print(op)
+print('*' * 80)
+
+# result = get_skipped_or_xfailed_ops_for('test_vmap_exhaustive')
+# result = get_skipped_or_xfailed_ops_for('test_op_has_batch_rule')
+# result = get_skipped_or_xfailed_ops_for('test_vjp')
+# result = get_skipped_or_xfailed_ops_for('test_vmapvjp')
+# result = get_skipped_or_xfailed_ops_for('test_vmapvjp_has_batch_rule')
+
 statuses = transpose_statuses()
 for test in tests:
     print(f'{test} coverage {len(statuses[test])}')
@@ -435,3 +544,41 @@
 # print_coverage_info(200, 50)
 
 # pprint.pprint(get_top_ops(100, 25))
+
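+# Map every op name (and alias) in op_db to its OpInfo entries, then split the
+# top-125 ops into those covered by test_vmapjvpall_has_batch_rule and those
+# that are skipped/xfailed or lack forward-AD support.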
+dct = {}
+for op in op_db:
+    for name in [op.name] + [alias.name for alias in op.aliases]:
+        dct.setdefault(name, []).append(op)
+
+top_ops_125 = set(get_top_ops(100, 25))
+
+# dct_125 ends up with only 110 of the top 125 ops, but that's OK.
+dct_125 = {k: v for k, v in dct.items() if k in top_ops_125}
+failed_ops = get_skipped_or_xfailed_ops_for('test_vmapjvpall_has_batch_rule')
+failed_ops = {k for k in failed_ops if k in dct_125}
+supports_bwd = {k for k, v in dct_125.items()
+                if any(opinfo.supports_autograd for opinfo in v)}
+supports_fwd = {k for k, v in dct_125.items()
+                if all(opinfo.supports_forward_ad for opinfo in v)}
+assert supports_fwd.issubset(supports_bwd)
+supports_bwd_but_not_fwd = supports_bwd - supports_fwd
+unsupported = failed_ops.union(supports_bwd_but_not_fwd)
+supported = set(dct_125.keys()) - unsupported
+
+print(f'total: {len(supported) + len(unsupported)}')
+
+print("&" * 80)
+print('supported:')
+for x in supported:
+    print(x)
+
+print("&" * 80)
+print('unsupported:')
+for x in unsupported:
+    print(x)
diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py
index ee17757..9aa719d 100644
--- a/functorch/test/test_ops.py
+++ b/functorch/test/test_ops.py
@@ -705,12 +705,7 @@
             for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op, bdims=(0,)):
                 self.assertEqual(loop_out, batched_out)
 
-    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
-    @opsToleranceOverride('TestOperators', 'test_vjp', (
-        tol1('nn.functional.conv_transpose3d',
-             {torch.float32: tol(atol=2e-04, rtol=9e-3)}, device_type='cuda'),
-    ))
-    @skipOps('TestOperators', 'test_vmapjvpall', {
+    vmapjvpall_fail = {
         skip('nn.functional.dropout'),  # randomness
         skip('nn.functional.rrelu'),  # randomness
         xfail('nn.functional.fractional_max_pool2d'),  # Cannot access data pointer of Tensor that doesn't have storage
@@ -760,7 +755,14 @@
         # Some kind of issue with unsymmetric tangent type
         # Runtime Error: The tangent part of the matrix A should also be symmetric.
         xfail('linalg.eigh'),
-    })
+    }
+
+    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
+    @opsToleranceOverride('TestOperators', 'test_vmapjvpall', (
+        tol1('nn.functional.conv_transpose3d',
+             {torch.float32: tol(atol=2e-04, rtol=9e-3)}, device_type='cuda'),
+    ))
+    @skipOps('TestOperators', 'test_vmapjvpall', vmapjvpall_fail)
     @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
     # This is technically a superset of test_vmapjvp. We should either delete test_vmapjvp
     # or figure out if we can split vmapjvpall. It's useful to keep test_vmapjvp intact
@@ -785,6 +787,61 @@
             for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op):
                 self.assertEqual(loop_out, batched_out)
 
+    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
+    @skipOps('TestOperators', 'test_vmapjvpall_has_batch_rule', vmapjvpall_fail.union({
+        xfail('linalg.solve_triangular'),
+        xfail('nn.functional.huber_loss'),
+        xfail('nn.functional.poisson_nll_loss'),
+        xfail('lu'),
+        xfail('cumprod'),
+        xfail('lu_solve'),
+        xfail('linalg.lstsq', 'grad_oriented'),
+        xfail('cross'),
+        xfail('qr'),
+        xfail('linalg.pinv'),
+        xfail('masked_fill'),
+        xfail('copysign'),
+        xfail('linalg.solve'),
+        xfail('linalg.eig'),
+        xfail('complex'),
+        xfail('linalg.pinv', 'hermitian'),
+        xfail('pinverse'),
+        skip('_masked.mean'),  # ???
+        xfail('linalg.cholesky_ex'),
+        xfail('masked_scatter'),
+        xfail('index_fill'),
+        xfail('take'),
+        xfail('linalg.eigvals'),
+        xfail('linalg.qr'),
+        xfail('linalg.tensorsolve'),
+        xfail('nn.functional.max_pool3d'),
+        xfail('vdot'),
+        xfail('linalg.cross'),
+    }))
+    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
+    def test_vmapjvpall_has_batch_rule(self, device, dtype, op):
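+        # Same setup as test_vmapjvpall, but checked via check_vmap_fallback:
+        # the goal is to verify that a batch rule exists (no vmap fallback),
+        # not to compare outputs.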
+        if is_inplace(op, op.get_op()):
+            # TODO: test in-place
+            self.skipTest("Skipped! NYI: inplace-testing not supported.")
+            return
+
+        samples = op.sample_inputs(device, dtype, requires_grad=False)
+
+        if not op.supports_forward_ad:
+            self.skipTest("Skipped! Forward AD not supported.")
+            return
+
+        def test():
+            for sample in samples:
+                fn, args = get_jvp_variant_primals_tangents(op, sample)
+                # compute_loop_out=False: we only need to run the vmapped path;
+                # output correctness is covered by test_vmapjvpall.
+                for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
+                        fn, args, {}, opinfo=op, compute_loop_out=False):
+                    pass
+        check_vmap_fallback(self, test, op, dry_run=False)
+
     @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
     @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
     @skipOps('TestOperators', 'test_vmapvjp_has_batch_rule', vmapvjp_fail.union({