#include "ATen/CUDAGenerator.h"

#include <stdexcept>
#include <string>

#include "ATen/Context.h"
// const_generator_cast(gen): view `gen` as a `const CUDAGenerator&`.
// Because the target is a reference type, a failed dynamic_cast throws
// std::bad_cast rather than yielding null.
// NOTE(review): no use of this macro is visible in this chunk; presumably
// `gen` is a Generator reference at call sites.
#define const_generator_cast(gen) dynamic_cast<const CUDAGenerator&>(gen)
| CUDAGenerator::CUDAGenerator(Context * context_) |
| int num_devices, current_device; |
| cudaGetDeviceCount(&num_devices); |
| cudaGetDevice(¤t_device); |
| THCRandom_init(context->thc_state, num_devices, current_device); |
| CUDAGenerator::~CUDAGenerator() { |
| // no-op Generator state is global to the program |
| CUDAGenerator& CUDAGenerator::copy(const Generator& from) { |
| throw std::runtime_error("CUDAGenerator::copy() not implemented"); |
| CUDAGenerator& CUDAGenerator::free() { |
| THCRandom_shutdown(context->thc_state); |
| unsigned long CUDAGenerator::seed() { |
| return THCRandom_initialSeed(context->thc_state); |
| CUDAGenerator& CUDAGenerator::manualSeed(unsigned long seed) { |
| THCRandom_manualSeed(context->thc_state, seed); |