Added ScatterND op
PiperOrigin-RevId: 278286268
Change-Id: Iacceb4ac18b7eb2f8da2faa385d0ca69159ffa1c
diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h
index 918f724..18c9b32 100644
--- a/tensorflow/lite/builtin_ops.h
+++ b/tensorflow/lite/builtin_ops.h
@@ -148,6 +148,7 @@
kTfLiteBuiltinWhile = 119,
kTfLiteBuiltinNonMaxSuppressionV4 = 120,
kTfLiteBuiltinNonMaxSuppressionV5 = 121,
+ kTfLiteBuiltinScatterNd = 122,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc
index a379cd8..e164f18 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc
@@ -818,6 +818,7 @@
case BuiltinOperator_QUANTIZE:
case BuiltinOperator_NON_MAX_SUPPRESSION_V4:
case BuiltinOperator_NON_MAX_SUPPRESSION_V5:
+ case BuiltinOperator_SCATTER_ND:
break;
}
return kTfLiteOk;
diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index d6427eb..88492d3 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -482,6 +482,7 @@
"reverse.cc",
"reverse_sequence.cc",
"round.cc",
+ "scatter_nd.cc",
"select.cc",
"shape.cc",
"skip_gram.cc",
@@ -1152,6 +1153,20 @@
)
cc_test(
+ name = "scatter_nd_test",
+ size = "small",
+ srcs = ["scatter_nd_test.cc"],
+ deps = [
+ ":builtin_ops",
+ ":test_main",
+ ":test_util",
+ "//tensorflow/lite:framework",
+ "//tensorflow/lite/c:c_api_internal",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_test(
name = "topk_v2_test",
size = "small",
srcs = ["topk_v2_test.cc"],
diff --git a/tensorflow/lite/kernels/builtin_op_kernels.h b/tensorflow/lite/kernels/builtin_op_kernels.h
index 3d137e4..9d3d09e 100644
--- a/tensorflow/lite/kernels/builtin_op_kernels.h
+++ b/tensorflow/lite/kernels/builtin_op_kernels.h
@@ -147,6 +147,7 @@
TfLiteRegistration* Register_WHILE();
TfLiteRegistration* Register_NON_MAX_SUPPRESSION_V4();
TfLiteRegistration* Register_NON_MAX_SUPPRESSION_V5();
+TfLiteRegistration* Register_SCATTER_ND();
} // namespace builtin
} // namespace ops
diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h
index 2a69faa..78ab91c 100644
--- a/tensorflow/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h
@@ -2187,6 +2187,48 @@
}
}
+template <typename IndicesT, typename UpdatesT>
+inline void ScatterNd(const RuntimeShape& indices_shape,
+ const IndicesT* indices_data,
+ const RuntimeShape& updates_shape,
+ const UpdatesT* updates_data,
+ const RuntimeShape& output_shape, UpdatesT* output_data) {
+ gemmlowp::ScopedProfilingLabel label("ScatterNd");
+
+ int n_slices = 1;
+ int slice_size = 1;
+ const int outer_dims = indices_shape.DimensionsCount() - 1;
+ const int indices_nd = indices_shape.Dims(outer_dims);
+ const int updates_dims = updates_shape.DimensionsCount();
+ for (int i = 0; i < outer_dims; ++i) {
+ n_slices *= indices_shape.Dims(i);
+ }
+ for (int i = outer_dims; i < updates_dims; ++i) {
+ slice_size *= updates_shape.Dims(i);
+ }
+
+ int output_flat_size = output_shape.FlatSize();
+ int remain_flat_size = output_flat_size;
+ std::vector<int> dims_to_count(indices_nd, 0);
+ for (int i = 0; i < indices_nd; ++i) {
+ dims_to_count[i] = remain_flat_size / output_shape.Dims(i);
+ remain_flat_size = dims_to_count[i];
+ }
+
+ memset(output_data, 0, sizeof(UpdatesT) * output_flat_size);
+ for (int i = 0; i < n_slices; ++i) {
+ int to_pos = 0;
+ for (int j = 0; j < indices_nd; ++j) {
+ IndicesT idx = indices_data[i * indices_nd + j];
+ TFLITE_DCHECK(0 <= idx && idx < output_shape.Dims(j));
+ to_pos += idx * dims_to_count[j];
+ }
+ for (int j = 0; j < slice_size; j++) {
+ output_data[to_pos + j] += updates_data[i * slice_size + j];
+ }
+ }
+}
+
template <typename T>
inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
const RuntimeShape& unextended_input_shape,
diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc
index 6e38840..8f2c3e4 100644
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -273,6 +273,7 @@
Register_NON_MAX_SUPPRESSION_V4());
AddBuiltin(BuiltinOperator_NON_MAX_SUPPRESSION_V5,
Register_NON_MAX_SUPPRESSION_V5());
+ AddBuiltin(BuiltinOperator_SCATTER_ND, Register_SCATTER_ND());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
diff --git a/tensorflow/lite/kernels/scatter_nd.cc b/tensorflow/lite/kernels/scatter_nd.cc
new file mode 100644
index 0000000..6c012ac
--- /dev/null
+++ b/tensorflow/lite/kernels/scatter_nd.cc
@@ -0,0 +1,190 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/context.h"
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace scatter_nd {
+constexpr int kIndices = 0;
+constexpr int kUpdates = 1;
+constexpr int kShape = 2;
+constexpr int kOutputTensor = 0;
+
+template <typename IndicesT>
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
+ const TfLiteTensor* shape,
+ TfLiteTensor* output) {
+ const int shape_rank = SizeOfDimension(shape, 0);
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape_rank);
+ const auto* shape_data = GetTensorData<IndicesT>(shape);
+
+ for (int i = 0; i < shape_rank; i++) {
+ output_shape->data[i] = shape_data[i];
+ }
+ return context->ResizeTensor(context, output, output_shape);
+}
+
+template <typename IndicesT>
+TfLiteStatus CheckShapes(TfLiteContext* context, const RuntimeShape& indices,
+ const RuntimeShape& updates,
+ const RuntimeShape& shape_shape,
+ const IndicesT* shape_data) {
+ TF_LITE_ENSURE(context, (indices.DimensionsCount() >= 1) &&
+ (updates.DimensionsCount() >= 1) &&
+ (shape_shape.DimensionsCount() == 1));
+
+ const int outer_dims = indices.DimensionsCount() - 1;
+ for (int i = 0; i < outer_dims; ++i) {
+ TF_LITE_ENSURE_EQ(context, indices.Dims(i), updates.Dims(i));
+ }
+
+ const int ix = indices.Dims(outer_dims);
+ TF_LITE_ENSURE_EQ(context, updates.DimensionsCount() - outer_dims,
+ shape_shape.Dims(0) - ix);
+ for (int i = 0; i + outer_dims < updates.DimensionsCount(); ++i) {
+ TF_LITE_ENSURE_EQ(context, updates.Dims(i + outer_dims),
+ shape_data[ix + i]);
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ const TfLiteTensor* indices = GetInput(context, node, kIndices);
+ const TfLiteTensor* updates = GetInput(context, node, kUpdates);
+ const TfLiteTensor* shape = GetInput(context, node, kShape);
+
+ switch (updates->type) {
+ case kTfLiteFloat32:
+ case kTfLiteUInt8:
+ case kTfLiteInt8:
+ case kTfLiteInt64:
+ case kTfLiteInt32:
+ break;
+ default:
+ context->ReportError(
+ context, "Updates of type '%s' are not supported by scatter_nd.",
+ TfLiteTypeGetName(updates->type));
+ return kTfLiteError;
+ }
+ if (indices->type != shape->type) {
+ context->ReportError(context, "Indices and shape must have the same type.");
+ return kTfLiteError;
+ }
+
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ output->type = updates->type;
+
+ if (IsConstantTensor(shape)) {
+ switch (indices->type) {
+ case kTfLiteInt32:
+ TF_LITE_ENSURE_OK(
+ context,
+ CheckShapes<int32_t>(context, GetTensorShape(indices),
+ GetTensorShape(updates), GetTensorShape(shape),
+ GetTensorData<int32_t>(shape)));
+ return ResizeOutputTensor<int32_t>(context, shape, output);
+ default:
+ context->ReportError(
+ context, "Indices of type '%s' are not supported by scatter_nd.",
+ TfLiteTypeGetName(indices->type));
+ return kTfLiteError;
+ }
+ } else {
+ SetTensorToDynamic(output);
+ return kTfLiteOk;
+ }
+}
+
+template <typename IndicesT, typename UpdatesT>
+TfLiteStatus ScatterNd(const TfLiteTensor* indices, const TfLiteTensor* updates,
+ TfLiteTensor* output) {
+ reference_ops::ScatterNd(
+ GetTensorShape(indices), GetTensorData<IndicesT>(indices),
+ GetTensorShape(updates), GetTensorData<UpdatesT>(updates),
+ GetTensorShape(output), GetTensorData<UpdatesT>(output));
+ return kTfLiteOk;
+}
+
+template <typename IndicesT>
+TfLiteStatus EvalScatterNd(TfLiteContext* context, const TfLiteTensor* indices,
+ const TfLiteTensor* updates,
+ const TfLiteTensor* shape, TfLiteTensor* output) {
+ if (IsDynamicTensor(output)) {
+ TF_LITE_ENSURE_OK(
+ context, CheckShapes<IndicesT>(
+ context, GetTensorShape(indices), GetTensorShape(updates),
+ GetTensorShape(shape), GetTensorData<IndicesT>(shape)));
+ TF_LITE_ENSURE_OK(context,
+ ResizeOutputTensor<IndicesT>(context, shape, output));
+ }
+
+ switch (updates->type) {
+ case kTfLiteFloat32:
+ return ScatterNd<IndicesT, float>(indices, updates, output);
+ case kTfLiteUInt8:
+ return ScatterNd<IndicesT, uint8_t>(indices, updates, output);
+ case kTfLiteInt8:
+ return ScatterNd<IndicesT, int8_t>(indices, updates, output);
+ case kTfLiteInt32:
+ return ScatterNd<IndicesT, int32_t>(indices, updates, output);
+ case kTfLiteInt64:
+ return ScatterNd<IndicesT, int64_t>(indices, updates, output);
+ default:
+ context->ReportError(
+ context, "Updates of type '%s' are not supported by scatter_nd.",
+ TfLiteTypeGetName(updates->type));
+ return kTfLiteError;
+ }
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* indices = GetInput(context, node, kIndices);
+ const TfLiteTensor* updates = GetInput(context, node, kUpdates);
+ const TfLiteTensor* shape = GetInput(context, node, kShape);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (indices->type) {
+ case kTfLiteInt32:
+ return EvalScatterNd<int32_t>(context, indices, updates, shape, output);
+ default:
+ context->ReportError(
+ context, "Indices of type '%s' are not supported by scatter_nd.",
+ TfLiteTypeGetName(indices->type));
+ return kTfLiteError;
+ }
+}
+
+} // namespace scatter_nd
+
+TfLiteRegistration* Register_SCATTER_ND() {
+ static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
+ scatter_nd::Prepare, scatter_nd::Eval};
+ return &r;
+}
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/lite/kernels/scatter_nd_test.cc b/tensorflow/lite/kernels/scatter_nd_test.cc
new file mode 100644
index 0000000..e25ba9b
--- /dev/null
+++ b/tensorflow/lite/kernels/scatter_nd_test.cc
@@ -0,0 +1,349 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class ScatterNdOpModel : public SingleOpModel {
+ public:
+ ScatterNdOpModel(const TensorData& indices, const TensorData& updates,
+ const TensorData& shape) {
+ indices_ = AddInput(indices);
+ updates_ = AddInput(updates);
+ shape_ = AddInput(shape);
+ output_ = AddOutput(updates.type);
+ SetBuiltinOp(BuiltinOperator_SCATTER_ND, BuiltinOptions_ScatterNdOptions,
+ CreateScatterNdOptions(builder_).Union());
+ BuildInterpreter(
+ {GetShape(indices_), GetShape(updates_), GetShape(shape_)});
+ }
+
+ template <typename T>
+ void SetIndices(std::initializer_list<T> data) {
+ PopulateTensor<T>(indices_, data);
+ }
+
+ template <typename T>
+ void SetUpdates(std::initializer_list<T> data) {
+ PopulateTensor<T>(updates_, data);
+ }
+
+ template <typename T>
+ void SetShape(std::initializer_list<T> data) {
+ PopulateTensor<T>(shape_, data);
+ }
+
+ template <typename T>
+ std::vector<T> GetOutput() {
+ return ExtractVector<T>(output_);
+ }
+
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ protected:
+ int indices_;
+ int updates_;
+ int shape_;
+ int output_;
+};
+
+TEST(ScatterNdOpTest, ScatterElementIntoVector) {
+ ScatterNdOpModel m({TensorType_INT32, {4, 1}}, {TensorType_FLOAT32, {4}},
+ {TensorType_INT32, {1}});
+ m.SetIndices<int32_t>({4, 3, 1, 7});
+ m.SetUpdates<float>({9, 10, 11, 12});
+ m.SetShape<int32_t>({8});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({8}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray({0, 11, 0, 10, 9, 0, 0, 12}));
+}
+
+TEST(ScatterNdOpTest, ScatterMatrixIntoRank3Tensor) {
+ ScatterNdOpModel m({TensorType_INT32, {2, 1}},
+ {TensorType_FLOAT32, {2, 4, 4}}, {TensorType_INT32, {3}});
+ m.SetIndices<int32_t>({0, 2});
+ m.SetUpdates<float>({5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
+ 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8});
+ m.SetShape<int32_t>({4, 4, 4});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 4, 4}));
+ EXPECT_THAT(
+ m.GetOutput<float>(),
+ ElementsAreArray({5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
+}
+
+TEST(ScatterNdOpTest, ScatterVectorIntoMatrix) {
+ ScatterNdOpModel m({TensorType_INT32, {4, 1}}, {TensorType_FLOAT32, {4, 4}},
+ {TensorType_INT32, {2}});
+ m.SetIndices<int32_t>({/*0*/ 9, /*1*/ 8, /*2*/ 0, /*3*/ 1});
+ m.SetUpdates<float>({/*0*/ 1, 2, 3, 4,
+ /*1*/ 5, 6, 7, 8,
+ /*2*/ 9, 10, 11, 12,
+ /*3*/ 13, 14, 15, 16});
+ m.SetShape<int32_t>({10, 4});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({10, 4}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray({/*0*/ 9, 10, 11, 12,
+ /*1*/ 13, 14, 15, 16,
+ /*2*/ 0, 0, 0, 0,
+ /*3*/ 0, 0, 0, 0,
+ /*4*/ 0, 0, 0, 0,
+ /*5*/ 0, 0, 0, 0,
+ /*6*/ 0, 0, 0, 0,
+ /*7*/ 0, 0, 0, 0,
+ /*8*/ 5, 6, 7, 8,
+ /*9*/ 1, 2, 3, 4}));
+}
+
+TEST(ScatterNdOpTest, ScatterMatricesIntoRank4Tensor) {
+ ScatterNdOpModel m({TensorType_INT32, {2, 2, 2}},
+ {TensorType_FLOAT32, {2, 2, 2, 2}},
+ {TensorType_INT32, {4}});
+ m.SetIndices<int32_t>(
+ {/*0,0*/ 1, 1, /*0,1*/ 0, 1, /*1,0*/ 0, 0, /*1,1*/ 1, 0});
+ m.SetUpdates<float>({/*0,0*/ 1, 2, 3, 4, /*0,1*/ 5, 6, 7, 8,
+ /*1,0*/ 9, 10, 11, 12, /*1,1*/ 13, 14, 15, 16});
+ m.SetShape<int32_t>({2, 2, 2, 2});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2, 2}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray({/*0, 0*/ 9, 10, 11, 12,
+ /*0, 1*/ 5, 6, 7, 8,
+ /*1, 0*/ 13, 14, 15, 16,
+ /*1, 1*/ 1, 2, 3, 4}));
+}
+
+TEST(ScatterNdOpTest, ScatterVectorIntoRank4Tensor) {
+ ScatterNdOpModel m({TensorType_INT32, {2, 2, 3}},
+ {TensorType_FLOAT32, {2, 2, 5}}, {TensorType_INT32, {4}});
+ m.SetIndices<int32_t>(
+ {/*0,0*/ 2, 2, 2, /*0,1*/ 1, 0, 1, /*1,0*/ 0, 2, 0, /*1,0*/ 2, 2, 0});
+ m.SetUpdates<float>(
+ {/*0,0*/ 1, 2, 3, 4, 5, /*0,1*/ 6, 7, 8, 9, 10,
+ /*1,0*/ 11, 12, 13, 14, 15, /*1,1*/ 16, 17, 18, 19, 20});
+ m.SetShape<int32_t>({3, 3, 3, 5});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3, 5}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray({
+ /*0, 0, 0*/ 0, 0, 0, 0, 0,
+ /*0, 0, 1*/ 0, 0, 0, 0, 0,
+ /*0, 0, 2*/ 0, 0, 0, 0, 0,
+ /*0, 1, 0*/ 0, 0, 0, 0, 0,
+ /*0, 1, 1*/ 0, 0, 0, 0, 0,
+ /*0, 1, 2*/ 0, 0, 0, 0, 0,
+ /*0, 2, 0*/ 11, 12, 13, 14, 15,
+ /*0, 2, 1*/ 0, 0, 0, 0, 0,
+ /*0, 2, 2*/ 0, 0, 0, 0, 0,
+ /*1, 0, 0*/ 0, 0, 0, 0, 0,
+ /*1, 0, 1*/ 6, 7, 8, 9, 10,
+ /*1, 0, 2*/ 0, 0, 0, 0, 0,
+ /*1, 1, 0*/ 0, 0, 0, 0, 0,
+ /*1, 1, 1*/ 0, 0, 0, 0, 0,
+ /*1, 1, 2*/ 0, 0, 0, 0, 0,
+ /*1, 2, 0*/ 0, 0, 0, 0, 0,
+ /*1, 2, 1*/ 0, 0, 0, 0, 0,
+ /*1, 2, 2*/ 0, 0, 0, 0, 0,
+ /*2, 0, 0*/ 0, 0, 0, 0, 0,
+ /*2, 0, 1*/ 0, 0, 0, 0, 0,
+ /*2, 0, 2*/ 0, 0, 0, 0, 0,
+ /*2, 1, 0*/ 0, 0, 0, 0, 0,
+ /*2, 1, 1*/ 0, 0, 0, 0, 0,
+ /*2, 1, 2*/ 0, 0, 0, 0, 0,
+ /*2, 2, 0*/ 16, 17, 18, 19, 20,
+ /*2, 2, 1*/ 0, 0, 0, 0, 0,
+ /*2, 2, 2*/ 1, 2, 3, 4, 5,
+ }));
+}
+
+TEST(ScatterNdOpTest, ScatterVectorIntoRank3Tensor) {
+ ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_FLOAT32, {4, 5}},
+ {TensorType_INT32, {3}});
+ m.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
+ m.SetUpdates<float>(
+ {/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
+ /*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
+ m.SetShape<int32_t>({2, 3, 5});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray({/*0, 0*/ 1, 2, 3, 4, 5,
+ /*0, 1*/ 0, 0, 0, 0, 0,
+ /*0, 2*/ 11, 12, 13, 14, 15,
+ /*1, 0*/ 6, 7, 8, 9, 10,
+ /*1, 1*/ 0, 0, 0, 0, 0,
+ /*1, 2*/ 16, 17, 18, 19, 20}));
+}
+
+TEST(ScatterNdOpTest, OverlappedIndicesSummed) {
+ ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_FLOAT32, {4, 5}},
+ {TensorType_INT32, {3}});
+ m.SetIndices<int32_t>({/*0*/ 1, 0, /*1*/ 0, 2, /*2*/ 0, 2, /*3*/ 1, 0});
+ m.SetUpdates<float>(
+ {/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
+ /*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
+ m.SetShape<int32_t>({2, 3, 5});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray({/*0, 0*/ 0, 0, 0, 0, 0,
+ /*0, 1*/ 0, 0, 0, 0, 0,
+ /*0, 2*/ 17, 19, 21, 23, 25,
+ /*1, 0*/ 17, 19, 21, 23, 25,
+ /*1, 1*/ 0, 0, 0, 0, 0,
+ /*1, 2*/ 0, 0, 0, 0, 0}));
+}
+
+TEST(ScatterNdOpTest, Int32IndicesUint8Updates) {
+ ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_UINT8, {4, 5}},
+ {TensorType_INT32, {3}});
+ m.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
+ m.SetUpdates<uint8_t>(
+ {/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
+ /*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
+ m.SetShape<int32_t>({2, 3, 5});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
+ EXPECT_THAT(m.GetOutput<uint8_t>(),
+ ElementsAreArray({/*0, 0*/ 1, 2, 3, 4, 5,
+ /*0, 1*/ 0, 0, 0, 0, 0,
+ /*0, 2*/ 11, 12, 13, 14, 15,
+ /*1, 0*/ 6, 7, 8, 9, 10,
+ /*1, 1*/ 0, 0, 0, 0, 0,
+ /*1, 2*/ 16, 17, 18, 19, 20}));
+}
+
+TEST(ScatterNdOpTest, Int32IndicesInt8Updates) {
+ ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_INT8, {4, 5}},
+ {TensorType_INT32, {3}});
+ m.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
+ m.SetUpdates<int8_t>(
+ {/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
+ /*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
+ m.SetShape<int32_t>({2, 3, 5});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
+ EXPECT_THAT(m.GetOutput<int8_t>(),
+ ElementsAreArray({/*0, 0*/ 1, 2, 3, 4, 5,
+ /*0, 1*/ 0, 0, 0, 0, 0,
+ /*0, 2*/ 11, 12, 13, 14, 15,
+ /*1, 0*/ 6, 7, 8, 9, 10,
+ /*1, 1*/ 0, 0, 0, 0, 0,
+ /*1, 2*/ 16, 17, 18, 19, 20}));
+}
+
+TEST(ScatterNdOpTest, Int32IndicesInt32Updates) {
+ ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_INT32, {4, 5}},
+ {TensorType_INT32, {3}});
+ m.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
+ m.SetUpdates<int32_t>(
+ {/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
+ /*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
+ m.SetShape<int32_t>({2, 3, 5});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
+ EXPECT_THAT(m.GetOutput<int32_t>(),
+ ElementsAreArray({/*0, 0*/ 1, 2, 3, 4, 5,
+ /*0, 1*/ 0, 0, 0, 0, 0,
+ /*0, 2*/ 11, 12, 13, 14, 15,
+ /*1, 0*/ 6, 7, 8, 9, 10,
+ /*1, 1*/ 0, 0, 0, 0, 0,
+ /*1, 2*/ 16, 17, 18, 19, 20}));
+}
+
+TEST(ScatterNdOpTest, Int32IndicesInt64Updates) {
+ ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_INT64, {4, 5}},
+ {TensorType_INT32, {3}});
+ m.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
+ m.SetUpdates<int64_t>(
+ {/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
+ /*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
+ m.SetShape<int32_t>({2, 3, 5});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
+ EXPECT_THAT(m.GetOutput<int64_t>(),
+ ElementsAreArray({/*0, 0*/ 1, 2, 3, 4, 5,
+ /*0, 1*/ 0, 0, 0, 0, 0,
+ /*0, 2*/ 11, 12, 13, 14, 15,
+ /*1, 0*/ 6, 7, 8, 9, 10,
+ /*1, 1*/ 0, 0, 0, 0, 0,
+ /*1, 2*/ 16, 17, 18, 19, 20}));
+}
+
+TEST(ScatterNdOpTest, DynamicShape) {
+ ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_INT64, {4, 5}},
+ {TensorType_INT32, {3}});
+ m.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
+ m.SetUpdates<int64_t>(
+ {/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
+ /*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
+ m.SetShape<int32_t>({2, 3, 5});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
+ EXPECT_THAT(m.GetOutput<int64_t>(),
+ ElementsAreArray({/*0, 0*/ 1, 2, 3, 4, 5,
+ /*0, 1*/ 0, 0, 0, 0, 0,
+ /*0, 2*/ 11, 12, 13, 14, 15,
+ /*1, 0*/ 6, 7, 8, 9, 10,
+ /*1, 1*/ 0, 0, 0, 0, 0,
+ /*1, 2*/ 16, 17, 18, 19, 20}));
+
+ m.SetIndices<int32_t>({/*0*/ 2, 3, /*1*/ 1, 0, /*2*/ 2, 0, /*3*/ 1, 2});
+ m.SetShape<int32_t>({3, 4, 5});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 4, 5}));
+ EXPECT_THAT(m.GetOutput<int64_t>(),
+ ElementsAreArray({/*0, 0*/ 0, 0, 0, 0, 0,
+ /*0, 1*/ 0, 0, 0, 0, 0,
+ /*0, 2*/ 0, 0, 0, 0, 0,
+ /*0, 3*/ 0, 0, 0, 0, 0,
+ /*1, 0*/ 6, 7, 8, 9, 10,
+ /*1, 1*/ 0, 0, 0, 0, 0,
+ /*1, 2*/ 16, 17, 18, 19, 20,
+ /*1, 3*/ 0, 0, 0, 0, 0,
+ /*2, 0*/ 11, 12, 13, 14, 15,
+ /*2, 1*/ 0, 0, 0, 0, 0,
+ /*2, 2*/ 0, 0, 0, 0, 0,
+ /*2, 3*/ 1, 2, 3, 4, 5}));
+}
+
+} // namespace
+} // namespace tflite
diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs
index 2a7c222..dc7aab1 100644
--- a/tensorflow/lite/schema/schema.fbs
+++ b/tensorflow/lite/schema/schema.fbs
@@ -235,6 +235,7 @@
WHILE = 119,
NON_MAX_SUPPRESSION_V4 = 120,
NON_MAX_SUPPRESSION_V5 = 121,
+ SCATTER_ND = 122
}
// Options for the builtin operators.
@@ -334,7 +335,8 @@
WhileOptions,
DepthToSpaceOptions,
NonMaxSuppressionV4Options,
- NonMaxSuppressionV5Options
+ NonMaxSuppressionV5Options,
+ ScatterNdOptions
}
enum Padding : byte { SAME, VALID }
@@ -812,6 +814,9 @@
table NonMaxSuppressionV5Options {
}
+table ScatterNdOptions {
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h
index fa249d3..ea2f1cc 100755
--- a/tensorflow/lite/schema/schema_generated.h
+++ b/tensorflow/lite/schema/schema_generated.h
@@ -319,6 +319,9 @@
struct NonMaxSuppressionV5Options;
struct NonMaxSuppressionV5OptionsT;
+struct ScatterNdOptions;
+struct ScatterNdOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -597,11 +600,12 @@
BuiltinOperator_WHILE = 119,
BuiltinOperator_NON_MAX_SUPPRESSION_V4 = 120,
BuiltinOperator_NON_MAX_SUPPRESSION_V5 = 121,
+ BuiltinOperator_SCATTER_ND = 122,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_NON_MAX_SUPPRESSION_V5
+ BuiltinOperator_MAX = BuiltinOperator_SCATTER_ND
};
-inline const BuiltinOperator (&EnumValuesBuiltinOperator())[122] {
+inline const BuiltinOperator (&EnumValuesBuiltinOperator())[123] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -724,7 +728,8 @@
BuiltinOperator_IF,
BuiltinOperator_WHILE,
BuiltinOperator_NON_MAX_SUPPRESSION_V4,
- BuiltinOperator_NON_MAX_SUPPRESSION_V5
+ BuiltinOperator_NON_MAX_SUPPRESSION_V5,
+ BuiltinOperator_SCATTER_ND
};
return values;
}
@@ -853,13 +858,14 @@
"WHILE",
"NON_MAX_SUPPRESSION_V4",
"NON_MAX_SUPPRESSION_V5",
+ "SCATTER_ND",
nullptr
};
return names;
}
inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
- if (e < BuiltinOperator_ADD || e > BuiltinOperator_NON_MAX_SUPPRESSION_V5) return "";
+ if (e < BuiltinOperator_ADD || e > BuiltinOperator_SCATTER_ND) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOperator()[index];
}
@@ -962,11 +968,12 @@
BuiltinOptions_DepthToSpaceOptions = 94,
BuiltinOptions_NonMaxSuppressionV4Options = 95,
BuiltinOptions_NonMaxSuppressionV5Options = 96,
+ BuiltinOptions_ScatterNdOptions = 97,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_NonMaxSuppressionV5Options
+ BuiltinOptions_MAX = BuiltinOptions_ScatterNdOptions
};
-inline const BuiltinOptions (&EnumValuesBuiltinOptions())[97] {
+inline const BuiltinOptions (&EnumValuesBuiltinOptions())[98] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -1064,7 +1071,8 @@
BuiltinOptions_WhileOptions,
BuiltinOptions_DepthToSpaceOptions,
BuiltinOptions_NonMaxSuppressionV4Options,
- BuiltinOptions_NonMaxSuppressionV5Options
+ BuiltinOptions_NonMaxSuppressionV5Options,
+ BuiltinOptions_ScatterNdOptions
};
return values;
}
@@ -1168,13 +1176,14 @@
"DepthToSpaceOptions",
"NonMaxSuppressionV4Options",
"NonMaxSuppressionV5Options",
+ "ScatterNdOptions",
nullptr
};
return names;
}
inline const char *EnumNameBuiltinOptions(BuiltinOptions e) {
- if (e < BuiltinOptions_NONE || e > BuiltinOptions_NonMaxSuppressionV5Options) return "";
+ if (e < BuiltinOptions_NONE || e > BuiltinOptions_ScatterNdOptions) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOptions()[index];
}
@@ -1567,6 +1576,10 @@
static const BuiltinOptions enum_value = BuiltinOptions_NonMaxSuppressionV5Options;
};
+template<> struct BuiltinOptionsTraits<ScatterNdOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_ScatterNdOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -2367,6 +2380,14 @@
return type == BuiltinOptions_NonMaxSuppressionV5Options ?
reinterpret_cast<const NonMaxSuppressionV5OptionsT *>(value) : nullptr;
}
+ ScatterNdOptionsT *AsScatterNdOptions() {
+ return type == BuiltinOptions_ScatterNdOptions ?
+ reinterpret_cast<ScatterNdOptionsT *>(value) : nullptr;
+ }
+ const ScatterNdOptionsT *AsScatterNdOptions() const {
+ return type == BuiltinOptions_ScatterNdOptions ?
+ reinterpret_cast<const ScatterNdOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -8226,6 +8247,46 @@
flatbuffers::Offset<NonMaxSuppressionV5Options> CreateNonMaxSuppressionV5Options(flatbuffers::FlatBufferBuilder &_fbb, const NonMaxSuppressionV5OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct ScatterNdOptionsT : public flatbuffers::NativeTable {
+ typedef ScatterNdOptions TableType;
+ ScatterNdOptionsT() {
+ }
+};
+
+struct ScatterNdOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ScatterNdOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ ScatterNdOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ScatterNdOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<ScatterNdOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ScatterNdOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct ScatterNdOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit ScatterNdOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ScatterNdOptionsBuilder &operator=(const ScatterNdOptionsBuilder &);
+ flatbuffers::Offset<ScatterNdOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ScatterNdOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ScatterNdOptions> CreateScatterNdOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ ScatterNdOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<ScatterNdOptions> CreateScatterNdOptions(flatbuffers::FlatBufferBuilder &_fbb, const ScatterNdOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -8650,6 +8711,9 @@
const NonMaxSuppressionV5Options *builtin_options_as_NonMaxSuppressionV5Options() const {
return builtin_options_type() == BuiltinOptions_NonMaxSuppressionV5Options ? static_cast<const NonMaxSuppressionV5Options *>(builtin_options()) : nullptr;
}
+ const ScatterNdOptions *builtin_options_as_ScatterNdOptions() const {
+ return builtin_options_type() == BuiltinOptions_ScatterNdOptions ? static_cast<const ScatterNdOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -9070,6 +9134,10 @@
return builtin_options_as_NonMaxSuppressionV5Options();
}
+template<> inline const ScatterNdOptions *Operator::builtin_options_as<ScatterNdOptions>() const {
+ return builtin_options_as_ScatterNdOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -12225,6 +12293,29 @@
_fbb);
}
+inline ScatterNdOptionsT *ScatterNdOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new ScatterNdOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void ScatterNdOptions::UnPackTo(ScatterNdOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<ScatterNdOptions> ScatterNdOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ScatterNdOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateScatterNdOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<ScatterNdOptions> CreateScatterNdOptions(flatbuffers::FlatBufferBuilder &_fbb, const ScatterNdOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ScatterNdOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateScatterNdOptions(
+ _fbb);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -12902,6 +12993,10 @@
auto ptr = reinterpret_cast<const NonMaxSuppressionV5Options *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_ScatterNdOptions: {
+ auto ptr = reinterpret_cast<const ScatterNdOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -13304,6 +13399,10 @@
auto ptr = reinterpret_cast<const NonMaxSuppressionV5Options *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_ScatterNdOptions: {
+ auto ptr = reinterpret_cast<const ScatterNdOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -13694,6 +13793,10 @@
auto ptr = reinterpret_cast<const NonMaxSuppressionV5OptionsT *>(value);
return CreateNonMaxSuppressionV5Options(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_ScatterNdOptions: {
+ auto ptr = reinterpret_cast<const ScatterNdOptionsT *>(value);
+ return CreateScatterNdOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -14084,6 +14187,10 @@
value = new NonMaxSuppressionV5OptionsT(*reinterpret_cast<NonMaxSuppressionV5OptionsT *>(u.value));
break;
}
+ case BuiltinOptions_ScatterNdOptions: {
+ value = new ScatterNdOptionsT(*reinterpret_cast<ScatterNdOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -14571,6 +14678,11 @@
delete ptr;
break;
}
+ case BuiltinOptions_ScatterNdOptions: {
+ auto ptr = reinterpret_cast<ScatterNdOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;