// blob: 8f2cac8206d4dd111d190488c89dee4e6759c9e2 [file] [log] [blame]
#include <ATen/DLConvertor.h>
#include <ATen/Functions.h>
#include <memory>
using namespace std;
namespace at {
// Describes the element type of `t` as a DLPack DLDataType.
//
// The descriptor always has a single lane, its bit width is the tensor's
// element size in bytes times 8, and its type code is chosen from the
// tensor's ScalarType. Float8, quantized, bit-packed and the sentinel scalar
// types (Undefined / NumOptions) are rejected through TORCH_CHECK.
DLDataType getDLDataType(const Tensor& t) {
  DLDataType result;
  result.lanes = 1;
  result.bits = t.element_size() * 8;

  // The switch deliberately has no `default` label so that every ScalarType
  // enumerator is listed explicitly.
  switch (t.scalar_type()) {
    // Unsigned integers.
    case ScalarType::Byte:
    case ScalarType::UInt16:
    case ScalarType::UInt32:
    case ScalarType::UInt64:
      result.code = DLDataTypeCode::kDLUInt;
      break;

    // Signed integers.
    case ScalarType::Char:
    case ScalarType::Short:
    case ScalarType::Int:
    case ScalarType::Long:
      result.code = DLDataTypeCode::kDLInt;
      break;

    // IEEE floating point.
    case ScalarType::Half:
    case ScalarType::Float:
    case ScalarType::Double:
      result.code = DLDataTypeCode::kDLFloat;
      break;

    case ScalarType::BFloat16:
      result.code = DLDataTypeCode::kDLBfloat;
      break;

    case ScalarType::Bool:
      result.code = DLDataTypeCode::kDLBool;
      break;

    // Complex numbers.
    case ScalarType::ComplexHalf:
    case ScalarType::ComplexFloat:
    case ScalarType::ComplexDouble:
      result.code = DLDataTypeCode::kDLComplex;
      break;

    // Everything below cannot be exported and raises.
    case ScalarType::Float8_e5m2:
    case ScalarType::Float8_e5m2fnuz:
    case ScalarType::Float8_e4m3fn:
    case ScalarType::Float8_e4m3fnuz:
      TORCH_CHECK(false, "float8 types are not supported by dlpack");
      break;

    case ScalarType::QInt8:
    case ScalarType::QUInt8:
    case ScalarType::QInt32:
    case ScalarType::QUInt4x2:
    case ScalarType::QUInt2x4:
      TORCH_CHECK(false, "QUInt/QInt types are not supported by dlpack");
      break;

    case ScalarType::Bits1x8:
    case ScalarType::Bits2x4:
    case ScalarType::Bits4x2:
    case ScalarType::Bits8:
    case ScalarType::Bits16:
      TORCH_CHECK(false, "Bit types are not supported by dlpack");
      break;

    case ScalarType::Undefined:
      TORCH_CHECK(false, "Undefined is not a valid ScalarType");
      break;

    case ScalarType::NumOptions:
      TORCH_CHECK(false, "NumOptions is not a valid ScalarType");
      break;
  }
  return result;
}
// Builds the DLPack device descriptor for `tensor`.
//
// `device_id` is stored as the descriptor's device ordinal. The DLPack device
// type is derived from the tensor's ATen device type; XPU tensors are
// delegated to the XPU hooks, and any device type not listed here is rejected
// through TORCH_CHECK.
static DLDevice getDLDevice(const Tensor& tensor, const int64_t& device_id) {
#ifdef USE_ROCM
  // ROCM, if enabled will look like cuda to PyTorch
  // while everyone else should see HIP
  constexpr auto kCudaDeviceType = DLDeviceType::kDLROCM;
#else
  constexpr auto kCudaDeviceType = DLDeviceType::kDLCUDA;
#endif

  DLDevice dl_device;
  dl_device.device_id = device_id;

  const auto aten_type = tensor.device().type();
  if (aten_type == DeviceType::XPU) {
    // The hooks receive the partially filled descriptor (ordinal only) and
    // return the finished one.
    return at::detail::getXPUHooks().getDLPackDeviceFromATenDevice(
        dl_device, tensor.device(), tensor.data_ptr());
  }

  if (aten_type == DeviceType::CPU) {
    dl_device.device_type = DLDeviceType::kDLCPU;
  } else if (aten_type == DeviceType::CUDA) {
    dl_device.device_type = kCudaDeviceType;
  } else if (aten_type == DeviceType::OPENCL) {
    dl_device.device_type = DLDeviceType::kDLOpenCL;
  } else if (aten_type == DeviceType::HIP) {
    dl_device.device_type = DLDeviceType::kDLROCM;
  } else {
    TORCH_CHECK(false, "Cannot pack tensors on " + tensor.device().str());
  }
  return dl_device;
}
// Translates a DLPack device descriptor into the corresponding ATen Device.
//
// `data` is only consulted for kDLOneAPI, where it is forwarded to the XPU
// hooks together with the descriptor. Device types that are not handled below
// are rejected through TORCH_CHECK.
static Device getATenDevice(const DLDevice& ctx, void* data) {
  const auto dl_type = ctx.device_type;

  if (dl_type == DLDeviceType::kDLCPU) {
    return at::Device(DeviceType::CPU);
  }
#ifndef USE_ROCM
  // if we are compiled under HIP, we cannot do cuda
  if (dl_type == DLDeviceType::kDLCUDA) {
    return at::Device(DeviceType::CUDA, ctx.device_id);
  }
#endif
  if (dl_type == DLDeviceType::kDLOpenCL) {
    return at::Device(DeviceType::OPENCL, ctx.device_id);
  }
  if (dl_type == DLDeviceType::kDLROCM) {
#ifdef USE_ROCM
    // this looks funny, we need to return CUDA here to masquerade
    return at::Device(DeviceType::CUDA, ctx.device_id);
#else
    return at::Device(DeviceType::HIP, ctx.device_id);
#endif
  }
  if (dl_type == DLDeviceType::kDLOneAPI) {
    return at::detail::getXPUHooks().getATenDeviceFromDLPackDevice(ctx, data);
  }

  TORCH_CHECK(
      false, "Unsupported device_type: " + c10::to_string(ctx.device_type));
}
// Maps a DLPack data type descriptor onto the matching ATen ScalarType.
//
// Only single-lane types are accepted. The (code, bits) pairs understood here
// mirror what getDLDataType() emits, so every dtype that can be exported can
// also be imported again:
//   kDLUInt    : 8, 16, 32, 64
//   kDLInt     : 8, 16, 32, 64
//   kDLFloat   : 16, 32, 64
//   kDLBfloat  : 16
//   kDLComplex : 32, 64, 128 (total bits of the real + imaginary parts)
//   kDLBool    : 8
// Any other combination is rejected through TORCH_CHECK, with a message that
// names the offending type code and bit width.
ScalarType toScalarType(const DLDataType& dtype) {
  // Initialised so the variable never holds an indeterminate value; every
  // path that reaches the final return overwrites it.
  ScalarType stype = ScalarType::Undefined;
  TORCH_CHECK(dtype.lanes == 1, "ATen does not support lanes != 1");
  switch (dtype.code) {
    case DLDataTypeCode::kDLUInt:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Byte;
          break;
        case 16:
          stype = ScalarType::UInt16;
          break;
        case 32:
          stype = ScalarType::UInt32;
          break;
        case 64:
          stype = ScalarType::UInt64;
          break;
        default:
          TORCH_CHECK(
              false, "Unsupported kUInt bits " + c10::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLInt:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Char;
          break;
        case 16:
          stype = ScalarType::Short;
          break;
        case 32:
          stype = ScalarType::Int;
          break;
        case 64:
          stype = ScalarType::Long;
          break;
        default:
          TORCH_CHECK(
              false, "Unsupported kInt bits " + c10::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLFloat:
      switch (dtype.bits) {
        case 16:
          stype = ScalarType::Half;
          break;
        case 32:
          stype = ScalarType::Float;
          break;
        case 64:
          stype = ScalarType::Double;
          break;
        default:
          TORCH_CHECK(
              false, "Unsupported kFloat bits " + c10::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLBfloat:
      switch (dtype.bits) {
        case 16:
          stype = ScalarType::BFloat16;
          break;
        default:
          TORCH_CHECK(
              false,
              "Unsupported kDLBfloat bits " + c10::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLComplex:
      switch (dtype.bits) {
        case 32:
          stype = ScalarType::ComplexHalf;
          break;
        case 64:
          stype = ScalarType::ComplexFloat;
          break;
        case 128:
          stype = ScalarType::ComplexDouble;
          break;
        default:
          TORCH_CHECK(
              false,
              "Unsupported kDLComplex bits " + c10::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLBool:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Bool;
          break;
        default:
          TORCH_CHECK(
              false, "Unsupported kDLBool bits " + c10::to_string(dtype.bits));
      }
      break;
    default:
      TORCH_CHECK(
          false, "Unsupported code " + c10::to_string(dtype.code));
  }
  return stype;
}
// Heap-allocated record that backs one tensor exported by toDLPack().
// toDLPack() returns the address of `tensor`, and stores the address of the
// whole record in `tensor.manager_ctx` so that deleter() can free it.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct ATenDLMTensor {
// The ATen tensor (a strided view of the exported one) stored by toDLPack();
// it is destroyed together with this record.
Tensor handle;
// The DLPack structure handed out to the consumer.
DLManagedTensor tensor;
};
// DLPack deleter callback installed by toDLPack(): frees the ATenDLMTensor
// record whose address was stored in `manager_ctx`, which also destroys the
// Tensor handle held inside it.
static void deleter(DLManagedTensor* arg) {
  auto* owner = static_cast<ATenDLMTensor*>(arg->manager_ctx);
  delete owner;
}
// Exports an ATen tensor as a DLPack managed tensor.
//
// Returns a raw pointer to a DLManagedTensor embedded in a heap-allocated
// ATenDLMTensor record. The record is released when the returned structure's
// `deleter` callback is invoked on it; nothing else frees it.
//
// Throws (through TORCH_CHECK in getDLDevice / getDLDataType) when the
// tensor's device or dtype cannot be represented; in that case nothing is
// leaked because the record is still owned by a unique_ptr at that point.
DLManagedTensor* toDLPack(const Tensor& src) {
  // create a new tensor with possibly normalized strides
  // gh-83069
  auto shape = src.sizes();
  auto strides = src.strides().vec();
  for (int64_t i = 0; i < src.dim(); i++) {
    // Dimensions of extent 0 or 1 get their stride rewritten to 1.
    if (shape[i] < 2) {
      strides[i] = 1;
    }
  }
  auto view = src.as_strided(shape, strides, src.storage_offset());

  // Keep the record under RAII ownership until it is fully populated: the
  // device / dtype conversions below may throw.
  auto atDLMTensor = std::make_unique<ATenDLMTensor>();
  atDLMTensor->handle = view;
  atDLMTensor->tensor.manager_ctx = atDLMTensor.get();
  atDLMTensor->tensor.deleter = &deleter;
  atDLMTensor->tensor.dl_tensor.data = view.data_ptr();

  // Only CUDA tensors report their device index; every other device type is
  // exported with ordinal 0 (XPU descriptors are produced by the XPU hooks
  // inside getDLDevice).
  int64_t device_id = 0;
  if (src.is_cuda()) {
    device_id = src.get_device();
  }
  atDLMTensor->tensor.dl_tensor.device = getDLDevice(src, device_id);
  atDLMTensor->tensor.dl_tensor.ndim = src.dim();
  atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);

  // NOTE(review): shape/strides alias the view's own metadata rather than a
  // copy; they are expected to stay valid because `handle` holds that same
  // view for the lifetime of the record -- confirm.
  atDLMTensor->tensor.dl_tensor.shape =
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      const_cast<int64_t*>(view.sizes().data());
  atDLMTensor->tensor.dl_tensor.strides =
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      const_cast<int64_t*>(view.strides().data());
  // data already points at the first element, so no extra offset is needed.
  atDLMTensor->tensor.dl_tensor.byte_offset = 0;

  // Hand ownership to the consumer; deleter() reclaims the record.
  return &(atDLMTensor.release()->tensor);
}
// Imports a DLPack managed tensor as an ATen tensor, using the producer's own
// `deleter` callback (when it is set) as the release action. Delegates to the
// two-argument overload.
Tensor fromDLPack(const DLManagedTensor* src) {
  auto release_producer = [src](void* /*unused*/) {
    if (!src->deleter) {
      return;
    }
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
    src->deleter(const_cast<DLManagedTensor*>(src));
  };
  return fromDLPack(src, std::move(release_producer));
}
// Imports a DLPack managed tensor as an ATen tensor without copying.
//
// `src`     : the DLPack tensor to wrap; its device and dtype are translated
//             with getATenDevice() / toScalarType(), which throw for
//             unsupported values.
// `deleter` : release action passed on to at::from_blob. It is given the
//             producer's base pointer (`dl_tensor.data`).
//
// The tensor's first element is located at `data + byte_offset`, as required
// by the DLPack specification. When `strides` is null the strides-less
// from_blob overload is used.
Tensor fromDLPack(
    const DLManagedTensor* src,
    std::function<void(void*)> deleter) {
  const auto& dl_tensor = src->dl_tensor;
  Device device = getATenDevice(dl_tensor.device, dl_tensor.data);
  ScalarType stype = toScalarType(dl_tensor.dtype);

  void* data = dl_tensor.data;
  if (dl_tensor.byte_offset != 0) {
    // Honour the producer's offset into its allocation.
    data = static_cast<char*>(dl_tensor.data) + dl_tensor.byte_offset;
    // Keep handing the deleter the producer's base pointer rather than the
    // shifted one, so its contract matches the zero-offset case.
    deleter = [base = dl_tensor.data,
               inner = std::move(deleter)](void* /*unused*/) {
      if (inner) {
        inner(base);
      }
    };
  }

  if (!dl_tensor.strides) {
    return at::from_blob(
        data,
        IntArrayRef(dl_tensor.shape, dl_tensor.ndim),
        std::move(deleter),
        at::device(device).dtype(stype),
        {device});
  }
  return at::from_blob(
      data,
      IntArrayRef(dl_tensor.shape, dl_tensor.ndim),
      IntArrayRef(dl_tensor.strides, dl_tensor.ndim),
      std::move(deleter),
      at::device(device).dtype(stype),
      {device});
}
} // namespace at