[functorch] Add exhaustive testing of vmap autograd composability (pytorch/functorch#851)

* refactor to simplify based on review comments

* cleanup

* more failing tests

* fix test failures

* more test failures

* update xfails
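
The new test checks, for each op in the op database, that gradients computed
with torch.autograd.grad under vmap match gradients computed one cotangent at
a time in a loop. A minimal sketch of the pattern being exercised (the toy
function, shapes, and names below are illustrative assumptions, not code from
the test itself):

    import torch
    from functorch import vmap

    x = torch.randn(3, requires_grad=True)
    out = x.sin()

    def compute_grad(cotangent):
        # closes over `out` and `x`; vmap maps only over the cotangent
        return torch.autograd.grad(out, x, cotangent, retain_graph=True)[0]

    cotangents = torch.randn(4, 3)  # a batch of cotangents
    expected = torch.stack([compute_grad(c) for c in cotangents])
    batched = vmap(compute_grad)(cotangents)
    assert torch.allclose(batched, expected)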
diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py
index 1d2539f..2482fbc 100644
--- a/functorch/test/test_ops.py
+++ b/functorch/test/test_ops.py
@@ -209,35 +209,6 @@
     return wrapped, tuple(flat_args + flat_cotangents)
 
 
-# returns a new function g(*args, *cotangents)
-# that computes vjps and (*args, cotangents) using torch.autograd.grad
-def get_autograd_fn_and_args_with_cotangents(f, sample, cotangents):
-    args = tuple([sample.input] + list(sample.args))
-    kwargs = sample.kwargs
-    flat_args, args_spec = tree_flatten(args)
-    flat_cotangents, cotangents_spec = tree_flatten(cotangents)
-
-    @functools.wraps(f)
-    def wrapped(*args):
-        assert len(args) == len(flat_args) + len(flat_cotangents)
-        actual_args = args[:len(flat_args)]
-        cotangents = args[len(flat_args):]
-        actual_args = tree_unflatten(actual_args, args_spec)
-        cotangents = tree_unflatten(cotangents, cotangents_spec)
-
-        fn, primals = normalize_op_input_output3(f, actual_args, kwargs,
-                                                 flat_args,
-                                                 sample.output_process_fn_grad)
-        out = fn(*primals)
-        diff_wrt = tuple(primal for primal in primals if (primal.requires_grad or primal.grad_fn is not None))
-        if diff_wrt:
-            return torch.autograd.grad(out, diff_wrt, grad_outputs=cotangents)
-        else:
-            return (torch.ones(()),)  # uuugh hack...this will need to be more generic
-
-    return wrapped, tuple(flat_args + flat_cotangents)
-
-
 # Returns a new function g(*args, *cotangents) that computes vjps and
 # sample (*args, *cotangents)
 def get_vjpfull_variant(f, sample):
@@ -1310,24 +1281,73 @@
                 cotangents = torch.randn_like(result, device=device)
                 self._compare_jacobians_of_vjp(fn, (cotangents, input, weight, bias))
 
-    @ops(tuple(filter(lambda op: op.name == "nn.functional.group_norm", functorch_lagging_op_db + additional_op_db)),
-         allowed_dtypes=(torch.float32, torch.double))  # TODO: generalize
-    def test_group_norm_backward(self, device, dtype, op):
-        # hacky, only works since no group norm inputs can be scalars
-        def was_skipped_from_batched_tensors(batched_out, batch_size):
-            return batched_out.shape == (batch_size,) and all(tuple(e == 1 for e in batched_out))
+    @skipOps('TestOperators', 'test_vmap_autograd_grad', {
+        # these ops call in-place functions
+        xfail('linalg.householder_product'),  # in-place
+        xfail('matrix_exp'),  # in-place
+        xfail('take'),  # in-place
+
+        xfail('linalg.eig'),  # allclose mismatch?
+        # The size of tensor a (4) must match the size of tensor b (10) at non-singleton dimension 0
+        xfail('masked_select'),
+        xfail('nn.functional.max_unpool2d', 'grad'),  # calls .contiguous()
+        xfail('nn.functional.max_unpool2d'),  # calls .contiguous()
+        xfail('to_sparse'),  # dispatch key issue
+
+        # numerical inconsistencies; these look like bugs
+        skip('ldexp', dtypes=(torch.float32,), device_type='cpu'),  # fails on all platforms except Mac
+        skip('__rmatmul__', dtypes=(torch.float32,), device_type='cpu'),  # fails on all platforms except Windows
+        skip('matmul', dtypes=(torch.float32,), device_type='cpu'),  # fails on all platforms except Windows
+        skip('nn.functional.conv_transpose3d', dtypes=(torch.float32,)),  # fails only on CPU-only Linux
+        skip('nn.functional.layer_norm', dtypes=(torch.float32,), device_type='cpu'),  # fails on Windows
+        skip('linalg.lu_factor', dtypes=(torch.float32,), device_type='cuda'),  # fails on all platforms except Windows
+        skip('linalg.lu_factor_ex', dtypes=(torch.float32,), device_type='cuda'),  # fails on all platforms except Windows
+    })
+    @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float32, torch.double))
+    def test_vmap_autograd_grad(self, device, dtype, op):
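+        # Checks that gradients computed with torch.autograd.grad under vmap
+        # match gradients computed one cotangent at a time in a loop.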
+        def is_differentiable(inp):
+            return isinstance(inp, Tensor) and (inp.grad_fn is not None or inp.requires_grad)
+
+        def get_flat_differentiable(pytree):
+            flattened = tree_flatten(pytree)[0]
+            return tuple(i for i in flattened if is_differentiable(i))
+
+        def get_differentiable_linked(list1, list2):
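+            # pair each output with its cotangent, keeping only the pairs
+            # whose output actually participates in autograd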
+            paired_list = zip(list1, list2)
+            paired_list = tuple((first, second) for (first, second) in paired_list if is_differentiable(first))
+            return zip(*paired_list)
+
+        def filter_none(out):
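+            # drop the None grads that allow_unused=True produces for
+            # inputs that do not participate in the graph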
+            flattened = tree_flatten(out)[0]
+            return tuple(o for o in flattened if o is not None)
+
+        if not op.supports_autograd:
+            self.skipTest("Skipped! Autograd not supported.")
+            return
 
         sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
 
         for sample_input in sample_inputs:
-            cotangents = get_sample_cotangents(op, sample_input)
-            f, args = get_autograd_fn_and_args_with_cotangents(op, sample_input, cotangents)
+            fn, primals = normalize_op_input_output(op, sample_input)
+            out = fn(*primals)
+            cotangents = tree_map(torch.randn_like, out)
+
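+            # compute_grad closes over `out` and `primals`; it takes only the
+            # cotangents as arguments so that vmap batches over the cotangents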
+            def compute_grad(cotangents):
+                out_flattened = out
+                cotangents_flattened = cotangents
+                if not isinstance(out_flattened, torch.Tensor):
+                    out_flattened = tree_flatten(out)[0]
+                    cotangents_flattened = tree_flatten(cotangents)[0]
+                    out_flattened, cotangents_flattened = get_differentiable_linked(out_flattened, cotangents_flattened)
+
+                return filter_none(
+                    torch.autograd.grad(out_flattened, get_flat_differentiable(primals), cotangents_flattened,
+                                        retain_graph=True, allow_unused=True))
+
             is_batch_norm_and_training = is_batch_norm_training(op, sample_input.kwargs)
             generator = get_fallback_and_vmap_exhaustive(
-                f, args, {}, is_batch_norm_and_training=is_batch_norm_and_training)
+                compute_grad, (cotangents,), {}, is_batch_norm_and_training=is_batch_norm_and_training)
             for loop_out, batched_out in generator:
-                if all(was_skipped_from_batched_tensors(bo, lo.shape[0]) for (bo, lo) in zip(batched_out, loop_out)):
-                    continue  # we weren't able to use the batched tensor in autograd.grad
                 self.assertEqual(loop_out, batched_out)