blob: 9a5f079ddf4d7b7f6a8a79617335870afe53bc8e [file] [log] [blame]
#include <ATen/ThreadLocalState.h>
#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
#include <ATen/core/grad_mode.h>
#endif
#include <ATen/record_function.h>
#include <ATen/SavedTensorHooks.h>
namespace at {
ThreadLocalState::ThreadLocalState()
: dispatch_key_(c10::impl::tls_local_dispatch_key_set()),
debug_info_(c10::ThreadLocalDebugInfo::current()),
functorch_tls_(functorch::getCopyOfFuncTorchTLS()),
autograd_tls_(c10::AutogradState::get_tls_state()),
python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()),
python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()) {
rf_tls_ = at::get_record_function_tls_();
saved_tensors_default_hooks_state_ = at::SavedTensorDefaultHooks::get_tls_state();
torch_dispatch_mode_state_ = c10::impl::TorchDispatchModeTLS::get_state();
}
void ThreadLocalState::set_grad_mode(bool enabled) {
autograd_tls_.set_grad_mode(enabled);
}
/* static */
void ThreadLocalState::setThreadLocalState(
const ThreadLocalState& state) {
// Note that setting the InferenceMode TLS in this function is ONLY ok because we always
// restore the dispatch key set TLS at the same time.
c10::AutogradState::set_tls_state(state.autograd_tls_);
c10::impl::TorchDispatchModeTLS::set_state(state.torch_dispatch_mode_state_);
at::impl::PythonTorchFunctionTLS::set_state(state.python_torch_function_state_);
at::set_record_function_tls_(state.rf_tls_);
at::SavedTensorDefaultHooks::set_tls_state(state.saved_tensors_default_hooks_state_);
c10::impl::PythonDispatcherTLS::set_state(state.python_dispatcher_state_);
c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);
c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);
functorch::setFuncTorchTLS(state.functorch_tls_);
}
} // namespace at