Add DynamicUpdateSlice op to TFLite runtime. Supported data types are float32 and int8/int32/int64.

PiperOrigin-RevId: 428637477
Change-Id: I8c2d9373c1da9a4375af66a6f277df3737edbb54
diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index 17f4d1a..e9f6fa3 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -573,6 +573,7 @@
     "dequantize.cc",
     "detection_postprocess.cc",
     "div.cc",
+    "dynamic_update_slice.cc",
     "elementwise.cc",
     "embedding_lookup.cc",
     "embedding_lookup_sparse.cc",
@@ -2734,6 +2735,19 @@
     ],
 )
 
+cc_test(
+    name = "dynamic_update_slice_test",
+    size = "small",
+    srcs = ["dynamic_update_slice_test.cc"],
+    deps = [
+        ":test_main",
+        ":test_util",
+        "//tensorflow/lite/schema:schema_fbs",
+        "@com_google_googletest//:gtest",
+        "@flatbuffers",
+    ],
+)
+
 tflite_portable_test_suite_combined(combine_conditions = {"deps": [":test_main"]})
 
 # Note this is created mainly for validating external delegates in OSS.
diff --git a/tensorflow/lite/kernels/dynamic_update_slice.cc b/tensorflow/lite/kernels/dynamic_update_slice.cc
new file mode 100644
index 0000000..7a9f7a6
--- /dev/null
+++ b/tensorflow/lite/kernels/dynamic_update_slice.cc
@@ -0,0 +1,190 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <vector>
+
+#include "tensorflow/lite/c/c_api_types.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace dynamic_update_slice {
+
+constexpr int kOperandTensor = 0;       // Input: tensor being updated.
+constexpr int kUpdateTensor = 1;        // Input: slice written into it.
+constexpr int kStartIndicesTensor = 2;  // Input: per-dimension start offsets.
+constexpr int kOutputTensor = 0;        // Output: updated copy of operand.
+
+// Validates the node's tensors and resizes the output to the operand's shape.
+// The op follows XLA DynamicUpdateSlice semantics; see
+// https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* operand;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kOperandTensor, &operand));
+  const TfLiteTensor* update;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kUpdateTensor, &update));
+  const TfLiteTensor* start_indices;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartIndicesTensor,
+                                          &start_indices));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
+
+  // start_indices must be a rank-1 tensor holding exactly one start offset
+  // per operand dimension.
+  TF_LITE_ENSURE(context, NumDimensions(start_indices) == 1);
+  TF_LITE_ENSURE(context,
+                 SizeOfDimension(start_indices, 0) == NumDimensions(operand));
+
+  // update must have the operand's rank and be no larger than the operand in
+  // any dimension; Eval's start-index clamping relies on this.
+  TF_LITE_ENSURE(context, NumDimensions(update) == NumDimensions(operand));
+  for (int i = 0; i < NumDimensions(operand); i++) {
+    TF_LITE_ENSURE(context,
+                   SizeOfDimension(update, i) <= SizeOfDimension(operand, i));
+  }
+
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  TF_LITE_ENSURE_TYPES_EQ(context, operand->type, update->type);
+  TF_LITE_ENSURE_TYPES_EQ(context, start_indices->type, kTfLiteInt32);
+
+  output->type = operand->type;  // Output mirrors the operand's type/shape.
+  TfLiteIntArray* output_size = TfLiteIntArrayCopy(operand->dims);
+  return context->ResizeTensor(context, output, output_size);
+}
+
+// Maps a multi-dimensional `index` to its row-major flat offset in `shape`,
+// adding `start_indices` per dimension when non-null. Rank 0 yields offset 0.
+int TensorIndexToFlat(const int* index, const int dims,
+                      const RuntimeShape& shape,
+                      const int* start_indices = nullptr) {
+  int flat_index = 0;  // Starting from 0 keeps `index` unread when dims == 0.
+  for (int i = 0; i < dims; i++) {
+    const int coordinate = index[i] + (start_indices ? start_indices[i] : 0);
+    flat_index = flat_index * shape.Dims(i) + coordinate;
+  }
+  return flat_index;
+}
+
+// Limits each requested start index to at least 0 and at most
+// input_dim - update_dim, so the update slice fits inside the operand.
+std::vector<int> ClampStartIndices(int input_dims, const int32_t* indices_data,
+                                   const RuntimeShape& input_shape,
+                                   const RuntimeShape& update_shape) {
+  std::vector<int> clamped(input_dims, 0);
+  for (int dim = 0; dim < input_dims; ++dim) {
+    const int max_start = input_shape.Dims(dim) - update_shape.Dims(dim);
+    const int non_negative = std::max(0, indices_data[dim]);
+    clamped[dim] = std::min(non_negative, max_start);
+  }
+  return clamped;
+}
+
+// Copies `input` into `output`, then overwrites the region that starts at the
+// clamped `indice` offsets with the contents of `update`.
+template <typename T>
+void DynamicUpdateSlice(const TfLiteTensor* input, const TfLiteTensor* update,
+                        const TfLiteTensor* indice, TfLiteTensor* output) {
+  const auto& input_shape = GetTensorShape(input);
+  const auto& update_shape = GetTensorShape(update);
+  const T* update_data = GetTensorData<T>(update);
+  const int32_t* indices_data = GetTensorData<int32_t>(indice);
+  T* output_data = GetTensorData<T>(output);
+  const int input_dims = input_shape.DimensionsCount();
+
+  // Clamped starts are guaranteed >= 0: Prepare caps update at operand size.
+  const std::vector<int> starts =
+      ClampStartIndices(input_dims, indices_data, input_shape, update_shape);
+  memcpy(output->data.raw, input->data.raw, input->bytes);  // Start as input.
+
+  // An empty update writes nothing, but the do/while below always runs once
+  // and would read update_data[0] past the end of its buffer, so stop here.
+  if (update_shape.FlatSize() == 0) return;
+
+  // Overwrites update to output, one update coordinate at a time.
+  std::vector<int> index(input_dims, 0);
+  do {
+    const int src = TensorIndexToFlat(index.data(), input_dims, update_shape);
+    const int dst =
+        TensorIndexToFlat(index.data(), input_dims, input_shape, starts.data());
+    output_data[dst] = update_data[src];
+  } while (NextIndex(input_dims,
+                     reinterpret_cast<const int*>(update_shape.DimsData()),
+                     index.data()));
+}
+
+// Runs the kernel: fetches the three inputs and the output, then dispatches to
+// the DynamicUpdateSlice<T> instantiation matching the operand's element type.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* operand;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kOperandTensor, &operand));
+  const TfLiteTensor* update;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kUpdateTensor, &update));
+  const TfLiteTensor* start_indices;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartIndicesTensor,
+                                          &start_indices));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
+
+  switch (operand->type) {
+    case kTfLiteFloat32:
+      DynamicUpdateSlice<float>(operand, update, start_indices, output);
+      return kTfLiteOk;
+    case kTfLiteInt8:
+      DynamicUpdateSlice<int8_t>(operand, update, start_indices, output);
+      return kTfLiteOk;
+    case kTfLiteInt32:
+      DynamicUpdateSlice<int32_t>(operand, update, start_indices, output);
+      return kTfLiteOk;
+    case kTfLiteInt64:
+      DynamicUpdateSlice<int64_t>(operand, update, start_indices, output);
+      return kTfLiteOk;
+    default:
+      TF_LITE_KERNEL_LOG(context,
+                         "DynamicUpdateSlice only currently supports "
+                         "8-bit/32-bit/64-bit integer or "
+                         "float type, got %d.",
+                         operand->type);
+      return kTfLiteError;
+  }
+}
+}  // namespace dynamic_update_slice
+
+TfLiteRegistration* Register_DYNAMIC_UPDATE_SLICE() {
+  static TfLiteRegistration r = {/*init=*/nullptr,
+                                 /*free=*/nullptr,
+                                 /*prepare=*/dynamic_update_slice::Prepare,
+                                 /*invoke=*/dynamic_update_slice::Eval};
+  return &r;  // Same function-local static instance on every call.
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/dynamic_update_slice_test.cc b/tensorflow/lite/kernels/dynamic_update_slice_test.cc
new file mode 100644
index 0000000..be32c9a
--- /dev/null
+++ b/tensorflow/lite/kernels/dynamic_update_slice_test.cc
@@ -0,0 +1,173 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <stdint.h>
+
+#include <algorithm>
+#include <initializer_list>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+// Single-op test harness for DYNAMIC_UPDATE_SLICE: wires up the operand,
+// update and start-indices inputs plus one output of the operand's type.
+class DynamicUpdateSliceOpModel : public SingleOpModel {
+ public:
+  DynamicUpdateSliceOpModel(const TensorData& operand, const TensorData& update,
+                            const TensorData& start_indices)
+      : input_(AddInput(operand)),
+        update_(AddInput(update)),
+        start_indices_(AddInput(start_indices)),
+        output_(AddOutput(operand.type)) {
+    SetBuiltinOp(BuiltinOperator_DYNAMIC_UPDATE_SLICE,
+                 BuiltinOptions_DynamicUpdateSliceOptions,
+                 CreateDynamicUpdateSliceOptions(builder_).Union());
+    BuildInterpreter(
+        {GetShape(input_), GetShape(update_), GetShape(start_indices_)});
+  }
+
+  template <typename T>
+  void SetInput(std::initializer_list<T> values) {
+    PopulateTensor<T>(input_, values);
+  }
+
+  template <typename T>
+  void SetUpdate(std::initializer_list<T> values) {
+    PopulateTensor<T>(update_, values);
+  }
+
+  void SetStringInput(std::initializer_list<string> values) {
+    PopulateStringTensor(input_, values);
+  }
+
+  template <typename T>
+  void SetStartIndices(std::initializer_list<T> values) {
+    PopulateTensor<T>(start_indices_, values);
+  }
+
+  template <typename T>
+  std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+
+  std::vector<string> GetStringOutput() {
+    return ExtractVector<string>(output_);
+  }
+  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ protected:
+  int input_;
+  int update_;
+  int start_indices_;
+  int output_;
+};
+
+TEST(DynamicUpdateSliceOpTest, SimpleTestF32) {
+  DynamicUpdateSliceOpModel m({TensorType_FLOAT32, {3, 3}},
+                              {TensorType_FLOAT32, {2, 1}},
+                              {TensorType_INT32, {2}});
+  m.SetInput<float>({1, 2, 3,  //
+                     4, 5, 6,  //
+                     7, 8, 9});
+  m.SetUpdate<float>({-1, -2});  // 2x1 column written over the operand.
+  m.SetStartIndices<int32_t>({1, 1});  // In range, so used unclamped.
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput<float>(),
+              ElementsAreArray(ArrayFloatNear({1, 2, 3,   //
+                                               4, -1, 6,  //
+                                               7, -2, 9})));
+}
+
+TEST(DynamicUpdateSliceOpTest, SimpleTestI8) {
+  DynamicUpdateSliceOpModel m({TensorType_INT8, {3, 3}},
+                              {TensorType_INT8, {2, 1}},
+                              {TensorType_INT32, {2}});
+  m.SetInput<int8_t>({1, 2, 3,  //
+                      4, 5, 6,  //
+                      7, 8, 9});
+  m.SetUpdate<int8_t>({-1, -2});  // 2x1 column written over the operand.
+  m.SetStartIndices<int32_t>({1, 1});  // In range, so used unclamped.
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({1, 2, 3,   //
+                                                       4, -1, 6,  //
+                                                       7, -2, 9}));
+}
+
+TEST(DynamicUpdateSliceOpTest, SimpleTestI32) {
+  DynamicUpdateSliceOpModel m({TensorType_INT32, {3, 3}},
+                              {TensorType_INT32, {2, 1}},
+                              {TensorType_INT32, {2}});
+  m.SetInput<int32_t>({1, 2, 3,  //
+                       4, 5, 6,  //
+                       7, 8, 9});
+  m.SetUpdate<int32_t>({-1, -2});  // 2x1 column written over the operand.
+  m.SetStartIndices<int32_t>({1, 1});  // In range, so used unclamped.
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput<int32_t>(), ElementsAreArray({1, 2, 3,   //
+                                                        4, -1, 6,  //
+                                                        7, -2, 9}));
+}
+
+TEST(DynamicUpdateSliceOpTest, SimpleTestI64) {
+  DynamicUpdateSliceOpModel m({TensorType_INT64, {3, 3}},
+                              {TensorType_INT64, {2, 1}},
+                              {TensorType_INT32, {2}});
+  m.SetInput<int64_t>({1, 2, 3,  //
+                       4, 5, 6,  //
+                       7, 8, 9});
+  m.SetUpdate<int64_t>({-1, -2});  // 2x1 column written over the operand.
+  m.SetStartIndices<int32_t>({1, 1});  // In range, so used unclamped.
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput<int64_t>(), ElementsAreArray({1, 2, 3,   //
+                                                        4, -1, 6,  //
+                                                        7, -2, 9}));
+}
+
+TEST(DynamicUpdateSliceOpTest, BoundaryTest) {
+  DynamicUpdateSliceOpModel m({TensorType_FLOAT32, {3, 3}},
+                              {TensorType_FLOAT32, {2, 2}},
+                              {TensorType_INT32, {2}});
+  m.SetInput<float>({1, 2, 3,  //
+                     4, 5, 6,  //
+                     7, 8, 9});
+  m.SetUpdate<float>({-1, -2,  //
+                      -3, -4});
+  m.SetStartIndices<int32_t>({2, 2});  // Too large: clamped to {1, 1}.
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput<float>(),
+              ElementsAreArray(ArrayFloatNear({1, 2, 3,    //
+                                               4, -1, -2,  //
+                                               7, -3, -4})));
+}
+
+TEST(DynamicUpdateSliceOpTest, UpdateShapeTooLargeTest) {
+  EXPECT_DEATH_IF_SUPPORTED(
+      DynamicUpdateSliceOpModel({TensorType_FLOAT32, {3, 3}},
+                                {TensorType_FLOAT32, {4, 2}},  // 4 rows > 3.
+                                {TensorType_INT32, {2}}),
+      "SizeOfDimension\\(update, i\\) <= SizeOfDimension\\(operand, "
+      "i\\) was not true.");
+}
+
+}  // namespace
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc
index df5ba5b..f30602f 100644
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -340,6 +340,8 @@
   AddBuiltin(BuiltinOperator_GELU, Register_GELU(),
              /* min_version = */ 1,
              /* max_version = */ 2);
