// blob: f18c883a2a06a7f35750392b91be0e3c06c454a9
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/csrc/utils/pybind.h>
#if !defined(USE_ROCM)
#include <cuda_profiler_api.h>
#else
#include <hip/hip_runtime_api.h>
#endif
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
namespace torch {
namespace cuda {
namespace shared {
#ifdef USE_ROCM
namespace {

// No-op stand-in that unconditionally reports success. On ROCm builds it is
// bound in place of the profiler start/stop runtime calls below.
hipError_t hipReturnSuccess() {
  const hipError_t status = hipSuccess;
  return status;
}

} // namespace
#endif
// Creates the `_cudart` submodule on `module` and registers a small set of
// CUDA runtime bindings on it: the error enum, profiler control, host memory
// (un)registration, stream creation/destruction and device memory info.
//
// `module` is the already-created Python module object the submodule is
// attached to.
//
// The lambda-based bindings drop the GIL for the duration of the runtime call:
// those calls can block inside the driver, and holding the GIL while they do
// stalls every other Python thread (and can deadlock against a thread that
// holds a runtime lock while waiting for the GIL). None of the lambdas touch
// Python objects while the GIL is released.
void initCudartBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  auto cudart = m.def_submodule("_cudart", "libcudart.so bindings");

  // By splitting the names of these objects into two literals we prevent the
  // HIP rewrite rules from changing these names when building with HIP.

#if !defined(USE_ROCM)
  py::enum_<cudaOutputMode_t>(
      cudart,
      "cuda"
      "OutputMode")
      .value("KeyValuePair", cudaKeyValuePair)
      .value("CSV", cudaCSV);
#endif

  py::enum_<cudaError_t>(
      cudart,
      "cuda"
      "Error")
      .value("success", cudaSuccess);

  cudart.def(
      "cuda"
      "GetErrorString",
      cudaGetErrorString);

  // On ROCm the profiler start/stop entry points are replaced by a stub that
  // always returns success.
  cudart.def(
      "cuda"
      "ProfilerStart",
#ifdef USE_ROCM
      hipReturnSuccess
#else
      cudaProfilerStart
#endif
  );
  cudart.def(
      "cuda"
      "ProfilerStop",
#ifdef USE_ROCM
      hipReturnSuccess
#else
      cudaProfilerStop
#endif
  );

  // `ptr` is the integer address of the host range to register; the runtime
  // status is returned to Python instead of being raised.
  cudart.def(
      "cuda"
      "HostRegister",
      [](uintptr_t ptr, size_t size, unsigned int flags) -> cudaError_t {
        py::gil_scoped_release no_gil;
        return C10_CUDA_ERROR_HANDLED(
            cudaHostRegister(reinterpret_cast<void*>(ptr), size, flags));
      });

  // `ptr` is the integer address previously passed to HostRegister.
  cudart.def(
      "cuda"
      "HostUnregister",
      [](uintptr_t ptr) -> cudaError_t {
        py::gil_scoped_release no_gil;
        return C10_CUDA_ERROR_HANDLED(
            cudaHostUnregister(reinterpret_cast<void*>(ptr)));
      });

  // `ptr` is the integer address of caller-owned storage for a stream handle;
  // the new handle is written through it.
  cudart.def(
      "cuda"
      "StreamCreate",
      [](uintptr_t ptr) -> cudaError_t {
        py::gil_scoped_release no_gil;
        return C10_CUDA_ERROR_HANDLED(
            cudaStreamCreate(reinterpret_cast<cudaStream_t*>(ptr)));
      });

  // Here `ptr` is the stream handle itself (as an integer), not its address.
  cudart.def(
      "cuda"
      "StreamDestroy",
      [](uintptr_t ptr) -> cudaError_t {
        py::gil_scoped_release no_gil;
        return C10_CUDA_ERROR_HANDLED(
            cudaStreamDestroy(reinterpret_cast<cudaStream_t>(ptr)));
      });

#if !defined(USE_ROCM)
  cudart.def(
      "cuda"
      "ProfilerInitialize",
      cudaProfilerInitialize);
#endif

  // Returns the (free, total) pair reported by the runtime for `device`.
  // Unlike the bindings above, a failing runtime call goes through
  // C10_CUDA_CHECK rather than being returned as a status code.
  cudart.def(
      "cuda"
      "MemGetInfo",
      [](int device) -> std::pair<size_t, size_t> {
        // Switch to `device` for the duration of the query.
        c10::cuda::CUDAGuard guard(device);
        size_t device_free = 0;
        size_t device_total = 0;
        // Released after the guard is set up, so the GIL is re-acquired
        // before the guard's destructor runs.
        py::gil_scoped_release no_gil;
        C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
        return {device_free, device_total};
      });
}
} // namespace shared
} // namespace cuda
} // namespace torch