[cond] support torch built-in functions as subgraphs (#126909)
Fixes https://github.com/pytorch/pytorch/issues/126818.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126909
Approved by: https://github.com/zou3519
ghstack dependencies: #127026
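
Previously, Dynamo asserted that the branch arguments of `cond`, `while_loop`, and `map` were user-defined functions or `nn.Module`s, so passing a torch built-in such as `torch.add` or `torch.mul` failed the `isinstance` check. The change replaces those assertions with a generic `callable(...)` check (`_check_supported_callable_arg`), as exercised by the new test below. A minimal sketch of the now-supported usage, assuming the public `torch.cond` entry point and `torch.compile` to drive the Dynamo path:

```python
import torch

# Sketch only: torch built-ins (torch.add / torch.mul) are used directly as the
# true/false branches; before this change Dynamo rejected them with an
# isinstance assertion in higher_order_ops.py.
def f(a, b):
    return torch.cond(a.sum() > 0, torch.add, torch.mul, (a, b))

a, b = torch.randn(3, 4), torch.randn(3, 4)
compiled = torch.compile(f, fullgraph=True)
print(torch.allclose(compiled(a, b), f(a, b)))  # expected: True
```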
diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py
index 92a988d..3cac101 100644
--- a/test/functorch/test_control_flow.py
+++ b/test/functorch/test_control_flow.py
@@ -873,6 +873,41 @@
)
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
+ def test_cond_accepts_torch_function_as_inputs(self):
+ a = torch.randn(3, 4)
+ b = torch.randn(3, 4)
+
+ def f(a, b):
+ return cond(a.sum() > 0, torch.add, torch.mul, (a, b))
+
+ gm = self._check_tracing(f, (a, b))["symbolic"]
+ self.assertExpectedInline(
+ gm.code.strip(),
+ """\
+def forward(self, a_1, b_1):
+ sum_1 = torch.ops.aten.sum.default(a_1)
+ gt = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
+ true_graph_0 = self.true_graph_0
+ false_graph_0 = self.false_graph_0
+ conditional = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1]); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = None
+ getitem = conditional[0]; conditional = None
+ return getitem""", # noqa: B950
+ )
+ self.assertExpectedInline(
+ gm.true_graph_0.code.strip(),
+ """\
+def forward(self, arg0_1, arg1_1):
+ add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
+ return (add,)""",
+ )
+ self.assertExpectedInline(
+ gm.false_graph_0.code.strip(),
+ """\
+def forward(self, arg0_1, arg1_1):
+ mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
+ return (mul,)""",
+ )
+
def test_cond_retrace_functionalized(self):
def true_fn(x):
return x.sin()
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index 2514ae0..00932f9 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -30,7 +30,6 @@
from .dicts import ConstDictVariable
from .lazy import LazyVariableTracker
from .lists import ListVariable, TupleVariable
-from .nn_module import NNModuleVariable, UnspecializedNNModuleVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
@@ -131,6 +130,14 @@
), "inputs to function body cannot alias outputs"
+def _check_supported_callable_arg(tx, func_var: VariableTracker, arg_name):
+ is_callable = (
+ BuiltinVariable(callable).call_function(tx, [func_var], {}).as_python_constant()
+ )
+ if not is_callable:
+ unimplemented(f"{arg_name} is of unsupported callable type {str(func_var)}.")
+
+
def validate_args_and_maybe_create_graph_inputs(
sub_args,
tracer,
@@ -567,12 +574,7 @@
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
- from . import (
- ListVariable,
- NestedUserFunctionVariable,
- TensorVariable,
- UserFunctionVariable,
- )
+ from . import ListVariable, TensorVariable
args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
@@ -613,29 +615,8 @@
)
# branches
- assert isinstance(
- args[1],
- (
- UserFunctionVariable,
- NestedUserFunctionVariable,
- NNModuleVariable,
- UnspecializedNNModuleVariable,
- ),
- ), str(
- type(args[1])
- ) # true_fn
-
- assert isinstance(
- args[2],
- (
- UserFunctionVariable,
- NestedUserFunctionVariable,
- NNModuleVariable,
- UnspecializedNNModuleVariable,
- ),
- ), str(
- type(args[2])
- ) # false_fn
+ _check_supported_callable_arg(tx, args[1], "true_fn")
+ _check_supported_callable_arg(tx, args[2], "false_fn")
# Our strategy for tracing the true/false branches of cond
# are to checkpoint our graphstate, run the true branch,
@@ -806,7 +787,7 @@
def call_function(
self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
) -> VariableTracker:
- from . import NestedUserFunctionVariable, TensorVariable, UserFunctionVariable
+ from . import TensorVariable
args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
@@ -828,19 +809,8 @@
f"Usage: while_loop(cond_fn, body_fn, operands)",
)
- def _check_supported_callable(fn_var):
- assert isinstance(
- fn_var,
- (
- UserFunctionVariable,
- NestedUserFunctionVariable,
- NNModuleVariable,
- UnspecializedNNModuleVariable,
- ),
- ), str(type(fn_var))
-
- _check_supported_callable(args[0])
- _check_supported_callable(args[1])
+ _check_supported_callable_arg(tx, args[0], "cond_fn")
+ _check_supported_callable_arg(tx, args[1], "body_fn")
# operands
if not isinstance(args[2], (ListVariable, TupleVariable)):
@@ -1074,7 +1044,7 @@
def call_function(
self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
) -> VariableTracker:
- from . import NestedUserFunctionVariable, TensorVariable, UserFunctionVariable
+ from . import TensorVariable
from .builder import wrap_fx_proxy_cls
if len(kwargs) > 0:
@@ -1082,10 +1052,8 @@
"torch.ops.higher_order.map: kwargs are not supported in the map operator."
)
- assert type(args[0].realize()) in (
- UserFunctionVariable,
- NestedUserFunctionVariable,
- )
+ _check_supported_callable_arg(tx, args[0].realize(), "map_fn")
+
assert type(args[1].realize()) is TensorVariable
sample_shape = get_fake_value(args[1].as_proxy().node, tx).size()