[functorch] pytree output support for vmap
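
vmap outputs may now be arbitrary pytrees of Tensors (e.g. nested tuples
and lists) instead of just a single Tensor or a flat tuple of Tensors.
`out_dims` may correspondingly be an int or a pytree of ints that is
broadcastable to the structure of the outputs. torch.return_types are
flattened as plain tuples for now via tree_flatten_hack (see the TODO
in pytree_hacks.py).

Example usage (a sketch mirroring the new test_pytree_returns* tests):

    import torch
    from functorch import vmap

    x = torch.randn(2, 3)

    def f(x):
        y = x.sin()
        return y, (y, y)

    # Outputs keep the pytree structure of f's return value.
    y0, (y1, y2) = vmap(f)(x)

    # out_dims broadcasts over the output pytree: y0 keeps the batch
    # dim at 0; y1 and y2 get it moved to dim 1.
    y0, (y1, y2) = vmap(f, out_dims=(0, 1))(x)
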
diff --git a/functorch/functorch/_src/eager_transforms.py b/functorch/functorch/_src/eager_transforms.py
index 68ae2a8..2d79fc8 100644
--- a/functorch/functorch/_src/eager_transforms.py
+++ b/functorch/functorch/_src/eager_transforms.py
@@ -4,6 +4,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils._pytree import tree_flatten, tree_unflatten
+from .pytree_hacks import tree_map, tree_map_
 import gc
 
 from .vmap import vmap
@@ -16,16 +17,6 @@
     _grad_decrement_nesting,
 )
 
-# TODO: replace this with tree_map from core
-def tree_map(fn, pytree):
-    flat_args, spec = tree_flatten(pytree)
-    return tree_unflatten([fn(arg) for arg in flat_args], spec)
-
-def tree_map_(fn_, pytree):
-    flat_args, _ = tree_flatten(pytree)
-    [fn_(arg) for arg in flat_args]
-    return pytree
-
 # TODO: replace all of these with pytrees
 def _create_differentiable(tensor_or_tuple_of_tensors, level=None):
     if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
diff --git a/functorch/functorch/_src/pytree_hacks.py b/functorch/functorch/_src/pytree_hacks.py
new file mode 100644
index 0000000..2aef9f6
--- /dev/null
+++ b/functorch/functorch/_src/pytree_hacks.py
@@ -0,0 +1,42 @@
+from typing import Any, List
+import torch.utils._pytree as _pytree
+from torch.utils._pytree import tree_flatten, tree_unflatten
+
+# TODO: The following function should only be used with vmap.
+# torch.return_types should be registered as PyTree nodes, but
+# we haven't figured out how to do that yet, so we turn all of them
+# into plain tuples for now (which is what vmap used to do anyway).
+# We probably want some special behavior for named tuples.
+def tree_flatten_hack(pytree):
+    if _pytree._is_leaf(pytree) and not isinstance(pytree, tuple):
+        return [pytree], _pytree.LeafSpec()
+
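+    # Any tuple subclass (including torch.return_types) is flattened with the
+    # plain-tuple flatten_fn registered in SUPPORTED_NODES.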
+    if isinstance(pytree, tuple):
+        typ = tuple
+    else:
+        typ = type(pytree)
+
+    flatten_fn = _pytree.SUPPORTED_NODES[typ].flatten_fn
+    child_pytrees, context = flatten_fn(pytree)
+
+    # Recursively flatten the children
+    result: List[Any] = []
+    children_specs: List[_pytree.TreeSpec] = []
+    for child in child_pytrees:
+        flat, child_spec = tree_flatten_hack(child)
+        result += flat
+        children_specs.append(child_spec)
+
+    return result, _pytree.TreeSpec(typ, context, children_specs)
+
+# TODO: replace this with tree_map from core
+def tree_map(fn, pytree):
+    flat_args, spec = tree_flatten(pytree)
+    return tree_unflatten([fn(arg) for arg in flat_args], spec)
+
+def tree_map_(fn_, pytree):
+    flat_args, _ = tree_flatten(pytree)
+    [fn_(arg) for arg in flat_args]
+    return pytree
diff --git a/functorch/functorch/_src/vmap.py b/functorch/functorch/_src/vmap.py
index 066c03f..028a94f 100644
--- a/functorch/functorch/_src/vmap.py
+++ b/functorch/functorch/_src/vmap.py
@@ -3,6 +3,8 @@
 from torch import Tensor
 from typing import Any, Callable, Optional, Tuple, Union, List
 from torch.utils._pytree import tree_flatten, tree_unflatten, _broadcast_to_and_flatten
+from .pytree_hacks import tree_flatten_hack, tree_map_
+from functools import partial
 import warnings
 
 from functorch._C import (
@@ -96,46 +98,60 @@
         batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
         out_dims: out_dims_t,
         vmap_level: int, batch_size: int, func: Callable) -> Tuple:
-    num_outputs = _num_outputs(batched_outputs)
-    out_dims_as_tuple = _as_tuple(
-        out_dims, num_outputs,
-        lambda: f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must '
-                f'have one dim per output (got {num_outputs} outputs) of {_get_name(func)}.')
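+    # Flatten the output pytree. tree_flatten_hack treats torch.return_types
+    # (unregistered tuple subclasses) as plain tuples; see pytree_hacks.py.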
+    flat_batched_outputs, output_spec = tree_flatten_hack(batched_outputs)
 
-    # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
-    # There is something wrong with our type bindings for functions that begin
-    # with '_', see #40397.
-    if isinstance(batched_outputs, Tensor):
-        out_dim = out_dims_as_tuple[0]
-        return _remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim)  # type: ignore
-    return tuple(_remove_batch_dim(out, vmap_level, batch_size, out_dim)  # type: ignore
-                 for out, out_dim in zip(batched_outputs, out_dims_as_tuple))
-
-# Checks that `fn` returned one or more Tensors and nothing else.
-# NB: A python function that return multiple arguments returns a single tuple,
-# so we are effectively checking that `outputs` is a single Tensor or a tuple of
-# Tensors.
-def _validate_outputs(outputs: Any, func: Callable) -> None:
-    if isinstance(outputs, Tensor):
-        return
-    if not isinstance(outputs, tuple):
-        raise ValueError(f'vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return '
-                         f'Tensors, got type {type(outputs)} as the return.')
-    for idx, output in enumerate(outputs):
-        if isinstance(output, Tensor):
+    for out in flat_batched_outputs:
+        if isinstance(out, torch.Tensor):
             continue
         raise ValueError(f'vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return '
-                         f'Tensors, got type {type(output)} for return {idx}.')
+                         f'Tensors, got type {type(out)} as a return.')
 
-def _check_out_dims_is_int_or_int_tuple(out_dims: out_dims_t, func: Callable) -> None:
+    def incompatible_error():
+        raise ValueError(
+            f'vmap({_get_name(func)}, ..., out_dims={out_dims})(<inputs>): '
+            f'out_dims is not compatible with the structure of `outputs`. '
+            f'out_dims has structure {tree_flatten(out_dims)[1]} but outputs '
+            f'has structure {output_spec}.')
+
+    if isinstance(batched_outputs, torch.Tensor):
+        # Edge case: a single Tensor output must accept both out_dims=dim and
+        # out_dims=(dim,), so we spell out both cases; see test_out_dims_edge_case.
+        if isinstance(out_dims, int):
+            flat_out_dims = [out_dims]
+        elif isinstance(out_dims, tuple) and len(out_dims) == 1:
+            flat_out_dims = out_dims
+            out_dims = out_dims[0]
+        else:
+            incompatible_error()
+    else:
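+        # out_dims may be an int or a pytree prefix of the output structure;
+        # broadcast it so there is one dim per flattened output.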
+        flat_out_dims = _broadcast_to_and_flatten(out_dims, output_spec)
+        if flat_out_dims is None:
+            incompatible_error()
+
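+    # Remove the batch dim from each flat output at its requested out_dim,
+    # then rebuild the original output pytree.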
+    flat_outputs = [
+        _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim)
+        for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims)
+    ]
+    return tree_unflatten(flat_outputs, output_spec)
+
+def _check_int(x, func, out_dims):
+    if isinstance(x, int):
+        return
+    raise ValueError(
+        f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be '
+        f'an int or a Python collection of ints representing where in the outputs the '
+        f'vmapped dimension should appear.')
+
+def _check_out_dims_is_int_or_int_pytree(out_dims: out_dims_t, func: Callable) -> None:
     if isinstance(out_dims, int):
         return
-    if not isinstance(out_dims, tuple) or \
-            not all([isinstance(out_dim, int) for out_dim in out_dims]):
-        raise ValueError(
-            f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be '
-            f'an int or a tuple of int representing where in the outputs the '
-            f'vmapped dimension should appear.')
+    tree_map_(partial(_check_int, func=func, out_dims=out_dims), out_dims)
 
 def _get_name(func: Callable):
     if hasattr(func, '__name__'):
@@ -250,13 +266,12 @@
 def _vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
     @functools.wraps(func)
     def wrapped(*args):
-        _check_out_dims_is_int_or_int_tuple(out_dims, func)
+        _check_out_dims_is_int_or_int_pytree(out_dims, func)
         vmap_level = _vmap_increment_nesting()
         torch._C._vmapmode_increment_nesting()
         try:
             batched_inputs, batch_size = _create_batched_inputs(in_dims, args, vmap_level, func)
             batched_outputs = func(*batched_inputs)
-            _validate_outputs(batched_outputs, func)
             return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
         finally:
             torch._C._vmapmode_decrement_nesting()
diff --git a/functorch/test/test_vmap.py b/functorch/test/test_vmap.py
index 24f2e92..3ac1b9a 100644
--- a/functorch/test/test_vmap.py
+++ b/functorch/test/test_vmap.py
@@ -27,13 +27,13 @@
 
 class TestVmapAPI(TestCase):
     def test_non_tensor_output_raises(self):
-        with self.assertRaisesRegex(ValueError, "got type <class 'float'> as the return"):
+        with self.assertRaisesRegex(ValueError, "got type <class 'float'> as a return"):
             output = vmap(lambda x: 3.14)(torch.ones(3))
 
         def multiple_outputs(x):
             return x, 3
 
-        with self.assertRaisesRegex(ValueError, "got type <class 'int'> for return 1"):
+        with self.assertRaisesRegex(ValueError, "got type <class 'int'> as a return"):
             vmap(multiple_outputs)(torch.ones(3))
 
     def test_different_map_dim_size_raises(self):
@@ -90,7 +90,7 @@
         self.assertEqual(outputs[0], x * x)
         self.assertEqual(outputs[1], x * x * x)
 
-    def test_multiple_outputs_error_cases(self):
+    def test_multiple_outputs(self):
         # This is the same thing as
         # def returns_tuple_of_tensors(x):
         #     return x, x
@@ -107,13 +107,8 @@
 
         # should not throw
         vmap(returns_tuple_of_tensors)(x)
-
-        # jax supports these, but we don't yet
-        msg = "must only return Tensors, got type <class 'list'>"
-        with self.assertRaisesRegex(ValueError, msg):
-            vmap(returns_list_of_two_tensors)(x)
-        with self.assertRaisesRegex(ValueError, msg):
-            vmap(returns_list_of_one_tensor)(x)
+        vmap(returns_list_of_two_tensors)(x)
+        vmap(returns_list_of_one_tensor)(x)
 
     def test_nested_with_same_map_dim(self):
         x = torch.randn(2, 3, 5)
@@ -267,8 +262,59 @@
         result = vmap(foo, out_dims=(1,))(tensor)
         self.assertEqual(result, expected)
 
-    def test_out_dims_must_be_int_or_tuple_of_int_err_msg(self):
-        msg = '`out_dims` must be an int or a tuple of int'
+    def test_pytree_returns(self):
+        x = torch.randn(2, 3)
+
+        def f(x):
+            y = x.sin()
+            return y, (y, y), [y, (y, y)]
+
+        y0, (y1, y2), (y3, (y4, y5)) = vmap(f)(x)
+        self.assertEqual(y0, x.sin())
+        self.assertEqual(y0, y1)
+        self.assertEqual(y2, y1)
+        self.assertEqual(y2, y3)
+        self.assertEqual(y4, y3)
+        self.assertEqual(y5, y4)
+
+    def test_pytree_returns_outdims(self):
+        x = torch.randn(2, 3)
+
+        def f(x):
+            y = x.sin()
+            return y, (y, y)
+
+        y0, (y1, y2) = vmap(f, out_dims=(0, (0, 1)))(x)
+        self.assertEqual(y0, x.sin())
+        self.assertEqual(y1, x.sin())
+        self.assertEqual(y2, x.sin().t())
+
+    def test_pytree_returns_broadcast_simple(self):
+        x = torch.randn(2, 3)
+
+        def f(x):
+            y = x.sin()
+            return y, (y, y)
+
+        y0, (y1, y2) = vmap(f, out_dims=1)(x)
+        self.assertEqual(y0, x.sin().t())
+        self.assertEqual(y1, y0)
+        self.assertEqual(y2, y0)
+
+    def test_pytree_returns_broadcast_nested(self):
+        x = torch.randn(2, 3)
+
+        def f(x):
+            y = x.sin()
+            return y, (y, y)
+
+        y0, (y1, y2) = vmap(f, out_dims=(0, 1))(x)
+        self.assertEqual(y0, x.sin())
+        self.assertEqual(y1, y0.t())
+        self.assertEqual(y2, y0.t())
+
+    def test_out_dims_must_be_int_or_collection_of_int_err_msg(self):
+        msg = 'must be an int or a Python collection of ints'
         tensor = torch.randn(2, 3)
         with self.assertRaisesRegex(ValueError, msg):
             vmap(lambda x: x, out_dims='lol')(tensor)
@@ -280,7 +326,7 @@
             vmap(lambda x: x, out_dims=(None,))(tensor)
 
     def test_out_dims_and_num_outputs_mismatch_err_msg(self):
-        msg = '`out_dims` must have one dim per output'
+        msg = 'not compatible'
         x = torch.randn(2, 3, 5)
 
         # Too many out_dims
@@ -2639,9 +2685,9 @@
                 self.assertEqual(loop_out, batched_out)
 
 
-instantiate_device_type_tests(TestVmapOperators, globals())
-
 only_for = ("cpu", "cuda")
+instantiate_device_type_tests(TestVmapOperators, globals(), only_for=only_for)
+
 instantiate_device_type_tests(
     TestVmapBatchedGradient,
     globals(),