Update vmap to accept None in out_dims (#91644)

* Fixes https://github.com/pytorch/functorch/issues/1082
* Fixes https://github.com/pytorch/functorch/issues/439
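
With this change, `out_dims` accepts `None` entries: an output whose entry is `None` is passed through as-is instead of having a batch dimension moved into place, which also lets a vmapped function return non-tensor Python objects; returning a batched tensor under a `None` entry raises a `ValueError`. Below is a minimal sketch of the new behavior, mirroring the tests in this diff (the `functorch` import path is an assumption; `torch.func.vmap` behaves the same in newer releases):

```python
import torch
from functorch import vmap  # assumed entry point; torch.func.vmap also works

def foo(x):
    # One per-sample tensor output plus a non-tensor output that is
    # shared across the whole batch.
    return x * 2, 'hello world'

tensor = torch.randn(2, 3)

# out_dims=(0, None): stack the first output along dim 0 and pass the
# second output through unchanged.
doubled, message = vmap(foo, out_dims=(0, None))(tensor)
assert doubled.shape == (2, 3)
assert message == 'hello world'

# A batched tensor cannot escape through a None out_dim:
# vmap(lambda x: x, out_dims=None)(tensor)  # raises ValueError
```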

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91644
Approved by: https://github.com/kshitij12345, https://github.com/Chillee
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index 954c440..32b4125 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -73,13 +73,13 @@
 
 class TestVmapAPI(TestCase):
     def test_non_tensor_output_raises(self):
-        with self.assertRaisesRegex(ValueError, "got type <class 'float'> as a return"):
+        with self.assertRaisesRegex(ValueError, "got type <class 'float'>"):
             vmap(lambda x: 3.14)(torch.ones(3))
 
         def multiple_outputs(x):
             return x, 3
 
-        with self.assertRaisesRegex(ValueError, "got type <class 'int'> as a return"):
+        with self.assertRaisesRegex(ValueError, "got type <class 'int'>"):
             vmap(multiple_outputs)(torch.ones(3))
 
     def test_different_map_dim_size_raises(self):
@@ -317,6 +317,49 @@
         result = vmap(foo, out_dims=(1,))(tensor)
         self.assertEqual(result, expected)
 
+    def test_out_dims_none_tuple(self):
+        def foo(x):
+            return x, 'hello world'
+
+        tensor = torch.randn(2, 3)
+        result = vmap(foo, out_dims=(0, None))(tensor)
+        self.assertEqual(result[1], 'hello world')
+        self.assertEqual(result[0], tensor)
+
+        def foo(x):
+            x.add_(1)
+            return None, 'hello world'
+        result = vmap(foo, out_dims=(None, None))(tensor)
+        self.assertEqual(result, (None, 'hello world'))
+
+
+    def test_out_dims_none(self):
+        def foo(x):
+            return x
+
+        tensor = torch.randn(2, 3)
+        with self.assertRaisesRegex(ValueError, 'cannot return a BatchedTensor when out_dim is None'):
+            vmap(foo, out_dims=None)(tensor)
+
+        def foo(x):
+            x.add_(1)
+            return 'hello world'
+        result = vmap(foo, out_dims=None)(tensor)
+        self.assertEqual(result, 'hello world')
+
+    def test_out_dims_normal_tensor(self):
+
+        def foo(x):
+            return torch.arange(3)
+
+        tensor = torch.randn(2, 3)
+        result = vmap(foo)(tensor)
+        self.assertEqual(result.shape, [2, 3])
+
+        result = vmap(foo, out_dims=None)(tensor)
+        self.assertEqual(result, torch.arange(3))
+
+
     def test_pytree_returns(self):
         x = torch.randn(2, 3)
 
@@ -382,16 +425,12 @@
         self.assertEqual(y2, y0.t())
 
     def test_out_dims_must_be_int_or_collection_of_int_err_msg(self):
-        msg = 'must be an int or a python collection of ints'
+        msg = 'must be an int, None, or a Python collection of ints'
         tensor = torch.randn(2, 3)
         with self.assertRaisesRegex(ValueError, msg):
             vmap(lambda x: x, out_dims='lol')(tensor)
         with self.assertRaisesRegex(ValueError, msg):
             vmap(lambda x: x, out_dims=('lol',))(tensor)
-        with self.assertRaisesRegex(ValueError, msg):
-            vmap(lambda x: x, out_dims=None)(tensor)
-        with self.assertRaisesRegex(ValueError, msg):
-            vmap(lambda x: x, out_dims=(None,))(tensor)
 
     def test_out_dims_and_num_outputs_mismatch_err_msg(self):
         msg = 'not compatible'
diff --git a/torch/_functorch/vmap.py b/torch/_functorch/vmap.py
index db3116a..0cae1b9 100644
--- a/torch/_functorch/vmap.py
+++ b/torch/_functorch/vmap.py
@@ -20,6 +20,7 @@
     _remove_batch_dim,
     _vmap_decrement_nesting,
     _vmap_increment_nesting,
+    is_batchedtensor,
 )
 from torch._functorch.utils import exposed_in
 
@@ -130,21 +131,33 @@
                       for in_dim, arg in zip(flat_in_dims, flat_args)]
     return tree_unflatten(batched_inputs, args_spec)
 
+
+def _maybe_remove_batch_dim(name, batched_output, vmap_level, batch_size, out_dim):
+
+    if out_dim is None:
+        if isinstance(batched_output, torch.Tensor) and is_batchedtensor(batched_output):
+            raise ValueError(
+                f'vmap({name}, ...): `{name}` cannot return a '
+                f'BatchedTensor when out_dim is None'
+            )
+        return batched_output
+
+    # out_dim is not None
+    if not isinstance(batched_output, torch.Tensor):
+        raise ValueError(f'vmap({name}, ...): `{name}` must only return '
+                         f'Tensors, got type {type(batched_output)}. '
+                         'Did you mean to set the corresponding out_dims entry to None?')
+
+    return _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim)
+
+
 # Undos the batching (and any batch dimensions) associated with the `vmap_level`.
-
-
 def _unwrap_batched(
         batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
         out_dims: out_dims_t,
         vmap_level: int, batch_size: int, func: Callable) -> Tuple:
     flat_batched_outputs, output_spec = tree_flatten(batched_outputs)
 
-    for out in flat_batched_outputs:
-        if isinstance(out, torch.Tensor):
-            continue
-        raise ValueError(f'vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return '
-                         f'Tensors, got type {type(out)} as a return.')
-
     def incompatible_error():
         raise ValueError(
             f'vmap({_get_name(func)}, ..., out_dims={out_dims})(<inputs>): '
@@ -159,7 +172,8 @@
             flat_out_dims = [out_dims]
         elif isinstance(out_dims, tuple) and len(out_dims) == 1:
             flat_out_dims = out_dims
-            out_dims = out_dims[0]
+        elif out_dims is None:
+            flat_out_dims = [out_dims]
         else:
             incompatible_error()
     else:
@@ -168,25 +182,27 @@
             incompatible_error()
 
     flat_outputs = [
-        _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim)
+        _maybe_remove_batch_dim(_get_name(func), batched_output, vmap_level, batch_size, out_dim)
         for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims)
     ]
     return tree_unflatten(flat_outputs, output_spec)
 
 
-def _check_int(x, func, out_dims):
+def _check_int_or_none(x, func, out_dims):
     if isinstance(x, int):
         return
+    if x is None:
+        return
     raise ValueError(
         f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be '
-        f'an int or a python collection of ints representing where in the outputs the '
+        f'an int, None, or a Python collection of ints representing where in the outputs the '
         f'vmapped dimension should appear.')
 
 
 def _check_out_dims_is_int_or_int_pytree(out_dims: out_dims_t, func: Callable) -> None:
     if isinstance(out_dims, int):
         return
-    tree_map_(partial(_check_int, func=func, out_dims=out_dims), out_dims)
+    tree_map_(partial(_check_int_or_none, func=func, out_dims=out_dims), out_dims)
 
 
 def _get_name(func: Callable):