Support passing native tflite resource variables with dataset ops.

Add runtime code to convert between TF Lite native resource tensor and TF resource tensor.

PiperOrigin-RevId: 410397929
Change-Id: If0c486c6c064daceaf30442fd2e7495fa1473e48
diff --git a/tensorflow/lite/delegates/flex/BUILD b/tensorflow/lite/delegates/flex/BUILD
index 6f14fec..81e3af8 100644
--- a/tensorflow/lite/delegates/flex/BUILD
+++ b/tensorflow/lite/delegates/flex/BUILD
@@ -64,6 +64,7 @@
         ":util",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite:string_util",
+        "//tensorflow/lite/experimental/resource",
     ] + select({
         "//tensorflow:android": [
             "//tensorflow/core:portable_tensorflow_lib_lite",
@@ -324,6 +325,8 @@
     deps = [
         "//tensorflow/lite/c:common",
         "//tensorflow/lite:kernel_api",
+        "@com_google_absl//absl/strings:str_format",
+        "//tensorflow/lite/kernels/internal:tensor",
     ] + select({
         "//tensorflow:android": [
             "//tensorflow/core:portable_tensorflow_lib_lite",
@@ -345,7 +348,9 @@
     srcs = ["util_test.cc"],
     deps = [
         ":util",
+        "//tensorflow/core:framework",
         "//tensorflow/lite:string",
+        "//tensorflow/lite:util",
         "//tensorflow/lite/testing:util",
         "@com_google_googletest//:gtest_main",
     ],
@@ -421,6 +426,9 @@
         ":buffer_map_util",
         ":subgraph_resource",
         ":util",
+        "@com_google_absl//absl/strings",
+        "//tensorflow/lite/kernels/internal:tensor",
+        "@com_google_absl//absl/strings:str_format",
         "//tensorflow/lite:cc_api",
         "//tensorflow/lite:string_util",
         "//tensorflow/lite/c:c_api_types",
diff --git a/tensorflow/lite/delegates/flex/buffer_map_test.cc b/tensorflow/lite/delegates/flex/buffer_map_test.cc
index 6b0d5ef..3e7c087 100644
--- a/tensorflow/lite/delegates/flex/buffer_map_test.cc
+++ b/tensorflow/lite/delegates/flex/buffer_map_test.cc
@@ -14,6 +14,8 @@
 ==============================================================================*/
 #include "tensorflow/lite/delegates/flex/buffer_map.h"
 
+#include <sys/types.h>
+
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "tensorflow/lite/interpreter.h"
@@ -160,6 +162,32 @@
               ElementsAre("", "", "", "s3", "", "", "s1", "s2"));
 }
 
+TEST(BufferMapTest, SetFromTfLiteBuiltinResource) {
+  BufferMap buffer_map;
+
+  // Constructs a fake resource tensor.
+  auto tensor = UniqueTfLiteTensor(new TfLiteTensor(), [](TfLiteTensor* t) {
+    TfLiteTensorDataFree(t);
+    TfLiteIntArrayFree(t->dims);
+    delete t;
+  });
+  tensor->allocation_type = kTfLiteDynamic;
+  tensor->type = kTfLiteResource;
+  tensor->dims = ConvertVectorToTfLiteIntArray({1});
+  TfLiteTensorRealloc(sizeof(int32_t), tensor.get());
+  tensor->delegate = nullptr;
+  tensor->data.i32[0] = 1;
+
+  buffer_map.SetFromTfLite(0, tensor.get());
+  // Also check details of the tensor.
+  tensorflow::Tensor out_tensor = buffer_map.GetTensor(0);
+  ASSERT_EQ(out_tensor.dtype(), tensorflow::DT_RESOURCE);
+  ASSERT_EQ(out_tensor.NumElements(), 1);
+  tensorflow::ResourceHandle handle =
+      out_tensor.flat<tensorflow::ResourceHandle>()(0);
+  EXPECT_EQ(handle.name(), "tflite_resource_variable:1");
+}
+
 TEST(BufferMapTest, SetFromTensorFlow) {
   tensorflow::Tensor t1 =
       MakeTensor<float>({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0});
diff --git a/tensorflow/lite/delegates/flex/buffer_map_util.cc b/tensorflow/lite/delegates/flex/buffer_map_util.cc
index 2f01b99..27a027e 100644
--- a/tensorflow/lite/delegates/flex/buffer_map_util.cc
+++ b/tensorflow/lite/delegates/flex/buffer_map_util.cc
@@ -17,7 +17,10 @@
 #include "tensorflow/core/framework/log_memory.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/typed_allocator.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/delegates/flex/util.h"
+#include "tensorflow/lite/experimental/resource/resource_variable.h"
 #include "tensorflow/lite/string_util.h"
 
 namespace tflite {
@@ -100,19 +103,34 @@
 
 tensorflow::Status SetTfTensorFromTfLite(const TfLiteTensor* tensor,
                                          tensorflow::Tensor* tf_tensor) {
-  // TODO(b/179094265): This is an experimental implementation, subject to
-  // change. This can be re-implemented with life cycle management mechanism
-  // like reference counting.
-  // In a different subgraph, it can load the TensorFlow tensor pointer of the
-  // given TensorFlow Lite tensor, which is stored in the `data` field. The
-  // memory management cycle of the shared TensorFlow's tensor will be managed
-  // by the buffer maps since the loaded tensors always will be kept in the
-  // buffer map.
-  //
-  // The life cycle of the pointer will be managed by the reference counting in
-  // the TensorFlow world and the pointer will be freed when all the buffer
-  // maps, who own it, are gone.
-  if (IsResourceOrVariant(tensor)) {
+  if (resource::IsBuiltinResource(tensor)) {
+    // If this is native TF Lite resource variable, then we create a TF resource
+    // tensor where the tensor handle encodes the identifier of the TF Lite
+    // resource.
+    // This approach assumes that there is only a single model being invoked
+    // via the Interpreter instance, so that the resource IDs won't have any
+    // collisions. If we plan to support concurrent execution in the future, we
+    // should make sure the resource ID being encoded is unique between
+    // different executions.
+    tensorflow::Tensor t(tensorflow::DT_RESOURCE, tensorflow::TensorShape({}));
+    tensorflow::ResourceHandle handle;
+    handle.set_name(TfLiteResourceIdentifier(tensor));
+    t.flat<tensorflow::ResourceHandle>()(0) = handle;
+    *tf_tensor = t;
+    return tensorflow::Status::OK();
+  } else if (IsResourceOrVariant(tensor)) {
+    // TODO(b/179094265): This is an experimental implementation, subject to
+    // change. This can be re-implemented with life cycle management mechanism
+    // like reference counting.
+    // In a different subgraph, it can load the TensorFlow tensor pointer of the
+    // given TensorFlow Lite tensor, which is stored in the `data` field. The
+    // memory management cycle of the shared TensorFlow's tensor will be managed
+    // by the buffer maps since the loaded tensors always will be kept in the
+    // buffer map.
+    //
+    // The life cycle of the pointer will be managed by the reference counting
+    // in the TensorFlow world and the pointer will be freed when all the buffer
+    // maps, who own it, are gone.
     const tensorflow::Tensor** tf_tensor_ptr =
         reinterpret_cast<const tensorflow::Tensor**>(tensor->data.raw);
     *tf_tensor = **tf_tensor_ptr;
diff --git a/tensorflow/lite/delegates/flex/tflite_subgraph_execute.cc b/tensorflow/lite/delegates/flex/tflite_subgraph_execute.cc
index fd160f7..44dbc21 100644
--- a/tensorflow/lite/delegates/flex/tflite_subgraph_execute.cc
+++ b/tensorflow/lite/delegates/flex/tflite_subgraph_execute.cc
@@ -14,10 +14,13 @@
 ==============================================================================*/
 #include <string>
 
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/resource_handle.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/framework/tensor_shape.h"
@@ -30,6 +33,7 @@
 #include "tensorflow/lite/delegates/flex/buffer_map_util.h"
 #include "tensorflow/lite/delegates/flex/subgraph_resource.h"
 #include "tensorflow/lite/delegates/flex/util.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/string_util.h"
 
@@ -136,6 +140,27 @@
   void SetSubgraphInput(OpKernelContext* ctx,
                         tflite::Subgraph& subgraph_selected,
                         TfLiteDelegate* flex_delegate) const {
+    auto InitializeVariantOrResource = [flex_delegate](
+                                           const Tensor& tf_tensor,
+                                           TfLiteTensor* subgraph_input) {
+      // The code here initializes the TfLiteTensor which points the data field
+      // to the original TF resource or variant tensor. This requires the TF
+      // tensor's lifetime must extend beyond the execution of callee subgraph.
+      // TODO(b/179094265): This is an experimental implementation, subject to
+      // change. This can be re-implemented with life cycle management
+      // mechanism like reference counting.
+      const size_t required_bytes = sizeof(tensorflow::Tensor**);
+      const tensorflow::Tensor** tf_tensor_ptr =
+          reinterpret_cast<const tensorflow::Tensor**>(malloc(required_bytes));
+      *tf_tensor_ptr = &tf_tensor;
+
+      TfLiteTensorDataFree(subgraph_input);
+      subgraph_input->data.raw = reinterpret_cast<char*>(tf_tensor_ptr);
+      subgraph_input->bytes = required_bytes;
+      subgraph_input->data_is_stale = true;
+      subgraph_input->delegate = flex_delegate;
+    };
+
     for (int i = 0; i < subgraph_selected.inputs().size(); ++i) {
       const Tensor& tf_tensor = ctx->input(i + 1);
       TfLiteTensor* subgraph_input =
@@ -151,21 +176,18 @@
         }
 
         dynamic_buffer.WriteToTensor(subgraph_input, /*new_shape=*/nullptr);
-      } else if (tflite::flex::IsResourceOrVariant(subgraph_input)) {
-        // TODO(b/179094265): This is an experimental implementation, subject to
-        // change. This can be re-implemented with life cycle management
-        // mechanism like reference counting.
-        const size_t required_bytes = sizeof(tensorflow::Tensor**);
-        const tensorflow::Tensor** tf_tensor_ptr =
-            reinterpret_cast<const tensorflow::Tensor**>(
-                malloc(required_bytes));
-        *tf_tensor_ptr = &tf_tensor;
-
-        TfLiteTensorDataFree(subgraph_input);
-        subgraph_input->data.raw = reinterpret_cast<char*>(tf_tensor_ptr);
-        subgraph_input->bytes = required_bytes;
-        subgraph_input->data_is_stale = false;
-        subgraph_input->delegate = flex_delegate;
+      } else if (subgraph_input->type == kTfLiteResource) {
+        // Here we will try to parse the input tensor handle to see if it
+        // contains a valid TF lite resource ID. If not, then we know that the
+        // input is a TF resource tensor.
+        tensorflow::ResourceHandle handle =
+            tf_tensor.flat<tensorflow::ResourceHandle>()(0);
+        if (!tflite::flex::GetTfLiteResourceTensorFromResourceHandle(
+                handle, subgraph_input)) {
+          InitializeVariantOrResource(tf_tensor, subgraph_input);
+        }
+      } else if (subgraph_input->type == kTfLiteVariant) {
+        InitializeVariantOrResource(tf_tensor, subgraph_input);
       } else {
         tensorflow::StringPiece tensor_data = tf_tensor.tensor_data();
         OP_REQUIRES(ctx, subgraph_input->bytes == tensor_data.size(),
@@ -190,6 +212,8 @@
           subgraph_selected.tensor(subgraph_selected.outputs()[i]);
 
       Tensor tensor;
+      fprintf(stdout, "1111111\n");
+      fflush(stdout);
       OP_REQUIRES_OK(
           ctx, tflite::flex::SetTfTensorFromTfLite(subgraph_output, &tensor));
       ctx->set_output(i, std::move(tensor));
diff --git a/tensorflow/lite/delegates/flex/util.cc b/tensorflow/lite/delegates/flex/util.cc
index 209f9e6..685f85d 100644
--- a/tensorflow/lite/delegates/flex/util.cc
+++ b/tensorflow/lite/delegates/flex/util.cc
@@ -14,9 +14,14 @@
 ==============================================================================*/
 #include "tensorflow/lite/delegates/flex/util.h"
 
+#include "absl/strings/str_format.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+
 namespace tflite {
 namespace flex {
 
+static constexpr char kResourceVariablePrefix[] = "tflite_resource_variable";
+
 TfLiteStatus ConvertStatus(TfLiteContext* context,
                            const tensorflow::Status& status) {
   if (!status.ok()) {
@@ -171,5 +176,31 @@
   return "invalid";
 }
 
+std::string TfLiteResourceIdentifier(const TfLiteTensor* tensor) {
+  // TODO(b/199782192): Create a util function to get Resource ID from a TF Lite
+  // resource tensor.
+  const int resource_id = tensor->data.i32[0];
+  return absl::StrFormat("%s:%d", kResourceVariablePrefix, resource_id);
+}
+
+bool GetTfLiteResourceTensorFromResourceHandle(
+    const tensorflow::ResourceHandle& resource_handle, TfLiteTensor* tensor) {
+  std::vector<std::string> parts = absl::StrSplit(resource_handle.name(), ':');
+  if (parts.size() != 2) {
+    return false;
+  }
+  const int kBytesRequired = sizeof(int32_t);
+  TfLiteTensorRealloc(kBytesRequired, tensor);
+  int resource_id;
+  if (parts[0] == kResourceVariablePrefix &&
+      absl::SimpleAtoi<int32_t>(parts[1], &resource_id)) {
+    // TODO(b/199782192): Create a util function to set the Resource ID of
+    // a TF Lite resource tensor.
+    GetTensorData<int32_t>(tensor)[0] = resource_id;
+    return true;
+  }
+  return false;
+}
+
 }  // namespace flex
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/flex/util.h b/tensorflow/lite/delegates/flex/util.h
index b29e765..2b2dedc 100644
--- a/tensorflow/lite/delegates/flex/util.h
+++ b/tensorflow/lite/delegates/flex/util.h
@@ -15,6 +15,8 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_FLEX_UTIL_H_
 #define TENSORFLOW_LITE_DELEGATES_FLEX_UTIL_H_
 
+#include <string>
+
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -49,6 +51,16 @@
   return tensor->type == kTfLiteResource || tensor->type == kTfLiteVariant;
 }
 
+// Returns the encoded string name for a TF Lite resource variable tensor.
+// This function will return a string in the format:
+// tflite_resource_variable:resource_id.
+std::string TfLiteResourceIdentifier(const TfLiteTensor* tensor);
+
+// Parses out the resource ID from the given `resource_handle` and sets it
+// to the corresponding TfLiteTensor. Returns true if succeed.
+bool GetTfLiteResourceTensorFromResourceHandle(
+    const tensorflow::ResourceHandle& resource_handle, TfLiteTensor* tensor);
+
 }  // namespace flex
 }  // namespace tflite
 
diff --git a/tensorflow/lite/delegates/flex/util_test.cc b/tensorflow/lite/delegates/flex/util_test.cc
index 68a991e..8c5a9ca 100644
--- a/tensorflow/lite/delegates/flex/util_test.cc
+++ b/tensorflow/lite/delegates/flex/util_test.cc
@@ -18,8 +18,10 @@
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "tensorflow/core/framework/resource_handle.h"
 #include "tensorflow/lite/string_type.h"
 #include "tensorflow/lite/testing/util.h"
+#include "tensorflow/lite/util.h"
 
 namespace tflite {
 namespace flex {
@@ -141,6 +143,40 @@
   EXPECT_EQ(kTfLiteVariant, GetTensorFlowLiteType(TF_VARIANT));
 }
 
+TEST(UtilTest, GetTfLiteResourceIdentifier) {
+  // Constructs a fake resource tensor.
+  TfLiteTensor tensor;
+  tensor.allocation_type = kTfLiteDynamic;
+  tensor.type = kTfLiteResource;
+  std::vector<int> dims = {1};
+  tensor.dims = ConvertVectorToTfLiteIntArray(dims);
+  tensor.data.raw = nullptr;
+  TfLiteTensorRealloc(sizeof(int32_t), &tensor);
+  tensor.delegate = nullptr;
+  tensor.data.i32[0] = 1;
+
+  EXPECT_EQ(TfLiteResourceIdentifier(&tensor), "tflite_resource_variable:1");
+  TfLiteIntArrayFree(tensor.dims);
+  TfLiteTensorDataFree(&tensor);
+}
+
+TEST(UtilTest, GetTfLiteResourceTensorFromResourceHandle) {
+  tensorflow::ResourceHandle handle;
+  handle.set_name("tflite_resource_variable:1");
+
+  TfLiteTensor tensor;
+  tensor.allocation_type = kTfLiteDynamic;
+  tensor.type = kTfLiteResource;
+  tensor.data.raw = nullptr;
+  std::vector<int> dims = {1};
+  tensor.dims = ConvertVectorToTfLiteIntArray(dims);
+  EXPECT_TRUE(GetTfLiteResourceTensorFromResourceHandle(handle, &tensor));
+  EXPECT_EQ(tensor.data.i32[0], 1);
+
+  TfLiteIntArrayFree(tensor.dims);
+  TfLiteTensorDataFree(&tensor);
+}
+
 }  // namespace
 }  // namespace flex
 }  // namespace tflite