[HigherOrderOp] expose torch.cond (#110293)

This PR exposes torch._higher_order_ops.cond as torch.cond (a usage sketch follows the notes below).

1. Add # noqa: F811 to the _check* definitions in torch/__init__.py to silence a confusing linter error, "Redefinition of unused 'cond'": only one cond is imported, and the flagged lines don't redefine cond; they merely take it as a parameter name.
2. Also add cond to the list of functions dynamo traces through, so that dynamo triggers the CondHigherOrder logic instead of creating a TorchVariable.
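
For reference, a minimal sketch of calling the newly public API (the branch bodies and input tensor are made up for illustration; the signature matches the one in the diff):

```python
import torch

def true_fn(x: torch.Tensor) -> torch.Tensor:
    return x.sin()

def false_fn(x: torch.Tensor) -> torch.Tensor:
    return x.cos()

x = torch.randn(4)
# pred may be a boolean scalar tensor (or a plain bool); operands is a
# list/tuple of tensors forwarded to whichever branch is taken. Both
# branches must return tensors with matching metadata, otherwise cond
# raises CondOpArgsMismatchError (see the checks in cond.py below).
out = torch.cond(x.sum() > 0, true_fn, false_fn, [x])
```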

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110293
Approved by: https://github.com/zou3519
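
For context on the @exposed_in("torch") decorator used in torch/_higher_order_ops/cond.py below: to my understanding it only rebinds the function's __module__ so that cond reports itself as living in torch (for docs and reprs). A sketch of the pattern, assuming that behavior:

```python
# Sketch only: assumed behavior of torch._functorch.utils.exposed_in,
# not a verbatim copy. It rewrites __module__ so tooling reports the
# public location rather than the private defining module.
def exposed_in(module: str):
    def wrapper(fn):
        fn.__module__ = module
        return fn
    return wrapper
```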
diff --git a/docs/source/control_flow_cond.rst b/docs/source/cond.rst
similarity index 100%
rename from docs/source/control_flow_cond.rst
rename to docs/source/cond.rst
diff --git a/docs/source/export.rst b/docs/source/export.rst
index ead1849..723f4a6 100644
--- a/docs/source/export.rst
+++ b/docs/source/export.rst
@@ -501,7 +501,7 @@
 x.shape[0] > 2``) when shapes are not being specialized, as a tracing compiler cannot
 possibly deal with without generating code for a combinatorially exploding
 number of paths. In such cases, users will need to rewrite their code using
-special control flow operators. Currently, we support :ref:`torch.cond <control_flow_cond>`
+special control flow operators. Currently, we support :ref:`torch.cond <cond>`
 to express if-else like control flow (more coming soon!).
 
 Data-Dependent Accesses
@@ -540,7 +540,7 @@
    torch.compiler_transformations
    torch.compiler_ir
    generated/exportdb/index
-   control_flow_cond
+   cond
 
 .. toctree::
    :caption: Deep Dive for PyTorch Developers
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 4df0844..389ea69 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -718,6 +718,17 @@
     export
     generated/exportdb/index
 
+Control Flow
+------------
+
+.. warning::
+    This feature is a prototype and may have compatibility breaking changes in the future.
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    cond
 Optimizations
 -------------
 .. autosummary::
diff --git a/functorch/experimental/control_flow.py b/functorch/experimental/control_flow.py
index ddfdd69..cb6ff2e 100644
--- a/functorch/experimental/control_flow.py
+++ b/functorch/experimental/control_flow.py
@@ -1,6 +1,4 @@
-from torch._higher_order_ops.cond import (  # noqa: F401
-    cond,
-    UnsupportedAliasMutationException,
-)
+from torch import cond  # noqa: F401
+from torch._higher_order_ops.cond import UnsupportedAliasMutationException  # noqa: F401
 
 from ._map import map  # noqa: F401
diff --git a/torch/__init__.py b/torch/__init__.py
index 74662ed..e13fa1c 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -56,7 +56,7 @@
     'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
     'SymBool', 'sym_not',
     'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap',
-    'export', 'autocast',
+    'export', 'autocast', 'cond',
 ]
 
 ################################################################################
@@ -986,7 +986,7 @@
 # These error checking functions must be kept consistent with their C++
 # equivalents. Their C++ equivalents are mentioned where applicable.
 
-def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]):
+def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]):  # noqa: F811
     if not isinstance(cond, (builtins.bool, torch.SymBool)):
         raise TypeError(f'cond must be a bool, but got {type(cond)}')
 
@@ -1010,7 +1010,7 @@
 
     raise error_type(message_evaluated)
 
-def _check(cond, message=None):
+def _check(cond, message=None):  # noqa: F811
     r"""Throws error containing an optional message if the specified condition
     is False.
 
@@ -1041,7 +1041,7 @@
     _check(i >= 0, message)
     torch.fx.experimental.symbolic_shapes._advise_is_size(i)
 
-def _check_index(cond, message=None):
+def _check_index(cond, message=None):  # noqa: F811
     r"""Throws error containing an optional message if the specified condition
     is False.
 
@@ -1058,7 +1058,7 @@
     """
     _check_with(IndexError, cond, message)
 
-def _check_value(cond, message=None):
+def _check_value(cond, message=None):  # noqa: F811
     r"""Throws error containing an optional message if the specified condition
     is False.
 
@@ -1075,7 +1075,7 @@
     """
     _check_with(ValueError, cond, message)
 
-def _check_type(cond, message=None):
+def _check_type(cond, message=None):  # noqa: F811
     r"""Throws error containing an optional message if the specified condition
     is False.
 
@@ -1092,7 +1092,7 @@
     """
     _check_with(TypeError, cond, message)
 
-def _check_not_implemented(cond, message=None):
+def _check_not_implemented(cond, message=None):  # noqa: F811
     r"""Throws error containing an optional message if the specified condition
     is False.
 
@@ -1109,7 +1109,7 @@
     """
     _check_with(NotImplementedError, cond, message)
 
-def _check_tensor_all_with(error_type, cond, message=None):
+def _check_tensor_all_with(error_type, cond, message=None):  # noqa: F811
     if not torch.is_tensor(cond):
         raise TypeError(f'cond must be a tensor, but got {type(cond)}')
 
@@ -1120,7 +1120,7 @@
     _check_with(error_type, cond._is_all_true().item(), message)
 
 # C++ equivalent: `TORCH_CHECK_TENSOR_ALL`
-def _check_tensor_all(cond, message=None):
+def _check_tensor_all(cond, message=None):  # noqa: F811
     r"""Throws error containing an optional message if the specified condition
     is False.
 
@@ -1761,6 +1761,7 @@
 
 from torch import export as export
 
+from torch._higher_order_ops import cond
 
 def _register_device_module(device_type, module):
     r"""Register an external runtime module of the specific :attr:`device_type`
diff --git a/torch/_dynamo/allowed_functions.py b/torch/_dynamo/allowed_functions.py
index 8beca2b..0c1c529 100644
--- a/torch/_dynamo/allowed_functions.py
+++ b/torch/_dynamo/allowed_functions.py
@@ -215,6 +215,7 @@
                     torch.func.vmap,
                     deprecated_func.vmap,
                     torch.nn.functional.triplet_margin_with_distance_loss,
+                    torch.cond,
                 ):
                     continue
 
diff --git a/torch/_higher_order_ops/__init__.py b/torch/_higher_order_ops/__init__.py
index e69de29..2ac132d 100644
--- a/torch/_higher_order_ops/__init__.py
+++ b/torch/_higher_order_ops/__init__.py
@@ -0,0 +1 @@
+from .cond import cond
diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py
index 0a35839..c984e3d 100644
--- a/torch/_higher_order_ops/cond.py
+++ b/torch/_higher_order_ops/cond.py
@@ -8,8 +8,7 @@
 import torch.utils._pytree as pytree
 
 from torch._C import DispatchKey
-from torch._dynamo.exc import CondOpArgsMismatchError
-from torch._dynamo.utils import disable_cache_limit
+from torch._functorch.utils import exposed_in
 
 from torch._higher_order_ops.utils import autograd_not_implemented
 from torch._ops import HigherOrderOperator
@@ -42,6 +41,7 @@
     reason: str
 
 
+@exposed_in("torch")
 def cond(pred, true_fn, false_fn, operands):
     r"""
     Conditionally applies `true_fn` or `false_fn`.
@@ -142,7 +142,7 @@
         raise RuntimeError("torch.cond requires dynamo support.")
 
     with _set_compilation_env():
-        with disable_cache_limit():
+        with torch._dynamo.utils.disable_cache_limit():
             return torch.compile(cond_op, backend="eager", fullgraph=True)(
                 pred, true_fn, false_fn, operands
             )
@@ -198,7 +198,7 @@
     flat_true_outs, _ = pytree.tree_flatten(true_outs)
     flat_false_outs, _ = pytree.tree_flatten(false_outs)
     if len(flat_true_outs) != len(flat_false_outs):
-        raise CondOpArgsMismatchError(
+        raise torch._dynamo.exc.CondOpArgsMismatchError(
             f"Expected to return same number of outputs but got:"
             f"\n  {true_fn.__name__} returns {len(flat_true_outs)} item(s)"
             f"\n  {false_fn.__name__} returns {len(flat_false_outs)} item(s)"
@@ -208,7 +208,7 @@
         true_out = flat_true_outs[i]
         false_out = flat_false_outs[i]
         if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]:
-            raise CondOpArgsMismatchError(
+            raise torch._dynamo.exc.CondOpArgsMismatchError(
                 f"Expected each tensor to have same metadata but got:"
                 f"\n  {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
                 f"\n  {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
@@ -291,7 +291,7 @@
         true_meta = _extract_tensor_metadata(true_out)
         false_meta = _extract_tensor_metadata(false_out)
         if true_meta != false_meta:
-            raise CondOpArgsMismatchError(
+            raise torch._dynamo.exc.CondOpArgsMismatchError(
                 f"Expected each tensor to have same metadata but got:"
                 f"\n  {true_fn.__name__} returns {true_meta}"
                 f"\n  {false_fn.__name__} returns {false_meta}"
diff --git a/torch/overrides.py b/torch/overrides.py
index a771d60..4793793 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -297,6 +297,7 @@
         torch.set_vital,
         torch.read_vitals,
         torch.vmap,
+        torch.cond,
         torch.frombuffer,
         torch.asarray,
         torch._functional_sym_constrain_range,