// blob: dd82f94fdfde87e8ffa1d8460b4bcefaf942ef3d [file] [log] [blame]
#pragma once
#include <memory>
#include <mutex>
#include "ATen/Generator.h"
#include "ATen/Type.h"
#include "ATen/Utils.h"
class THCState;
namespace at {
// Holds the per-backend Generator table, the per-(backend, scalar type) Type
// table, and the THCState pointer, and triggers CUDA initialization lazily
// the first time a CUDA entry is requested.
//
// Only the inline members are defined here; Context(), ~Context(), hasCUDA()
// and doInitCUDA() are declared in this header and defined elsewhere.
class Context {
public:
  Context();
  ~Context();

  // Looks up the Type registered for backend `p` and scalar type `s`.
  // For Backend::CUDA this first runs the one-time CUDA initialization.
  // An empty slot is reported through runtime_error (declared outside this
  // header; presumably it throws and never returns -- confirm).
  Type & getType(Backend p, ScalarType s) {
    initCUDAIfNeeded(p);
    const int backend_index = static_cast<int>(p);
    const int scalar_index = static_cast<int>(s);
    auto & slot = type_registry[backend_index][scalar_index];
    if (!slot) {
      runtime_error("%s%sType is not enabled.",toString(p),toString(s));
    }
    return *slot;
  }

  // Returns the Generator stored for backend `p`, running the one-time CUDA
  // initialization first when `p` is Backend::CUDA. An empty slot is
  // reported through runtime_error.
  Generator & defaultGenerator(Backend p) {
    initCUDAIfNeeded(p);
    const int backend_index = static_cast<int>(p);
    auto & slot = generator_registry[backend_index];
    if (!slot) {
      runtime_error("%s backend type not enabled.",toString(p));
    }
    return *slot;
  }

  bool hasCUDA() const;

  // Kept in the header so that getType, which is called fairly frequently,
  // can inline the call_once check.
  // Invokes doInitCUDA() through std::call_once on `thc_init`, then returns
  // the current value of `thc_state`.
  THCState* lazyInitCUDA() {
    std::call_once(thc_init, [this] { doInitCUDA(); });
    return thc_state;
  }

  // One Generator slot per Backend value below Backend::NumOptions.
  std::unique_ptr<Generator>
      generator_registry[static_cast<int>(Backend::NumOptions)];

  // One Type slot per (Backend, ScalarType) pair; indexed [backend][scalar].
  std::unique_ptr<Type> type_registry
      [static_cast<int>(Backend::NumOptions)]
      [static_cast<int>(ScalarType::NumOptions)];

  // Value handed back by lazyInitCUDA(); presumably assigned by
  // doInitCUDA() -- confirm against its definition.
  THCState * thc_state;

private:
  // Triggers lazyInitCUDA() only for the CUDA backend; every other backend
  // is a no-op.
  void initCUDAIfNeeded(Backend p) {
    if (p != Backend::CUDA) {
      return;
    }
    lazyInitCUDA();
  }

  void doInitCUDA();

  // Guard used by lazyInitCUDA() so doInitCUDA() is invoked via call_once.
  std::once_flag thc_init;
};
// Returns a reference to a Context. Only declared here; the definition is
// not in this header. NOTE(review): the name suggests a single process-wide
// instance -- confirm against the definition.
Context & globalContext();
// Calls globalContext() and discards the returned reference, so whatever
// work that call performs happens at this point.
static inline void init() {
  static_cast<void>(globalContext());
}
// Free-function shorthand for Context::getType on the context returned by
// globalContext(): yields the Type for backend `p` and scalar type `s`.
static inline Type& getType(Backend p, ScalarType s) {
  Context & ctx = globalContext();
  return ctx.getType(p, s);
}
// Returns the Type for scalar type `s` on Backend::CPU, looked up through
// the context returned by globalContext().
static inline Type& CPU(ScalarType s) {
  return globalContext().getType(Backend::CPU, s);
}
// Returns the Type for scalar type `s` on Backend::CUDA, looked up through
// the context returned by globalContext().
static inline Type& CUDA(ScalarType s) {
  return globalContext().getType(Backend::CUDA, s);
}
static inline bool hasCUDA() {
return globalContext().hasCUDA();
}
}