[functorch] Add exhaustive testing of vmap autograd composability (pytorch/functorch#851)
* refactor to make simpler based on comments
* cleanup
* more failing tests
* fix test failures
* more test failures
* update xfails
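
For context, the composition this PR exercises is vmap applied over a function that calls torch.autograd.grad internally, once per op in the op database. Below is a minimal sketch of that composition, assuming the standalone functorch package of this era; the op (sin) and the shapes are illustrative and not taken from the diff:

import torch
from functorch import vmap  # assumption: standalone functorch package

x = torch.randn(3, requires_grad=True)
out = x.sin()

def compute_grad(cotangent):
    # vmap treats `cotangent` as a single per-example cotangent
    # (logically one slice of batched_cotangents below).
    return torch.autograd.grad(out, x, cotangent, retain_graph=True)[0]

batched_cotangents = torch.randn(5, 3)
# Each row i should equal batched_cotangents[i] * x.cos().
grads = vmap(compute_grad)(batched_cotangents)

The new test_vmap_autograd_grad in the diff below does this exhaustively: it builds cotangents matching each sample's output structure and feeds a compute_grad closure to get_fallback_and_vmap_exhaustive, comparing a per-example loop fallback against the vmapped call.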
diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py
index 1d2539f..2482fbc 100644
--- a/functorch/test/test_ops.py
+++ b/functorch/test/test_ops.py
@@ -209,35 +209,6 @@
return wrapped, tuple(flat_args + flat_cotangents)
-# returns a new function g(*args, *cotangents)
-# that computes vjps and (*args, cotangents) using torch.autograd.grad
-def get_autograd_fn_and_args_with_cotangents(f, sample, cotangents):
- args = tuple([sample.input] + list(sample.args))
- kwargs = sample.kwargs
- flat_args, args_spec = tree_flatten(args)
- flat_cotangents, cotangents_spec = tree_flatten(cotangents)
-
- @functools.wraps(f)
- def wrapped(*args):
- assert len(args) == len(flat_args) + len(flat_cotangents)
- actual_args = args[:len(flat_args)]
- cotangents = args[len(flat_args):]
- actual_args = tree_unflatten(actual_args, args_spec)
- cotangents = tree_unflatten(cotangents, cotangents_spec)
-
- fn, primals = normalize_op_input_output3(f, actual_args, kwargs,
- flat_args,
- sample.output_process_fn_grad)
- out = fn(*primals)
- diff_wrt = tuple(primal for primal in primals if (primal.requires_grad or primal.grad_fn is not None))
- if diff_wrt:
- return torch.autograd.grad(out, diff_wrt, grad_outputs=cotangents)
- else:
- return (torch.ones(()),) # uuugh hack...this will need to be more generic
-
- return wrapped, tuple(flat_args + flat_cotangents)
-
-
# Returns a new function g(*args, *cotangents) that computes vjps and
# sample (*args, *cotangents)
def get_vjpfull_variant(f, sample):
@@ -1310,24 +1281,73 @@
cotangents = torch.randn_like(result, device=device)
self._compare_jacobians_of_vjp(fn, (cotangents, input, weight, bias))
- @ops(tuple(filter(lambda op: op.name == "nn.functional.group_norm", functorch_lagging_op_db + additional_op_db)),
- allowed_dtypes=(torch.float32, torch.double)) # TODO: generalize
- def test_group_norm_backward(self, device, dtype, op):
- # hacky, only works since no group norm inputs can be scalars
- def was_skipped_from_batched_tensors(batched_out, batch_size):
- return batched_out.shape == (batch_size,) and all(tuple(e == 1 for e in batched_out))
+ @skipOps('TestOperators', 'test_vmap_autograd_grad', {
+ # call inplace functions
+ xfail('linalg.householder_product'), # inplace
+ xfail('matrix_exp'), # inplace
+ xfail('take'), # inplace
+
+ xfail('linalg.eig'), # all close?
+ # The size of tensor a (4) must match the size of tensor b (10) at non-singleton dimension 0
+ xfail('masked_select'),
+ xfail('nn.functional.max_unpool2d', 'grad'), # contiguous call
+ xfail('nn.functional.max_unpool2d'), # contiguous call
+ xfail('to_sparse'), # dispatch key issue
+
+ # numerical inconsistencies, look like bugs
+ skip('ldexp', dtypes=(torch.float32,), device_type='cpu'), # fails on all but mac
+ skip('__rmatmul__', dtypes=(torch.float32,), device_type='cpu'), # fails on all but windows
+ skip('matmul', dtypes=(torch.float32,), device_type='cpu'), # fails on all but windows
+ skip('nn.functional.conv_transpose3d', dtypes=(torch.float32,)), # only fails on CPU-only Linux
+ skip('nn.functional.layer_norm', dtypes=(torch.float32,), device_type='cpu'), # fails on windows
+ skip('linalg.lu_factor', dtypes=(torch.float32,), device_type='cuda'), # fails on all but windows
+ skip('linalg.lu_factor_ex', dtypes=(torch.float32,), device_type='cuda'), # fails on all but windows
+ })
+ @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float32, torch.double))
+ def test_vmap_autograd_grad(self, device, dtype, op):
+ def is_differentiable(inp):
+ return isinstance(inp, Tensor) and (inp.grad_fn is not None or inp.requires_grad)
+
+ def get_flat_differentiable(pytree):
+ flattened = tree_flatten(pytree)[0]
+ return tuple(i for i in flattened if is_differentiable(i))
+
+ def get_differentiable_linked(list1, list2):
+ paired_list = zip(list1, list2)
+ paired_list = tuple((first, second) for (first, second) in paired_list if is_differentiable(first))
+ return zip(*paired_list)
+
+ def filter_none(out):
+ flattened = tree_flatten(out)[0]
+ return tuple(o for o in flattened if o is not None)
+
+ if not op.supports_autograd:
+ self.skipTest("Skipped! Autograd not supported.")
+ return
sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
for sample_input in sample_inputs:
- cotangents = get_sample_cotangents(op, sample_input)
- f, args = get_autograd_fn_and_args_with_cotangents(op, sample_input, cotangents)
+ fn, primals = normalize_op_input_output(op, sample_input)
+ out = fn(*primals)
+ cotangents = tree_map(torch.randn_like, out)
+
+ def compute_grad(cotangents):
+ out_flattened = out
+ cotangents_flattened = cotangents
+ if not isinstance(out_flattened, torch.Tensor):
+ out_flattened = tree_flatten(out)[0]
+ cotangents_flattened = tree_flatten(cotangents)[0]
+ out_flattened, cotangents_flattened = get_differentiable_linked(out_flattened, cotangents_flattened)
+
+ return filter_none(
+ torch.autograd.grad(out_flattened, get_flat_differentiable(primals), cotangents_flattened,
+ retain_graph=True, allow_unused=True))
+
is_batch_norm_and_training = is_batch_norm_training(op, sample_input.kwargs)
generator = get_fallback_and_vmap_exhaustive(
- f, args, {}, is_batch_norm_and_training=is_batch_norm_and_training)
+ compute_grad, (cotangents,), {}, is_batch_norm_and_training=is_batch_norm_and_training)
for loop_out, batched_out in generator:
- if all(was_skipped_from_batched_tensors(bo, lo.shape[0]) for (bo, lo) in zip(batched_out, loop_out)):
- continue # we weren't able to use the batched tensor in autograd.grad
self.assertEqual(loop_out, batched_out)
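
A note on the allow_unused=True / filter_none pairing inside compute_grad: torch.autograd.grad returns None for any requested input that the output does not actually depend on (and raises without allow_unused=True), so the test drops those None entries before comparing loop and vmap results. A standalone illustration with made-up tensors:

import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)  # never used below
out = a.exp().sum()

# Without allow_unused=True this call would raise because `b` is unused.
grads = torch.autograd.grad(out, (a, b), allow_unused=True)
assert torch.allclose(grads[0], a.exp())  # d(out)/da
assert grads[1] is None                   # unused input -> None gradient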