Support passing native TF Lite resource variables with dataset ops.
Add runtime code to convert between TF Lite native resource tensors and TF resource tensors.
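
The conversion encodes the TF Lite resource ID (a single int32 stored in the TF Lite
resource tensor) into the name of a tensorflow::ResourceHandle carried by a scalar
DT_RESOURCE tensor, and parses that name back when the handle re-enters a TF Lite
subgraph. Below is a minimal, standalone sketch of the naming scheme only (plain C++
with no TF/TF Lite dependencies; EncodeResourceName and ParseResourceName are
illustrative stand-ins for TfLiteResourceIdentifier and
GetTfLiteResourceTensorFromResourceHandle, not part of this change):

    #include <cstdio>
    #include <cstdlib>
    #include <string>

    namespace {

    constexpr char kPrefix[] = "tflite_resource_variable";

    // Stand-in for TfLiteResourceIdentifier(): encode a TF Lite resource ID
    // as the name stored on the scalar DT_RESOURCE handle.
    std::string EncodeResourceName(int resource_id) {
      return std::string(kPrefix) + ":" + std::to_string(resource_id);
    }

    // Stand-in for GetTfLiteResourceTensorFromResourceHandle(): recover the
    // ID, or return false when the name does not use the TF Lite encoding
    // (i.e. the handle refers to a genuine TF resource).
    bool ParseResourceName(const std::string& name, int* resource_id) {
      const size_t colon = name.find(':');
      if (colon == std::string::npos ||
          name.find(':', colon + 1) != std::string::npos ||
          name.substr(0, colon) != kPrefix) {
        return false;
      }
      const char* digits = name.c_str() + colon + 1;
      char* end = nullptr;
      const long value = std::strtol(digits, &end, 10);
      if (end == digits || *end != '\0') return false;  // not a valid integer
      *resource_id = static_cast<int>(value);
      return true;
    }

    }  // namespace

    int main() {
      const std::string name = EncodeResourceName(1);  // "tflite_resource_variable:1"
      int id = 0;
      if (ParseResourceName(name, &id)) {
        std::printf("TF Lite resource id: %d\n", id);  // prints 1
      }
      return 0;
    }

In the actual change, the encoding lives in TfLiteResourceIdentifier() and the parsing
in GetTfLiteResourceTensorFromResourceHandle() (tensorflow/lite/delegates/flex/util.cc);
handles whose names do not follow this format fall back to the existing pass-by-pointer
path for TF resource/variant tensors.
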
PiperOrigin-RevId: 410397929
Change-Id: If0c486c6c064daceaf30442fd2e7495fa1473e48
diff --git a/tensorflow/lite/delegates/flex/BUILD b/tensorflow/lite/delegates/flex/BUILD
index 6f14fec..81e3af8 100644
--- a/tensorflow/lite/delegates/flex/BUILD
+++ b/tensorflow/lite/delegates/flex/BUILD
@@ -64,6 +64,7 @@
":util",
"//tensorflow/lite/c:common",
"//tensorflow/lite:string_util",
+ "//tensorflow/lite/experimental/resource",
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",
@@ -324,6 +325,9 @@
deps = [
"//tensorflow/lite/c:common",
"//tensorflow/lite:kernel_api",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "//tensorflow/lite/kernels/internal:tensor",
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",
@@ -345,7 +348,9 @@
srcs = ["util_test.cc"],
deps = [
":util",
+ "//tensorflow/core:framework",
"//tensorflow/lite:string",
+ "//tensorflow/lite:util",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest_main",
],
@@ -421,6 +426,9 @@
":buffer_map_util",
":subgraph_resource",
":util",
+ "@com_google_absl//absl/strings",
+ "//tensorflow/lite/kernels/internal:tensor",
+ "@com_google_absl//absl/strings:str_format",
"//tensorflow/lite:cc_api",
"//tensorflow/lite:string_util",
"//tensorflow/lite/c:c_api_types",
diff --git a/tensorflow/lite/delegates/flex/buffer_map_test.cc b/tensorflow/lite/delegates/flex/buffer_map_test.cc
index 6b0d5ef..3e7c087 100644
--- a/tensorflow/lite/delegates/flex/buffer_map_test.cc
+++ b/tensorflow/lite/delegates/flex/buffer_map_test.cc
@@ -14,6 +14,8 @@
==============================================================================*/
#include "tensorflow/lite/delegates/flex/buffer_map.h"
+#include <cstdint>
+
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
@@ -160,6 +162,32 @@
ElementsAre("", "", "", "s3", "", "", "s1", "s2"));
}
+TEST(BufferMapTest, SetFromTfLiteBuiltinResource) {
+ BufferMap buffer_map;
+
+ // Constructs a fake resource tensor.
+ auto tensor = UniqueTfLiteTensor(new TfLiteTensor(), [](TfLiteTensor* t) {
+ TfLiteTensorDataFree(t);
+ TfLiteIntArrayFree(t->dims);
+ delete t;
+ });
+ tensor->allocation_type = kTfLiteDynamic;
+ tensor->type = kTfLiteResource;
+ tensor->dims = ConvertVectorToTfLiteIntArray({1});
+ TfLiteTensorRealloc(sizeof(int32_t), tensor.get());
+ tensor->delegate = nullptr;
+ tensor->data.i32[0] = 1;
+
+ buffer_map.SetFromTfLite(0, tensor.get());
+ // Also check details of the tensor.
+ tensorflow::Tensor out_tensor = buffer_map.GetTensor(0);
+ ASSERT_EQ(out_tensor.dtype(), tensorflow::DT_RESOURCE);
+ ASSERT_EQ(out_tensor.NumElements(), 1);
+ tensorflow::ResourceHandle handle =
+ out_tensor.flat<tensorflow::ResourceHandle>()(0);
+ EXPECT_EQ(handle.name(), "tflite_resource_variable:1");
+}
+
TEST(BufferMapTest, SetFromTensorFlow) {
tensorflow::Tensor t1 =
MakeTensor<float>({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0});
diff --git a/tensorflow/lite/delegates/flex/buffer_map_util.cc b/tensorflow/lite/delegates/flex/buffer_map_util.cc
index 2f01b99..27a027e 100644
--- a/tensorflow/lite/delegates/flex/buffer_map_util.cc
+++ b/tensorflow/lite/delegates/flex/buffer_map_util.cc
@@ -17,7 +17,10 @@
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/typed_allocator.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/delegates/flex/util.h"
+#include "tensorflow/lite/experimental/resource/resource_variable.h"
#include "tensorflow/lite/string_util.h"
namespace tflite {
@@ -100,19 +103,34 @@
tensorflow::Status SetTfTensorFromTfLite(const TfLiteTensor* tensor,
tensorflow::Tensor* tf_tensor) {
- // TODO(b/179094265): This is an experimental implementation, subject to
- // change. This can be re-implemented with life cycle management mechanism
- // like reference counting.
- // In a different subgraph, it can load the TensorFlow tensor pointer of the
- // given TensorFlow Lite tensor, which is stored in the `data` field. The
- // memory management cycle of the shared TensorFlow's tensor will be managed
- // by the buffer maps since the loaded tensors always will be kept in the
- // buffer map.
- //
- // The life cycle of the pointer will be managed by the reference counting in
- // the TensorFlow world and the pointer will be freed when all the buffer
- // maps, who own it, are gone.
- if (IsResourceOrVariant(tensor)) {
+ if (resource::IsBuiltinResource(tensor)) {
+ // If this is a native TF Lite resource variable, then we create a TF
+ // resource tensor whose handle encodes the identifier of the TF Lite
+ // resource.
+ // This approach assumes that only a single model is being invoked via the
+ // Interpreter instance, so the resource IDs cannot collide. If we plan
+ // to support concurrent execution in the future, we should make sure
+ // that the resource ID being encoded is unique between different
+ // executions.
+ tensorflow::Tensor t(tensorflow::DT_RESOURCE, tensorflow::TensorShape({}));
+ tensorflow::ResourceHandle handle;
+ handle.set_name(TfLiteResourceIdentifier(tensor));
+ t.flat<tensorflow::ResourceHandle>()(0) = handle;
+ *tf_tensor = t;
+ return tensorflow::Status::OK();
+ } else if (IsResourceOrVariant(tensor)) {
+ // TODO(b/179094265): This is an experimental implementation, subject to
+ // change. It can be re-implemented with a life cycle management mechanism
+ // such as reference counting.
+ // A different subgraph can load the TensorFlow tensor pointer of the
+ // given TensorFlow Lite tensor, which is stored in the `data` field. The
+ // memory management cycle of the shared TensorFlow tensor is handled by
+ // the buffer maps, since the loaded tensors will always be kept in the
+ // buffer map.
+ //
+ // The life cycle of the pointer is managed by reference counting in the
+ // TensorFlow world, and the pointer is freed once all the buffer maps
+ // that own it are gone.
const tensorflow::Tensor** tf_tensor_ptr =
reinterpret_cast<const tensorflow::Tensor**>(tensor->data.raw);
*tf_tensor = **tf_tensor_ptr;
diff --git a/tensorflow/lite/delegates/flex/tflite_subgraph_execute.cc b/tensorflow/lite/delegates/flex/tflite_subgraph_execute.cc
index fd160f7..44dbc21 100644
--- a/tensorflow/lite/delegates/flex/tflite_subgraph_execute.cc
+++ b/tensorflow/lite/delegates/flex/tflite_subgraph_execute.cc
@@ -14,10 +14,13 @@
==============================================================================*/
#include <string>
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -30,6 +33,7 @@
#include "tensorflow/lite/delegates/flex/buffer_map_util.h"
#include "tensorflow/lite/delegates/flex/subgraph_resource.h"
#include "tensorflow/lite/delegates/flex/util.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"
@@ -136,6 +140,27 @@
void SetSubgraphInput(OpKernelContext* ctx,
tflite::Subgraph& subgraph_selected,
TfLiteDelegate* flex_delegate) const {
+ auto InitializeVariantOrResource = [flex_delegate](
+ const Tensor& tf_tensor,
+ TfLiteTensor* subgraph_input) {
+ // Initializes the TfLiteTensor so that its `data` field points at the
+ // original TF resource or variant tensor. This requires that the TF
+ // tensor's lifetime extends beyond the execution of the callee subgraph.
+ // TODO(b/179094265): This is an experimental implementation, subject to
+ // change. It can be re-implemented with a life cycle management
+ // mechanism such as reference counting.
+ const size_t required_bytes = sizeof(tensorflow::Tensor**);
+ const tensorflow::Tensor** tf_tensor_ptr =
+ reinterpret_cast<const tensorflow::Tensor**>(malloc(required_bytes));
+ *tf_tensor_ptr = &tf_tensor;
+
+ TfLiteTensorDataFree(subgraph_input);
+ subgraph_input->data.raw = reinterpret_cast<char*>(tf_tensor_ptr);
+ subgraph_input->bytes = required_bytes;
+ subgraph_input->data_is_stale = true;
+ subgraph_input->delegate = flex_delegate;
+ };
+
for (int i = 0; i < subgraph_selected.inputs().size(); ++i) {
const Tensor& tf_tensor = ctx->input(i + 1);
TfLiteTensor* subgraph_input =
@@ -151,21 +176,18 @@
}
dynamic_buffer.WriteToTensor(subgraph_input, /*new_shape=*/nullptr);
- } else if (tflite::flex::IsResourceOrVariant(subgraph_input)) {
- // TODO(b/179094265): This is an experimental implementation, subject to
- // change. This can be re-implemented with life cycle management
- // mechanism like reference counting.
- const size_t required_bytes = sizeof(tensorflow::Tensor**);
- const tensorflow::Tensor** tf_tensor_ptr =
- reinterpret_cast<const tensorflow::Tensor**>(
- malloc(required_bytes));
- *tf_tensor_ptr = &tf_tensor;
-
- TfLiteTensorDataFree(subgraph_input);
- subgraph_input->data.raw = reinterpret_cast<char*>(tf_tensor_ptr);
- subgraph_input->bytes = required_bytes;
- subgraph_input->data_is_stale = false;
- subgraph_input->delegate = flex_delegate;
+ } else if (subgraph_input->type == kTfLiteResource) {
+ // Try to parse the input resource handle to see whether it encodes a
+ // valid TF Lite resource ID. If it does not, the input is a regular TF
+ // resource tensor.
+ tensorflow::ResourceHandle handle =
+ tf_tensor.flat<tensorflow::ResourceHandle>()(0);
+ if (!tflite::flex::GetTfLiteResourceTensorFromResourceHandle(
+ handle, subgraph_input)) {
+ InitializeVariantOrResource(tf_tensor, subgraph_input);
+ }
+ } else if (subgraph_input->type == kTfLiteVariant) {
+ InitializeVariantOrResource(tf_tensor, subgraph_input);
} else {
tensorflow::StringPiece tensor_data = tf_tensor.tensor_data();
OP_REQUIRES(ctx, subgraph_input->bytes == tensor_data.size(),
diff --git a/tensorflow/lite/delegates/flex/util.cc b/tensorflow/lite/delegates/flex/util.cc
index 209f9e6..685f85d 100644
--- a/tensorflow/lite/delegates/flex/util.cc
+++ b/tensorflow/lite/delegates/flex/util.cc
@@ -14,9 +14,16 @@
==============================================================================*/
#include "tensorflow/lite/delegates/flex/util.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_split.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+
namespace tflite {
namespace flex {
+static constexpr char kResourceVariablePrefix[] = "tflite_resource_variable";
+
TfLiteStatus ConvertStatus(TfLiteContext* context,
const tensorflow::Status& status) {
if (!status.ok()) {
@@ -171,5 +176,31 @@
return "invalid";
}
+std::string TfLiteResourceIdentifier(const TfLiteTensor* tensor) {
+ // TODO(b/199782192): Create a util function to get Resource ID from a TF Lite
+ // resource tensor.
+ const int resource_id = tensor->data.i32[0];
+ return absl::StrFormat("%s:%d", kResourceVariablePrefix, resource_id);
+}
+
+bool GetTfLiteResourceTensorFromResourceHandle(
+ const tensorflow::ResourceHandle& resource_handle, TfLiteTensor* tensor) {
+ std::vector<std::string> parts = absl::StrSplit(resource_handle.name(), ':');
+ if (parts.size() != 2) {
+ return false;
+ }
+ const int kBytesRequired = sizeof(int32_t);
+ TfLiteTensorRealloc(kBytesRequired, tensor);
+ int resource_id;
+ if (parts[0] == kResourceVariablePrefix &&
+ absl::SimpleAtoi<int32_t>(parts[1], &resource_id)) {
+ // TODO(b/199782192): Create a util function to set the Resource ID of
+ // a TF Lite resource tensor.
+ GetTensorData<int32_t>(tensor)[0] = resource_id;
+ return true;
+ }
+ return false;
+}
+
} // namespace flex
} // namespace tflite
diff --git a/tensorflow/lite/delegates/flex/util.h b/tensorflow/lite/delegates/flex/util.h
index b29e765..2b2dedc 100644
--- a/tensorflow/lite/delegates/flex/util.h
+++ b/tensorflow/lite/delegates/flex/util.h
@@ -15,6 +15,8 @@
#ifndef TENSORFLOW_LITE_DELEGATES_FLEX_UTIL_H_
#define TENSORFLOW_LITE_DELEGATES_FLEX_UTIL_H_
+#include <string>
+
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
@@ -49,6 +51,16 @@
return tensor->type == kTfLiteResource || tensor->type == kTfLiteVariant;
}
+// Returns the encoded string name for a TF Lite resource variable tensor.
+// This function will return a string in the format:
+// tflite_resource_variable:resource_id.
+std::string TfLiteResourceIdentifier(const TfLiteTensor* tensor);
+
+// Parses the resource ID out of the given `resource_handle` and writes it
+// into the given TfLiteTensor. Returns true on success.
+bool GetTfLiteResourceTensorFromResourceHandle(
+ const tensorflow::ResourceHandle& resource_handle, TfLiteTensor* tensor);
+
} // namespace flex
} // namespace tflite
diff --git a/tensorflow/lite/delegates/flex/util_test.cc b/tensorflow/lite/delegates/flex/util_test.cc
index 68a991e..8c5a9ca 100644
--- a/tensorflow/lite/delegates/flex/util_test.cc
+++ b/tensorflow/lite/delegates/flex/util_test.cc
@@ -18,8 +18,10 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
+#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow/lite/testing/util.h"
+#include "tensorflow/lite/util.h"
namespace tflite {
namespace flex {
@@ -141,6 +143,40 @@
EXPECT_EQ(kTfLiteVariant, GetTensorFlowLiteType(TF_VARIANT));
}
+TEST(UtilTest, GetTfLiteResourceIdentifier) {
+ // Constructs a fake resource tensor.
+ TfLiteTensor tensor;
+ tensor.allocation_type = kTfLiteDynamic;
+ tensor.type = kTfLiteResource;
+ std::vector<int> dims = {1};
+ tensor.dims = ConvertVectorToTfLiteIntArray(dims);
+ tensor.data.raw = nullptr;
+ TfLiteTensorRealloc(sizeof(int32_t), &tensor);
+ tensor.delegate = nullptr;
+ tensor.data.i32[0] = 1;
+
+ EXPECT_EQ(TfLiteResourceIdentifier(&tensor), "tflite_resource_variable:1");
+ TfLiteIntArrayFree(tensor.dims);
+ TfLiteTensorDataFree(&tensor);
+}
+
+TEST(UtilTest, GetTfLiteResourceTensorFromResourceHandle) {
+ tensorflow::ResourceHandle handle;
+ handle.set_name("tflite_resource_variable:1");
+
+ TfLiteTensor tensor;
+ tensor.allocation_type = kTfLiteDynamic;
+ tensor.type = kTfLiteResource;
+ tensor.data.raw = nullptr;
+ std::vector<int> dims = {1};
+ tensor.dims = ConvertVectorToTfLiteIntArray(dims);
+ EXPECT_TRUE(GetTfLiteResourceTensorFromResourceHandle(handle, &tensor));
+ EXPECT_EQ(tensor.data.i32[0], 1);
+
+ TfLiteIntArrayFree(tensor.dims);
+ TfLiteTensorDataFree(&tensor);
+}
+
} // namespace
} // namespace flex
} // namespace tflite