| #pragma once |
| |
| #include <memory> |
| #include <mutex> |
| #include "ATen/Generator.h" |
| #include "ATen/Type.h" |
| #include "ATen/Utils.h" |
| |
| class THCState; |
| |
| namespace at { |
| |
| class Context { |
| public: |
| Context(); |
| Type & getType(Backend p, ScalarType s) { |
| initCUDAIfNeeded(p); |
| auto & type = type_registry[static_cast<int>(p)][static_cast<int>(s)]; |
| if(!type) |
| runtime_error("%s%sType is not enabled.",toString(p),toString(s)); |
| return *type; |
| } |
| Generator & defaultGenerator(Backend p) { |
| initCUDAIfNeeded(p); |
| auto & generator = generator_registry[static_cast<int>(p)]; |
| if(!generator) |
| runtime_error("%s backend type not enabled.",toString(p)); |
| return *generator; |
| } |
| bool hasCUDA() const; |
| // defined in header so that getType has ability to inline |
| // call_once check. getType is called fairly frequently |
| THCState* lazyInitCUDA() { |
| std::call_once(thc_init,[&] { |
| doInitCUDA(); |
| }); |
| return thc_state; |
| } |
| ~Context(); |
| std::unique_ptr<Generator> |
| generator_registry[static_cast<int>(Backend::NumOptions)]; |
| std::unique_ptr<Type> type_registry |
| [static_cast<int>(Backend::NumOptions)] |
| [static_cast<int>(ScalarType::NumOptions)]; |
| THCState * thc_state; |
| private: |
| void initCUDAIfNeeded(Backend p) { |
| if(p == Backend::CUDA) |
| lazyInitCUDA(); |
| } |
| void doInitCUDA(); |
| std::once_flag thc_init; |
| }; |
| |
| Context & globalContext(); |
| |
| static inline void init() { |
| globalContext(); |
| } |
| |
| static inline Type& getType(Backend p, ScalarType s) { |
| return globalContext().getType(p,s); |
| } |
| |
| static inline Type& CPU(ScalarType s) { |
| return getType(Backend::CPU, s); |
| } |
| |
| static inline Type& CUDA(ScalarType s) { |
| return getType(Backend::CUDA, s); |
| } |
| |
| static inline bool hasCUDA() { |
| return globalContext().hasCUDA(); |
| } |
| |
| } |