[HigherOrderOp] add pytree operands tests for cond (#112661)

This is a follow-up to #111611. After this PR, we allow pytrees with tensor-only leaves as the operands of the branches.
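
A minimal sketch of the newly allowed usage (illustrative names; a hypothetical example, not taken from the tests below):

    import torch

    def true_fn(tree):
        a, (b,), d = tree
        return a + b + d["c"]

    def false_fn(tree):
        a, (b,), d = tree
        return a * b * d["c"]

    pred = torch.tensor(True)
    # operands may now be a tuple whose elements are nested
    # dict/list/tuple structures with tensor leaves.
    tree = (torch.randn(3), (torch.randn(3),), {"c": torch.randn(3)})
    out = torch.cond(pred, true_fn, false_fn, (tree,))

    # Non-tensor leaves are still rejected with a RuntimeError:
    #     torch.cond(pred, true_fn, false_fn, ((1.0,),))

Both branches still have to return a single Tensor of matching type and shape; only the operands may be pytrees.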

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112661
Approved by: https://github.com/zou3519
diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index 2653f08..afc8dd6 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -3014,7 +3014,7 @@
         example_inputs = (torch.rand(5),)
         with self.assertRaisesRegex(
             RuntimeError,
-            "Expect operands to be a tuple of Tensors, but got",
+            r"Expect operands to be a tuple of possibly nested dict/list/tuple",
         ):
             f_non_list_operands(*example_inputs)
 
@@ -3027,7 +3027,8 @@
 
         example_inputs = (torch.rand(5),)
         with self.assertRaisesRegex(
-            RuntimeError, "Expect operands to be a tuple of Tensors, but got"
+            RuntimeError,
+            r"Expect operands to be a tuple of possibly nested dict/list/tuple",
         ):
             f_non_tensor_operands(*example_inputs)
 
diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py
index bfa29ad..e3c7ede 100644
--- a/test/dynamo/test_higher_order_ops.py
+++ b/test/dynamo/test_higher_order_ops.py
@@ -2122,6 +2122,93 @@
  'sum_2': ['vmap_impl', 'vmap_impl', 'sum_2']}""",
         )
 
+    def test_cond_pytree_operands(self):
+        def _construct_pytree():
+            a = torch.randn(3, 3)
+            b = torch.randn(3, 3)
+            c = torch.randn(3, 3)
+            d = torch.randn(3, 3)
+            e = torch.randn(3, 3)
+            f = torch.randn(3, 3)
+            g = torch.randn(3, 3)
+            return (a, [[[b]]], c, (d, (e,), f), {"g": g})
+
+        pred = torch.tensor(True)
+        inp = _construct_pytree()
+
+        def _reduce_sum(flattened):
+            init = 0
+            for val in flattened:
+                init += val
+            return init
+
+        def _reduce_max(flattened):
+            init = flattened[0]
+            for val in flattened:
+                init = max(val, init)
+            return init
+
+        def true_fn(pytree_in):
+            flattened, spec = pytree.tree_flatten(pytree_in)
+            return _reduce_sum(flattened)
+
+        def false_fn(pytree_in):
+            flattened, spec = pytree.tree_flatten(pytree_in)
+            return _reduce_max(flattened)
+
+        def fn(pred, pytree_in):
+            return torch.cond(pred, true_fn, false_fn, [pytree_in])
+
+        backend = EagerAndRecordGraphs()
+        cnt = CompileCounterWithBackend(backend)
+        compiled_res = torch.compile(fn, backend=backend)(pred, inp)
+        eager_res = fn(pred, inp)
+        self.assertEqual(compiled_res, eager_res)
+        graph = backend.graphs[0]
+
+        # Dynamic shapes produce a slightly different graph.
+        if check_dynamic_shape_capture():
+            return
+
+        self.assertExpectedInline(
+            graph.code.strip(),
+            """\
+def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytree_in_1_0_0_0_ : torch.Tensor, L_pytree_in_2_ : torch.Tensor, L_pytree_in_3_0_ : torch.Tensor, L_pytree_in_3_1_0_ : torch.Tensor, L_pytree_in_3_2_ : torch.Tensor, L_pytree_in_4_g_ : torch.Tensor):
+    l_pred_ = L_pred_
+    l_pytree_in_0_ = L_pytree_in_0_
+    l_pytree_in_1_0_0_0_ = L_pytree_in_1_0_0_0_
+    l_pytree_in_2_ = L_pytree_in_2_
+    l_pytree_in_3_0_ = L_pytree_in_3_0_
+    l_pytree_in_3_1_0_ = L_pytree_in_3_1_0_
+    l_pytree_in_3_2_ = L_pytree_in_3_2_
+    l_pytree_in_4_g_ = L_pytree_in_4_g_
+    cond_true_0 = self.cond_true_0
+    cond_false_0 = self.cond_false_0
+    cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [l_pytree_in_0_, l_pytree_in_1_0_0_0_, l_pytree_in_2_, l_pytree_in_3_0_, l_pytree_in_3_1_0_, l_pytree_in_3_2_, l_pytree_in_4_g_]);  l_pred_ = cond_true_0 = cond_false_0 = l_pytree_in_0_ = l_pytree_in_1_0_0_0_ = l_pytree_in_2_ = l_pytree_in_3_0_ = l_pytree_in_3_1_0_ = l_pytree_in_3_2_ = l_pytree_in_4_g_ = None
+    return (cond,)""",  # noqa: B950
+        )
+
+    def test_cond_pytree_operands_with_non_tensor_leaves(self):
+        def fn(pred, pytree_in):
+            return torch.cond(
+                pred, lambda x: x[0] + 1, lambda x: x[0] * 2, (pytree_in,)
+            )
+
+        pred = torch.tensor(True)
+        for pytree_in in [(1,), ("string",), (1.0,)]:
+            with self.assertRaisesRegex(
+                RuntimeError,
+                r"Expect operands to be a tuple of possibly nested dict/list/tuple",
+            ):
+                fn(pred, pytree_in)
+
+        for pytree_in in [(1,), ("string",), (1.0,)]:
+            with self.assertRaisesRegex(
+                torch._dynamo.exc.UncapturedHigherOrderOpError,
+                r"Cond doesn't work unless it is captured completely with torch.compile",
+            ):
+                torch.compile(fn, backend="eager")(pred, pytree_in)
+
 
 class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
     def run(self, result=None):
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index fbe5a77..8acb27c 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -418,18 +418,9 @@
                 f"Expected a tuple but got {args[3].python_type()}",
             )
         operands = args[3].unpack_var_sequence(tx)
-        if not all(
-            isinstance(operand, (TensorVariable, torch.Tensor)) for operand in operands
-        ):
+        if not only_consist_of(args[3], (TensorVariable,)):
             unimplemented(
-                "Expected a tuple of tensors but got {actual_args}".format(  # noqa: UP032
-                    actual_args=[
-                        str(operand.python_type())
-                        if isinstance(operand, VariableTracker)
-                        else str(type(operand))
-                        for operand in operands
-                    ],
-                ),
+                "Expect operands to be a tuple of pytrees that only consists of tensor leaves."
             )
 
         # branches
diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py
index dde1b37..96b522b 100644
--- a/torch/_higher_order_ops/cond.py
+++ b/torch/_higher_order_ops/cond.py
@@ -75,7 +75,7 @@
           have consistent input and outputs, meaning the inputs have to be
           the same, and the outputs have to be the same type and shape.
 
-        operands (Tuple[torch.Tensor]): A tuple of inputs to the true/false functions.
+        operands (Tuple of possibly nested dict/list/tuple of torch.Tensor): A tuple of inputs to the true/false functions.
 
     Example::
 
@@ -108,8 +108,6 @@
 
         - `cond` only supports **inference** right now. Autograd will be supported in the future.
 
-        - The **operands** must be a **tuple of tensors**. Pytree of tensors will be supported in the future.
-
         - The **output** of branches must be a **single Tensor**. Pytree of tensors will be supported in the future.
 
     """
@@ -129,11 +127,12 @@
         if not callable(true_fn) or not callable(false_fn):
             raise RuntimeError("Expect both branches to be callbale.")
 
-        if not isinstance(operands, (tuple, list)) or any(
-            not isinstance(t, torch.Tensor) for t in operands
+        if not isinstance(operands, (tuple, list)) or pytree.tree_any(
+            lambda t: not isinstance(t, torch.Tensor), operands
         ):
             raise RuntimeError(
-                f"Expect operands to be a tuple of Tensors, but got {operands}."
+                "Expect operands to be a tuple of possibly nested dict/list/tuple that only"
+                f"consists of tensor leaves, but got {operands}."
             )
 
     _validate_input(pred, true_fn, false_fn, operands)