blob: b53043ca841475c2cac86b000dfa78c1cef579fd [file] [log] [blame]
#include <ATen/PythonModeTLS.h>
namespace at { namespace impl {
thread_local std::shared_ptr<TorchDispatchTypeObject> pythonModeState;
void PythonModeTLS::set_state(const std::shared_ptr<TorchDispatchTypeObject>& state) {
pythonModeState = state;
if (state) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
} else {
PythonModeTLS::reset_state();
}
}
const std::shared_ptr<TorchDispatchTypeObject>& PythonModeTLS::get_state() {
return pythonModeState;
}
void PythonModeTLS::reset_state() {
pythonModeState.reset((TorchDispatchTypeObject*)nullptr);
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
}
} // namespace impl
} // namespace at