+  AddBuiltin(BuiltinOperator_DYNAMIC_UPDATE_SLICE,
+             Register_DYNAMIC_UPDATE_SLICE());
   AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
   // custom ops aren't always included by default.
diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc
index 8c81557..eb859e2 100644
--- a/tensorflow/lite/kernels/register_ref.cc
+++ b/tensorflow/lite/kernels/register_ref.cc
@@ -169,6 +169,7 @@
 TfLiteRegistration* Register_RANDOM_UNIFORM();
 TfLiteRegistration* Register_MULTINOMIAL();
 TfLiteRegistration* Register_GELU();
+TfLiteRegistration* Register_DYNAMIC_UPDATE_SLICE();
 
 namespace {
 
@@ -494,6 +495,8 @@
   AddBuiltin(BuiltinOperator_GELU, Register_GELU(),
              /* min_version = */ 1,
              /* max_version = */ 2);
+  AddBuiltin(BuiltinOperator_DYNAMIC_UPDATE_SLICE,
+             Register_DYNAMIC_UPDATE_SLICE());
   AddCustom("NumericVerify",
             tflite::ops::custom::Register_NUMERIC_VERIFY_REF());
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc
index 0d6c8a0..16563d9 100644
--- a/tensorflow/lite/tools/versioning/runtime_version.cc
+++ b/tensorflow/lite/tools/versioning/runtime_version.cc
@@ -373,6 +373,7 @@
               {{BuiltinOperator_MULTINOMIAL, 1}, "2.8.0"},
               {{BuiltinOperator_GELU, 1}, "2.9.0"},
               {{BuiltinOperator_GELU, 2}, "2.9.0"},
+              {{BuiltinOperator_DYNAMIC_UPDATE_SLICE, 1}, "2.9.0"},
           });
 
   std::pair<BuiltinOperator, int> version_key = {op_code, op_version};