Fix DLPack conversion: manage DLManagedTensor lifetime via TfDlManagedTensorCtx, export explicit strides for compact row-major tensors, and validate incoming strides
diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc
index ce36a5f..fdc439d 100644
--- a/tensorflow/c/eager/dlpack.cc
+++ b/tensorflow/c/eager/dlpack.cc
@@ -27,12 +27,17 @@
namespace {
+// Owning context for the DLManagedTensor; it manages the lifetime of the
+// DLManagedTensor. When DLManagedTensor::deleter is called, it notifies the
+// original framework of the destruction, and this context is also deleted.
struct TfDlManagedTensorCtx {
- TensorReference* reference;
+ TensorReference reference;
std::vector<int64_t> shape;
+ std::vector<int64_t> strides;
DLManagedTensor tensor;
- TfDlManagedTensorCtx()
+ TfDlManagedTensorCtx(const TensorReference& ref)
+ : reference(ref), shape(), tensor() {}
};
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
@@ -61,8 +66,7 @@
void DLManagedTensorDeleter(DLManagedTensor* arg) {
TfDlManagedTensorCtx* owner =
static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx);
- owner->reference->Unref();
- delete owner->reference;
+ owner->reference.Unref();
delete owner;
}
@@ -129,31 +133,41 @@
TF_Status* status) {
const Tensor* tensor = GetTensorFromHandle(h, status);
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
- auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx;
+ TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
- TensorReference* tensor_ref =
- new TensorReference(*tensor); // This will call buf_->Ref()
+ auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
tf_dlm_tensor_ctx->reference = tensor_ref;
- tf_dlm_tensor_ctx->tensor.manager_ctx = tf_dlm_tensor_ctx;
- tf_dlm_tensor_ctx->tensor.deleter = &DLManagedTensorDeleter;
- tf_dlm_tensor_ctx->tensor.dl_tensor.ctx = GetDLContext(h, status);
+
+ DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
+ dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
+ dlm_tensor->deleter = &DLManagedTensorDeleter;
+ dlm_tensor->dl_tensor.ctx = GetDLContext(h, status);
int ndim = tensor->dims();
- tf_dlm_tensor_ctx->tensor.dl_tensor.ndim = ndim;
- tf_dlm_tensor_ctx->tensor.dl_tensor.data =
- TFE_TensorHandleDevicePointer(h, status);
- tf_dlm_tensor_ctx->tensor.dl_tensor.dtype = GetDLDataType(data_type, status);
+ dlm_tensor->dl_tensor.ndim = ndim;
+ dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
+ dlm_tensor->dl_tensor.dtype = GetDLDataType(data_type, status);
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
+ std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
shape_arr->resize(ndim);
+ stride_arr->resize(ndim, 1);
for (int i = 0; i < ndim; i++) {
(*shape_arr)[i] = tensor->dim_size(i);
}
+ for (int i = ndim - 2; i >= 0; --i) {
+ (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
+ }
- tf_dlm_tensor_ctx->tensor.dl_tensor.shape =
+ dlm_tensor->dl_tensor.shape =
reinterpret_cast<std::int64_t*>(shape_arr->data());
- tf_dlm_tensor_ctx->tensor.dl_tensor.strides =
- nullptr; // nullptr indicates tensor is compact and row-majored.
- tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset =
+  // There are two ways to represent compact row-major data:
+  // 1) nullptr indicates the tensor is compact and row-major.
+  // 2) fill in the strides array with the actual values for compact
+  //    row-major data. Here we choose option 2, since some frameworks do not
+  //    handle the strides argument properly.
+ dlm_tensor->dl_tensor.strides =
+ reinterpret_cast<std::int64_t*>(stride_arr->data());
+ dlm_tensor->dl_tensor.byte_offset =
0; // TF doesn't handle the strides and byte_offsets here
return &tf_dlm_tensor_ctx->tensor;
}
@@ -250,6 +264,15 @@
dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
}
+bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
+ int ndim) {
+ for (int i = ndim - 2; i >= 0; --i) {
+ if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) {
+ return false;
+ };
+ }
+ return true;
+}
} // namespace
void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
@@ -268,23 +291,32 @@
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
-
+ DLTensor* dl_tensor = &dlmt->dl_tensor;
absl::optional<std::string> device_name =
- DeviceNameFromDlContext(dlmt->dl_tensor.ctx, status);
+ DeviceNameFromDlContext(dl_tensor->ctx, status);
if (!device_name.has_value()) {
status->status =
tensorflow::errors::InvalidArgument("Unsupported Device Type");
return nullptr;
}
- TF_DataType dtype = TfDataTypeFormDlDataType(dlmt->dl_tensor.dtype, status);
- int num_dims = dlmt->dl_tensor.ndim;
- const int64_t* dims = dlmt->dl_tensor.shape;
- void* data = dlmt->dl_tensor.data;
+ TF_DataType dtype = TfDataTypeFormDlDataType(dl_tensor->dtype, status);
+ int num_dims = dl_tensor->ndim;
+ const int64_t* dims = dl_tensor->shape;
+ void* data = dl_tensor->data;
- size_t total_bytes = dlmt->dl_tensor.dtype.bits / 8;
+ size_t total_bytes = dl_tensor->dtype.bits / 8;
for (int i = 0; i < num_dims; i++) {
total_bytes *= dims[i];
}
+
+ if ((dl_tensor->strides != nullptr) &&
+ !IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
+ num_dims)) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Invalid strides array from DLPack");
+ return nullptr;
+ }
+
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
total_bytes, &DeallocatorWrapperFunc, &dlmt, status);