fix
diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc
index ce36a5f..fdc439d 100644
--- a/tensorflow/c/eager/dlpack.cc
+++ b/tensorflow/c/eager/dlpack.cc
@@ -27,12 +27,17 @@
 
 namespace {
 
+// Managing context for the DLManagedTensor, will manage the lifetime of
+// DLManagedTensor. When calling DLManagedTensor::deleter, it will notify the
+// original framework of destruction, and this context will be deleted also.
 struct TfDlManagedTensorCtx {
-  TensorReference* reference;
+  TensorReference reference;
   std::vector<int64_t> shape;
+  std::vector<int64_t> strides;
   DLManagedTensor tensor;
 
-  TfDlManagedTensorCtx()
+  TfDlManagedTensorCtx(const TensorReference& ref)
+      : reference(ref), shape(), tensor() {}
 };
 
 const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
@@ -61,8 +66,7 @@
 void DLManagedTensorDeleter(DLManagedTensor* arg) {
   TfDlManagedTensorCtx* owner =
       static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx);
-  owner->reference->Unref();
-  delete owner->reference;
+  owner->reference.Unref();
   delete owner;
 }
 
@@ -129,31 +133,41 @@
                                                  TF_Status* status) {
   const Tensor* tensor = GetTensorFromHandle(h, status);
   TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
-  auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx;
+  TensorReference tensor_ref(*tensor);  // This will call buf_->Ref()
 
-  TensorReference* tensor_ref =
-      new TensorReference(*tensor);  // This will call buf_->Ref()
+  auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
   tf_dlm_tensor_ctx->reference = tensor_ref;
-  tf_dlm_tensor_ctx->tensor.manager_ctx = tf_dlm_tensor_ctx;
-  tf_dlm_tensor_ctx->tensor.deleter = &DLManagedTensorDeleter;
-  tf_dlm_tensor_ctx->tensor.dl_tensor.ctx = GetDLContext(h, status);
+
+  DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
+  dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
+  dlm_tensor->deleter = &DLManagedTensorDeleter;
+  dlm_tensor->dl_tensor.ctx = GetDLContext(h, status);
   int ndim = tensor->dims();
-  tf_dlm_tensor_ctx->tensor.dl_tensor.ndim = ndim;
-  tf_dlm_tensor_ctx->tensor.dl_tensor.data =
-      TFE_TensorHandleDevicePointer(h, status);
-  tf_dlm_tensor_ctx->tensor.dl_tensor.dtype = GetDLDataType(data_type, status);
+  dlm_tensor->dl_tensor.ndim = ndim;
+  dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
+  dlm_tensor->dl_tensor.dtype = GetDLDataType(data_type, status);
 
   std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
+  std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
   shape_arr->resize(ndim);
+  stride_arr->resize(ndim, 1);
   for (int i = 0; i < ndim; i++) {
     (*shape_arr)[i] = tensor->dim_size(i);
   }
+  for (int i = ndim - 2; i >= 0; --i) {
+    (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
+  }
 
-  tf_dlm_tensor_ctx->tensor.dl_tensor.shape =
+  dlm_tensor->dl_tensor.shape =
       reinterpret_cast<std::int64_t*>(shape_arr->data());
-  tf_dlm_tensor_ctx->tensor.dl_tensor.strides =
-      nullptr;  // nullptr indicates tensor is compact and row-majored.
-  tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset =
+  // There are two ways to represent compact row-major data
+  // 1) nullptr indicates tensor is compact and row-majored.
+  // 2) fill in the strides array as the real case for compact row-major data
+  // Here we choose option 2, since some framework didn't handle the strides
+  // argument properly
+  dlm_tensor->dl_tensor.strides =
+      reinterpret_cast<std::int64_t*>(stride_arr->data());
+  dlm_tensor->dl_tensor.byte_offset =
       0;  // TF doesn't handle the strides and byte_offsets here
   return &tf_dlm_tensor_ctx->tensor;
 }
@@ -250,6 +264,15 @@
   dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
 }
 
+bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
+                                      int ndim) {
+  for (int i = ndim - 2; i >= 0; --i) {
+    if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) {
+      return false;
+    };
+  }
+  return true;
+}
 }  // namespace
 
 void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
@@ -268,23 +291,32 @@
   TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
-
+  DLTensor* dl_tensor = &dlmt->dl_tensor;
   absl::optional<std::string> device_name =
-      DeviceNameFromDlContext(dlmt->dl_tensor.ctx, status);
+      DeviceNameFromDlContext(dl_tensor->ctx, status);
   if (!device_name.has_value()) {
     status->status =
         tensorflow::errors::InvalidArgument("Unsupported Device Type");
     return nullptr;
   }
-  TF_DataType dtype = TfDataTypeFormDlDataType(dlmt->dl_tensor.dtype, status);
-  int num_dims = dlmt->dl_tensor.ndim;
-  const int64_t* dims = dlmt->dl_tensor.shape;
-  void* data = dlmt->dl_tensor.data;
+  TF_DataType dtype = TfDataTypeFormDlDataType(dl_tensor->dtype, status);
+  int num_dims = dl_tensor->ndim;
+  const int64_t* dims = dl_tensor->shape;
+  void* data = dl_tensor->data;
 
-  size_t total_bytes = dlmt->dl_tensor.dtype.bits / 8;
+  size_t total_bytes = dl_tensor->dtype.bits / 8;
   for (int i = 0; i < num_dims; i++) {
     total_bytes *= dims[i];
   }
+
+  if ((dl_tensor->strides != nullptr) &&
+      !IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
+                                        num_dims)) {
+    status->status = tensorflow::errors::InvalidArgument(
+        "Invalid strides array from DLPack");
+    return nullptr;
+  }
+
   TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
       ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
       total_bytes, &DeallocatorWrapperFunc, &dlmt, status);