// blob: 245f221029d6ba5afdd161c5ecbd61d7e15d7862 [file] [log] [blame]
#include "Handles.h"
#include <unordered_map>
#include <mutex>
#include "Exceptions.h"
namespace torch { namespace cudnn {
namespace {
struct Handle {
cudnnHandle_t handle;
Handle() : handle(NULL) {
CHECK(cudnnCreate(&handle));
}
~Handle() {
if (handle) {
cudnnDestroy(handle);
}
}
};
// Serializes access to `handles`; getCudnnHandle() holds it for the duration
// of the map lookup/insertion.
std::mutex mutex;
// Cache of cuDNN handles keyed by the device index reported by
// cudaGetDevice(). Entries are default-constructed on first access through
// operator[] in getCudnnHandle().
std::unordered_map<int, Handle> handles;
} // namespace
// Returns the cuDNN handle cached for the device reported by cudaGetDevice().
//
// The first request for a given device index default-constructs a Handle in
// the file-local `handles` map; later requests return the same stored handle.
// The map is only touched while `mutex` is held.
cudnnHandle_t getCudnnHandle()
{
  int current_device = -1;
  CUDA_CHECK(cudaGetDevice(&current_device));

  // Lock only around the map access; the device query above needs no lock.
  std::lock_guard<std::mutex> lock(mutex);
  Handle& entry = handles[current_device];  // inserts on first use
  return entry.handle;
}
}} // namespace torch::cudnn