| /* |
| * Copyright (C) 2018 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "utils/tflite/dist_diversification.h" |
| |
| #include <algorithm> |
| #include "tensorflow/lite/context.h" |
| #include "tensorflow/lite/kernels/kernel_util.h" |
| #include "tensorflow/lite/model.h" |
| |
| namespace libtextclassifier3 { |
| namespace { |
| |
// Greedily selects a diverse subset of row indices from a distance matrix.
//
// Index 0 is always selected first (when there is anything to select). The
// remaining indices are then scanned in increasing order, and an index is kept
// only if its distance to every previously selected index is at least
// `min_distance`. Scanning stops as soon as `max_num_results` indices have been
// selected or all `matrix_size` rows have been considered.
//
// `distance_matrix` must be callable as `distance_matrix(row, col)` and return
// a value that can be compared with a float.
//
// Returns the selected indices in increasing order. The result is empty when
// `matrix_size` or `max_num_results` is not positive, so it never holds more
// than `max_num_results` entries and never names a row that does not exist.
template <typename DistanceMatrixType>
std::vector<int> DiversifyByDistance(const DistanceMatrixType& distance_matrix,
                                     const int matrix_size,
                                     const float min_distance,
                                     const int max_num_results) {
  std::vector<int> result;
  // Nothing can be selected from an empty matrix or with a non-positive
  // budget. Bailing out here also keeps negative values away from the unsigned
  // size arithmetic below.
  if (matrix_size <= 0 || max_num_results <= 0) {
    return result;
  }
  // At most one entry per row can ever be selected.
  result.reserve(std::min(matrix_size, max_num_results));
  result.push_back(0);
  for (int index = 1;
       index < matrix_size &&
       static_cast<int>(result.size()) < max_num_results;
       ++index) {
    // Reject the candidate if it is closer than `min_distance` to anything
    // that has already been selected.
    const bool too_close = std::any_of(
        result.begin(), result.end(), [&](const int selected_index) {
          return distance_matrix(index, selected_index) < min_distance;
        });
    if (!too_close) {
      result.push_back(index);
    }
  }
  return result;
}
| |
// Positions of the op's input tensors within `node->inputs`.
enum DistDiversificationInputs {
  // Float distance matrix, read as a flat row-major buffer.
  DIST_DIVERSIFICATION_INPUT_DISTANCE_MATRIX = 0,
  // Float tensor whose first element is the minimum allowed distance.
  DIST_DIVERSIFICATION_INPUT_MIN_DISTANCE = 1,
  // Int32 tensor whose first element is the number of results requested.
  DIST_DIVERSIFICATION_INPUT_NUM_RESULTS = 2,
};
| |
// Positions of the op's output tensors within `node->outputs`.
enum DistDiversificationOutputs {
  // Int32 tensor receiving the selected row indices.
  DIST_DIVERSIFICATION_OUTPUT_INDICES = 0,
  // Int32 tensor receiving the number of selected indices.
  DIST_DIVERSIFICATION_OUTPUT_LENGTH = 1,
};
| |
| TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) { |
| TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size()); |
| int index = 0; |
| for (const int size : sizes) { |
| array_size->data[index++] = size; |
| } |
| return array_size; |
| } |
| |
| TfLiteStatus AllocateOutputIndexes(TfLiteContext* context, TfLiteNode* node) { |
| const TfLiteTensor& num_results = |
| context |
| ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]]; |
| TfLiteTensor& output_indices = |
| context |
| ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]]; |
| return context->ResizeTensor(context, &output_indices, |
| CreateSizeArray({num_results.data.i32[0]})); |
| } |
| |
| TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
| const TfLiteTensor& num_results = |
| context |
| ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]]; |
| if (tflite::IsConstantTensor(&num_results)) { |
| TF_LITE_ENSURE_OK(context, AllocateOutputIndexes(context, node)); |
| } else { |
| TfLiteTensor& output_indices = |
| context |
| ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]]; |
| tflite::SetTensorToDynamic(&output_indices); |
| } |
| TfLiteTensor& output_length = |
| context->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_LENGTH]]; |
| TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, &output_length, |
| CreateSizeArray({1}))); |
| return kTfLiteOk; |
| } |
| |
| TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
| TfLiteTensor& output_indices = |
| context |
| ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]]; |
| if (tflite::IsDynamicTensor(&output_indices)) { |
| TF_LITE_ENSURE_OK(context, AllocateOutputIndexes(context, node)); |
| } |
| const TfLiteTensor& distance_matrix = |
| context->tensors[node->inputs |
| ->data[DIST_DIVERSIFICATION_INPUT_DISTANCE_MATRIX]]; |
| const int distance_matrix_dim = distance_matrix.dims->data[0]; |
| const float min_distance = |
| context |
| ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_MIN_DISTANCE]] |
| .data.f[0]; |
| const int num_results = |
| context |
| ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]] |
| .data.i32[0]; |
| const auto indices = DiversifyByDistance( |
| [&](int row, int col) { |
| return distance_matrix.data.f[row * distance_matrix_dim + col]; |
| }, |
| distance_matrix_dim, min_distance, num_results); |
| std::copy(indices.begin(), indices.end(), output_indices.data.i32); |
| std::fill_n(output_indices.data.i32 + indices.size(), |
| num_results - indices.size(), -1); |
| TfLiteTensor& output_length = |
| context->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_LENGTH]]; |
| *output_length.data.i32 = indices.size(); |
| return kTfLiteOk; |
| } |
| |
| } // namespace |
| } // namespace libtextclassifier3 |
| |
| namespace tflite { |
| namespace ops { |
| namespace custom { |
| TfLiteRegistration* Register_DISTANCE_DIVERSIFICATION() { |
| static TfLiteRegistration r = {nullptr, nullptr, libtextclassifier3::Prepare, |
| libtextclassifier3::Eval}; |
| return &r; |
| } |
| } // namespace custom |
| } // namespace ops |
| } // namespace tflite |