Change to torch.ops.higher_order.cond in verifier (#108302)
We need to match against torch.ops.higher_order.cond in verifier.
Test Plan:
added test to guard against change.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108302
Approved by: https://github.com/angelayi
diff --git a/test/export/test_verifier.py b/test/export/test_verifier.py
index 88e1c48..c0ba22b 100644
--- a/test/export/test_verifier.py
+++ b/test/export/test_verifier.py
@@ -7,7 +7,7 @@
import torch.nn as nn
import torch._dynamo as torchdynamo
from functorch import make_fx
-from functorch.experimental import functionalize
+from functorch.experimental import functionalize, control_flow
from torch import Tensor
from torch.testing._internal.common_utils import run_tests, TestCase
from torch._dynamo.eval_frame import is_dynamo_supported
@@ -108,6 +108,11 @@
y = self.dropout2(y)
return y
+class ControlFlow(nn.Module):
+
+ def forward(self, pred: Tensor, x: Tensor) -> Tensor:
+ return control_flow.cond(pred, lambda x: x.sin(), lambda x: x.cos(), (x,))
+
class VerifierTest(TestCase):
@@ -206,6 +211,14 @@
verifier(egm)
self.assertFalse(verifier.is_valid(egm))
+ @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
+ def test_verifier_control_flow_success(self) -> None:
+ m = ControlFlow()
+ gm = torch._export.export(m, (torch.tensor(True), torch.randn(3, 4))).graph_module
+ # No error should be raised
+ verifier = ATenDialectVerifier()
+ verifier(gm)
+
if __name__ == '__main__':
run_tests()
diff --git a/torch/_export/verifier.py b/torch/_export/verifier.py
index 7e7d2b2..0bbf5c2 100644
--- a/torch/_export/verifier.py
+++ b/torch/_export/verifier.py
@@ -4,7 +4,6 @@
from typing import Set
import torch
-from functorch.experimental import control_flow
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx import GraphModule
@@ -61,7 +60,7 @@
def valid_builtin_funcs(self):
return [
operator.getitem,
- control_flow.cond,
+ torch.ops.higher_order.cond,
torch.ops.map_impl,
]