// blob: ddd30b473a457e4f9810d2c6d148dce1b8c014af [file] [log] [blame]
#include <THD/base/Cuda.hpp>
#include <stdexcept>
#include <unordered_map>
#ifdef USE_CUDA
// Process-wide pointer to a THCState* slot. It holds nullptr until
// THDSetCudaStatePtr() stores a value in it.
THCState** _THDCudaState = nullptr;
// Stores `state` in the global _THDCudaState pointer, replacing whatever
// value was stored there before. The pointer is saved as-is; nothing is
// dereferenced or copied here.
void THDSetCudaStatePtr(THCState** state) {
  _THDCudaState = state;
}
// Registry that maps each registered CUDA stream handle to its assigned id.
static std::unordered_map<cudaStream_t, int> streamIdMap;
// Id handed to the next stream that gets registered. Numbering starts at 1
// because id 0 is reserved for the default stream.
static int nextStreamId = 1;
// Registers `stream` in streamIdMap under the next free id.
//
// Registering a stream that is already present is a no-op: the existing id
// is kept and no new id is consumed, so repeated registration cannot leave
// gaps in the id sequence.
void THDRegisterCudaStream(cudaStream_t stream) {
  // emplace() leaves the map untouched when the key already exists, so the
  // counter must only advance when an entry was really inserted.
  const auto result = streamIdMap.emplace(stream, nextStreamId);
  if (result.second) {
    ++nextStreamId;
  }
}
// Returns the id associated with `stream`.
//
// A null stream handle yields 0 (the id reserved for the default stream)
// without consulting the registry. Any other handle is looked up in
// streamIdMap.
//
// Throws std::runtime_error when a non-null stream has no entry in
// streamIdMap, i.e. it was never passed to THDRegisterCudaStream().
int THDGetStreamId(cudaStream_t stream) {
  if (!stream) {
    return 0;
  }

  const auto it = streamIdMap.find(stream);
  if (it == streamIdMap.end()) {
    throw std::runtime_error(
        "using a stream that hasn't been registered in THD");
  }
  return it->second;
}
#endif