Replace exir.control_flow.cond with torch.cond

Summary: This diff removes exir.control_flow.cond and replaces its existing usages with torch.cond.

Reviewed By: angelayi

Differential Revision: D47924374

fbshipit-source-id: d296479ee1f708cb423a27feb55e258632238902
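
A minimal sketch of the call-site change, mirroring the test updates below. The branch bodies here are simplified stand-ins (the real tests use an `addloop` helper); the predicate, branch functions, and operand tuple keep the same positional layout as before, and the `control_flow.tracing_context` decorators are no longer needed:

```python
import torch
from functorch.experimental.control_flow import cond

def true_branch(c, x):
    return x + 3

def false_branch(c, x):
    return x + 4

c = torch.BoolTensor([True])  # single-element boolean tensor predicate
x = torch.randn(10)

# Before: y = control_flow.cond(c, true_branch, false_branch, (c, x))
# After:
y = cond(c, true_branch, false_branch, (c, x))
```

When traced for export, this call lowers to a `torch.ops.higher_order.cond` node, which is why the pass skiplist and submodule traversal in exir/passes/__init__.py now match on that target instead of control_flow.cond.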
diff --git a/exir/control_flow.py b/exir/control_flow.py
index 8b6d959..9426526 100644
--- a/exir/control_flow.py
+++ b/exir/control_flow.py
@@ -125,67 +125,6 @@
     return gm
 
 
-def cond(
-    pred: bool,
-    true_fn: Callable[..., Tuple[torch.Tensor]],
-    false_fn: Callable[..., Tuple[torch.Tensor]],
-    inputs: pytree.PyTree,
-) -> Union[List[torch.Tensor], Value]:
-    """
-    A higher order function returning result based on passed predicate
-    value and conditionally execute one of true_fn and false_fn.
-
-    Detects whether a tracer is present in the context, and if so will
-    trace_through both true_fn and false_fn with local inputs provided
-    by tracing_context dictionary from the current tracer. When
-    returning, wraps two traced graphs into a cond() call and construct
-    a call_function node in the tracer's graph.
-
-    Checks and enforces that the returning value(s) from both
-    branches has the same Tensor type. For now enforces that both
-    branches have the same number of tensor inputs.
-    """
-    flattened_inputs, _ = pytree.tree_flatten(inputs)
-
-    if not all([isinstance(i, torch.Tensor) for i in flattened_inputs]):
-        raise ExportError(
-            ExportErrorType.INVALID_INPUT_TYPE,
-            f"control_flow.cond() expects all inputs values to be tensors, actual inputs: {inputs}",
-        )
-
-    with using_tracer(None):
-        outputs = true_fn(*inputs) if pred else false_fn(*inputs)
-
-    flattened_outputs, _ = pytree.tree_flatten(outputs)
-
-    if not all([isinstance(r, torch.Tensor) for r in flattened_outputs]):
-        raise ExportError(
-            ExportErrorType.INVALID_OUTPUT_TYPE,
-            f"control_flow.cond() only supports tensors as output, actual output: {outputs}",
-        )
-
-    tracer = DispatchTracer.get()
-
-    if tracer is None:
-        return outputs
-
-    # Once global tracer is present, we need to assume all tensors are
-    # PythonTensor wrapped with FunctionalTensorWrapper.
-
-    gm_true = _make_submodule(true_fn, example_returns=flattened_outputs)
-    gm_false = _make_submodule(false_fn, example_returns=flattened_outputs)
-    proxies = tuple([unwrap_proxy(i) for i in flattened_inputs])
-
-    proxy = tracer.create_proxy(
-        "call_function",
-        cond,
-        (unwrap_proxy(pred), gm_true, gm_false, proxies),
-        {},
-    )
-
-    return tree_return(outputs, proxy, update_with_proxy)
-
-
 def while_loop(
     cond_fn: Callable[..., torch.Tensor],
     body_fn: Callable[..., Tuple[torch.Tensor]],
diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py
index 4d4db54..eeaca5a 100644
--- a/exir/passes/__init__.py
+++ b/exir/passes/__init__.py
@@ -236,7 +236,7 @@
 # pyre-ignore
 to_out_var_skiplist: Set[Callable[[Any], Any]] = {
     _operator.getitem,
-    control_flow.cond,
+    torch.ops.higher_order.cond,
     control_flow.while_loop,
     # memory.alloc will be added after the to_out_variant pass so usually
     # we won't see it in the input graph to the to_out_variant pass, unless
@@ -321,7 +321,7 @@
                 continue
 
             target = node.target
-            if target == control_flow.cond or target == torch.ops.higher_order.cond:
+            if target == torch.ops.higher_order.cond:
                 self.call(get_submodule(node.args[1]))
                 self.call(get_submodule(node.args[2]))
                 continue
diff --git a/test/end2end/test_end2end.py b/test/end2end/test_end2end.py
index 29b1c93..d278d0a 100644
--- a/test/end2end/test_end2end.py
+++ b/test/end2end/test_end2end.py
@@ -56,6 +56,7 @@
 from executorch.exir.tests.dynamic_shape_models import BatchNormModel
 
 from executorch.exir.tests.transformer import Transformer
+from functorch.experimental.control_flow import cond
 
 kernel_mode = None  # either aten mode or lean mode
 try:
@@ -245,15 +246,13 @@
                 out = out + x
             return out
 
-        @control_flow.tracing_context(inputs=(torch.randn(1), torch.randn(10)))
         def true_branch(c, x):
             return addloop(x, 3)
 
-        @control_flow.tracing_context(inputs=(torch.randn(1), torch.randn(10)))
         def false_branch(c, x):
             return addloop(x, 4)
 
-        y = control_flow.cond(c, true_branch, false_branch, (c, x))
+        y = cond(c, true_branch, false_branch, (c, x))
         return y * y
 
     def get_random_inputs(self):
@@ -273,18 +272,14 @@
                 out = out + x
             return out
 
-        @control_flow.tracing_context(inputs=(torch.randn(1), torch.randn(10)))
         def true_branch(c, x):
             return addloop(x, 3)
 
-        @control_flow.tracing_context(inputs=(torch.randn(1), torch.randn(10)))
         def false_branch(c, x):
             return addloop(x, 4)
 
-        # pyre-fixme[6]: Incompatible parameter type
-        y = control_flow.cond(c, true_branch, false_branch, (c, x))
+        y = cond(c, true_branch, false_branch, (c, x))
 
-        # pyre-fixme[58]: Unsupported operand type for binary operator '*'
         return y * y
 
     def get_random_inputs(self):
@@ -319,7 +314,7 @@
             def false_branch(cnt):
                 return torch.zeros([1], dtype=torch.long)
 
-            accum = accum + control_flow.cond(
+            accum = accum + cond(
                 torch.BoolTensor([True]), true_branch, false_branch, (cnt,)
             )
             # 'cnt - 1' does not work yet since the runtime does not expect
@@ -372,9 +367,9 @@
         def false_branch(accum, cnt):
             return accum, cnt
 
-        return control_flow.cond(
-            torch.BoolTensor([True]), true_branch, false_branch, (accum, cnt)
-        )[0]
+        return cond(torch.BoolTensor([True]), true_branch, false_branch, (accum, cnt))[
+            0
+        ]
 
     def get_random_inputs(self):
         return (torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))