blob: e1b079752b7674b88fc9f756fa64546e18cb7409 [file] [log] [blame]
#include "Cuda.hpp"
#include <unordered_map>
#ifdef USE_CUDA
// Global THC state pointer; null until THDSetCudaStatePtr stores a value in it.
THCState** _THDCudaState = nullptr;
// Stores `state` in the global _THDCudaState, overwriting any value set by an
// earlier call. The pointer is stored as-is; nothing is copied or validated.
void THDSetCudaStatePtr(THCState **state)
{
  _THDCudaState = state;
}
// Id that will be handed to the next newly registered stream. Counting starts
// at 1 because id 0 is reserved for the default (null) stream.
static int nextStreamId{1};
// Maps each registered CUDA stream handle to the id it was assigned.
static std::unordered_map<cudaStream_t, int> streamIdMap{};
// Registers `stream` with THD and assigns it the next free stream id.
//
// Registration is idempotent: a stream that is already registered keeps its
// existing id and does not consume a new one. The null (default) stream is
// ignored, because THDGetStreamId always reports it as id 0 without consulting
// the map.
void THDRegisterCudaStream(cudaStream_t stream) {
  if (!stream) return;
  // emplace leaves an existing entry untouched, so only advance the counter
  // when a new entry was actually inserted; otherwise every repeated
  // registration would silently skip an id.
  const auto result = streamIdMap.emplace(stream, nextStreamId);
  if (result.second) {
    ++nextStreamId;
  }
}
// Returns the THD id associated with `stream`.
//
// The null (default) stream always maps to id 0. Any other stream must have
// been registered beforehand through THDRegisterCudaStream.
//
// Throws std::runtime_error if `stream` is non-null and has not been
// registered.
int THDGetStreamId(cudaStream_t stream) {
  if (!stream) return 0;
  const auto entry = streamIdMap.find(stream);
  if (entry != streamIdMap.end()) {
    return entry->second;
  }
  throw std::runtime_error("using a stream that hasn't been registered in THD");
}
#endif