Revert "[HigherOrderOp] Should automatically pop modes (#109157)"
This reverts commit f03b8abd4706e53b3fb6aefbd4304884e537616d.
Reverted https://github.com/pytorch/pytorch/pull/109157 on behalf of https://github.com/clee2000 because it broke internal builds (D49346922) ([comment](https://github.com/pytorch/pytorch/pull/109157#issuecomment-1722571262))
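
Context (not part of the patch): PR #109157 had the dispatcher pop the active mode and pass it as the first argument to every `py_impl` handler. This revert restores the earlier convention, in which a handler receives only the operator's own arguments and is itself responsible for fetching the current mode and popping it when needed. A minimal sketch of the restored handler shape, for a hypothetical HigherOrderOperator `my_op`; the helper names come from the diff below, while `my_op` and its bodies are placeholders:

    # Sketch only, not part of the patch.
    from torch._ops import HigherOrderOperator
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
    from torch.utils._python_dispatch import (
        _get_current_dispatch_mode,
        _pop_mode_temporarily,
    )

    my_op = HigherOrderOperator("my_op")

    @my_op.py_impl(ProxyTorchDispatchMode)
    def my_op_proxy_mode(*args, **kwargs):
        # The mode is no longer passed in; the handler looks it up and pops it
        # from the stack itself before deciding whether to trace.
        mode = _get_current_dispatch_mode()
        assert mode is not None, "Mode should always be enabled for python fallback key"
        with _pop_mode_temporarily() as mode:
            if mode.enable_tracing:
                ...  # record my_op into the graph via mode.tracer
            else:
                return my_op(*args, **kwargs)

    @my_op.py_impl(FakeTensorMode)
    def my_op_fake_mode(*args, **kwargs):
        # No `with mode:` re-entry is needed: the FakeTensorMode is still active,
        # because the dispatcher no longer pops it before calling the handler.
        ...  # run a dense/fake implementation of my_op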
diff --git a/functorch/experimental/_map.py b/functorch/experimental/_map.py
index 138d8d3..09a17d9 100644
--- a/functorch/experimental/_map.py
+++ b/functorch/experimental/_map.py
@@ -23,6 +23,10 @@
track_tensor_tree,
)
from torch.multiprocessing.reductions import StorageWeakRef
+from torch.utils._python_dispatch import (
+ _get_current_dispatch_mode,
+ _pop_mode_temporarily,
+)
# TODO: We add this to prevent dymamo from tracing into map_wrapper,
@@ -307,17 +311,19 @@
@map_impl.py_impl(ProxyTorchDispatchMode)
-def map_proxy_torch_dispatch_mode(mode, f, num_mapped, *args):
- if mode.enable_tracing:
- return trace_map(mode, map_impl, f, num_mapped, *args)
- else:
- return map_impl(f, num_mapped, *args)
+def map_proxy_torch_dispatch_mode(f, num_mapped, *args):
+ mode = _get_current_dispatch_mode()
+ assert mode is not None, "Mode should always be enabled for python fallback key"
+ with _pop_mode_temporarily() as mode:
+ if mode.enable_tracing:
+ return trace_map(mode, map_impl, f, num_mapped, *args)
+ else:
+ return map_impl(f, num_mapped, *args)
@map_impl.py_impl(FakeTensorMode)
-def map_fake_tensor_mode(mode, f, num_mapped, *args):
- with mode:
- return map_dense(f, num_mapped, *args)
+def map_fake_tensor_mode(f, num_mapped, *args):
+ return map_dense(f, num_mapped, *args)
@map_impl.py_impl(DispatchKey.Functionalize)
diff --git a/torch/_export/wrappers.py b/torch/_export/wrappers.py
index 956d2ca..c26a6ec 100644
--- a/torch/_export/wrappers.py
+++ b/torch/_export/wrappers.py
@@ -8,6 +8,10 @@
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.utils import _pytree as pytree
+from torch.utils._python_dispatch import (
+ _get_current_dispatch_mode,
+ _pop_mode_temporarily,
+)
_export_tracepoint = HigherOrderOperator("_export_tracepoint")
@@ -22,20 +26,22 @@
@_export_tracepoint.py_impl(ProxyTorchDispatchMode)
-def export_tracepoint_dispatch_mode(mode, *args, **kwargs):
- if not mode.enable_tracing:
- return _export_tracepoint(*args, **kwargs)
- p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs))
- proxy = mode.tracer.create_proxy(
- "call_function", _export_tracepoint, p_args, p_kwargs
- )
- return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer)
+def export_tracepoint_dispatch_mode(*args, **kwargs):
+ mode = _get_current_dispatch_mode()
+ assert mode is not None, "Mode should always be enabled for python fallback key"
+ with _pop_mode_temporarily() as mode:
+ if not mode.enable_tracing:
+ return _export_tracepoint(*args, **kwargs)
+ p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs))
+ proxy = mode.tracer.create_proxy(
+ "call_function", _export_tracepoint, p_args, p_kwargs
+ )
+ return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer)
@_export_tracepoint.py_impl(FakeTensorMode)
-def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs):
- with mode:
- return args
+def export_tracepoint_fake_tensor_mode(*args, **kwargs):
+ return args
@_export_tracepoint.py_impl(DispatchKey.Functionalize)
diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py
index 46ee261..b0fb1a5 100644
--- a/torch/_higher_order_ops/cond.py
+++ b/torch/_higher_order_ops/cond.py
@@ -26,7 +26,10 @@
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.multiprocessing.reductions import StorageWeakRef
-from torch.utils._python_dispatch import _get_current_dispatch_mode
+from torch.utils._python_dispatch import (
+ _get_current_dispatch_mode,
+ _pop_mode_temporarily,
+)
@contextmanager
@@ -278,19 +281,35 @@
@cond_op.py_impl(ProxyTorchDispatchMode)
-def inner(mode, pred, true_fn, false_fn, operands):
- if mode.enable_tracing:
- return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands)
- else:
- return cond_op(pred, true_fn, false_fn, operands)
+def inner(pred, true_fn, false_fn, operands):
+ # TODO Move this to proper utility function
+ from torch._ops import mode_stack_per_key, temporarily_pop_mode
+
+ # torch.cond expects ProxyTorchDispatchMode to **still** be on the stack
+ # at the time that its proxy implementation is called.
+ # However, the mode can live in one of two places, depending on
+ # whether we're doing pre_dispatch tracing or normal tracing.
+ pre_dispatch_modes = mode_stack_per_key().get(DispatchKey.PreDispatch, []) # type: ignore[attr-defined]
+ if len(pre_dispatch_modes) > 0:
+ with temporarily_pop_mode(pre_dispatch_modes) as mode:
+ if mode.enable_tracing:
+ return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands)
+ else:
+ return cond_op(pred, true_fn, false_fn, operands)
+ mode = _get_current_dispatch_mode()
+ assert mode is not None, "Mode should always be enabled for python fallback key"
+ with _pop_mode_temporarily() as mode:
+ if mode.enable_tracing:
+ return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands)
+ else:
+ return cond_op(pred, true_fn, false_fn, operands)
@cond_op.py_impl(FakeTensorMode)
-def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
- with mode:
- true_outs = true_fn(*operands)
- flat_true_outs, _ = pytree.tree_flatten(true_outs)
- flat_false_outs, _ = pytree.tree_flatten(false_fn(*operands))
+def cond_fake_tensor_mode(pred, true_fn, false_fn, operands):
+ true_outs = true_fn(*operands)
+ flat_true_outs, _ = pytree.tree_flatten(true_outs)
+ flat_false_outs, _ = pytree.tree_flatten(false_fn(*operands))
if len(flat_true_outs) != len(flat_false_outs):
raise RuntimeError("Unmatched number of outputs from cond() branches.")
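
Aside (not part of the patch): for cond, and for out_dtype below, the restored handler additionally checks the PreDispatch mode stack, because under pre_dispatch tracing the ProxyTorchDispatchMode sits on `mode_stack_per_key()[DispatchKey.PreDispatch]` rather than on the regular dispatch-mode stack. A minimal sketch of that lookup, factored into a hypothetical helper `_pop_tracing_mode`; the underlying helpers are the ones used in the hunk above:

    # Sketch only, not part of the patch.
    from torch._C import DispatchKey
    from torch._ops import mode_stack_per_key, temporarily_pop_mode
    from torch.utils._python_dispatch import (
        _get_current_dispatch_mode,
        _pop_mode_temporarily,
    )

    def _pop_tracing_mode():
        # Pre-dispatch tracing: the mode lives on the per-key PreDispatch stack.
        pre_dispatch_modes = mode_stack_per_key().get(DispatchKey.PreDispatch, [])
        if len(pre_dispatch_modes) > 0:
            return temporarily_pop_mode(pre_dispatch_modes)
        # Normal tracing: the mode lives on the regular dispatch-mode stack.
        mode = _get_current_dispatch_mode()
        assert mode is not None, "Mode should always be enabled for python fallback key"
        return _pop_mode_temporarily()

    # Usage inside a py_impl handler: `with _pop_tracing_mode() as mode: ...`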
diff --git a/torch/_higher_order_ops/out_dtype.py b/torch/_higher_order_ops/out_dtype.py
index 71faf78..b2e50ad 100644
--- a/torch/_higher_order_ops/out_dtype.py
+++ b/torch/_higher_order_ops/out_dtype.py
@@ -7,6 +7,10 @@
track_tensor_tree,
maybe_handle_decomp,
)
+from torch.utils._python_dispatch import (
+ _get_current_dispatch_mode,
+ _pop_mode_temporarily,
+)
from torch._C import DispatchKey, _ExcludeDispatchKeyGuard, DispatchKeySet
from torch._functorch.eager_transforms import (
_unwrap_all_tensors_from_functional,
@@ -155,26 +159,36 @@
@out_dtype.py_impl(ProxyTorchDispatchMode)
def out_dtype_proxy(
- mode: ProxyTorchDispatchMode,
op: torch._ops.OpOverload,
output_dtype: torch.dtype,
*args
):
- if mode.enable_tracing:
- return trace_out_dtype(mode, out_dtype, op, output_dtype, *args)
- else:
- return out_dtype(op, output_dtype, *args)
+ # TODO Move this to proper utility function
+ from torch._ops import mode_stack_per_key, temporarily_pop_mode
+ pre_dispatch_modes = mode_stack_per_key().get(DispatchKey.PreDispatch, []) # type: ignore[attr-defined]
+ if len(pre_dispatch_modes) > 0:
+ with temporarily_pop_mode(pre_dispatch_modes) as mode:
+ if mode.enable_tracing:
+ return trace_out_dtype(mode, out_dtype, op, output_dtype, *args)
+ else:
+ return out_dtype(op, output_dtype, *args)
+
+ mode = _get_current_dispatch_mode()
+ assert (mode is not None), "Mode should always be enabled for python fallback key"
+ with _pop_mode_temporarily() as mode:
+ if mode.enable_tracing:
+ return trace_out_dtype(mode, out_dtype, op, output_dtype, *args)
+ else:
+ return out_dtype(op, output_dtype, *args)
@out_dtype.py_impl(FakeTensorMode)
def out_dtype_fake_tensor_mode(
- mode: FakeTensorMode,
op: torch._ops.OpOverload,
output_dtype: torch.dtype,
*args
):
- with mode:
- return out_dtype_dense(op, output_dtype, *args)
+ return out_dtype_dense(op, output_dtype, *args)
@out_dtype.py_impl(DispatchKey.Functionalize)
diff --git a/torch/_ops.py b/torch/_ops.py
index 9c5b8d7..4c0bd5d 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -243,9 +243,7 @@
return dispatch_functorch(self, args, kwargs)
if dispatch_key == torch._C.DispatchKey.Python:
- # The place to handle ProxyTorchDispatchMode, FakeTensorMode, etc
- from torch.utils._python_dispatch import _pop_mode_temporarily
-
+ # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
curr_mode = _get_current_dispatch_mode()
assert (
curr_mode is not None
@@ -253,13 +251,11 @@
assert (
type(curr_mode) in self.python_key_mode_table
), f"Current active mode {curr_mode} not registered"
- handler = self.python_key_mode_table[type(curr_mode)]
- with _pop_mode_temporarily() as mode:
- return handler(mode, *args, **kwargs)
+ # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
+ return self.python_key_mode_table[type(curr_mode)](*args, **kwargs)
functionality_key = torch._C._to_functionality_key(dispatch_key) # type: ignore[attr-defined]
if functionality_key in mode_stack_per_key():
- # The place to handle DispatchKey.PreDispatch
curr_stack = mode_stack_per_key()[functionality_key]
# The check for Python in the exclude set is so we properly respect `with no_dispatch()`
# calls inside of a mode.
@@ -269,13 +265,7 @@
DispatchKey.Python
):
curr_mode = curr_stack[-1]
- pre_dispatch_modes = mode_stack_per_key().get(
- DispatchKey.PreDispatch, [] # type: ignore[attr-defined]
- )
- handler = self.python_key_mode_table[type(curr_mode)]
- if len(pre_dispatch_modes) > 0:
- with temporarily_pop_mode(pre_dispatch_modes) as mode:
- return handler(mode, *args, **kwargs)
+ return self.python_key_mode_table[type(curr_mode)](*args, **kwargs)
final_key = resolve_key(self, dispatch_key)
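
Aside (not part of the patch): the torch/_ops.py hunk above is the dispatcher-side half of the revert. For DispatchKey.Python the dispatcher now only looks up the handler registered for the type of the top mode and calls it with the original arguments; it neither pops the mode nor passes it along, which is why every handler above re-acquires it. Roughly, as a standalone rendering with `op` standing in for the operator instance:

    # Sketch only, not part of the patch.
    from torch.utils._python_dispatch import _get_current_dispatch_mode

    def dispatch_python_key(op, *args, **kwargs):
        curr_mode = _get_current_dispatch_mode()
        assert curr_mode is not None
        assert type(curr_mode) in op.python_key_mode_table, (
            f"Current active mode {curr_mode} not registered"
        )
        # The mode stays on the stack and is not forwarded; the registered
        # py_impl handler pops it itself if it needs to.
        return op.python_key_mode_table[type(curr_mode)](*args, **kwargs)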
diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py
index 8f14419..7dd16db 100644
--- a/torch/_prims/rng_prims.py
+++ b/torch/_prims/rng_prims.py
@@ -16,6 +16,10 @@
track_tensor_tree,
)
from torch.types import _device, _dtype
+from torch.utils._python_dispatch import (
+ _get_current_dispatch_mode,
+ _pop_mode_temporarily,
+)
rngprim_namespace = "rngprims"
@@ -188,23 +192,27 @@
return impl(op, *args, **kwargs)
@run_and_save_rng_state.py_impl(FakeTensorMode)
- def impl_fake_tensor_mode(mode, op, *args, **kwargs):
+ def impl_fake_tensor_mode(op, *args, **kwargs):
# Check device to call the right impl
- with mode:
- return impl_backend_select(op, *args, **kwargs)
+ return impl_backend_select(op, *args, **kwargs)
@run_and_save_rng_state.py_impl(ProxyTorchDispatchMode)
- def impl_proxy_dispatch_mode(mode, op, *args, **kwargs):
- if mode.enable_tracing:
- out = impl_backend_select(op, *args, **kwargs)
- proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
- proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
- out_proxy = mode.tracer.create_proxy(
- "call_function", run_and_save_rng_state, proxy_args, proxy_kwargs
- )
- return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
- else:
- return run_and_save_rng_state(op, *args, **kwargs)
+ def impl_proxy_dispatch_mode(op, *args, **kwargs):
+ mode = _get_current_dispatch_mode()
+ assert mode is not None
+ with _pop_mode_temporarily() as mode:
+ if mode.enable_tracing:
+ out = impl_fake_tensor_mode(op, *args, **kwargs)
+ proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
+ proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
+ out_proxy = mode.tracer.create_proxy(
+ "call_function", run_and_save_rng_state, proxy_args, proxy_kwargs
+ )
+ return track_tensor_tree(
+ out, out_proxy, constant=None, tracer=mode.tracer
+ )
+ else:
+ return run_and_save_rng_state(op, *args, **kwargs)
return run_and_save_rng_state
@@ -239,20 +247,25 @@
return out
@run_with_rng_state.py_impl(ProxyTorchDispatchMode)
- def impl_proxy_dispatch_mode(mode, rng_state, op, *args, **kwargs):
- if mode.enable_tracing:
- with disable_proxy_modes_tracing():
- out = run_with_rng_state(rng_state, op, *args, **kwargs)
- proxy_args = pytree.tree_map(
- mode.tracer.unwrap_proxy, (rng_state, op, *args)
- )
- proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
- out_proxy = mode.tracer.create_proxy(
- "call_function", run_with_rng_state, proxy_args, proxy_kwargs
- )
- return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
- else:
- return run_with_rng_state(rng_state, op, *args, **kwargs)
+ def impl_proxy_dispatch_mode(rng_state, op, *args, **kwargs):
+ mode = _get_current_dispatch_mode()
+ assert mode is not None
+ with _pop_mode_temporarily() as mode:
+ if mode.enable_tracing:
+ with disable_proxy_modes_tracing():
+ out = run_with_rng_state(rng_state, op, *args, **kwargs)
+ proxy_args = pytree.tree_map(
+ mode.tracer.unwrap_proxy, (rng_state, op, *args)
+ )
+ proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
+ out_proxy = mode.tracer.create_proxy(
+ "call_function", run_with_rng_state, proxy_args, proxy_kwargs
+ )
+ return track_tensor_tree(
+ out, out_proxy, constant=None, tracer=mode.tracer
+ )
+ else:
+ return run_with_rng_state(rng_state, op, *args, **kwargs)
@run_with_rng_state.py_impl(DispatchKey.BackendSelect)
def impl_backend_select(rng_state, op, *args, **kwargs):
@@ -263,11 +276,10 @@
return impl(rng_state, op, *args, **kwargs)
@run_with_rng_state.py_impl(FakeTensorMode)
- def impl_fake_tensor_mode(mode, rng_state, op, *args, **kwargs):
+ def impl_fake_tensor_mode(rng_state, op, *args, **kwargs):
# Skip setting the set_rng_state as it does not work well with fake tensors.
# And it does not matter for the fake tensor mode.
- with mode:
- return op(*args, **kwargs)
+ return op(*args, **kwargs)
return run_with_rng_state
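
Aside (not part of the patch): the rng_prims handlers above follow the same recording pattern as the other proxy-mode implementations: compute the outputs, map the inputs to proxies, emit a call_function node for the higher-order op, and associate the outputs with the new proxy. A minimal sketch of that pattern, with `my_op` and `my_op_dense` as hypothetical stand-ins:

    # Sketch only, not part of the patch.
    from torch.fx.experimental.proxy_tensor import track_tensor_tree
    from torch.utils import _pytree as pytree

    def record_into_graph(mode, my_op, my_op_dense, *args, **kwargs):
        # Run a dense/fake implementation so the traced outputs carry real metadata.
        out = my_op_dense(*args, **kwargs)
        # Swap tensors for their proxies and emit a call_function node for the op.
        proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
        proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
        out_proxy = mode.tracer.create_proxy(
            "call_function", my_op, proxy_args, proxy_kwargs
        )
        # Tie the concrete outputs to the proxy so later ops trace through them.
        return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)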