blob: f4cc18cab631fea8a6188257554360cdef6e98d8 [file] [log] [blame]
#include "CUDAUtils.hpp"
#include "./private/CUDAUtils.hpp"
namespace c10d {
CUDADevice::CUDADevice(int device) {
setDevice(device);
}
CUDADevice::~CUDADevice() {
setDevice(originalDevice_);
}
void CUDADevice::setDevice(int device) {
if (device >= 0) {
if (originalDevice_ == -1) {
C10D_CUDA_CHECK(cudaGetDevice(&originalDevice_));
if (device != originalDevice_) {
C10D_CUDA_CHECK(cudaSetDevice(device));
}
} else {
C10D_CUDA_CHECK(cudaSetDevice(device));
}
}
}
CUDAEvent CUDAEvent::create() {
CUDAEvent event;
C10D_CUDA_CHECK(cudaEventCreate(&event.event_));
return event;
}
CUDAEvent::~CUDAEvent() {
if (event_ != nullptr) {
C10D_CUDA_CHECK(cudaEventDestroy(event_));
}
}
CUDAStream CUDAStream::create() {
CUDAStream stream;
stream.stream_ = THCStream_new(cudaStreamNonBlocking);
return stream;
}
CUDAStream::~CUDAStream() {
if (stream_ != nullptr) {
THCStream_free(stream_);
stream_ = nullptr;
}
}
cudaStream_t CUDAStream::getStream() const {
return THCStream_stream(stream_);
}
} // namespace c10d