// blob: 6f5e46f18ed15f2410094cd8de91e91755b3121e [file] [log] [blame]
#pragma once
#include <cuda.h>
#define NVML_NO_UNVERSIONED_FUNC_DEFS
#include <nvml.h>
// Evaluates EXPR (an expression yielding a CUresult) exactly once. When the
// result is not CUDA_SUCCESS, the macro asks the driver for a description via
// c10::cuda::DriverAPI::get()->cuGetErrorString_ and raises with AT_ERROR:
//   - "CUDA driver error: <description>" when the lookup succeeds,
//   - "CUDA driver error: unknown error" when the lookup itself fails.
//
// Hygiene notes:
//   - EXPR is parenthesized, following the usual convention for macro
//     arguments.
//   - The temporaries use a c10_driver_check_ prefix instead of
//     double-underscore names, which C++ reserves for the implementation.
//   - The message pointer starts as nullptr so it never holds an
//     indeterminate value.
//   - Comments are kept outside the macro body on purpose: a `//` comment on
//     a backslash-continued line would swallow the lines that follow it.
#define C10_CUDA_DRIVER_CHECK(EXPR) \
  do { \
    const CUresult c10_driver_check_result_ = (EXPR); \
    if (c10_driver_check_result_ != CUDA_SUCCESS) { \
      const char* c10_driver_check_msg_ = nullptr; \
      CUresult c10_driver_check_lookup_ C10_UNUSED = \
          c10::cuda::DriverAPI::get()->cuGetErrorString_( \
              c10_driver_check_result_, &c10_driver_check_msg_); \
      if (c10_driver_check_lookup_ != CUDA_SUCCESS) { \
        AT_ERROR("CUDA driver error: unknown error"); \
      } else { \
        AT_ERROR("CUDA driver error: ", c10_driver_check_msg_); \
      } \
    } \
  } while (0)
// X-macro list of CUDA driver functions. Each entry is passed to the
// caller-supplied macro `_`. Inside struct DriverAPI it is expanded with
// CREATE_MEMBER, producing one function-pointer member per entry whose name
// is the function name plus a trailing underscore (e.g. cuMemCreate_).
// cuGetErrorString is listed because C10_CUDA_DRIVER_CHECK calls
// cuGetErrorString_ to format its error message.
// NOTE(review): the LIBCUDA in the name suggests these symbols are resolved
// from libcuda at runtime; that lookup code is not in this header, so confirm.
// Keep the entry order stable: it determines the member order in DriverAPI.
#define C10_LIBCUDA_DRIVER_API(_) \
_(cuMemAddressReserve) \
_(cuMemRelease) \
_(cuMemMap) \
_(cuMemAddressFree) \
_(cuMemSetAccess) \
_(cuMemUnmap) \
_(cuMemCreate) \
_(cuGetErrorString)
// X-macro list of NVML functions, used the same way as
// C10_LIBCUDA_DRIVER_API: each entry is passed to the caller-supplied macro
// `_`, and struct DriverAPI expands it into a function-pointer member named
// after the entry plus a trailing underscore (e.g. nvmlInit_v2_).
// Entries that have a versioned symbol are spelled with their explicit suffix
// (_v2); this header defines NVML_NO_UNVERSIONED_FUNC_DEFS before including
// nvml.h.
// NOTE(review): presumably that define stops nvml.h from remapping
// unversioned names to versioned ones; confirm against the NVML header.
#define C10_NVML_DRIVER_API(_) \
_(nvmlInit_v2) \
_(nvmlDeviceGetHandleByPciBusId_v2) \
_(nvmlDeviceGetComputeRunningProcesses)
namespace c10 {
namespace cuda {

// Table of function pointers, one per entry of C10_LIBCUDA_DRIVER_API and
// C10_NVML_DRIVER_API. Every member has exactly the pointer type of the
// function it is named after and carries a trailing underscore
// (cuMemMap -> cuMemMap_, nvmlInit_v2 -> nvmlInit_v2_).
struct DriverAPI {
  // Returns a pointer to a DriverAPI table. Only declared here.
  // NOTE(review): presumably a shared, lazily populated instance; confirm
  // against the definition.
  static DriverAPI* get();

  // Returns an untyped handle associated with NVML. Only declared here.
  // NOTE(review): what the handle refers to is not visible in this header.
  static void* get_nvml_handle();

// Expands one X-macro entry `fn` into the data member `fn_`, typed as a
// pointer to the function `fn`. The helper is undefined again right after
// use so it does not leak out of this header.
#define CREATE_MEMBER(fn) decltype(&fn) fn##_;
  C10_LIBCUDA_DRIVER_API(CREATE_MEMBER)
  C10_NVML_DRIVER_API(CREATE_MEMBER)
#undef CREATE_MEMBER
};

} // namespace cuda
} // namespace c10