/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/dlpack.h"
#include "include/dlpack/dlpack.h" // TF:dlpack
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace {
// Owning context for a DLManagedTensor. It keeps the underlying TF buffer
// alive via a TensorReference and backs the shape/strides arrays. When
// DLManagedTensor::deleter is called, the consuming framework signals that it
// is done with the data, the reference is released, and this context is
// deleted as well.
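//
// A minimal sketch of the ownership flow (for illustration only; `ctx` stands
// for a heap-allocated TfDlManagedTensorCtx):
//
//   ctx->tensor.manager_ctx = ctx;
//   ctx->tensor.deleter = &DLManagedTensorDeleter;
//   // ... later, the consumer calls ctx->tensor.deleter(&ctx->tensor),
//   // which unrefs ctx->reference and deletes ctx.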
struct TfDlManagedTensorCtx {
TensorReference reference;
std::vector<int64_t> shape;
std::vector<int64_t> strides;
DLManagedTensor tensor;
explicit TfDlManagedTensorCtx(const TensorReference& ref)
    : reference(ref), shape(), strides(), tensor() {}
};
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || !h->handle->IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
    "The passed in handle is a nullptr or an invalid handle");
return nullptr;
}
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();
if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
"DLPack doesn't support remote tensor");
return nullptr;
}
const tensorflow::Tensor* tensor;
status->status = handle->Tensor(&tensor);
if (!status->status.ok()) {
return nullptr;
}
return tensor;
}
void DLManagedTensorDeleter(DLManagedTensor* arg) {
TfDlManagedTensorCtx* owner =
static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx);
owner->reference.Unref();
delete owner;
}
DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) {
DLDataType dtype;
dtype.lanes = 1;
dtype.bits = TF_DataTypeSize(data_type) * 8;
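  // For example, TF_FLOAT is 4 bytes wide, so bits == 32; combined with
  // code = kDLFloat and lanes = 1 below, that is DLPack's float32.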
switch (data_type) {
case TF_DataType::TF_HALF:
case TF_DataType::TF_FLOAT:
case TF_DataType::TF_DOUBLE:
dtype.code = DLDataTypeCode::kDLFloat;
break;
case TF_DataType::TF_INT8:
case TF_DataType::TF_INT16:
case TF_DataType::TF_INT32:
case TF_DataType::TF_INT64:
dtype.code = DLDataTypeCode::kDLInt;
break;
case TF_DataType::TF_BOOL:
case TF_DataType::TF_UINT8:
case TF_DataType::TF_UINT16:
case TF_DataType::TF_UINT32:
case TF_DataType::TF_UINT64:
dtype.code = DLDataTypeCode::kDLUInt;
break;
case TF_DataType::TF_BFLOAT16:
dtype.code = DLDataTypeCode::kDLBfloat;
break;
default:
status->status = tensorflow::errors::InvalidArgument(
DataType_Name(static_cast<DataType>(data_type)),
" is not supported by dlpack");
break;
}
return dtype;
}
DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) {
DLContext ctx;
const char* device_name = h->handle->DeviceName(&status->status);
DeviceNameUtils::ParsedName parsed_name;
tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
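  // For example, "/job:localhost/replica:0/task:0/device:GPU:0" parses to
  // type "GPU" with id 0.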
std::string device_type = parsed_name.type;
int device_id = -1;
if (parsed_name.has_id) {
device_id = parsed_name.id;
  }  // TODO: Can a fully specified device name ever omit the id?
ctx.device_id = device_id;
if (device_type == "CPU") {
ctx.device_type = DLDeviceType::kDLCPU;
} else if (device_type == "GPU") {
ctx.device_type = DLDeviceType::kDLGPU;
} else {
status->status = tensorflow::errors::InvalidArgument(
"Unsupported Device Type for dlpack");
}
return ctx;
}
DLManagedTensor* TFEHandleToTfDlManagedTensorCtx(TFE_TensorHandle* h,
TF_Status* status) {
  const Tensor* tensor = GetTensorFromHandle(h, status);
  if (!status->status.ok()) {
    return nullptr;
  }
  TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
dlm_tensor->deleter = &DLManagedTensorDeleter;
dlm_tensor->dl_tensor.ctx = GetDLContext(h, status);
int ndim = tensor->dims();
dlm_tensor->dl_tensor.ndim = ndim;
dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
dlm_tensor->dl_tensor.dtype = GetDLDataType(data_type, status);
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
shape_arr->resize(ndim);
stride_arr->resize(ndim, 1);
for (int i = 0; i < ndim; i++) {
(*shape_arr)[i] = tensor->dim_size(i);
}
for (int i = ndim - 2; i >= 0; --i) {
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
}
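  // For example, a tensor of shape {2, 3, 4} gets strides {12, 4, 1},
  // counted in elements (not bytes), as DLPack expects.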
  dlm_tensor->dl_tensor.shape = shape_arr->data();
  // There are two ways to represent compact row-major data:
  // 1) a nullptr strides field, which indicates that the tensor is compact
  //    and row-major, or
  // 2) a strides array filled in with the explicit row-major strides.
  // We choose option 2 here, since some frameworks do not handle a null
  // strides argument properly.
  dlm_tensor->dl_tensor.strides = stride_arr->data();
  // The data pointer already points at the first element, so no additional
  // byte offset is needed.
  dlm_tensor->dl_tensor.byte_offset = 0;
return &tf_dlm_tensor_ctx->tensor;
}
absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
TF_Status* status) {
switch (ctx.device_type) {
case DLDeviceType::kDLCPU:
return "CPU:0";
case DLDeviceType::kDLGPU:
return absl::StrCat("GPU:", ctx.device_id);
default:
return absl::nullopt;
  }
}
TF_DataType TfDataTypeFromDlDataType(const DLDataType& dtype,
                                     TF_Status* status) {
TF_DataType tf_dtype;
switch (dtype.code) {
case DLDataTypeCode::kDLUInt:
switch (dtype.bits) {
case 8:
tf_dtype = TF_DataType::TF_UINT8;
break;
case 16:
tf_dtype = TF_DataType::TF_UINT16;
break;
case 32:
tf_dtype = TF_DataType::TF_UINT32;
break;
case 64:
tf_dtype = TF_DataType::TF_UINT64;
break;
default:
status->status = tensorflow::errors::InvalidArgument(
"Unsupported UInt bits: ", dtype.bits);
}
break;
case DLDataTypeCode::kDLInt:
switch (dtype.bits) {
case 8:
tf_dtype = TF_DataType::TF_INT8;
break;
case 16:
tf_dtype = TF_DataType::TF_INT16;
break;
case 32:
tf_dtype = TF_DataType::TF_INT32;
break;
case 64:
tf_dtype = TF_DataType::TF_INT64;
break;
default:
status->status = tensorflow::errors::InvalidArgument(
"Unsupported Int bits: ", dtype.bits);
}
break;
case DLDataTypeCode::kDLFloat:
switch (dtype.bits) {
case 16:
tf_dtype = TF_DataType::TF_HALF;
break;
case 32:
tf_dtype = TF_DataType::TF_FLOAT;
break;
case 64:
tf_dtype = TF_DataType::TF_DOUBLE;
break;
default:
status->status = tensorflow::errors::InvalidArgument(
"Unsupported Float bits: ", dtype.bits);
}
break;
case DLDataTypeCode::kDLBfloat:
switch (dtype.bits) {
case 16:
tf_dtype = TF_DataType::TF_BFLOAT16;
break;
default:
status->status = tensorflow::errors::InvalidArgument(
"Unsupported BFloat bits: ", dtype.bits);
}
break;
default:
status->status = tensorflow::errors::InvalidArgument(
"Unsupported Type Codes: ", dtype.code);
}
return tf_dtype;
}
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr);
  dlmt->deleter(dlmt);
}
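// Returns whether `stride_arr` describes a compact row-major layout of
// `shape_arr`. For example, shape {2, 3} with strides {3, 1} is accepted,
// while column-major strides {1, 2} are rejected.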
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
int ndim) {
  if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
    return false;
  }
  for (int i = ndim - 2; i >= 0; --i) {
    if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) {
      return false;
    }
  }
return true;
}
} // namespace
void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
DLManagedTensor* dlMTensor = static_cast<DLManagedTensor*>(dlm_ptr);
if (dlMTensor->deleter != nullptr) {
dlMTensor->deleter(dlMTensor);
}
}
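// Example usage of the export path (a sketch; `handle` is assumed to be an
// existing TFE_TensorHandle*, and error handling is elided):
//
//   TF_Status* s = TF_NewStatus();
//   void* dlm_ptr = tensorflow::TFE_HandleToDLPack(handle, s);
//   // Hand the DLManagedTensor* to the consuming framework. If it is never
//   // consumed, release it explicitly:
//   tensorflow::TFE_CallDLManagedTensorDeleter(dlm_ptr);
//   TF_DeleteStatus(s);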
void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
DLManagedTensor* tfdlmtensor = TFEHandleToTfDlManagedTensorCtx(h, status);
return static_cast<void*>(tfdlmtensor);
}
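// Example usage of the import path (a sketch; `dlm_ptr` is assumed to be a
// DLManagedTensor* produced by another framework, and error handling is
// elided):
//
//   TF_Status* s = TF_NewStatus();
//   TFE_TensorHandle* handle = tensorflow::TFE_HandleFromDLPack(dlm_ptr, s);
//   // The returned handle takes ownership of dlm_ptr; its deleter runs when
//   // the underlying buffer is released.
//   TFE_DeleteTensorHandle(handle);
//   TF_DeleteStatus(s);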
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
DLTensor* dl_tensor = &dlmt->dl_tensor;
absl::optional<std::string> device_name =
DeviceNameFromDlContext(dl_tensor->ctx, status);
if (!device_name.has_value()) {
status->status =
tensorflow::errors::InvalidArgument("Unsupported Device Type");
return nullptr;
}
  TF_DataType dtype = TfDataTypeFromDlDataType(dl_tensor->dtype, status);
  if (!status->status.ok()) {
    return nullptr;
  }
int num_dims = dl_tensor->ndim;
const int64_t* dims = dl_tensor->shape;
void* data = dl_tensor->data;
size_t total_bytes = dl_tensor->dtype.bits / 8;
for (int i = 0; i < num_dims; i++) {
total_bytes *= dims[i];
}
if ((dl_tensor->strides != nullptr) &&
!IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
num_dims)) {
status->status = tensorflow::errors::InvalidArgument(
"Invalid strides array from DLPack");
return nullptr;
}
  TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
      ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
      total_bytes, &DeallocatorWrapperFunc, dlmt, status);
  return handle;
}
} // namespace tensorflow