[cond] support torch built-in functions as subgraphs (#126909)

Fixes https://github.com/pytorch/pytorch/issues/126818.

Previously, the Dynamo handlers for `cond`, `while_loop`, and `map` asserted specific variable types (e.g. `UserFunctionVariable`, `NestedUserFunctionVariable`, `NNModuleVariable`, `UnspecializedNNModuleVariable`) for their callable arguments, so passing a torch built-in such as `torch.add` or `torch.mul` failed with a bare `AssertionError`. This PR replaces those isinstance asserts with a shared `_check_supported_callable_arg` helper that only checks `callable()` on the variable and otherwise graph-breaks via `unimplemented` with a message naming the offending argument.
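A minimal usage sketch of what now works, mirroring the new test `test_cond_accepts_torch_function_as_inputs` (the `functorch.experimental.control_flow` import path follows the test suite; compiling with the eager backend is illustrative, the test itself traces through an internal helper):

```python
import torch
from functorch.experimental.control_flow import cond

def f(a, b):
    # Both branches are torch built-ins rather than user-defined
    # functions or nn.Modules.
    return cond(a.sum() > 0, torch.add, torch.mul, (a, b))

a = torch.randn(3, 4)
b = torch.randn(3, 4)

# Dynamo now traces each built-in into its own subgraph
# (true_graph_0 / false_graph_0) instead of raising an AssertionError.
compiled = torch.compile(f, backend="eager", fullgraph=True)
assert torch.equal(compiled(a, b), f(a, b))
```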

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126909
Approved by: https://github.com/zou3519
ghstack dependencies: #127026
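For contrast, a hypothetical sketch of the new failure mode: a non-callable argument is now rejected by the shared `callable()` check with a graph break (an `Unsupported` error under `fullgraph=True`) rather than a bare `AssertionError` (the exact message text is an assumption):

```python
import torch
from functorch.experimental.control_flow import cond

def bad(x):
    # The false branch is a tensor, which is not callable, so the new
    # _check_supported_callable_arg helper calls unimplemented().
    return cond(x.sum() > 0, torch.add, x, (x, x))

try:
    torch.compile(bad, backend="eager", fullgraph=True)(torch.randn(3))
except Exception as e:
    print(f"{type(e).__name__}: {e}")
```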
diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py
index 92a988d..3cac101 100644
--- a/test/functorch/test_control_flow.py
+++ b/test/functorch/test_control_flow.py
@@ -873,6 +873,41 @@
         )
         self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
 
+    def test_cond_accepts_torch_function_as_inputs(self):
+        a = torch.randn(3, 4)
+        b = torch.randn(3, 4)
+
+        def f(a, b):
+            return cond(a.sum() > 0, torch.add, torch.mul, (a, b))
+
+        gm = self._check_tracing(f, (a, b))["symbolic"]
+        self.assertExpectedInline(
+            gm.code.strip(),
+            """\
+def forward(self, a_1, b_1):
+    sum_1 = torch.ops.aten.sum.default(a_1)
+    gt = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None
+    true_graph_0 = self.true_graph_0
+    false_graph_0 = self.false_graph_0
+    conditional = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1]);  gt = true_graph_0 = false_graph_0 = a_1 = b_1 = None
+    getitem = conditional[0];  conditional = None
+    return getitem""",  # noqa: B950
+        )
+        self.assertExpectedInline(
+            gm.true_graph_0.code.strip(),
+            """\
+def forward(self, arg0_1, arg1_1):
+    add = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
+    return (add,)""",
+        )
+        self.assertExpectedInline(
+            gm.false_graph_0.code.strip(),
+            """\
+def forward(self, arg0_1, arg1_1):
+    mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
+    return (mul,)""",
+        )
+
     def test_cond_retrace_functionalized(self):
         def true_fn(x):
             return x.sin()
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index 2514ae0..00932f9 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -30,7 +30,6 @@
 from .dicts import ConstDictVariable
 from .lazy import LazyVariableTracker
 from .lists import ListVariable, TupleVariable
-from .nn_module import NNModuleVariable, UnspecializedNNModuleVariable
 
 if TYPE_CHECKING:
     from torch._dynamo.symbolic_convert import InstructionTranslator
@@ -131,6 +130,14 @@
     ), "inputs to function body cannot alias outputs"
 
 
+def _check_supported_callable_arg(tx, func_var: VariableTracker, arg_name):
+    is_callable = (
+        BuiltinVariable(callable).call_function(tx, [func_var], {}).as_python_constant()
+    )
+    if not is_callable:
+        unimplemented(f"{arg_name} is of unsupported callable type {str(func_var)}.")
+
+
 def validate_args_and_maybe_create_graph_inputs(
     sub_args,
     tracer,
@@ -567,12 +574,7 @@
     def call_function(
         self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
     ) -> "VariableTracker":
-        from . import (
-            ListVariable,
-            NestedUserFunctionVariable,
-            TensorVariable,
-            UserFunctionVariable,
-        )
+        from . import ListVariable, TensorVariable
 
         args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
 
@@ -613,29 +615,8 @@
             )
 
         # branches
-        assert isinstance(
-            args[1],
-            (
-                UserFunctionVariable,
-                NestedUserFunctionVariable,
-                NNModuleVariable,
-                UnspecializedNNModuleVariable,
-            ),
-        ), str(
-            type(args[1])
-        )  # true_fn
-
-        assert isinstance(
-            args[2],
-            (
-                UserFunctionVariable,
-                NestedUserFunctionVariable,
-                NNModuleVariable,
-                UnspecializedNNModuleVariable,
-            ),
-        ), str(
-            type(args[2])
-        )  # false_fn
+        _check_supported_callable_arg(tx, args[1], "true_fn")
+        _check_supported_callable_arg(tx, args[2], "false_fn")
 
         # Our strategy for tracing the true/false branches of cond
         # are to checkpoint our graphstate, run the true branch,
@@ -806,7 +787,7 @@
     def call_function(
         self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
     ) -> VariableTracker:
-        from . import NestedUserFunctionVariable, TensorVariable, UserFunctionVariable
+        from . import TensorVariable
 
         args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
 
@@ -828,19 +809,8 @@
                 f"Usage: while_loop(cond_fn, body_fn, operands)",
             )
 
-        def _check_supported_callable(fn_var):
-            assert isinstance(
-                fn_var,
-                (
-                    UserFunctionVariable,
-                    NestedUserFunctionVariable,
-                    NNModuleVariable,
-                    UnspecializedNNModuleVariable,
-                ),
-            ), str(type(fn_var))
-
-        _check_supported_callable(args[0])
-        _check_supported_callable(args[1])
+        _check_supported_callable_arg(tx, args[0], "cond_fn")
+        _check_supported_callable_arg(tx, args[1], "body_fn")
 
         # operands
         if not isinstance(args[2], (ListVariable, TupleVariable)):
@@ -1074,7 +1044,7 @@
     def call_function(
         self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
     ) -> VariableTracker:
-        from . import NestedUserFunctionVariable, TensorVariable, UserFunctionVariable
+        from . import TensorVariable
         from .builder import wrap_fx_proxy_cls
 
         if len(kwargs) > 0:
@@ -1082,10 +1052,8 @@
                 "torch.ops.higher_order.map: kwargs are not supported in the map operator."
             )
 
-        assert type(args[0].realize()) in (
-            UserFunctionVariable,
-            NestedUserFunctionVariable,
-        )
+        _check_supported_callable_arg(tx, args[0].realize(), "map_fn")
+
         assert type(args[1].realize()) is TensorVariable
 
         sample_shape = get_fake_value(args[1].as_proxy().node, tx).size()