blob: a4a30937c22fd9754cf51d396af340ee7e05be60 [file] [log] [blame]
#include "Context.h"
#include <mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#ifdef AT_CUDA_ENABLED
#include "THC/THC.h"
#include "ATen/CUDAGenerator.h"
#endif
#include "ATen/CPUGenerator.h"
namespace at {

// TH reports errors by invoking a C callback. The two handlers below are
// installed as TH's default handlers by Context::Context(), so every TH error
// surfaces to ATen callers as a std::runtime_error.
//
// Both are [[noreturn]]: they always throw. The opaque `data` pointer in the
// callback signature is not used (the handlers are registered with nullptr),
// so its name is commented out to avoid -Wunused-parameter.

// Generic TH error: rethrows `msg` verbatim as std::runtime_error.
[[noreturn]] static void errorHandler(const char * msg, void * /*data*/) {
  throw std::runtime_error(msg);
}

// TH argument-check failure: `arg` is the argument position TH reports and
// `msg` its explanation. Throws std::runtime_error carrying
// "invalid argument <arg>: <msg>".
[[noreturn]] static void argErrorHandler(int arg, const char * msg, void * /*data*/) {
  // std::to_string formats without consulting the global C++ locale, unlike
  // streaming an int into a std::stringstream (which could add digit grouping).
  throw std::runtime_error("invalid argument " + std::to_string(arg) + ": " + msg);
}

// Constructs a context: routes TH errors through the throwing handlers above,
// registers the CPU generator, and lets every Type register itself against
// this context. CUDA is not touched here; thc_state is initialised to null
// and is only allocated by doInitCUDA().
Context::Context()
    : thc_state(nullptr) {
  THSetDefaultErrorHandler(errorHandler, nullptr);
  THSetDefaultArgErrorHandler(argErrorHandler, nullptr);
  generator_registry[static_cast<int>(Backend::CPU)]
      .reset(new CPUGenerator(this));
  Type::registerAll(this);
}

// Sets up THC (the CUDA backend) for this context. In a build with
// AT_CUDA_ENABLED it allocates the THCState, installs THC's caching device
// and host allocators, initialises THC, and registers the CUDA generator.
// In a CPU-only build it does nothing.
//
// NOTE(review): there is no guard here against being called twice (a second
// call would overwrite, and leak, the previous thc_state). Callers are
// presumably meant to go through a run-once wrapper declared in Context.h --
// confirm.
void Context::doInitCUDA() {
#ifdef AT_CUDA_ENABLED
  thc_state = THCState_alloc();
  THCState_setDeviceAllocator(thc_state, THCCachingAllocator_get());
  thc_state->cudaHostAllocator = &THCCachingHostAllocator;
  THCudaInit(thc_state);
  generator_registry[static_cast<int>(Backend::CUDA)]
      .reset(new CUDAGenerator(this));
#endif
}

// Frees the THCState if doInitCUDA() allocated one; a context that never
// initialised CUDA has a null thc_state and skips this.
// NOTE(review): only THCState_free is called; whether THC needs further
// per-device teardown before that is not visible from this file -- confirm.
Context::~Context() {
#ifdef AT_CUDA_ENABLED
  if (thc_state)
    THCState_free(thc_state);
#endif
}

// Returns the process-wide Context. It is a function-local static, so it is
// constructed on first call (initialisation is thread-safe since C++11) and
// destroyed during static destruction at program exit.
Context & globalContext() {
  static Context globalContext_;
  return globalContext_;
}

// Reports whether this translation unit was compiled with AT_CUDA_ENABLED.
// This is a compile-time property; it does not probe for a driver or a usable
// GPU at runtime.
bool Context::hasCUDA() const {
#ifdef AT_CUDA_ENABLED
  return true;
#else
  return false;
#endif
}

}  // namespace at