Add DynamicUpdateSlice op to TFLite runtime. Supported data types are float and 8/32/64 bit int.
PiperOrigin-RevId: 428637477
Change-Id: I8c2d9373c1da9a4375af66a6f277df3737edbb54
diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index 17f4d1a..e9f6fa3 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -573,6 +573,7 @@
"dequantize.cc",
"detection_postprocess.cc",
"div.cc",
+ "dynamic_update_slice.cc",
"elementwise.cc",
"embedding_lookup.cc",
"embedding_lookup_sparse.cc",
@@ -2734,6 +2735,19 @@
],
)
+cc_test(
+ name = "dynamic_update_slice_test",
+ size = "small",
+ srcs = ["dynamic_update_slice_test.cc"],
+ deps = [
+ ":test_main",
+ ":test_util",
+ "//tensorflow/lite/schema:schema_fbs",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
+
tflite_portable_test_suite_combined(combine_conditions = {"deps": [":test_main"]})
# Note this is created mainly for validating external delegates in OSS.
diff --git a/tensorflow/lite/kernels/dynamic_update_slice.cc b/tensorflow/lite/kernels/dynamic_update_slice.cc
new file mode 100644
index 0000000..7a9f7a6
--- /dev/null
+++ b/tensorflow/lite/kernels/dynamic_update_slice.cc
@@ -0,0 +1,190 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <vector>
+
+#include "tensorflow/lite/c/c_api_types.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace dynamic_update_slice {
+
+constexpr int kOperandTensor = 0;
+constexpr int kUpdateTensor = 1;
+constexpr int kStartIndicesTensor = 2;
+constexpr int kOutputTensor = 0;
+
+// TFLite DynamicUpdateSlice op follows the semantics of XLA DynamicUpdateSlice
+// op. See https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice
+// for details.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* operand;
+ TF_LITE_ENSURE_OK(context,
+ GetInputSafe(context, node, kOperandTensor, &operand));
+ const TfLiteTensor* update;
+ TF_LITE_ENSURE_OK(context,
+ GetInputSafe(context, node, kUpdateTensor, &update));
+ const TfLiteTensor* start_indices;
+ TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartIndicesTensor,
+ &start_indices));
+ TfLiteTensor* output;
+ TF_LITE_ENSURE_OK(context,
+ GetOutputSafe(context, node, kOutputTensor, &output));
+
+ // The shape of start_indices must be rank == 1, with dimension size equal to
+ // the rank of operand.
+ TF_LITE_ENSURE(context, NumDimensions(start_indices) == 1);
+ TF_LITE_ENSURE(context,
+ SizeOfDimension(start_indices, 0) == NumDimensions(operand));
+
+ // Update must be less than or equal to the operand size for each dimension to
+ // avoid generating out-of-bounds update indices.
+ TF_LITE_ENSURE(context, NumDimensions(update) == NumDimensions(operand));
+ for (int i = 0; i < NumDimensions(operand); i++) {
+ TF_LITE_ENSURE(context,
+ SizeOfDimension(update, i) <= SizeOfDimension(operand, i));
+ }
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ TF_LITE_ENSURE_TYPES_EQ(context, operand->type, update->type);
+ TF_LITE_ENSURE_TYPES_EQ(context, start_indices->type, kTfLiteInt32);
+
+ output->type = operand->type;
+ TfLiteIntArray* output_size = TfLiteIntArrayCopy(operand->dims);
+ return context->ResizeTensor(context, output, output_size);
+}
+
+// A helper function that converts a tensor index into a flat array index.
+// Takes `start_indices` as an offset if not null.
+int TensorIndexToFlat(const int* index, const int dims,
+ const RuntimeShape& shape,
+ const int* start_indices = nullptr) {
+ int flat_index = index[0] + (start_indices ? start_indices[0] : 0);
+ for (int i = 1; i < dims; i++) {
+ flat_index = flat_index * shape.Dims(i) + index[i] +
+ (start_indices ? start_indices[i] : 0);
+ }
+ return flat_index;
+}
+
+// A helper function to compute the clamped start indices to ensure they are
+// not out of bounds.
+std::vector<int> ClampStartIndices(int input_dims, const int32_t* indices_data,
+ const RuntimeShape& input_shape,
+ const RuntimeShape& update_shape) {
+ std::vector<int> clamped_start_indices(input_dims, 0);
+ for (int i = 0; i < input_dims; i++) {
+ clamped_start_indices[i] =
+ std::min(std::max(0, indices_data[i]),
+ input_shape.Dims(i) - update_shape.Dims(i));
+ }
+ return clamped_start_indices;
+}
+
+template <typename T>
+void DynamicUpdateSlice(const TfLiteTensor* input, const TfLiteTensor* update,
+ const TfLiteTensor* indice, TfLiteTensor* output) {
+ const auto& input_shape = GetTensorShape(input);
+ const auto& update_shape = GetTensorShape(update);
+ const T* update_data = GetTensorData<T>(update);
+ const int32_t* indices_data = GetTensorData<int32_t>(indice);
+ T* output_data = GetTensorData<T>(output);
+
+ const int input_dims = input_shape.DimensionsCount();
+ // Computes the effective slice indices.
+ // The clamped indices are gauranteed to >= 0 since update is less than or
+ // equal to the operand size for each dimension.
+ std::vector<int> clamped_start_indices =
+ ClampStartIndices(input_dims, indices_data, input_shape, update_shape);
+
+ // Copies input to output first.
+ memcpy(output->data.raw, input->data.raw, input->bytes);
+
+ std::vector<int> current_dim(input_dims, 0);
+ // Overwrites update to output.
+ do {
+ int flat_update_index =
+ TensorIndexToFlat(current_dim.data(), input_dims, update_shape);
+ int flat_input_index =
+ TensorIndexToFlat(current_dim.data(), input_dims, input_shape,
+ clamped_start_indices.data());
+ output_data[flat_input_index] = update_data[flat_update_index];
+ } while (NextIndex(input_dims,
+ reinterpret_cast<const int*>(update_shape.DimsData()),
+ current_dim.data()));
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* operand;
+ TF_LITE_ENSURE_OK(context,
+ GetInputSafe(context, node, kOperandTensor, &operand));
+ const TfLiteTensor* update;
+ TF_LITE_ENSURE_OK(context,
+ GetInputSafe(context, node, kUpdateTensor, &update));
+ const TfLiteTensor* indice;
+ TF_LITE_ENSURE_OK(context,
+ GetInputSafe(context, node, kStartIndicesTensor, &indice));
+ TfLiteTensor* output;
+ TF_LITE_ENSURE_OK(context,
+ GetOutputSafe(context, node, kOutputTensor, &output));
+
+ switch (operand->type) {
+ case kTfLiteFloat32:
+ DynamicUpdateSlice<float>(operand, update, indice, output);
+ break;
+ case kTfLiteInt8:
+ DynamicUpdateSlice<int8_t>(operand, update, indice, output);
+ break;
+ case kTfLiteInt32:
+ DynamicUpdateSlice<int32_t>(operand, update, indice, output);
+ break;
+ case kTfLiteInt64:
+ DynamicUpdateSlice<int64_t>(operand, update, indice, output);
+ break;
+ default:
+ TF_LITE_KERNEL_LOG(context,
+ "DynamicUpdateSlice only currently supports "
+ "8-bit/32-bit/64-bit integer or "
+ "float type, got %d.",
+ operand->type);
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+} // namespace dynamic_update_slice
+
+TfLiteRegistration* Register_DYNAMIC_UPDATE_SLICE() {
+ static TfLiteRegistration r = {/*init=*/nullptr,
+ /*free=*/nullptr,
+ dynamic_update_slice::Prepare,
+ dynamic_update_slice::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/lite/kernels/dynamic_update_slice_test.cc b/tensorflow/lite/kernels/dynamic_update_slice_test.cc
new file mode 100644
index 0000000..be32c9a
--- /dev/null
+++ b/tensorflow/lite/kernels/dynamic_update_slice_test.cc
@@ -0,0 +1,173 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <stdint.h>
+
+#include <algorithm>
+#include <initializer_list>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "flatbuffers/flatbuffers.h" // from @flatbuffers
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class DynamicUpdateSliceOpModel : public SingleOpModel {
+ public:
+ DynamicUpdateSliceOpModel(const TensorData& operand, const TensorData& update,
+ const TensorData& start_indices) {
+ input_ = AddInput(operand);
+ update_ = AddInput(update);
+ start_indices_ = AddInput(start_indices);
+ output_ = AddOutput(operand.type);
+ SetBuiltinOp(BuiltinOperator_DYNAMIC_UPDATE_SLICE,
+ BuiltinOptions_DynamicUpdateSliceOptions,
+ CreateDynamicUpdateSliceOptions(builder_).Union());
+ BuildInterpreter(
+ {GetShape(input_), GetShape(update_), GetShape(start_indices_)});
+ }
+
+ template <typename T>
+ void SetInput(std::initializer_list<T> data) {
+ PopulateTensor<T>(input_, data);
+ }
+
+ template <typename T>
+ void SetUpdate(std::initializer_list<T> data) {
+ PopulateTensor<T>(update_, data);
+ }
+
+ void SetStringInput(std::initializer_list<string> data) {
+ PopulateStringTensor(input_, data);
+ }
+
+ template <typename T>
+ void SetStartIndices(std::initializer_list<T> data) {
+ PopulateTensor<T>(start_indices_, data);
+ }
+
+ template <typename T>
+ std::vector<T> GetOutput() {
+ return ExtractVector<T>(output_);
+ }
+
+ std::vector<string> GetStringOutput() {
+ return ExtractVector<string>(output_);
+ }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ protected:
+ int input_;
+ int update_;
+ int start_indices_;
+ int output_;
+};
+
+TEST(DynamicUpdateSliceOpTest, SimpleTestF32) {
+ DynamicUpdateSliceOpModel m({TensorType_FLOAT32, {3, 3}},
+ {TensorType_FLOAT32, {2, 1}},
+ {TensorType_INT32, {2}});
+ m.SetInput<float>({1, 2, 3, //
+ 4, 5, 6, //
+ 7, 8, 9});
+ m.SetUpdate<float>({-1, -2});
+ m.SetStartIndices<int32_t>({1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({1, 2, 3, //
+ 4, -1, 6, //
+ 7, -2, 9})));
+}
+
+TEST(DynamicUpdateSliceOpTest, SimpleTestI8) {
+ DynamicUpdateSliceOpModel m({TensorType_INT8, {3, 3}},
+ {TensorType_INT8, {2, 1}},
+ {TensorType_INT32, {2}});
+ m.SetInput<int8_t>({1, 2, 3, //
+ 4, 5, 6, //
+ 7, 8, 9});
+ m.SetUpdate<int8_t>({-1, -2});
+ m.SetStartIndices<int32_t>({1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({1, 2, 3, //
+ 4, -1, 6, //
+ 7, -2, 9}));
+}
+
+TEST(DynamicUpdateSliceOpTest, SimpleTestI32) {
+ DynamicUpdateSliceOpModel m({TensorType_INT32, {3, 3}},
+ {TensorType_INT32, {2, 1}},
+ {TensorType_INT32, {2}});
+ m.SetInput<int32_t>({1, 2, 3, //
+ 4, 5, 6, //
+ 7, 8, 9});
+ m.SetUpdate<int32_t>({-1, -2});
+ m.SetStartIndices<int32_t>({1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<int32_t>(), ElementsAreArray({1, 2, 3, //
+ 4, -1, 6, //
+ 7, -2, 9}));
+}
+
+TEST(DynamicUpdateSliceOpTest, SimpleTestI64) {
+ DynamicUpdateSliceOpModel m({TensorType_INT64, {3, 3}},
+ {TensorType_INT64, {2, 1}},
+ {TensorType_INT32, {2}});
+ m.SetInput<int64_t>({1, 2, 3, //
+ 4, 5, 6, //
+ 7, 8, 9});
+ m.SetUpdate<int64_t>({-1, -2});
+ m.SetStartIndices<int32_t>({1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<int64_t>(), ElementsAreArray({1, 2, 3, //
+ 4, -1, 6, //
+ 7, -2, 9}));
+}
+
+TEST(DynamicUpdateSliceOpTest, BoundaryTest) {
+ DynamicUpdateSliceOpModel m({TensorType_FLOAT32, {3, 3}},
+ {TensorType_FLOAT32, {2, 2}},
+ {TensorType_INT32, {2}});
+ m.SetInput<float>({1, 2, 3, //
+ 4, 5, 6, //
+ 7, 8, 9});
+ m.SetUpdate<float>({-1, -2, //
+ -3, -4});
+ m.SetStartIndices<int32_t>({2, 2});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({1, 2, 3, //
+ 4, -1, -2, //
+ 7, -3, -4})));
+}
+
+TEST(DynamicUpdateSliceOpTest, UpdateShapeTooLargeTest) {
+ EXPECT_DEATH_IF_SUPPORTED(
+ DynamicUpdateSliceOpModel({TensorType_FLOAT32, {3, 3}},
+ {TensorType_FLOAT32, {4, 2}},
+ {TensorType_INT32, {2}}),
+ "SizeOfDimension\\(update, i\\) <= SizeOfDimension\\(operand, "
+ "i\\) was not true.");
+}
+
+} // namespace
+} // namespace tflite
diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc
index df5ba5b..f30602f 100644
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -340,6 +340,8 @@
AddBuiltin(BuiltinOperator_GELU, Register_GELU(),
/* min_version = */ 1,
/* max_version = */ 2);
+ AddBuiltin(BuiltinOperator_DYNAMIC_UPDATE_SLICE,
+ Register_DYNAMIC_UPDATE_SLICE());
AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc
index 8c81557..eb859e2 100644
--- a/tensorflow/lite/kernels/register_ref.cc
+++ b/tensorflow/lite/kernels/register_ref.cc
@@ -169,6 +169,7 @@
TfLiteRegistration* Register_RANDOM_UNIFORM();
TfLiteRegistration* Register_MULTINOMIAL();
TfLiteRegistration* Register_GELU();
+TfLiteRegistration* Register_DYNAMIC_UPDATE_SLICE();
namespace {
@@ -494,6 +495,8 @@
AddBuiltin(BuiltinOperator_GELU, Register_GELU(),
/* min_version = */ 1,
/* max_version = */ 2);
+ AddBuiltin(BuiltinOperator_DYNAMIC_UPDATE_SLICE,
+ Register_DYNAMIC_UPDATE_SLICE());
AddCustom("NumericVerify",
tflite::ops::custom::Register_NUMERIC_VERIFY_REF());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc
index 0d6c8a0..16563d9 100644
--- a/tensorflow/lite/tools/versioning/runtime_version.cc
+++ b/tensorflow/lite/tools/versioning/runtime_version.cc
@@ -373,6 +373,7 @@
{{BuiltinOperator_MULTINOMIAL, 1}, "2.8.0"},
{{BuiltinOperator_GELU, 1}, "2.9.0"},
{{BuiltinOperator_GELU, 2}, "2.9.0"},
+ {{BuiltinOperator_DYNAMIC_UPDATE_SLICE, 1}, "2.9.0"},
});
std::pair<BuiltinOperator, int> version_key = {op_code, op_version};