blob: 8338598e190e88f408e671a7ccbfa841f2379643 [file] [log] [blame]
#include <torch/csrc/utils/pybind.h>
#include <cuda.h>
#include <cuda_runtime.h>
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#else
#include <hip/hip_runtime_api.h>
#endif
namespace torch { namespace cuda { namespace shared {

// Registers a `_cudart` submodule on `module` that exposes a small set of raw
// CUDA runtime (libcudart) entry points to Python: error-string lookup,
// profiler start/stop, host-memory registration, and the enums they use.
//
// The bound surface depends on the build:
//  - ROCm/HIP builds (__HIP_PLATFORM_HCC__) skip cudaOutputMode and
//    cudaProfilerInitialize.
//  - CUDA 12+ builds skip them too. cudaProfilerInitialize and
//    cudaOutputMode_t were deprecated in CUDA 11 and removed from the CUDA 12
//    runtime, so binding them unconditionally breaks compilation.
void initCudartBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  auto cudart = m.def_submodule("_cudart", "libcudart.so bindings");

#if !defined(__HIP_PLATFORM_HCC__) && defined(CUDA_VERSION) && \
    CUDA_VERSION < 12000
  // Only needed as the argument type of cudaProfilerInitialize below.
  py::enum_<cudaOutputMode_t>(cudart, "cudaOutputMode")
      .value("KeyValuePair", cudaKeyValuePair)
      .value("CSV", cudaCSV);
#endif

  // Only cudaSuccess is given a Python-visible name.
  py::enum_<cudaError_t>(cudart, "cudaError")
      .value("success", cudaSuccess);

  cudart.def("cudaGetErrorString", cudaGetErrorString);

  // These runtime calls never touch Python objects, and they can block
  // (profiler control, page pinning). Release the GIL while they run so
  // other Python threads are not stalled.
  cudart.def("cudaProfilerStart", cudaProfilerStart,
             py::call_guard<py::gil_scoped_release>());
  cudart.def("cudaProfilerStop", cudaProfilerStop,
             py::call_guard<py::gil_scoped_release>());
  // NOTE(review): bound with the raw `void*` signature. pybind11 converts a
  // `void*` parameter only from a capsule or None, not from an integer
  // address such as Tensor.data_ptr(). Confirm this matches how callers
  // invoke it.
  cudart.def("cudaHostRegister", cudaHostRegister,
             py::call_guard<py::gil_scoped_release>());

#if !defined(__HIP_PLATFORM_HCC__) && defined(CUDA_VERSION) && \
    CUDA_VERSION < 12000
  // Removed in CUDA 12; profiler setup no longer needs an explicit init call.
  cudart.def("cudaProfilerInitialize", cudaProfilerInitialize,
             py::call_guard<py::gil_scoped_release>());
#endif
}

} // namespace shared
} // namespace cuda
} // namespace torch