Revert "[HigherOrderOp] Should automatically pop modes (#109157)"

This reverts commit f03b8abd4706e53b3fb6aefbd4304884e537616d.

Reverted https://github.com/pytorch/pytorch/pull/109157 on behalf of https://github.com/clee2000 due to broke internal builds D49346922 ([comment](https://github.com/pytorch/pytorch/pull/109157#issuecomment-1722571262))
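
For context, this revert restores the pre-#109157 convention in which a HigherOrderOperator's mode-based py_impl handlers look up and temporarily pop the active dispatch mode themselves, instead of receiving an already-popped `mode` as their first argument. Below is a minimal sketch of that convention, modeled on the `_export_tracepoint` handler in this diff; `my_hop` and `my_hop_proxy_mode` are hypothetical names used only for illustration.

    import torch.utils._pytree as pytree
    from torch._ops import HigherOrderOperator
    from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
    from torch.utils._python_dispatch import (
        _get_current_dispatch_mode,
        _pop_mode_temporarily,
    )

    my_hop = HigherOrderOperator("my_hop")  # hypothetical op, for illustration only

    @my_hop.py_impl(ProxyTorchDispatchMode)
    def my_hop_proxy_mode(*args, **kwargs):
        # After the revert, the dispatcher calls this handler without popping the
        # mode first, so the handler fetches it and pops it while it traces.
        mode = _get_current_dispatch_mode()
        assert mode is not None, "Mode should always be enabled for python fallback key"
        with _pop_mode_temporarily() as mode:
            if not mode.enable_tracing:
                return my_hop(*args, **kwargs)
            p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs))
            proxy = mode.tracer.create_proxy("call_function", my_hop, p_args, p_kwargs)
            return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer)

The same pattern recurs in each file touched below; the FakeTensorMode handlers likewise drop the explicit `mode` parameter and the `with mode:` re-entry, since the mode is still on the stack when they run.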
diff --git a/functorch/experimental/_map.py b/functorch/experimental/_map.py
index 138d8d3..09a17d9 100644
--- a/functorch/experimental/_map.py
+++ b/functorch/experimental/_map.py
@@ -23,6 +23,10 @@
     track_tensor_tree,
 )
 from torch.multiprocessing.reductions import StorageWeakRef
+from torch.utils._python_dispatch import (
+    _get_current_dispatch_mode,
+    _pop_mode_temporarily,
+)
 
 
 # TODO: We add this to prevent dynamo from tracing into map_wrapper,
@@ -307,17 +311,19 @@
 
 
 @map_impl.py_impl(ProxyTorchDispatchMode)
-def map_proxy_torch_dispatch_mode(mode, f, num_mapped, *args):
-    if mode.enable_tracing:
-        return trace_map(mode, map_impl, f, num_mapped, *args)
-    else:
-        return map_impl(f, num_mapped, *args)
+def map_proxy_torch_dispatch_mode(f, num_mapped, *args):
+    mode = _get_current_dispatch_mode()
+    assert mode is not None, "Mode should always be enabled for python fallback key"
+    with _pop_mode_temporarily() as mode:
+        if mode.enable_tracing:
+            return trace_map(mode, map_impl, f, num_mapped, *args)
+        else:
+            return map_impl(f, num_mapped, *args)
 
 
 @map_impl.py_impl(FakeTensorMode)
-def map_fake_tensor_mode(mode, f, num_mapped, *args):
-    with mode:
-        return map_dense(f, num_mapped, *args)
+def map_fake_tensor_mode(f, num_mapped, *args):
+    return map_dense(f, num_mapped, *args)
 
 
 @map_impl.py_impl(DispatchKey.Functionalize)
diff --git a/torch/_export/wrappers.py b/torch/_export/wrappers.py
index 956d2ca..c26a6ec 100644
--- a/torch/_export/wrappers.py
+++ b/torch/_export/wrappers.py
@@ -8,6 +8,10 @@
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
 from torch.utils import _pytree as pytree
+from torch.utils._python_dispatch import (
+    _get_current_dispatch_mode,
+    _pop_mode_temporarily,
+)
 
 
 _export_tracepoint = HigherOrderOperator("_export_tracepoint")
@@ -22,20 +26,22 @@
 
 
 @_export_tracepoint.py_impl(ProxyTorchDispatchMode)
-def export_tracepoint_dispatch_mode(mode, *args, **kwargs):
-    if not mode.enable_tracing:
-        return _export_tracepoint(*args, **kwargs)
-    p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs))
-    proxy = mode.tracer.create_proxy(
-        "call_function", _export_tracepoint, p_args, p_kwargs
-    )
-    return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer)
+def export_tracepoint_dispatch_mode(*args, **kwargs):
+    mode = _get_current_dispatch_mode()
+    assert mode is not None, "Mode should always be enabled for python fallback key"
+    with _pop_mode_temporarily() as mode:
+        if not mode.enable_tracing:
+            return _export_tracepoint(*args, **kwargs)
+        p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs))
+        proxy = mode.tracer.create_proxy(
+            "call_function", _export_tracepoint, p_args, p_kwargs
+        )
+        return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer)
 
 
 @_export_tracepoint.py_impl(FakeTensorMode)
-def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs):
-    with mode:
-        return args
+def export_tracepoint_fake_tensor_mode(*args, **kwargs):
+    return args
 
 
 @_export_tracepoint.py_impl(DispatchKey.Functionalize)
diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py
index 46ee261..b0fb1a5 100644
--- a/torch/_higher_order_ops/cond.py
+++ b/torch/_higher_order_ops/cond.py
@@ -26,7 +26,10 @@
 )
 from torch.fx.passes.shape_prop import _extract_tensor_metadata
 from torch.multiprocessing.reductions import StorageWeakRef
-from torch.utils._python_dispatch import _get_current_dispatch_mode
+from torch.utils._python_dispatch import (
+    _get_current_dispatch_mode,
+    _pop_mode_temporarily,
+)
 
 
 @contextmanager
@@ -278,19 +281,35 @@
 
 
 @cond_op.py_impl(ProxyTorchDispatchMode)
-def inner(mode, pred, true_fn, false_fn, operands):
-    if mode.enable_tracing:
-        return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands)
-    else:
-        return cond_op(pred, true_fn, false_fn, operands)
+def inner(pred, true_fn, false_fn, operands):
+    # TODO Move this to proper utility function
+    from torch._ops import mode_stack_per_key, temporarily_pop_mode
+
+    # torch.cond expects ProxyTorchDispatchMode to **still** be on the stack
+    # at the time that its proxy implementation is called.
+    # However, the mode can live in one of two places, depending on
+    # whether we're doing pre_dispatch tracing or normal tracing.
+    pre_dispatch_modes = mode_stack_per_key().get(DispatchKey.PreDispatch, [])  # type: ignore[attr-defined]
+    if len(pre_dispatch_modes) > 0:
+        with temporarily_pop_mode(pre_dispatch_modes) as mode:
+            if mode.enable_tracing:
+                return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands)
+            else:
+                return cond_op(pred, true_fn, false_fn, operands)
+    mode = _get_current_dispatch_mode()
+    assert mode is not None, "Mode should always be enabled for python fallback key"
+    with _pop_mode_temporarily() as mode:
+        if mode.enable_tracing:
+            return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands)
+        else:
+            return cond_op(pred, true_fn, false_fn, operands)
 
 
 @cond_op.py_impl(FakeTensorMode)
-def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
-    with mode:
-        true_outs = true_fn(*operands)
-        flat_true_outs, _ = pytree.tree_flatten(true_outs)
-        flat_false_outs, _ = pytree.tree_flatten(false_fn(*operands))
+def cond_fake_tensor_mode(pred, true_fn, false_fn, operands):
+    true_outs = true_fn(*operands)
+    flat_true_outs, _ = pytree.tree_flatten(true_outs)
+    flat_false_outs, _ = pytree.tree_flatten(false_fn(*operands))
     if len(flat_true_outs) != len(flat_false_outs):
         raise RuntimeError("Unmatched number of outputs from cond() branches.")
 
diff --git a/torch/_higher_order_ops/out_dtype.py b/torch/_higher_order_ops/out_dtype.py
index 71faf78..b2e50ad 100644
--- a/torch/_higher_order_ops/out_dtype.py
+++ b/torch/_higher_order_ops/out_dtype.py
@@ -7,6 +7,10 @@
     track_tensor_tree,
     maybe_handle_decomp,
 )
+from torch.utils._python_dispatch import (
+    _get_current_dispatch_mode,
+    _pop_mode_temporarily,
+)
 from torch._C import DispatchKey, _ExcludeDispatchKeyGuard, DispatchKeySet
 from torch._functorch.eager_transforms import (
     _unwrap_all_tensors_from_functional,
@@ -155,26 +159,36 @@
 
 @out_dtype.py_impl(ProxyTorchDispatchMode)
 def out_dtype_proxy(
-    mode: ProxyTorchDispatchMode,
     op: torch._ops.OpOverload,
     output_dtype: torch.dtype,
     *args
 ):
-    if mode.enable_tracing:
-        return trace_out_dtype(mode, out_dtype, op, output_dtype, *args)
-    else:
-        return out_dtype(op, output_dtype, *args)
+    # TODO Move this to proper utility function
+    from torch._ops import mode_stack_per_key, temporarily_pop_mode
+    pre_dispatch_modes = mode_stack_per_key().get(DispatchKey.PreDispatch, [])  # type: ignore[attr-defined]
+    if len(pre_dispatch_modes) > 0:
+        with temporarily_pop_mode(pre_dispatch_modes) as mode:
+            if mode.enable_tracing:
+                return trace_out_dtype(mode, out_dtype, op, output_dtype, *args)
+            else:
+                return out_dtype(op, output_dtype, *args)
+
+    mode = _get_current_dispatch_mode()
+    assert (mode is not None), "Mode should always be enabled for python fallback key"
+    with _pop_mode_temporarily() as mode:
+        if mode.enable_tracing:
+            return trace_out_dtype(mode, out_dtype, op, output_dtype, *args)
+        else:
+            return out_dtype(op, output_dtype, *args)
 
 
 @out_dtype.py_impl(FakeTensorMode)
 def out_dtype_fake_tensor_mode(
-    mode: FakeTensorMode,
     op: torch._ops.OpOverload,
     output_dtype: torch.dtype,
     *args
 ):
-    with mode:
-        return out_dtype_dense(op, output_dtype, *args)
+    return out_dtype_dense(op, output_dtype, *args)
 
 
 @out_dtype.py_impl(DispatchKey.Functionalize)
diff --git a/torch/_ops.py b/torch/_ops.py
index 9c5b8d7..4c0bd5d 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -243,9 +243,7 @@
             return dispatch_functorch(self, args, kwargs)
 
         if dispatch_key == torch._C.DispatchKey.Python:
-            # The place to handle ProxyTorchDispatchMode, FakeTensorMode, etc
-            from torch.utils._python_dispatch import _pop_mode_temporarily
-
+            # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
             curr_mode = _get_current_dispatch_mode()
             assert (
                 curr_mode is not None
@@ -253,13 +251,11 @@
             assert (
                 type(curr_mode) in self.python_key_mode_table
             ), f"Current active mode {curr_mode} not registered"
-            handler = self.python_key_mode_table[type(curr_mode)]
-            with _pop_mode_temporarily() as mode:
-                return handler(mode, *args, **kwargs)
+            # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
+            return self.python_key_mode_table[type(curr_mode)](*args, **kwargs)
 
         functionality_key = torch._C._to_functionality_key(dispatch_key)  # type: ignore[attr-defined]
         if functionality_key in mode_stack_per_key():
-            # The place to handle DispatchKey.PreDispatch
             curr_stack = mode_stack_per_key()[functionality_key]
             # The check for Python in the exclude set is so we properly respect `with no_dispatch()`
             # calls inside of a mode.
@@ -269,13 +265,7 @@
                 DispatchKey.Python
             ):
                 curr_mode = curr_stack[-1]
-                pre_dispatch_modes = mode_stack_per_key().get(
-                    DispatchKey.PreDispatch, []  # type: ignore[attr-defined]
-                )
-                handler = self.python_key_mode_table[type(curr_mode)]
-                if len(pre_dispatch_modes) > 0:
-                    with temporarily_pop_mode(pre_dispatch_modes) as mode:
-                        return handler(mode, *args, **kwargs)
+                return self.python_key_mode_table[type(curr_mode)](*args, **kwargs)
 
         final_key = resolve_key(self, dispatch_key)
 
diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py
index 8f14419..7dd16db 100644
--- a/torch/_prims/rng_prims.py
+++ b/torch/_prims/rng_prims.py
@@ -16,6 +16,10 @@
     track_tensor_tree,
 )
 from torch.types import _device, _dtype
+from torch.utils._python_dispatch import (
+    _get_current_dispatch_mode,
+    _pop_mode_temporarily,
+)
 
 
 rngprim_namespace = "rngprims"
@@ -188,23 +192,27 @@
         return impl(op, *args, **kwargs)
 
     @run_and_save_rng_state.py_impl(FakeTensorMode)
-    def impl_fake_tensor_mode(mode, op, *args, **kwargs):
+    def impl_fake_tensor_mode(op, *args, **kwargs):
         # Check device to call the right impl
-        with mode:
-            return impl_backend_select(op, *args, **kwargs)
+        return impl_backend_select(op, *args, **kwargs)
 
     @run_and_save_rng_state.py_impl(ProxyTorchDispatchMode)
-    def impl_proxy_dispatch_mode(mode, op, *args, **kwargs):
-        if mode.enable_tracing:
-            out = impl_backend_select(op, *args, **kwargs)
-            proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
-            proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
-            out_proxy = mode.tracer.create_proxy(
-                "call_function", run_and_save_rng_state, proxy_args, proxy_kwargs
-            )
-            return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
-        else:
-            return run_and_save_rng_state(op, *args, **kwargs)
+    def impl_proxy_dispatch_mode(op, *args, **kwargs):
+        mode = _get_current_dispatch_mode()
+        assert mode is not None
+        with _pop_mode_temporarily() as mode:
+            if mode.enable_tracing:
+                out = impl_fake_tensor_mode(op, *args, **kwargs)
+                proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
+                proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
+                out_proxy = mode.tracer.create_proxy(
+                    "call_function", run_and_save_rng_state, proxy_args, proxy_kwargs
+                )
+                return track_tensor_tree(
+                    out, out_proxy, constant=None, tracer=mode.tracer
+                )
+            else:
+                return run_and_save_rng_state(op, *args, **kwargs)
 
     return run_and_save_rng_state
 
@@ -239,20 +247,25 @@
         return out
 
     @run_with_rng_state.py_impl(ProxyTorchDispatchMode)
-    def impl_proxy_dispatch_mode(mode, rng_state, op, *args, **kwargs):
-        if mode.enable_tracing:
-            with disable_proxy_modes_tracing():
-                out = run_with_rng_state(rng_state, op, *args, **kwargs)
-            proxy_args = pytree.tree_map(
-                mode.tracer.unwrap_proxy, (rng_state, op, *args)
-            )
-            proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
-            out_proxy = mode.tracer.create_proxy(
-                "call_function", run_with_rng_state, proxy_args, proxy_kwargs
-            )
-            return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
-        else:
-            return run_with_rng_state(rng_state, op, *args, **kwargs)
+    def impl_proxy_dispatch_mode(rng_state, op, *args, **kwargs):
+        mode = _get_current_dispatch_mode()
+        assert mode is not None
+        with _pop_mode_temporarily() as mode:
+            if mode.enable_tracing:
+                with disable_proxy_modes_tracing():
+                    out = run_with_rng_state(rng_state, op, *args, **kwargs)
+                proxy_args = pytree.tree_map(
+                    mode.tracer.unwrap_proxy, (rng_state, op, *args)
+                )
+                proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
+                out_proxy = mode.tracer.create_proxy(
+                    "call_function", run_with_rng_state, proxy_args, proxy_kwargs
+                )
+                return track_tensor_tree(
+                    out, out_proxy, constant=None, tracer=mode.tracer
+                )
+            else:
+                return run_with_rng_state(rng_state, op, *args, **kwargs)
 
     @run_with_rng_state.py_impl(DispatchKey.BackendSelect)
     def impl_backend_select(rng_state, op, *args, **kwargs):
@@ -263,11 +276,10 @@
         return impl(rng_state, op, *args, **kwargs)
 
     @run_with_rng_state.py_impl(FakeTensorMode)
-    def impl_fake_tensor_mode(mode, rng_state, op, *args, **kwargs):
+    def impl_fake_tensor_mode(rng_state, op, *args, **kwargs):
         # Skip setting the set_rng_state as it does not work well with fake tensors.
         # And it does not matter for the fake tensor mode.
-        with mode:
-            return op(*args, **kwargs)
+        return op(*args, **kwargs)
 
     return run_with_rng_state
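
A note on the pre_dispatch branch restored in cond.py and out_dtype.py above: those handlers first check the per-key mode stack for DispatchKey.PreDispatch and, if a mode is parked there, pop that one instead of the regular Python-mode stack. A minimal sketch of that check, with `handle` as a hypothetical placeholder for the op-specific tracing logic:

    from torch._C import DispatchKey
    from torch._ops import mode_stack_per_key, temporarily_pop_mode
    from torch.utils._python_dispatch import (
        _get_current_dispatch_mode,
        _pop_mode_temporarily,
    )

    def dispatch_popping_right_stack(handle, *args, **kwargs):
        # The mode can live in one of two places depending on whether we are doing
        # pre_dispatch tracing or normal tracing (see the comment in cond.py above).
        pre_dispatch_modes = mode_stack_per_key().get(DispatchKey.PreDispatch, [])
        if len(pre_dispatch_modes) > 0:
            with temporarily_pop_mode(pre_dispatch_modes) as mode:
                return handle(mode, *args, **kwargs)
        mode = _get_current_dispatch_mode()
        assert mode is not None, "Mode should always be enabled for python fallback key"
        with _pop_mode_temporarily() as mode:
            return handle(mode, *args, **kwargs)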