Add support for torch.cond in vmap (#114523)

Fixes: https://github.com/pytorch/pytorch/issues/114136

This patch enables conversion of a BatchedTensor into a FakeTensor and adds
torch.cond vmap support using torch.where.
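
For illustration, a minimal usage sketch (mirroring the new
test_cond_vmap_simple test below): with this patch, vmapping a function
that calls torch.cond works instead of erroring:

    import torch

    def fn(x):
        return torch.cond(
            pred=torch.tensor([True]),
            true_fn=lambda x: x + 100,
            false_fn=lambda x: x,
            operands=(x,),
        )

    a = torch.arange(15).reshape(3, 5)
    res = torch.vmap(fn)(a)
    assert torch.equal(res, a + 100)

When the predicate itself is batched, the new batching rule evaluates both
branches under vmap and merges their outputs elementwise with torch.where
(see cond_batch_rule below).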

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114523
Approved by: https://github.com/zou3519
diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py
index 8098408..a8e3067 100644
--- a/test/dynamo/test_higher_order_ops.py
+++ b/test/dynamo/test_higher_order_ops.py
@@ -3312,13 +3312,12 @@
 
         actual = fn(x, y)
         expected = torch.compile(fn, backend="aot_eager", fullgraph=False)(x, y)
-        self.assertEqual(len(counters["graph_break"]), 2)
+        self.assertEqual(len(counters["graph_break"]), 1)
         assert_dict_matches_regex(
             self,
             dict(counters["graph_break"]),
             {
                 ".*torch.vmap with body that accepts non-Tensors as input": 2,
-                "Unsupported: meta converter nyi with fake tensor propagation.": 1,
             },
         )
         self.assertEqual(actual, expected)
diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py
index da5596d..8dc4925 100644
--- a/test/functorch/test_control_flow.py
+++ b/test/functorch/test_control_flow.py
@@ -1564,11 +1564,11 @@
                 def false_fn(*operands):
                     return inner_most_fn(*operands)
 
-            def fn(pred, operands):
+            def fn(*operands):
                 if len(operands) == 0 and len(closure_list) == 0:
                     return torch.zeros(1)
                 return cond(pred, true_fn, false_fn, operands)
-            return (pred, operands), fn
+            return operands, fn
         else:
             args, inner_fn = self._create_test_fns_for_cond(pred <= 0, inner_most_fn, operands, closure_list, nested_level - 1)
 
@@ -1578,11 +1578,11 @@
             def false_fn(*operands):
                 return inner_most_fn(*operands) - inner_fn(*args)
 
-            def fn(pred, operands):
+            def fn(*operands):
                 if len(operands) == 0 and len(closure_list) == 0:
                     return torch.ones(1)
                 return cond(pred, true_fn, false_fn, operands)
-            return (pred, operands), fn
+            return operands, fn
 
     def _init_predicate(self, pred_type):
         if pred_type == "bool":
@@ -1624,6 +1624,142 @@
                 gm = make_fx(fn, tracing_mode=tracing_mode, _allow_non_fake_inputs=True)(*args)
                 self.assertEqual(gm(*args), eager_res)
 
+    @parametrize("predType", ["boolTensor"])
+    @parametrize("innerFnType", ["function", "module", "object"])
+    @parametrize("nOperands", [1, 2])
+    @parametrize("nClosure", [0, 1])
+    @parametrize("nesting", [0])
+    def test_cond_vmap(self, predType, innerFnType, nOperands, nClosure, nesting):
+        pred = self._init_predicate(predType)
+        inner_fn = self._init_fn(innerFnType)
+        operands = [torch.ones(2, 3) + i for i in range(nOperands)]
+        closure = [torch.ones(2, 3) - i for i in range(nClosure)]
+        args, fn = self._create_test_fns_for_cond(pred, inner_fn, operands, closure, nesting)
+        eager_res = fn(*args)
+        out = torch.vmap(fn)(*args)
+        if nClosure == 0:
+            self.assertEqual(eager_res, out)
+        else:
+            self.assertEqual(eager_res, out[0])
+            self.assertEqual(eager_res, out[1])
+
+    def test_cond_vmap_simple(self):
+
+        def fn(x):
+            return torch.cond(
+                pred=torch.tensor([True]),
+                true_fn=lambda x: x + 100,
+                false_fn=lambda x: x,
+                operands=(x,)
+            )
+
+        a = torch.arange(15).reshape((3, 5))
+        res = torch.vmap(fn, in_dims=(0,))(a)
+        self.assertEqual(res.shape, (3, 5))
+        self.assertEqual(res, a + 100)
+
+    def test_cond_vmap_multiple_inputs(self):
+
+        def fn(x, y):
+            return torch.cond(
+                pred=x.sum() < y.sum(),
+                true_fn=lambda x, y: x + 100,
+                false_fn=lambda x, y: y,
+                operands=(x, y)
+            )
+
+        a = torch.arange(15).reshape(3, 5)
+        b = torch.ones_like(a) + 3
+        res = torch.vmap(fn, in_dims=(0, 0))(a, b)
+        expected = torch.tensor(
+            [
+                [100, 101, 102, 103, 104],
+                [4, 4, 4, 4, 4],
+                [4, 4, 4, 4, 4]
+            ]
+        )
+        self.assertEqual(res.shape, (3, 5))
+        self.assertEqual(expected, res)
+
+    def test_cond_vmap_single_input_with_closure(self):
+
+        a = torch.ones((3, 5)) + 3
+        c = torch.arange(5)
+
+        def fn(x):
+            return torch.cond(
+                pred=torch.tensor([True]),
+                true_fn=lambda x: x + c,
+                false_fn=lambda x: x - c,
+                operands=(x,)
+            )
+
+        res = torch.vmap(fn, in_dims=(0,))(a,)
+        self.assertEqual(a + c, res)
+
+    def test_cond_vmap_multiple_args_with_closure(self):
+
+        a = torch.ones((3, 5), dtype=torch.int64) + 3
+        b = torch.arange(15).reshape(3, 5)
+        c = torch.arange(5)
+
+        def fn(x, y):
+            return torch.cond(
+                pred=torch.tensor([False]),
+                true_fn=lambda x, y: x + c,
+                false_fn=lambda x, y: y - c,
+                operands=(x, y)
+            )
+
+        res = torch.vmap(fn)(a, b)
+        self.assertEqual(b - c, res)
+
+    @parametrize("nClosure", [0, 1])
+    def test_cond_vmap_multiple_outputs(self, nClosure):
+
+        if nClosure:
+            c = torch.ones(5, dtype=torch.int64) + 5
+
+            def fn(x):
+                return torch.cond(
+                    pred=torch.tensor([True]),
+                    true_fn=lambda x: (x + c, x - c),
+                    false_fn=lambda x: (x, x),
+                    operands=(x,)
+                )
+        else:
+            def fn(x):
+                return torch.cond(
+                    pred=torch.tensor([True]),
+                    true_fn=lambda x: (x + 1, x - 1),
+                    false_fn=lambda x: (x, x),
+                    operands=(x,)
+                )
+
+        a = torch.arange(15).reshape(3, 5)
+        res = torch.vmap(fn)(a,)
+        self.assertEqual(len(res), 2)
+        if nClosure:
+            self.assertEqual(res, (a + c, a - c))
+        else:
+            self.assertEqual(res, (a + 1, a - 1))
+
+    def test_vmap_vmap(self):
+        def fn(x):
+            return torch.cond(
+                pred=torch.tensor([True]),
+                true_fn=lambda x: x + 1,
+                false_fn=lambda x: x - 1,
+                operands=(x,)
+            )
+
+        def wrapper(x):
+            return torch.vmap(fn)(x)
+
+        a = torch.ones((3, 4, 5))
+        res = torch.vmap(wrapper)(a)
+        self.assertEqual(res, a + 1)
+
 instantiate_parametrized_tests(TestControlFlowTraced)
 
 if __name__ == '__main__':
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
index f49a6b0..ad56670 100644
--- a/test/functorch/test_eager_transforms.py
+++ b/test/functorch/test_eager_transforms.py
@@ -10,7 +10,7 @@
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, parametrize, subtest, instantiate_parametrized_tests,
     IS_FBCODE, freeze_rng_state, skipIfTorchDynamo, IS_WINDOWS, IS_MACOS, IS_ARM64,
-    markDynamoStrictTest, xfailIfTorchDynamo
+    markDynamoStrictTest, xfailIfTorchDynamo, TEST_WITH_TORCHDYNAMO
 )
 import torch
 import torch.nn as nn
@@ -1645,7 +1645,6 @@
         result = vmap(partial(grad(compute_loss), weights))(data, targets)
         self._compare_expected_and_result(expected, result, mechanism)
 
-    @xfailIfTorchDynamo
     def test_log_softmax(self, device):
         x = torch.randn(3, 5, device=device)
         v = torch.randn(5, device=device)
@@ -1667,6 +1666,10 @@
 
 FIXME_jacrev_only = parametrize("jacapi", [subtest(jacrev, name='jacrev')])
 
+FIXME_skip_jacrev_dynamo = parametrize("jacapi", [
+    subtest(jacrev, name='jacrev', decorators=[unittest.expectedFailure] if TEST_WITH_TORCHDYNAMO else None),
+    subtest(jacfwd, name='jacfwd')
+])
 
 @markDynamoStrictTest
 class TestJac(TestCase):
@@ -1745,8 +1748,7 @@
         expected = torch.diagflat(x)
         assert torch.allclose(z, expected)
 
-    @xfailIfTorchDynamo
-    @jacrev_and_jacfwd
+    @FIXME_skip_jacrev_dynamo
     def test_multiple_outputs_multiple_argnums(self, device, jacapi):
         def f(x, y):
             return 2 * x + 3 * y, 4 * x + 5 * y
@@ -1768,8 +1770,7 @@
         self.assertEqual(z[1][0], expected_out1_x)
         self.assertEqual(z[1][1], expected_out1_y)
 
-    @xfailIfTorchDynamo
-    @jacrev_and_jacfwd
+    @FIXME_skip_jacrev_dynamo
     def test_multiple_outputs_single_argnums(self, device, jacapi):
         def f(x, y):
             return 2 * x + 3 * y, 4 * x + 5 * y
@@ -1812,8 +1813,7 @@
         self.assertTrue(isinstance(z['right'], tuple))
         self.assertEqual(z, expected)
 
-    @xfailIfTorchDynamo
-    @jacrev_and_jacfwd
+    @FIXME_skip_jacrev_dynamo
     def test_multiple_inputs_pytree(self, device, jacapi):
         def f(a, b, c):
             a0, a1 = a
@@ -1838,8 +1838,7 @@
         expected = (torch.tensor(1., device=device), torch.tensor(2., device=device))
         self.assertEqual(result, expected)
 
-    @xfailIfTorchDynamo
-    @jacrev_and_jacfwd
+    @FIXME_skip_jacrev_dynamo
     def test_dimensionality(self, device, jacapi):
         def f(x):
             return x
@@ -1867,8 +1866,7 @@
         self.assertEqual(result, torch.eye(3, 3, device=device))
         self.assertEqual(aux, x.cos())
 
-    @xfailIfTorchDynamo
-    @jacrev_and_jacfwd
+    @FIXME_skip_jacrev_dynamo
     def test_aux_pytree(self, device, jacapi):
         def f(x):
             y = x.clone()
@@ -1887,8 +1885,7 @@
             with self.assertRaisesRegex(RuntimeError, r"Expected tensors, got unsupported type"):
                 _ = jacapi(lambda x: (x, [x, aux]), has_aux=True)(x)
 
-    @xfailIfTorchDynamo
-    @jacrev_and_jacfwd
+    @FIXME_skip_jacrev_dynamo
     def test_outputs_can_any_pytree(self, device, jacapi):
         x = torch.randn(2, 3, device=device)
 
@@ -1922,8 +1919,7 @@
         assert isinstance(out, list)
         assert isinstance(out[0], tuple) and isinstance(out[0][1], dict)
 
-    @xfailIfTorchDynamo
-    @jacrev_and_jacfwd
+    @FIXME_skip_jacrev_dynamo
     def test_multiple_inputs_outputs_pytree(self, device, jacapi):
         def f(a, b, c):
             a0, a1 = a
@@ -1972,8 +1968,7 @@
         }
         self.assertEqual(result, expected)
 
-    @xfailIfTorchDynamo
-    @jacrev_and_jacfwd
+    @FIXME_skip_jacrev_dynamo
     def test_unrelated_input(self, device, jacapi):
         def f(x, y):
             return x
@@ -1988,8 +1983,7 @@
         self.assertTrue(isinstance(result, tuple))
         self.assertEqual(result, expected)
 
-    @xfailIfTorchDynamo
-    @jacrev_and_jacfwd
+    @FIXME_skip_jacrev_dynamo
     def test_unrelated_output(self, device, jacapi):
         y = torch.randn(2, 3, device=device)
 
@@ -2002,8 +1996,7 @@
         expected = x.new_zeros(2, 3, 2, 3)
         self.assertEqual(result, expected)
 
-    @xfailIfTorchDynamo
-    @jacrev_and_jacfwd
+    @FIXME_skip_jacrev_dynamo
     def test_empty_output(self, device, jacapi):
         x = torch.randn(3, device=device)
         y = torch.randn(3, device=device)
@@ -2044,8 +2037,7 @@
         assert isinstance(z, torch.Tensor)
         assert torch.allclose(z, expected0)
 
-    @xfailIfTorchDynamo
-    @jacrev_and_jacfwd
+    @FIXME_skip_jacrev_dynamo
     def test_argnums_defaults_to_zero(self, device, jacapi):
         def f(x, y):
             return x * 2 + y * 3
@@ -2093,7 +2085,6 @@
         with self.assertRaisesRegex(RuntimeError, "must be int"):
             jacapi(torch.multiply, argnums=(1, 0.0))(x, x)
 
-    @xfailIfTorchDynamo
     def test_hessian_simple(self, device):
         def f(x):
             return x.sin()
@@ -2109,8 +2100,7 @@
         result = jacapi(foo)(inputs)
         self.assertEqual(result, expected)
 
-    @xfailIfTorchDynamo
-    @jacrev_and_jacfwd
+    @FIXME_skip_jacrev_dynamo
     def test_against_reference_simple(self, device, jacapi):
         def f(x):
             return 3 * x ** 2
@@ -2118,8 +2108,7 @@
         x = torch.randn(2, 3, 5, device=device)
         self._test_against_reference(f, (x,), jacapi)
 
-    @xfailIfTorchDynamo
-    @jacrev_and_jacfwd
+    @FIXME_skip_jacrev_dynamo
     def test_against_reference_multi_input(self, device, jacapi):
         def f(x, y):
             return (x.cos() * x) @ y.sin()
@@ -2128,8 +2117,7 @@
         y = torch.randn(3, 5, device=device)
         self._test_against_reference(f, (x, y), jacapi)
 
-    @xfailIfTorchDynamo
-    @jacrev_and_jacfwd
+    @FIXME_skip_jacrev_dynamo
     def test_against_reference_multi_input_multi_output(self, device, jacapi):
         def f(x, y):
             return (x * x) @ y, x @ (x.sum(1) * y), y.sum()
@@ -2138,8 +2126,7 @@
         y = torch.randn(3, 5, device=device)
         self._test_against_reference(f, (x, y), jacapi)
 
-    @xfailIfTorchDynamo
-    @jacrev_and_jacfwd
+    @FIXME_skip_jacrev_dynamo
     def test_against_reference_unrelated_outputs(self, device, jacapi):
         def f(x, y):
             return x, y, x, y
@@ -2148,8 +2135,7 @@
         y = torch.randn(3, device=device)
         self._test_against_reference(f, (x, y), jacapi)
 
-    @xfailIfTorchDynamo
-    @jacrev_and_jacfwd
+    @FIXME_skip_jacrev_dynamo
     def test_against_reference_zero_dim(self, device, jacapi):
         # zero-dim output
         def f(x, y):
@@ -2174,8 +2160,7 @@
         y = torch.randn(1, device=device)
         self._test_against_reference(h, (x, y), jacapi)
 
-    @xfailIfTorchDynamo
-    @jacrev_and_jacfwd
+    @FIXME_skip_jacrev_dynamo
     def test_against_reference_correctness_different_devices(self, device, jacapi):
         def f(x, y):
             return x * y, (x * y).to(device=device)
@@ -2184,8 +2169,7 @@
         y = torch.randn(3)
         self._test_against_reference(f, (x, y), jacapi)
 
-    @xfailIfTorchDynamo
-    @jacrev_and_jacfwd
+    @FIXME_skip_jacrev_dynamo
     def test_against_reference_default_arg(self, device, jacapi):
         def f(x, y, z=3.):
             return x * y * z
@@ -2194,8 +2178,7 @@
         y = torch.randn(3, device=device)
         self._test_against_reference(f, (x, y), jacapi)
 
-    @xfailIfTorchDynamo
-    @jacrev_and_jacfwd
+    @FIXME_skip_jacrev_dynamo
     def test_inplace(self, device, jacapi):
         def f(x, y):
             y.copy_(x)
@@ -2243,7 +2226,6 @@
         with self.assertRaisesRegex(ValueError, err_msg):
             jacrev(f, argnums=(0, ), chunk_size=-2)(x, y)
 
-    @xfailIfTorchDynamo
     @parametrize('_preallocate_and_copy', (True, False))
     def test_chunk_jacrev_composition(self, device, _preallocate_and_copy):
         x = torch.randn(10, 2, device=device)
@@ -2323,8 +2305,7 @@
         with self.assertRaisesRegex(RuntimeError, "jacfwd: Expected all outputs"):
             jacfwd(fn)(x)
 
-    @xfailIfTorchDynamo
-    @jacrev_and_jacfwd
+    @FIXME_skip_jacrev_dynamo
     def test_jac_with_non_tensor_args(self, device, jacapi):
         def f(t, int_x):
             return t + int_x
@@ -2345,7 +2326,6 @@
         result = hessian(foo)(inputs)
         self.assertEqual(result, expected)
 
-    @xfailIfTorchDynamo
     def test_hessian_vectorize_correctness_simple(self, device):
         def f(x):
             return (3 * x ** 2).sum()
@@ -2353,7 +2333,6 @@
         x = torch.randn(2, 3, 5, device=device)
         self._test_against_reference(f, (x,))
 
-    @xfailIfTorchDynamo
     def test_hessian_vectorize_correctness_multi_input(self, device):
         def f(x, y, z):
             return ((x.relu() * x) @ y.sin() @ z).sum()
@@ -2363,7 +2342,6 @@
         z = torch.randn(5, 5, device=device)
         self._test_against_reference(f, (x, y, z))
 
-    @xfailIfTorchDynamo
     def test_hessian_vectorize_correctness_unrelated_outputs(self, device):
         # output unrelated to one input
         def f(x, y):
@@ -2381,7 +2359,6 @@
         y = torch.randn(3, device=device)
         self._test_against_reference(f, (x, y))
 
-    @xfailIfTorchDynamo
     def test_jacfwd_different_levels(self, device):
         # Test case from:
         # https://github.com/pytorch/functorch/issues/597
@@ -2852,7 +2829,6 @@
 @markDynamoStrictTest
 class TestVmapJvpInplaceView(TestCase):
     # Case 1 in [Forward Grad View/inplace]
-    @xfailIfTorchDynamo
     def test_all_dual_no_view(self, device):
         B = 2
 
@@ -2881,7 +2857,6 @@
         self.assertEqual(out_tangent, yt.expand(B, 3))
 
     # Case 2 in [Forward Grad View/inplace]
-    @xfailIfTorchDynamo
     def test_all_dual_base_view_inplace(self, device):
         B = 2
 
@@ -2919,7 +2894,6 @@
         self.assertEqual(x.movedim(2, 0), expected)
 
     # Case 3 in [Forward Grad View/inplace]
-    @xfailIfTorchDynamo
     def test_all_dual_base_inplace(self, device):
         B = 2
 
@@ -2948,7 +2922,6 @@
         self.assertEqual(out_tangent[1], yt.movedim(1, 0)[:, 0, ::2])
 
     # Case 4 in [Forward Grad View/inplace]
-    @xfailIfTorchDynamo
     def test_right_dual_view_prop(self, device):
         B = 2
 
@@ -2980,7 +2953,6 @@
         self.assertEqual(tangents[1], expected_tangent_1)
 
     # Case 5 in [Forward Grad View/inplace]
-    @xfailIfTorchDynamo
     def test_right_dual_base_prop(self, device):
         B = 2
 
@@ -4156,7 +4128,6 @@
 
         self.assertEqual(result, expected)
 
-    @xfailIfTorchDynamo
     @parametrize('mechanism', ["make_functional", "functional_call"])
     def test_ensemble_regression(self, device, mechanism):
         def make_spirals(n_samples, noise_std=0., rotations=1.):
@@ -4241,7 +4212,6 @@
         self.assertEqual(result_loss, expected_loss)
         self.assertEqual(result_weights, expected_weights)
 
-    @xfailIfTorchDynamo
     @parametrize("dropout_layer", [
         subtest(nn.Dropout, 'Dropout'),
         subtest(nn.AlphaDropout, 'AlphaDropout'),
@@ -4388,7 +4358,6 @@
         self.assertEqual(inpt1, inpt2)
         self.assertEqual(inpt1, inpt3)
 
-    @xfailIfTorchDynamo
     def test_simple_view(self, device):
 
         def f(x: torch.Tensor) -> torch.Tensor:
@@ -4398,7 +4367,6 @@
             return x
         self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device))
 
-    @xfailIfTorchDynamo
     def test_multioutput_view(self, device):
 
         def f(x: torch.Tensor) -> torch.Tensor:
@@ -4488,7 +4456,6 @@
         self.assertEqual(out1, out2)
         self.assertEqual(inpt1, inpt2)
 
-    @xfailIfTorchDynamo
     @unittest.skipIf(IS_FBCODE, 'fails in fbcode')
     def test_vmap_functionalize_jvp(self, device):
 
@@ -4815,7 +4782,6 @@
         ggx = grad(grad_f_sum)(x)
         self.assertEqual(ggx, -x.sin())
 
-    @xfailIfTorchDynamo
     def test_vmap_grad_sum(self, device):
         x = torch.randn(2, 3, device=device)
         gx = vmap(grad(sum_pyop), (0, None))(x, 0)
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index dacadec..95107ef 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -25,6 +25,7 @@
 from torch.testing._internal.common_utils import (
     parametrize,
     instantiate_parametrized_tests,
+    IS_WINDOWS,
     subtest,
     skipIfRocm,
     TEST_WITH_TORCHDYNAMO,
@@ -2155,6 +2156,8 @@
         self.assertEqual(vmap(foo)(float_tensor), torch.tensor([1, 1, 1]))
         self.assertEqual(vmap(foo)(long_tensor), torch.tensor([0, 0, 0]))
 
+    @unittest.skipIf(IS_WINDOWS,
+                     reason="Windows not yet supported for torch.compile")
     def test_is_contiguous(self):
         def foo(x):
             if x.is_contiguous():
@@ -2211,6 +2214,16 @@
         with self.assertRaisesRegex(RuntimeError, msg):
             vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor)
 
+        for mf in (torch.channels_last, torch.channels_last_3d):
+            @torch.compile(backend="eager", fullgraph=True)
+            def f(x):
+                if x.is_contiguous(memory_format=mf):
+                    return x.sin()
+                return x.cos()
+
+            with self.assertRaisesRegex(RuntimeError, msg):
+                vmap(f)(torch.randn(3, 3))
+
     def test_unsqueeze(self):
         op = torch.unsqueeze
         test = self._vmap_view_test
@@ -2840,7 +2853,6 @@
         res = vmap(foo)(x)
         self.assertEqual(res, x.conj())
 
-    @xfailIfTorchDynamo
     def test_mode_key(self):
         def vmap_f(x):
             return x + torch.randn(())
@@ -5092,8 +5104,9 @@
     @skipIfTorchDynamo
     @parametrize('transform', ['vmap', 'grad', 'grad_and_value', 'vjp', 'jvp', 'jacrev', 'jacfwd'])
     def test_fails_with_autograd_function(self, device, transform):
+        failed_build_envs = ('linux-focal-py3.8-clang10', 'linux-focal-py3.11-clang10')
         if (device == 'cpu' and transform in ['grad', 'vmap'] and
-                TEST_WITH_TORCHDYNAMO and os.getenv('BUILD_ENVIRONMENT', '') == 'linux-focal-py3.8-clang10'):
+                TEST_WITH_TORCHDYNAMO and os.getenv('BUILD_ENVIRONMENT', '') in failed_build_envs):
             raise unittest.SkipTest("Unexpected successes on focal with dynamo," +
                                     " see https://github.com/pytorch/pytorch/issues/107173")
 
@@ -5352,7 +5365,6 @@
                 RuntimeError, "Nested tensors can only be vmapped over dim=0"):
             vmap(f, in_dims=2)(x)
 
-    @xfailIfTorchDynamo
     def test_nt_with_nonzero_out_dim_raises(self, device):
         def f(x):
             return x
@@ -5362,8 +5374,6 @@
                 RuntimeError, "Nested tensors can only be vmapped over dim=0"):
             vmap(f, out_dims=2)(x)
 
-    @xfailIfTorchDynamo
-    @allowVmapFallbackUsage
     def test_fallback_with_nt_and_batched_dense_with_nonzero_bdim_raises(self, device):
         def f(x, y):
             return x @ y
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 14a5965..2f56229 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -22,6 +22,7 @@
 from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
 from torch.fx.passes.fake_tensor_prop import FakeTensorProp
 from torch._dynamo.testing import rand_strided
+from torch._C._functorch import is_batchedtensor, _add_batch_dim, get_unwrapped
 from torch.testing import FileCheck
 import unittest
 import torch._prims as prims
@@ -201,6 +202,23 @@
                 FileCheck().check("CPU").check("AutocastCPU").run(torch._C._dispatch_key_set(y))
                 FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(torch._C._dispatch_key_set(y))
 
+    def test_batch_tensor(self):
+        x = torch.rand((3, 4, 5))
+        b = _add_batch_dim(x, 0, 0)
+        mode = FakeTensorMode()
+        fake_b = mode.from_tensor(b)
+        prims.utils.compare_tensor_meta(b, fake_b, check_strides=True)
+
+        b1 = _add_batch_dim(x, 1, 1)
+        b2 = _add_batch_dim(b1, 0, 2)
+        fake_b2 = mode.from_tensor(b2)
+        prims.utils.compare_tensor_meta(b2, fake_b2, check_strides=True)
+        self.assertTrue(is_batchedtensor(fake_b2))
+        fake_b1 = get_unwrapped(fake_b2)
+        self.assertTrue(is_batchedtensor(fake_b1))
+        fake_tensor = get_unwrapped(fake_b1)
+        self.assertIsInstance(fake_tensor, FakeTensor)
+
     def test_constructor(self):
         with FakeTensorMode():
             x = torch.rand([4, 4], device="cpu")
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index be76fa6..1e0754b 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -546,8 +546,12 @@
             all_diffs = []
             for i, (var1, var2) in enumerate(zip(tensor_vars1, tensor_vars2)):
                 # We check the meta data associated with meta["example_value"]
-                meta1 = _extract_tensor_metadata(var1.proxy.node.meta["example_value"])
-                meta2 = _extract_tensor_metadata(var2.proxy.node.meta["example_value"])
+                meta1 = _extract_tensor_metadata(
+                    var1.proxy.node.meta["example_value"], include_contiguity=False
+                )
+                meta2 = _extract_tensor_metadata(
+                    var2.proxy.node.meta["example_value"], include_contiguity=False
+                )
                 if meta1 != meta2:
                     all_diffs.append((f"pair{i}:", meta1, meta2))
             return all_diffs
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index 2f68e77..bac958f 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -158,13 +158,18 @@
                 [int(s) if is_symbolic(s) else s for s in value.size()]
             )
             props["stride"] = tuple(value.stride())
-            props["is_contiguous"] = tuple(
-                [
-                    x
-                    for x in torch._prims_common._memory_formats
-                    if value.is_contiguous(memory_format=x)
-                ]
-            )
+            if torch._C._functorch.is_batchedtensor(value):
+                # Batched tensors do not support contiguity patterns, so
+                # we refrain from computing the `is_contiguous` property
+                props["is_contiguous"] = None
+            else:
+                props["is_contiguous"] = tuple(
+                    [
+                        x
+                        for x in torch._prims_common._memory_formats
+                        if value.is_contiguous(memory_format=x)
+                    ]
+                )
         return props
 
     def dynamic_getattr(self, tx, name):
diff --git a/torch/_functorch/vmap.py b/torch/_functorch/vmap.py
index 0ba631a..69bc77f 100644
--- a/torch/_functorch/vmap.py
+++ b/torch/_functorch/vmap.py
@@ -274,10 +274,17 @@
         return _chunked_vmap(func, flat_in_dims, chunks_flat_args,
                              args_spec, out_dims, randomness, **kwargs)
 
-    # If chunk_size is not specified.
-    return _flat_vmap(
-        func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
-    )
+    from torch._dynamo import disable
+
+    # remove @disable once #114306 is fixed
+    @disable
+    def wrapper():
+        # If chunk_size is not specified.
+        return _flat_vmap(
+            func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
+        )
+
+    return wrapper()
 
 def get_chunk_sizes(total_elems, chunk_size):
     n_chunks = n_chunks = total_elems // chunk_size
diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py
index 96b522b..98c25d2 100644
--- a/torch/_higher_order_ops/cond.py
+++ b/torch/_higher_order_ops/cond.py
@@ -8,6 +8,12 @@
 import torch.utils._pytree as pytree
 
 from torch._C import DispatchKey
+from torch._C._functorch import (
+    _add_batch_dim,
+    get_unwrapped,
+    is_batchedtensor,
+    maybe_get_bdim,
+)
 from torch._functorch.utils import exposed_in
 
 from torch._higher_order_ops.utils import autograd_not_implemented
@@ -404,3 +410,50 @@
             unwrapped_pred, functional_true, functional_false, unwrapped_inputs
         )
         return ctx.wrap_tensors(cond_return)
+
+
+@cond_op.py_impl(torch._C._functorch.TransformType.Vmap)
+def cond_batch_rule(interpreter, pred, true_fn, false_fn, inputs):
+    assert isinstance(
+        inputs, (list, tuple)
+    ), "Cond inputs must be a list or tuple of tensors"
+    assert all(
+        isinstance(i, torch.Tensor) for i in inputs
+    ), "Cond inputs must be a list of tensors"
+
+    pred_ = get_unwrapped(pred) if is_batchedtensor(pred) else pred
+
+    # unbatched tensors are not vmapped
+    tensors, in_dims = zip(
+        *[
+            (get_unwrapped(t), maybe_get_bdim(t)) if is_batchedtensor(t) else (t, None)
+            for t in inputs
+        ]
+    )
+
+    if is_batchedtensor(pred):
+        # prepend "pred" and vmap everything
+        tensors = (pred_,) + tensors
+        in_dims = (0,) + in_dims
+
+        def fn(p, *args):
+            t = true_fn(*args)
+            f = false_fn(*args)
+            return torch.where(p, t[0], f[0])
+
+        with interpreter.lower():
+            result = torch.vmap(fn, in_dims=in_dims)(*tensors)
+
+    else:
+        # predicate is known at this stage and it is a boolean expression or a
+        # tensor with one element.
+        true_fn = torch.vmap(true_fn, in_dims=in_dims)
+        false_fn = torch.vmap(false_fn, in_dims=in_dims)
+
+        with interpreter.lower():
+            result = cond_op(pred, true_fn, false_fn, tensors)
+
+    if not isinstance(result, tuple):
+        result = (result,)
+    lvl = interpreter.level()
+    return tuple([_add_batch_dim(r, 0, lvl) for r in result])
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index ebc1440..dbb122a 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -188,6 +188,9 @@
         reapply_views = torch._C._functionalization_reapply_views_tls()
         unwrapped = torch._C._functorch._unwrap_functional_tensor(x, reapply_views)
         return is_fake(unwrapped)
+    elif isinstance(x, torch.Tensor) and torch._C._functorch.is_batchedtensor(x):
+        unwrapped = torch._C._functorch.get_unwrapped(x)
+        return is_fake(unwrapped)
     return False
 
 
@@ -206,6 +209,9 @@
         reapply_views = torch._C._functionalization_reapply_views_tls()
         unwrapped = torch._C._functorch._unwrap_functional_tensor(t, reapply_views)
         return maybe_get_fake_mode(unwrapped)
+    elif isinstance(t, torch.Tensor) and torch._C._functorch.is_batchedtensor(t):
+        unwrapped = torch._C._functorch.get_unwrapped(t)
+        return maybe_get_fake_mode(unwrapped)
     return None
 
 
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index ba7e0d2..40442bf 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -5,9 +5,14 @@
 
 import torch
 from torch._C._functorch import (
+    _add_batch_dim,
     _unwrap_functional_tensor,
     _wrap_functional_tensor,
     current_level,
+    get_unwrapped,
+    is_batchedtensor,
+    maybe_get_bdim,
+    maybe_get_level,
     peek_interpreter_stack,
     TransformType,
 )
@@ -128,7 +133,7 @@
         # hold a weak ref to self, otherwise it will be kept alive
         # by the del_ten closure
         self_weak_ref = weakref.ref(self)
-        if t.is_sparse or t.is_mkldnn:
+        if t.is_sparse or t.is_mkldnn or is_batchedtensor(t):
             weak_st = None
         else:
             weak_st = StorageWeakRef(t._typed_storage())
@@ -313,6 +318,30 @@
                     if t.requires_grad and not is_leaf:
                         with torch.enable_grad():
                             r = r.clone()
+                elif is_batchedtensor(t):
+                    # Wraps a BatchedTensor in a FakeTensor
+                    def _to_fake_tensor(t):
+                        if is_batchedtensor(t):
+                            ft = _to_fake_tensor(get_unwrapped(t))
+                            lvl = maybe_get_level(t)
+                            bdim = maybe_get_bdim(t)
+                            r = _add_batch_dim(ft, bdim, lvl)
+                        else:
+                            # regular tensor
+                            sizes = t.size()
+                            strides = t.stride()
+                            r = callback(
+                                lambda: torch.empty_strided(
+                                    sizes,
+                                    strides,
+                                    dtype=t.dtype,
+                                    device="meta",
+                                )
+                            )
+                        return r
+
+                    r = _to_fake_tensor(t)
+
                 elif t._is_view():
                     # Construct views in two steps: recursively meta-fy their
                     # base, and then create view(s) off that.  NB: doing it
@@ -511,7 +540,9 @@
                                 r = r.clone(memory_format=torch.preserve_format)
 
                     # Graph-Break for wrapped tensors
-                    if torch._C._functorch.is_functorch_wrapped_tensor(t):
+                    if not is_batchedtensor(
+                        t
+                    ) and torch._C._functorch.is_functorch_wrapped_tensor(t):
                         return NotImplemented
 
                     s = t.untyped_storage()
diff --git a/torch/fx/passes/shape_prop.py b/torch/fx/passes/shape_prop.py
index 69260a5..6208980 100644
--- a/torch/fx/passes/shape_prop.py
+++ b/torch/fx/passes/shape_prop.py
@@ -26,7 +26,7 @@
     is_quantized : bool
     qparams: Dict[str, Any]
 
-def _extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata:
+def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> TensorMetadata:
     """
     Extract a TensorMetadata NamedTuple describing `result`.
     """
@@ -35,18 +35,18 @@
     requires_grad = result.requires_grad
     stride = result.stride()
 
-    memory_formats = {
-        torch.contiguous_format,
-        torch.channels_last,
-        torch.channels_last_3d,
-    }
-
     memory_format = None
 
-    for query_format in memory_formats:
-        if result.is_contiguous(memory_format=query_format):
-            memory_format = query_format
-            break
+    if include_contiguity:
+        memory_formats = {
+            torch.contiguous_format,
+            torch.channels_last,
+            torch.channels_last_3d,
+        }
+        for query_format in memory_formats:
+            if result.is_contiguous(memory_format=query_format):
+                memory_format = query_format
+                break
 
     is_quantized = result.is_quantized
     qparams: Dict[str, Any] = {}