[HigherOrderOp] add pytree operands tests for cond (#112661)
This is a follow-up to #111611. After this PR, we allow pytrees (possibly nested dict/list/tuple) with tensor-only leaves as the operands of cond's branches.
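A minimal sketch of the usage this enables (the key names and shapes below are illustrative, not taken from the tests):

import torch

def true_fn(operands):
    # runs when pred is True: add the two tensor leaves
    return operands["x"][0] + operands["y"]

def false_fn(operands):
    # runs when pred is False: multiply the two tensor leaves
    return operands["x"][0] * operands["y"]

pred = torch.tensor(True)
# The operands tuple may now hold a nested dict/list/tuple pytree,
# as long as every leaf is a Tensor.
inp = ({"x": [torch.randn(3, 3)], "y": torch.randn(3, 3)},)
out = torch.cond(pred, true_fn, false_fn, inp)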
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112661
Approved by: https://github.com/zou3519
diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index 2653f08..afc8dd6 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -3014,7 +3014,7 @@
example_inputs = (torch.rand(5),)
with self.assertRaisesRegex(
RuntimeError,
- "Expect operands to be a tuple of Tensors, but got",
+ r"Expect operands to be a tuple of possibly nested dict/list/tuple",
):
f_non_list_operands(*example_inputs)
@@ -3027,7 +3027,8 @@
example_inputs = (torch.rand(5),)
with self.assertRaisesRegex(
- RuntimeError, "Expect operands to be a tuple of Tensors, but got"
+ RuntimeError,
+ r"Expect operands to be a tuple of possibly nested dict/list/tuple",
):
f_non_tensor_operands(*example_inputs)
diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py
index bfa29ad..e3c7ede 100644
--- a/test/dynamo/test_higher_order_ops.py
+++ b/test/dynamo/test_higher_order_ops.py
@@ -2122,6 +2122,93 @@
'sum_2': ['vmap_impl', 'vmap_impl', 'sum_2']}""",
)
+ def test_cond_pytree_operands(self):
+ def _construct_pytree():
+ a = torch.randn(3, 3)
+ b = torch.randn(3, 3)
+ c = torch.randn(3, 3)
+ d = torch.randn(3, 3)
+ e = torch.randn(3, 3)
+ f = torch.randn(3, 3)
+ g = torch.randn(3, 3)
+ return (a, [[[b]]], c, (d, (e,), f), {"g": g})
+
+ pred = torch.tensor(True)
+ inp = _construct_pytree()
+
+ def _reduce_sum(flattened):
+ init = 0
+ for val in flattened:
+ init += val
+ return init
+
+ def _reduce_max(flattened):
+ init = flattened[0]
+ for val in flattened:
+ init = max(val, init)
+ return init
+
+ def true_fn(pytree_in):
+ flattened, spec = pytree.tree_flatten(pytree_in)
+ return _reduce_sum(flattened)
+
+ def false_fn(pytree_in):
+ flattened, spec = pytree.tree_flatten(pytree_in)
+ return _reduce_max(flattened)
+
+ def fn(pred, pytree_in):
+ return torch.cond(pred, true_fn, false_fn, [pytree_in])
+
+ backend = EagerAndRecordGraphs()
+ cnt = CompileCounterWithBackend(backend)
+ compiled_res = torch.compile(fn, backend=backend)(pred, inp)
+ eager_res = fn(pred, inp)
+ self.assertEqual(compiled_res, eager_res)
+ graph = backend.graphs[0]
+
+ # Dynamic shapes produce a slightly different graph.
+ if check_dynamic_shape_capture():
+ return
+
+ self.assertExpectedInline(
+ graph.code.strip(),
+ """\
+def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytree_in_1_0_0_0_ : torch.Tensor, L_pytree_in_2_ : torch.Tensor, L_pytree_in_3_0_ : torch.Tensor, L_pytree_in_3_1_0_ : torch.Tensor, L_pytree_in_3_2_ : torch.Tensor, L_pytree_in_4_g_ : torch.Tensor):
+ l_pred_ = L_pred_
+ l_pytree_in_0_ = L_pytree_in_0_
+ l_pytree_in_1_0_0_0_ = L_pytree_in_1_0_0_0_
+ l_pytree_in_2_ = L_pytree_in_2_
+ l_pytree_in_3_0_ = L_pytree_in_3_0_
+ l_pytree_in_3_1_0_ = L_pytree_in_3_1_0_
+ l_pytree_in_3_2_ = L_pytree_in_3_2_
+ l_pytree_in_4_g_ = L_pytree_in_4_g_
+ cond_true_0 = self.cond_true_0
+ cond_false_0 = self.cond_false_0
+ cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [l_pytree_in_0_, l_pytree_in_1_0_0_0_, l_pytree_in_2_, l_pytree_in_3_0_, l_pytree_in_3_1_0_, l_pytree_in_3_2_, l_pytree_in_4_g_]); l_pred_ = cond_true_0 = cond_false_0 = l_pytree_in_0_ = l_pytree_in_1_0_0_0_ = l_pytree_in_2_ = l_pytree_in_3_0_ = l_pytree_in_3_1_0_ = l_pytree_in_3_2_ = l_pytree_in_4_g_ = None
+ return (cond,)""", # noqa: B950
+ )
+
+ def test_cond_pytree_operands_with_non_tensor_leaves(self):
+ def fn(pred, pytree_in):
+ return torch.cond(
+ pred, lambda x: x[0] + 1, lambda x: x[0] * 2, (pytree_in,)
+ )
+
+ pred = torch.tensor(True)
+ for pytree_in in [(1,), ("string",), (1.0,)]:
+ with self.assertRaisesRegex(
+ RuntimeError,
+ r"Expect operands to be a tuple of possibly nested dict/list/tuple",
+ ):
+ fn(pred, pytree_in)
+
+ for pytree_in in [(1,), ("string",), (1.0,)]:
+ with self.assertRaisesRegex(
+ torch._dynamo.exc.UncapturedHigherOrderOpError,
+ r"Cond doesn't work unless it is captured completely with torch.compile",
+ ):
+ torch.compile(fn, backend="eager")(pred, pytree_in)
+
class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
def run(self, result=None):
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index fbe5a77..8acb27c 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -418,18 +418,9 @@
f"Expected a tuple but got {args[3].python_type()}",
)
operands = args[3].unpack_var_sequence(tx)
- if not all(
- isinstance(operand, (TensorVariable, torch.Tensor)) for operand in operands
- ):
+ if not only_consist_of(args[3], (TensorVariable,)):
unimplemented(
- "Expected a tuple of tensors but got {actual_args}".format( # noqa: UP032
- actual_args=[
- str(operand.python_type())
- if isinstance(operand, VariableTracker)
- else str(type(operand))
- for operand in operands
- ],
- ),
+ "Expect operands to be a tuple of pytrees that only consists of tensor leaves."
)
# branches
diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py
index dde1b37..96b522b 100644
--- a/torch/_higher_order_ops/cond.py
+++ b/torch/_higher_order_ops/cond.py
@@ -75,7 +75,7 @@
have consistent input and outputs, meaning the inputs have to be
the same, and the outputs have to be the same type and shape.
- operands (Tuple[torch.Tensor]): A tuple of inputs to the true/false functions.
+ operands (Tuple of possibly nested dict/list/tuple of torch.Tensor): A tuple of inputs to the true/false functions.
Example::
@@ -108,8 +108,6 @@
- `cond` only supports **inference** right now. Autograd will be supported in the future.
- - The **operands** must be a **tuple of tensors**. Pytree of tensors will be supported in the future.
-
- The **output** of branches must be a **single Tensor**. Pytree of tensors will be supported in the future.
"""
@@ -129,11 +127,12 @@
if not callable(true_fn) or not callable(false_fn):
raise RuntimeError("Expect both branches to be callbale.")
- if not isinstance(operands, (tuple, list)) or any(
- not isinstance(t, torch.Tensor) for t in operands
+ if not isinstance(operands, (tuple, list)) or pytree.tree_any(
+ lambda t: not isinstance(t, torch.Tensor), operands
):
raise RuntimeError(
- f"Expect operands to be a tuple of Tensors, but got {operands}."
+ "Expect operands to be a tuple of possibly nested dict/list/tuple that only"
+ f"consists of tensor leaves, but got {operands}."
)
_validate_input(pred, true_fn, false_fn, operands)