Replace exir cond with torch cond
Summary: This diff removes exir.control_flow.cond and replace its existing usage with torch.cond.
Reviewed By: angelayi
Differential Revision: D47924374
fbshipit-source-id: d296479ee1f708cb423a27feb55e258632238902
diff --git a/exir/control_flow.py b/exir/control_flow.py
index 8b6d959..9426526 100644
--- a/exir/control_flow.py
+++ b/exir/control_flow.py
@@ -125,67 +125,6 @@
return gm
-def cond(
- pred: bool,
- true_fn: Callable[..., Tuple[torch.Tensor]],
- false_fn: Callable[..., Tuple[torch.Tensor]],
- inputs: pytree.PyTree,
-) -> Union[List[torch.Tensor], Value]:
- """
- A higher order function returning result based on passed predicate
- value and conditionally execute one of true_fn and false_fn.
-
- Detects whether a tracer is present in the context, and if so will
- trace_through both true_fn and false_fn with local inputs provided
- by tracing_context dictionary from the current tracer. When
- returning, wraps two traced graphs into a cond() call and construct
- a call_function node in the tracer's graph.
-
- Checks and enforces that the returning value(s) from both
- branches has the same Tensor type. For now enforces that both
- branches have the same number of tensor inputs.
- """
- flattened_inputs, _ = pytree.tree_flatten(inputs)
-
- if not all([isinstance(i, torch.Tensor) for i in flattened_inputs]):
- raise ExportError(
- ExportErrorType.INVALID_INPUT_TYPE,
- f"control_flow.cond() expects all inputs values to be tensors, actual inputs: {inputs}",
- )
-
- with using_tracer(None):
- outputs = true_fn(*inputs) if pred else false_fn(*inputs)
-
- flattened_outputs, _ = pytree.tree_flatten(outputs)
-
- if not all([isinstance(r, torch.Tensor) for r in flattened_outputs]):
- raise ExportError(
- ExportErrorType.INVALID_OUTPUT_TYPE,
- f"control_flow.cond() only supports tensors as output, actual output: {outputs}",
- )
-
- tracer = DispatchTracer.get()
-
- if tracer is None:
- return outputs
-
- # Once global tracer is present, we need to assume all tensors are
- # PythonTensor wrapped with FunctionalTensorWrapper.
-
- gm_true = _make_submodule(true_fn, example_returns=flattened_outputs)
- gm_false = _make_submodule(false_fn, example_returns=flattened_outputs)
- proxies = tuple([unwrap_proxy(i) for i in flattened_inputs])
-
- proxy = tracer.create_proxy(
- "call_function",
- cond,
- (unwrap_proxy(pred), gm_true, gm_false, proxies),
- {},
- )
-
- return tree_return(outputs, proxy, update_with_proxy)
-
-
def while_loop(
cond_fn: Callable[..., torch.Tensor],
body_fn: Callable[..., Tuple[torch.Tensor]],
diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py
index 4d4db54..eeaca5a 100644
--- a/exir/passes/__init__.py
+++ b/exir/passes/__init__.py
@@ -236,7 +236,7 @@
# pyre-ignore
to_out_var_skiplist: Set[Callable[[Any], Any]] = {
_operator.getitem,
- control_flow.cond,
+ torch.ops.higher_order.cond,
control_flow.while_loop,
# memory.alloc will be added after the to_out_variant pass so usually
# we won't see it in the input graph to the to_out_variant pass, unless
@@ -321,7 +321,7 @@
continue
target = node.target
- if target == control_flow.cond or target == torch.ops.higher_order.cond:
+ if target == torch.ops.higher_order.cond:
self.call(get_submodule(node.args[1]))
self.call(get_submodule(node.args[2]))
continue
diff --git a/test/end2end/test_end2end.py b/test/end2end/test_end2end.py
index 29b1c93..d278d0a 100644
--- a/test/end2end/test_end2end.py
+++ b/test/end2end/test_end2end.py
@@ -56,6 +56,7 @@
from executorch.exir.tests.dynamic_shape_models import BatchNormModel
from executorch.exir.tests.transformer import Transformer
+from functorch.experimental.control_flow import cond
kernel_mode = None # either aten mode or lean mode
try:
@@ -245,15 +246,13 @@
out = out + x
return out
- @control_flow.tracing_context(inputs=(torch.randn(1), torch.randn(10)))
def true_branch(c, x):
return addloop(x, 3)
- @control_flow.tracing_context(inputs=(torch.randn(1), torch.randn(10)))
def false_branch(c, x):
return addloop(x, 4)
- y = control_flow.cond(c, true_branch, false_branch, (c, x))
+ y = cond(c, true_branch, false_branch, (c, x))
return y * y
def get_random_inputs(self):
@@ -273,18 +272,14 @@
out = out + x
return out
- @control_flow.tracing_context(inputs=(torch.randn(1), torch.randn(10)))
def true_branch(c, x):
return addloop(x, 3)
- @control_flow.tracing_context(inputs=(torch.randn(1), torch.randn(10)))
def false_branch(c, x):
return addloop(x, 4)
- # pyre-fixme[6]: Incompatible parameter type
- y = control_flow.cond(c, true_branch, false_branch, (c, x))
+ y = cond(c, true_branch, false_branch, (c, x))
- # pyre-fixme[58]: Unsupported operand type for binary operator '*'
return y * y
def get_random_inputs(self):
@@ -319,7 +314,7 @@
def false_branch(cnt):
return torch.zeros([1], dtype=torch.long)
- accum = accum + control_flow.cond(
+ accum = accum + cond(
torch.BoolTensor([True]), true_branch, false_branch, (cnt,)
)
# 'cnt - 1' does not work yet since the runtime does not expect
@@ -372,9 +367,9 @@
def false_branch(accum, cnt):
return accum, cnt
- return control_flow.cond(
- torch.BoolTensor([True]), true_branch, false_branch, (accum, cnt)
- )[0]
+ return cond(torch.BoolTensor([True]), true_branch, false_branch, (accum, cnt))[
+ 0
+ ]
def get_random_inputs(self):
return (torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))