blob: 2cdd4e044f38ec1fde3f5433b6d64dc22c1024b4 [file] [log] [blame]
#include <ATen/DLConvertor.h>
#include <ATen/Functions.h>
#include <iostream>
#include <memory>
#include <sstream>
using namespace std;
namespace at {
// Builds the DLPack data-type descriptor for the element type of `t`.
//
// `lanes` is always 1 and `bits` is the tensor's element size in bytes
// times 8. The type code is chosen from the tensor's scalar type; scalar
// types with no DLPack equivalent here raise std::logic_error.
//
// The switch has no default label, so a scalar type that is not listed
// would leave `dtype.code` unset.
static DLDataType getDLDataType(const Tensor& t) {
  DLDataType dtype;
  dtype.lanes = 1;
  dtype.bits = t.element_size() * 8;

  switch (t.scalar_type()) {
    // Byte and Bool are both exported with the unsigned-integer code.
    case ScalarType::Byte:
    case ScalarType::Bool:
      dtype.code = DLDataTypeCode::kDLUInt;
      break;

    // Signed integers; the width is carried by `bits`.
    case ScalarType::Char:
    case ScalarType::Short:
    case ScalarType::Int:
    case ScalarType::Long:
      dtype.code = DLDataTypeCode::kDLInt;
      break;

    // Floating point; the width is carried by `bits`.
    case ScalarType::Half:
    case ScalarType::Float:
    case ScalarType::Double:
      dtype.code = DLDataTypeCode::kDLFloat;
      break;

    // Everything below is rejected.
    case ScalarType::BFloat16:
      throw std::logic_error("BFloat16 is not supported by dlpack");
    case ScalarType::QInt8:
      throw std::logic_error("QInt8 is not supported by dlpack");
    case ScalarType::QUInt8:
      throw std::logic_error("QUInt8 is not supported by dlpack");
    case ScalarType::QInt32:
      throw std::logic_error("QInt32 is not supported by dlpack");
    case ScalarType::ComplexHalf:
      throw std::logic_error("ComplexHalf is not supported by dlpack");
    case ScalarType::ComplexFloat:
      throw std::logic_error("ComplexFloat is not supported by dlpack");
    case ScalarType::ComplexDouble:
      throw std::logic_error("ComplexDouble is not supported by dlpack");
    case ScalarType::Undefined:
      throw std::logic_error("Undefined is not a valid ScalarType");
    case ScalarType::NumOptions:
      throw std::logic_error("NumOptions is not a valid ScalarType");
  }
  return dtype;
}
// Builds the DLPack context (device type + device id) for `tensor`.
//
// `device_id` is copied through unchanged. The device type is kDLGPU when
// the tensor reports is_cuda(), and kDLCPU for every other tensor.
static DLContext getDLContext(const Tensor& tensor, const int64_t& device_id) {
  DLContext ctx;
  ctx.device_id = device_id;
  ctx.device_type =
      tensor.is_cuda() ? DLDeviceType::kDLGPU : DLDeviceType::kDLCPU;
  return ctx;
}
// Translates a DLPack context into the corresponding ATen device.
//
// CPU maps to a CPU device without an index; GPU, OpenCL and ROCm map to
// CUDA, OPENCL and HIP devices carrying the context's device id. Any other
// device type raises std::logic_error.
static Device getATenDevice(const DLContext& ctx) {
  if (ctx.device_type == DLDeviceType::kDLCPU) {
    return at::Device(DeviceType::CPU);
  }
  if (ctx.device_type == DLDeviceType::kDLGPU) {
    return at::Device(DeviceType::CUDA, ctx.device_id);
  }
  if (ctx.device_type == DLDeviceType::kDLOpenCL) {
    return at::Device(DeviceType::OPENCL, ctx.device_id);
  }
  if (ctx.device_type == DLDeviceType::kDLROCM) {
    return at::Device(DeviceType::HIP, ctx.device_id);
  }
  throw std::logic_error(
      "Unsupported device_type: " + std::to_string(ctx.device_type));
}
// Maps a DLPack data-type descriptor onto an ATen ScalarType.
//
// Only single-lane types are accepted. Supported combinations:
//   kDLUInt  :  8            -> Byte
//   kDLInt   :  8/16/32/64   -> Char/Short/Int/Long
//   kDLFloat : 16/32/64      -> Half/Float/Double
// Anything else raises std::logic_error naming the offending code or width.
ScalarType toScalarType(const DLDataType& dtype) {
  if (dtype.lanes != 1) {
    throw std::logic_error("ATen does not support lanes != 1");
  }

  if (dtype.code == DLDataTypeCode::kDLUInt) {
    if (dtype.bits == 8) {
      return ScalarType::Byte;
    }
    throw std::logic_error(
        "Unsupported kUInt bits " + std::to_string(dtype.bits));
  }

  if (dtype.code == DLDataTypeCode::kDLInt) {
    switch (dtype.bits) {
      case 8:
        return ScalarType::Char;
      case 16:
        return ScalarType::Short;
      case 32:
        return ScalarType::Int;
      case 64:
        return ScalarType::Long;
      default:
        throw std::logic_error(
            "Unsupported kInt bits " + std::to_string(dtype.bits));
    }
  }

  if (dtype.code == DLDataTypeCode::kDLFloat) {
    switch (dtype.bits) {
      case 16:
        return ScalarType::Half;
      case 32:
        return ScalarType::Float;
      case 64:
        return ScalarType::Double;
      default:
        throw std::logic_error(
            "Unsupported kFloat bits " + std::to_string(dtype.bits));
    }
  }

  throw std::logic_error("Unsupported code " + std::to_string(dtype.code));
}
// Heap-allocated bundle that pairs an exported DLManagedTensor with the ATen
// tensor it was built from. toDLPack() stores a pointer to this object in
// `tensor.manager_ctx`, and deleter() deletes the object through that pointer.
struct ATenDLMTensor {
// Copy of the source tensor handle; destroyed together with this object.
Tensor handle;
// The DLPack view handed out to the consumer.
DLManagedTensor tensor;
};
// DLPack deleter callback installed by toDLPack(): destroys the
// ATenDLMTensor that `arg->manager_ctx` points to (which contains `arg`
// itself, so `arg` must not be used afterwards).
void deleter(DLManagedTensor* arg) {
  auto* owner = static_cast<ATenDLMTensor*>(arg->manager_ctx);
  delete owner;
}
// Exports an ATen tensor as a DLPack managed tensor.
//
// Returns a raw pointer to a heap-allocated DLManagedTensor. The caller (or
// whoever consumes the DLPack tensor) releases it by invoking the returned
// object's `deleter` on it; it must not be freed any other way.
//
// The returned structure stores a copy of `src` next to the DLPack view.
// `shape` and `strides` point directly at the size/stride arrays reported by
// `src`; they are not copied.
//
// Throws whatever the helpers throw (e.g. std::logic_error from
// getDLDataType for scalar types DLPack cannot represent). The allocation is
// held by a unique_ptr until every field is filled in, so nothing leaks when
// that happens.
DLManagedTensor* toDLPack(const Tensor& src) {
  // make_unique is avoided on purpose: it needs C++14.
  std::unique_ptr<ATenDLMTensor> holder(new ATenDLMTensor);
  holder->handle = src;

  DLManagedTensor& managed = holder->tensor;
  managed.manager_ctx = holder.get();
  managed.deleter = &deleter;

  DLTensor& view = managed.dl_tensor;
  view.data = src.data_ptr();

  // Only CUDA tensors carry a meaningful device index; everything else
  // is exported with device id 0.
  int64_t device_id = 0;
  if (src.is_cuda()) {
    device_id = src.get_device();
  }
  view.ctx = getDLContext(src, device_id);
  view.ndim = src.dim();
  view.dtype = getDLDataType(src); // may throw for unsupported types
  // DLPack wants mutable pointers; the arrays themselves are not modified here.
  view.shape = const_cast<int64_t*>(src.sizes().data());
  view.strides = const_cast<int64_t*>(src.strides().data());
  view.byte_offset = 0;

  // Success: hand ownership over to the DLPack deleter.
  return &(holder.release()->tensor);
}
// Imports a DLPack managed tensor as an ATen tensor without copying.
//
// The returned tensor aliases the producer's memory (via at::from_blob) and
// takes over the responsibility of calling `src->deleter` once the blob is
// released. `src` must therefore stay valid until that point and must not be
// deleted by the caller.
//
// - The data pointer is advanced by `dl_tensor.byte_offset`, as required by
//   the DLPack layout.
// - A null `strides` pointer means the tensor is compact and row-major, so
//   the sizes-only from_blob overload is used.
// - A null `deleter` is allowed by DLPack (the producer has nothing to free)
//   and is simply skipped.
//
// Throws std::logic_error if the device type or data type is unsupported
// (see getATenDevice / toScalarType).
Tensor fromDLPack(const DLManagedTensor* src) {
  Device device = getATenDevice(src->dl_tensor.ctx);
  ScalarType stype = toScalarType(src->dl_tensor.dtype);

  // The blob pointer passed in is ignored: the producer's deleter wants the
  // DLManagedTensor itself.
  auto release = [src](void* /*self*/) {
    if (src->deleter) {
      src->deleter(const_cast<DLManagedTensor*>(src));
    }
  };

  // First element lives byte_offset bytes past the base pointer.
  void* data =
      static_cast<char*>(src->dl_tensor.data) + src->dl_tensor.byte_offset;
  const IntArrayRef sizes(src->dl_tensor.shape, src->dl_tensor.ndim);

  if (!src->dl_tensor.strides) {
    return at::from_blob(
        data, sizes, release, at::device(device).dtype(stype));
  }
  return at::from_blob(
      data,
      sizes,
      IntArrayRef(src->dl_tensor.strides, src->dl_tensor.ndim),
      release,
      at::device(device).dtype(stype));
}
} // namespace at