blob: 54b34e4e33196837e94e231bfcf6535e2c01b90b [file] [log] [blame]
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
* Sentencepiece tflite detokenizer implementation.
*/
#include <algorithm>
#include <iterator>
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.h"
#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer.h"
namespace tflite {
namespace ops {
namespace custom {
namespace sentencepiece {
namespace detokenizer {
// Position in node->outputs of the tensor that receives the decoded strings.
constexpr int kOutputValuesInd{0};
// Creates per-node state for the op. None is needed: the decoder model is
// read from an input tensor in Eval, so the init buffer is ignored and no
// user data is returned.
void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/,
                 size_t /*length*/) {
  void* const user_data = nullptr;
  return user_data;
}
// Releases per-node state created by Initialize. Initialize allocates
// nothing, so there is nothing to release here.
void Free(TfLiteContext* /*context*/, void* /*buffer*/) {
  // Intentionally empty.
}
// Validates the node's tensors and marks the output tensor as dynamic.
//
// Eval dereferences the model, id and split inputs and the output by index.
// It reads the ids and the splits through `data.i32` and writes string
// data. Those assumptions are checked here, so a malformed model is
// rejected instead of being indexed out of bounds or silently reinterpreted.
// The output is made dynamic because its contents are only known once
// every sentence has been decoded in Eval.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const int num_inputs = NumInputs(node);
  TF_LITE_ENSURE(context, tensorflow::ops::kSPModelIndex < num_inputs);
  TF_LITE_ENSURE(context, tensorflow::ops::kInputIndex < num_inputs);
  TF_LITE_ENSURE(context, tensorflow::ops::kInputSplits < num_inputs);
  TF_LITE_ENSURE(context, kOutputValuesInd < NumOutputs(node));

  const TfLiteTensor& input_encoded =
      context->tensors[node->inputs->data[tensorflow::ops::kInputIndex]];
  TF_LITE_ENSURE_TYPE_EQ(context, input_encoded.type, kTfLiteInt32);
  const TfLiteTensor& input_splits =
      context->tensors[node->inputs->data[tensorflow::ops::kInputSplits]];
  TF_LITE_ENSURE_TYPE_EQ(context, input_splits.type, kTfLiteInt32);

  TfLiteTensor& output_values =
      context->tensors[node->outputs->data[kOutputValuesInd]];
  TF_LITE_ENSURE_TYPE_EQ(context, output_values.type, kTfLiteString);
  SetTensorToDynamic(&output_values);
  return kTfLiteOk;
}
// Decodes every sentence of a ragged batch of SentencePiece ids into text.
//
// Inputs (indices declared in sentencepiece_detokenizer.h):
//   kSPModelIndex - tensor whose raw data pointer is handed to DecodeString
//                   as the decoder model.
//   kInputIndex   - ids of all sentences, stored back to back (read as
//                   int32).
//   kInputSplits  - row splits (read as int32); sentence i owns
//                   splits[i + 1] - splits[i] consecutive ids.
// Output:
//   kOutputValuesInd - one decoded string per sentence.
//
// Returns kTfLiteError if the splits decrease, if they describe more ids
// than the id tensor holds, or if decoding any sentence fails.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor& model_tensor =
      context->tensors[node->inputs->data[tensorflow::ops::kSPModelIndex]];
  const auto model_buffer_data = model_tensor.data.data;
  const TfLiteTensor& input_encoded =
      context->tensors[node->inputs->data[tensorflow::ops::kInputIndex]];
  const int32_t* input_encoded_data = input_encoded.data.i32;
  const int64_t num_encoded = NumElements(input_encoded.dims);
  const TfLiteTensor& input_splits =
      context->tensors[node->inputs->data[tensorflow::ops::kInputSplits]];
  const int num_of_sentences = NumElements(input_splits.dims) - 1;
  const int32_t* input_splits_data = input_splits.data.i32;
  DynamicBuffer buf;
  // Reused across sentences so its storage is not reallocated every time.
  std::vector<int> codes_for_split;
  // Ids are consumed sequentially from the start of input_encoded; the
  // splits only determine how many ids belong to each sentence.
  int64_t input_offset = 0;
  for (int i = 0; i < num_of_sentences; i++) {
    // Widen before subtracting so extreme split values cannot overflow.
    const int64_t split_size =
        static_cast<int64_t>(input_splits_data[i + 1]) - input_splits_data[i];
    // Malformed splits would otherwise produce an inverted or
    // out-of-bounds range over input_encoded.
    TF_LITE_ENSURE_MSG(context, split_size >= 0,
                       "Sentencepiece detokenizer: input splits must be "
                       "non-decreasing");
    TF_LITE_ENSURE_MSG(context, split_size <= num_encoded - input_offset,
                       "Sentencepiece detokenizer: input splits exceed the "
                       "number of encoded values");
    const int32_t* const split_begin = input_encoded_data + input_offset;
    codes_for_split.assign(split_begin, split_begin + split_size);
    const auto res = DecodeString(codes_for_split, model_buffer_data);
    TF_LITE_ENSURE_MSG(context, res.type == DecoderResultType::SUCCESS,
                       "Sentencepiece decoding failed");
    buf.AddString(res.decoded.data(), res.decoded.length());
    input_offset += split_size;
  }
  TfLiteTensor& output_values =
      context->tensors[node->outputs->data[kOutputValuesInd]];
  buf.WriteToTensor(&output_values, nullptr);
  return kTfLiteOk;
}
} // namespace detokenizer
} // namespace sentencepiece
// Returns the registration (init/free/prepare/invoke callbacks) for the
// SentencePiece detokenizer custom op. The registration is a function-local
// static, so every call returns the same pointer.
TfLiteRegistration* Register_SENTENCEPIECE_DETOKENIZER() {
  namespace detok = sentencepiece::detokenizer;
  static TfLiteRegistration registration = {detok::Initialize, detok::Free,
                                            detok::Prepare, detok::Eval};
  return &registration;
}
} // namespace custom
} // namespace ops
} // namespace tflite