blob: 1e6d156bdb274d26a3e135d355155606250ec8cd [file] [log] [blame]
#include <ATen/FuncTorchTLS.h>
namespace at { namespace functorch {
namespace {
// File-local, per-thread slot for the functorch state object. Each thread
// gets its own pointer, which starts out null.
thread_local std::unique_ptr<FuncTorchTLSBase> kFuncTorchTLS{nullptr};
} // namespace
// Returns a copy of the calling thread's functorch TLS state, produced via
// the state's own deepcopy(). Yields nullptr when the calling thread has no
// state installed.
std::unique_ptr<FuncTorchTLSBase> getCopyOfFuncTorchTLS() {
  if (kFuncTorchTLS) {
    return kFuncTorchTLS->deepcopy();
  }
  // Nothing stored on this thread, so there is nothing to copy.
  return nullptr;
}
// Replaces the calling thread's functorch TLS state.
//
// A non-null `state` is not stored directly: the thread-local slot receives
// the result of `state->deepcopy()`. A null `state` clears the slot.
void setFuncTorchTLS(const std::shared_ptr<const FuncTorchTLSBase>& state) {
  if (state) {
    kFuncTorchTLS = state->deepcopy();
  } else {
    kFuncTorchTLS.reset();
  }
}
// Exposes the calling thread's functorch TLS slot by mutable reference, so
// callers can inspect or replace the stored pointer in place. The referenced
// pointer may be null.
std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor() {
  auto& slot = kFuncTorchTLS;
  return slot;
}
}}