Preserve dispatch state across function tracing (#122073)
If we throw an exception in the "wrong" place, we can leave the thread-local dispatch key set in an inconsistent state, which can then cause all future dispatching to fail. Preserve and restore it as part of `preserve_global_state` so we know it's sane afterward.
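For illustration only (not part of the PR), here's a minimal sketch of what the guard buys us. `_PreserveDispatchKeyGuard` and `_dispatch_tls_set_dispatch_key_included` both appear in the diff below; the `_dispatch_tls_is_dispatch_key_included` probe is a pre-existing binding I'm assuming here:
```python
import torch

# Minimal sketch: _PreserveDispatchKeyGuard snapshots the thread-local
# dispatch key set on entry and force-restores it on exit, even when an
# exception propagates out of the body.
try:
    with torch._C._PreserveDispatchKeyGuard():
        # Perturb the TLS state the way a half-finished trace might.
        torch._C._dispatch_tls_set_dispatch_key_included(
            torch._C.DispatchKey.Dense, True
        )
        raise RuntimeError("exception in the 'wrong' place")
except RuntimeError:
    pass

# The Dense bit set inside the `with` block has been rolled back here.
assert not torch._C._dispatch_tls_is_dispatch_key_included(
    torch._C.DispatchKey.Dense
)
```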
Also, fake_tensor's `in_kernel_invocation_manager()` was leaving a bit set in the dispatcher (`DispatchKey.Dense`) that affected follow-on code. Fixed that to reset afterward as well.
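A hedged sketch of that leak (again not part of the PR; the behavior is as described by the comment added in fake_tensor.py below):
```python
import torch

dense = torch._C.DispatchKey.Dense

# Including Meta in the TLS dispatch set also implies Dense...
torch._C._set_meta_in_tls_dispatch_include(True)
print(torch._C._dispatch_tls_is_dispatch_key_included(dense))  # True

# ...but flipping it back off can leave the implied Dense bit set
# (per the comment in the diff), which is the stale state this PR
# now scopes inside a _PreserveDispatchKeyGuard instead of trying
# to reset by hand.
torch._C._set_meta_in_tls_dispatch_include(False)
print(torch._C._dispatch_tls_is_dispatch_key_included(dense))
```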
Repro:
before:
```
$ rm test/dynamo_skips/TestSparseCPU.test_to_dense_with_gradcheck_sparse_cpu_complex64
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest -s test/dynamo/test_export.py test/test_sparse.py -k 'test_to_dense_with_gradcheck_sparse_cpu_complex64'
======== 1 passed, 6173 deselected in 5.21s =============
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest -s test/dynamo/test_export.py test/test_sparse.py -k 'test_torch_inference_mode_ctx or test_to_dense_with_gradcheck_sparse_cpu_complex64'
========= 1 skipped, 6172 deselected, 1 error in 5.29s =========
```
(note that test_to_dense_with_gradcheck_sparse_cpu_complex64 passes on its own but errors out when the skipped test_export.py test is included)
after:
```
$ rm test/dynamo_skips/TestSparseCPU.test_to_dense_with_gradcheck_sparse_cpu_complex64
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest -s test/dynamo/test_export.py test/test_sparse.py -k 'test_to_dense_with_gradcheck_sparse_cpu_complex64'
===================== 1 passed, 6173 deselected in 5.42s =====================
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest -s test/dynamo/test_export.py test/test_sparse.py -k 'test_torch_inference_mode_ctx or test_to_dense_with_gradcheck_sparse_cpu_complex64'
===================== 1 passed, 1 skipped, 6172 deselected in 7.30s ======================
```
(note that test_to_dense_with_gradcheck_sparse_cpu_complex64 passes in both runs)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122073
Approved by: https://github.com/zou3519
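For orientation before the diff: a condensed, hypothetical sketch of the pattern convert_frame.py adopts (names trimmed; only a few of the saved/restored pieces shown). The key point is that the `with` sits outside the `try`/`finally`, so the dispatch key set is restored last, no matter how `fn()` exits:
```python
import random
import torch

def preserve_global_state_sketch(fn):  # hypothetical, trimmed-down version
    def _fn(*args, **kwargs):
        prior_grad_mode = torch.is_grad_enabled()
        # Guard __exit__ runs after the finally block below, restoring
        # the TLS dispatch key set even if fn() raised mid-trace.
        with torch._C._PreserveDispatchKeyGuard():
            py_rng_state = random.getstate()
            torch_rng_state = torch.random.get_rng_state()
            try:
                return fn(*args, **kwargs)
            finally:
                torch._C._set_grad_enabled(prior_grad_mode)
                random.setstate(py_rng_state)
                torch.random.set_rng_state(torch_rng_state)
    return _fn
```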
diff --git a/c10/core/impl/LocalDispatchKeySet.h b/c10/core/impl/LocalDispatchKeySet.h
index ef4acbb..176d0a6 100644
--- a/c10/core/impl/LocalDispatchKeySet.h
+++ b/c10/core/impl/LocalDispatchKeySet.h
@@ -117,14 +117,16 @@
struct C10_API ForceDispatchKeyGuard {
public:
+ ForceDispatchKeyGuard()
+ : saved_keyset_(c10::impl::tls_local_dispatch_key_set()) {}
ForceDispatchKeyGuard(c10::impl::LocalDispatchKeySet key_set)
- : saved_keyset_(c10::impl::tls_local_dispatch_key_set()) {
+ : ForceDispatchKeyGuard() {
c10::impl::_force_tls_local_dispatch_key_set(key_set);
}
ForceDispatchKeyGuard(
c10::DispatchKeySet include,
c10::DispatchKeySet exclude)
- : saved_keyset_(c10::impl::tls_local_dispatch_key_set()) {
+ : ForceDispatchKeyGuard() {
auto updated_set = saved_keyset_;
updated_set.included_ = include;
updated_set.excluded_ = exclude;
diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py
index ce7a233..1380de4 100644
--- a/test/inductor/test_cpu_cpp_wrapper.py
+++ b/test/inductor/test_cpu_cpp_wrapper.py
@@ -133,16 +133,21 @@
tests.setUpClass()
tests.setUp()
try:
- _, code = test_torchinductor.run_and_get_cpp_code(
- func, *func_inputs if func_inputs else []
- )
- self.assertEqual("CppWrapperCodeCache" in code, True)
- self.assertTrue(
- all(
- code.count(string) == code_string_count[string]
- for string in code_string_count
+ with torch._C._PreserveDispatchKeyGuard():
+ torch._C._dispatch_tls_set_dispatch_key_included(
+ torch._C.DispatchKey.Dense, True
)
- )
+
+ _, code = test_torchinductor.run_and_get_cpp_code(
+ func, *func_inputs if func_inputs else []
+ )
+ self.assertEqual("CppWrapperCodeCache" in code, True)
+ self.assertTrue(
+ all(
+ code.count(string) == code_string_count[string]
+ for string in code_string_count
+ )
+ )
finally:
tests.tearDown()
tests.tearDownClass()
diff --git a/test/inductor/test_cuda_cpp_wrapper.py b/test/inductor/test_cuda_cpp_wrapper.py
index 78f9bee..b662e24 100644
--- a/test/inductor/test_cuda_cpp_wrapper.py
+++ b/test/inductor/test_cuda_cpp_wrapper.py
@@ -3,6 +3,7 @@
import unittest
from typing import NamedTuple
+import torch
from torch._inductor import config
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal.common_device_type import (
@@ -142,16 +143,21 @@
tests.setUpClass()
tests.setUp()
try:
- _, code = test_torchinductor.run_and_get_cpp_code(
- func, *func_inputs if func_inputs else []
- )
- self.assertEqual("CppWrapperCodeCache" in code, True)
- self.assertTrue(
- all(
- code.count(string) == code_string_count[string]
- for string in code_string_count
+ with torch._C._PreserveDispatchKeyGuard():
+ torch._C._dispatch_tls_set_dispatch_key_included(
+ torch._C.DispatchKey.Dense, True
)
- )
+
+ _, code = test_torchinductor.run_and_get_cpp_code(
+ func, *func_inputs if func_inputs else []
+ )
+ self.assertEqual("CppWrapperCodeCache" in code, True)
+ self.assertTrue(
+ all(
+ code.count(string) == code_string_count[string]
+ for string in code_string_count
+ )
+ )
finally:
tests.tearDown()
tests.tearDownClass()
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 42d0a52..b39f0f4 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1532,6 +1532,11 @@
def __enter__(self): ...
def __exit__(self, exc_type, exc_value, traceback): ...
+class _PreserveDispatchKeyGuard:
+ def __init__(self): ...
+ def __enter__(self): ...
+ def __exit__(self, exc_type, exc_value, traceback): ...
+
class _AutoDispatchBelowAutograd:
def __init__(self): ...
def __enter__(self): ...
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index 93ff6d9..41a37cb 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -153,35 +153,42 @@
def _fn(*args, **kwargs):
guards = GlobalStateGuard()
prior_grad_mode = torch.is_grad_enabled()
- prior_inference_mode = torch.is_inference_mode_enabled()
- prior_deterministic = torch.are_deterministic_algorithms_enabled()
- prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
- py_rng_state = random.getstate()
- torch_rng_state = torch.random.get_rng_state()
- if torch.cuda.is_available():
- cuda_rng_state = torch.cuda.get_rng_state()
- allow_tf32 = torch._C._get_cublas_allow_tf32()
- prior_fwd_from_src = torch.fx.graph_module._forward_from_src
- torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
- cleanup = setup_compile_debug()
- try:
- return fn(*args, **kwargs)
- finally:
- cleanup.close()
- torch._C._set_grad_enabled(prior_grad_mode)
- torch.torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
- torch.use_deterministic_algorithms(
- prior_deterministic, warn_only=prior_warn_only
- )
- random.setstate(py_rng_state)
- torch.random.set_rng_state(torch_rng_state)
+ # Just in case we get left in a bad dispatch state we want to restore
+ # it. This can happen because the dispatch bits aren't a true
+ # stack/counter - so we can't just increment/decrement them as we enter
+ # and leave.
+ with torch._C._PreserveDispatchKeyGuard():
+ prior_inference_mode = torch.is_inference_mode_enabled()
+ prior_deterministic = torch.are_deterministic_algorithms_enabled()
+ prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
+ py_rng_state = random.getstate()
+ torch_rng_state = torch.random.get_rng_state()
if torch.cuda.is_available():
- torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined]
- torch._C._set_cublas_allow_tf32(allow_tf32)
- torch.fx.graph_module._forward_from_src = prior_fwd_from_src
- assert (
- guards.check()
- ), f"Global {guards.reason()}state changed while dynamo tracing, please report a bug"
+ cuda_rng_state = torch.cuda.get_rng_state()
+ allow_tf32 = torch._C._get_cublas_allow_tf32()
+ prior_fwd_from_src = torch.fx.graph_module._forward_from_src
+ torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
+ cleanup = setup_compile_debug()
+ try:
+ return fn(*args, **kwargs)
+ finally:
+ cleanup.close()
+ torch._C._set_grad_enabled(prior_grad_mode)
+ torch.torch.autograd.grad_mode._enter_inference_mode(
+ prior_inference_mode
+ )
+ torch.use_deterministic_algorithms(
+ prior_deterministic, warn_only=prior_warn_only
+ )
+ random.setstate(py_rng_state)
+ torch.random.set_rng_state(torch_rng_state)
+ if torch.cuda.is_available():
+ torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined]
+ torch._C._set_cublas_allow_tf32(allow_tf32)
+ torch.fx.graph_module._forward_from_src = prior_fwd_from_src
+ assert (
+ guards.check()
+ ), f"Global {guards.reason()}state changed while dynamo tracing, please report a bug"
_fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
return _fn
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 4bb5232..4cf1485 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -336,15 +336,17 @@
meta_in_tls = torch._C._meta_in_tls_dispatch_include()
assert meta_in_tls == prev_in_kernel, f"{meta_in_tls}, {prev_in_kernel}"
- guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
- fake_mode.in_kernel_invocation = True
- torch._C._set_meta_in_tls_dispatch_include(True)
- try:
- yield
- finally:
- fake_mode.in_kernel_invocation = prev_in_kernel
- torch._C._set_meta_in_tls_dispatch_include(prev_in_kernel)
- del guard
+ with torch._C._DisableTorchDispatch():
+ fake_mode.in_kernel_invocation = True
+ # Unfortunately _set_meta_in_tls_dispatch_include(False) can leave
+ # `Dense` turned on (because it's implied by `Meta`)
+ with torch._C._PreserveDispatchKeyGuard():
+ torch._C._set_meta_in_tls_dispatch_include(True)
+ try:
+ yield
+ finally:
+ fake_mode.in_kernel_invocation = prev_in_kernel
+ # torch._C._set_meta_in_tls_dispatch_include(prev_in_kernel)
# Return if the function allows Python numbers to bind to Tensors
diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp
index 3d1c6d0..2b905a9 100644
--- a/torch/csrc/utils/python_dispatch.cpp
+++ b/torch/csrc/utils/python_dispatch.cpp
@@ -721,6 +721,8 @@
c10::impl::ForceDispatchKeyGuard,
c10::DispatchKeySet,
c10::DispatchKeySet>(m, "_ForceDispatchKeyGuard");
+ py_context_manager<c10::impl::ForceDispatchKeyGuard>(
+ m, "_PreserveDispatchKeyGuard");
py_context_manager<c10::impl::IncludeDispatchKeyGuard, c10::DispatchKey>(
m, "_IncludeDispatchKeyGuard");
py_context_manager<c10::impl::ExcludeDispatchKeyGuard, c10::DispatchKeySet>(