|  | #pragma once | 
|  |  | 
|  | // This header provides C++ wrappers around commonly used CUDA API functions. | 
|  | // The benefit of using C++ here is that we can raise an exception in the | 
|  | // event of an error, rather than explicitly pass around error codes.  This | 
|  | // leads to more natural APIs. | 
|  | // | 
|  | // The naming convention used here matches the naming convention of torch.cuda | 
|  |  | 
|  | #include <cuda_runtime_api.h> | 
|  |  | 
|  | #include <c10/macros/Macros.h> | 
|  | #include <c10/core/Device.h> | 
|  | #include <c10/cuda/CUDAException.h> | 
|  |  | 
|  | namespace c10 { | 
|  | namespace cuda { | 
|  |  | 
|  | inline DeviceIndex device_count() noexcept { | 
|  | int count; | 
|  | // NB: In the past, we were inconsistent about whether or not this reported | 
|  | // an error if there were driver problems are not.  Based on experience | 
|  | // interacting with users, it seems that people basically ~never want this | 
|  | // function to fail; it should just return zero if things are not working. | 
|  | // Oblige them. | 
|  | cudaError_t err = cudaGetDeviceCount(&count); | 
|  | if (err != cudaSuccess) { | 
|  | // Clear out the error state, so we don't spuriously trigger someone else. | 
|  | // (This shouldn't really matter, since we won't be running very much CUDA | 
|  | // code in this regime.) | 
|  | cudaError_t last_err = cudaGetLastError(); | 
|  | (void)last_err; | 
|  | return 0; | 
|  | } | 
|  | return static_cast<DeviceIndex>(count); | 
|  | } | 
|  |  | 
|  | inline DeviceIndex current_device() { | 
|  | int cur_device; | 
|  | C10_CUDA_CHECK(cudaGetDevice(&cur_device)); | 
|  | return static_cast<DeviceIndex>(cur_device); | 
|  | } | 
|  |  | 
|  | inline void set_device(DeviceIndex device) { | 
|  | C10_CUDA_CHECK(cudaSetDevice(static_cast<int>(device))); | 
|  | } | 
|  |  | 
|  | }} // namespace c10::cuda |