#include <pybind11/pybind11.h>
#include <torch/csrc/utils/pybind.h>
#if defined(USE_CUFILE)
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cufile.h>

#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <string>
namespace {
// Returns an error message for the cuFileRead/cuFileWrite APIs, which return
// ssize_t (-1 for a filesystem error, a negative CUfileOpError enum value
// otherwise).
template <
class T,
typename std::enable_if<std::is_integral<T>::value, std::nullptr_t>::type =
nullptr>
std::string cuGDSFileGetErrorString(T status) {
status = std::abs(status);
return IS_CUFILE_ERR(status) ? std::string(CUFILE_ERRSTR(status))
: std::string(std::strerror(errno));
}
// Returns an error message for the buffer/handle registration APIs, which
// return CUfileError_t.
template <
class T,
typename std::enable_if<!std::is_integral<T>::value, std::nullptr_t>::type =
nullptr>
std::string cuGDSFileGetErrorString(T status) {
std::string errStr = cuGDSFileGetErrorString(static_cast<int>(status.err));
if (IS_CUDA_ERR(status))
errStr.append(".").append(
cudaGetErrorString(static_cast<cudaError_t>(status.cu_err)));
return errStr;
}
} // namespace
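// Reads storage.nbytes() bytes from the file backing `handle` (a
// CUfileHandle_t previously returned to Python as an int64_t by
// gds_register_handle) at `offset` directly into the storage's CUDA device
// memory via cuFileRead.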
void gds_load_storage(
int64_t handle,
const at::Storage& storage,
off_t offset) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
c10::cuda::CUDAGuard gpuGuard(storage.device());
void* dataPtr = storage.mutable_data();
const size_t nbytes = storage.nbytes();
// Read the binary file
  ssize_t ret = cuFileRead(cf_handle, dataPtr, nbytes, offset, 0);
TORCH_CHECK(ret >= 0, "cuFileRead failed: ", cuGDSFileGetErrorString(ret));
}
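// Writes the storage's CUDA device memory (storage.nbytes() bytes) to the
// file backing `handle` at `offset` via cuFileWrite.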
void gds_save_storage(
int64_t handle,
const at::Storage& storage,
off_t offset) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
c10::cuda::CUDAGuard gpuGuard(storage.device());
void* dataPtr = storage.mutable_data();
const size_t nbytes = storage.nbytes();
// Write device memory contents to the file
ssize_t ret = cuFileWrite(cf_handle, dataPtr, nbytes, offset, 0);
TORCH_CHECK(ret >= 0, "cuFileWrite failed: ", cuGDSFileGetErrorString(ret));
}
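// Registers the storage's device buffer with cuFile (cuFileBufRegister) so
// that subsequent cuFileRead/cuFileWrite calls on it can take the optimized
// GPUDirect Storage path.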
void gds_register_buffer(const at::Storage& storage) {
void* dataPtr = storage.mutable_data();
const size_t nbytes = storage.nbytes();
CUfileError_t status = cuFileBufRegister(dataPtr, nbytes, 0);
TORCH_CHECK(
status.err == CU_FILE_SUCCESS,
"cuFileBufRegister failed: ",
cuGDSFileGetErrorString(status));
}
void gds_deregister_buffer(const at::Storage& storage) {
void* dataPtr = storage.mutable_data();
CUfileError_t status = cuFileBufDeregister(dataPtr);
TORCH_CHECK(
status.err == CU_FILE_SUCCESS,
"cuFileBufDeregister failed: ",
cuGDSFileGetErrorString(status));
}
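// Wraps an already-open POSIX file descriptor in a CUfileHandle_t via
// cuFileHandleRegister and returns the handle as an int64_t token that can
// cross the Python binding boundary.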
int64_t gds_register_handle(int fd) {
CUfileDescr_t cf_descr;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
CUfileHandle_t cf_handle;
memset((void*)&cf_descr, 0, sizeof(CUfileDescr_t));
cf_descr.handle.fd = fd;
cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
CUfileError_t status = cuFileHandleRegister(&cf_handle, &cf_descr);
  TORCH_CHECK(
      status.err == CU_FILE_SUCCESS,
      "cuFileHandleRegister failed: ",
      cuGDSFileGetErrorString(status));
  // Return the CUfileHandle_t to Python as an opaque int64_t token.
return reinterpret_cast<int64_t>(cf_handle);
}
void gds_deregister_handle(int64_t handle) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
cuFileHandleDeregister(cf_handle);
}
#endif
namespace torch::cuda::shared {
void initGdsBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
#if defined(USE_CUFILE)
m.def("_gds_register_handle", &gds_register_handle);
m.def("_gds_deregister_handle", &gds_deregister_handle);
m.def("_gds_register_buffer", &gds_register_buffer);
m.def("_gds_deregister_buffer", &gds_deregister_buffer);
m.def("_gds_load_storage", &gds_load_storage);
m.def("_gds_save_storage", &gds_save_storage);
#endif
}
} // namespace torch::cuda::shared