// blob: bd4b5a17f75c940741eead11d38639315c66834a [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.h"
#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer.h"
namespace tensorflow {
namespace ops {
// Declares the interface of the SentencePiece detokenize op.
//
// Inputs:
//   sp_model:     uint8 buffer holding the model.
//   input_values: int32 ids, required by the shape function to be rank 1.
//   input_splits: split offsets of type `Tsplits`, required to be rank 1.
// Output:
//   output: string vector whose length is (number of splits - 1).
REGISTER_OP("TFSentencepieceDetokenizeOp")
    .Input("sp_model: uint8")
    .Input("input_values: int32")
    .Input("input_splits: Tsplits")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .Output("output: string")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext* context) {
      // Both the values and the splits must be vectors; the refined shape
      // handle produced by WithRank is not needed afterwards.
      shape_inference::ShapeHandle rank_checked;
      TF_RETURN_IF_ERROR(
          context->WithRank(context->input(1), 1, &rank_checked));
      TF_RETURN_IF_ERROR(
          context->WithRank(context->input(2), 1, &rank_checked));

      // One output string per pair of adjacent splits.
      const shape_inference::DimensionHandle splits_count =
          context->NumElements(context->input(2));
      shape_inference::DimensionHandle sentence_count;
      TF_RETURN_IF_ERROR(context->Subtract(splits_count, 1, &sentence_count));

      context->set_output(0, context->Vector(sentence_count));
      return Status::OK();
    });
// CPU kernel that turns a batch of SentencePiece id sequences back into
// strings.
//
// The batch is given in ragged form: `input_values` holds the ids of all
// sentences back to back and `input_splits` (of type `Tsplits`) delimits the
// sentences. Sentence `i` has `input_splits[i + 1] - input_splits[i]` ids,
// consumed consecutively from the start of `input_values`. Output 0 is a
// string vector with one decoded entry per sentence.
//
// The input positions come from kSPModelIndex / kInputIndex / kInputSplits,
// which are declared outside this class.
//
// Compute() fails with INVALID_ARGUMENT when `input_splits` is empty, is
// decreasing, or covers more ids than `input_values` contains, and with
// INTERNAL when the decoder reports a failure.
template <typename Tsplits>
class TFSentencepieceDetokenizerOp : public tensorflow::OpKernel {
 public:
  explicit TFSentencepieceDetokenizerOp(tensorflow::OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(tensorflow::OpKernelContext* ctx) override {
    const auto& model_tensor = ctx->input(kSPModelIndex);
    const auto& input_values_tensor = ctx->input(kInputIndex);
    const auto input_values_flat =
        input_values_tensor.flat<tensorflow::int32>();
    const auto& input_splits_tensor = ctx->input(kInputSplits);
    const auto input_splits_flat = input_splits_tensor.flat<Tsplits>();

    // N sentences need N + 1 splits. Without this check an empty splits
    // tensor would yield a sentence count of -1 and an invalid output shape.
    OP_REQUIRES(ctx, input_splits_flat.size() >= 1,
                tensorflow::Status(
                    tensorflow::error::INVALID_ARGUMENT,
                    "input_splits must contain at least one element"));
    const tensorflow::int64 num_of_sentences = input_splits_flat.size() - 1;

    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, {num_of_sentences}, &output_tensor));
    auto output_flat = output_tensor->flat<tensorflow::tstring>();

    const tensorflow::uint64 num_values =
        static_cast<tensorflow::uint64>(input_values_flat.size());
    // Number of ids consumed so far; invariant: input_offset <= num_values.
    tensorflow::uint64 input_offset = 0;
    std::vector<int> codes_for_split;
    for (tensorflow::int64 i = 0; i < num_of_sentences; ++i) {
      const tensorflow::int64 split_begin = input_splits_flat(i);
      const tensorflow::int64 split_end = input_splits_flat(i + 1);
      OP_REQUIRES(ctx, split_begin <= split_end,
                  tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
                                     "input_splits must be non-decreasing"));
      // Unsigned subtraction is well defined for any begin <= end, including
      // values at the extremes of int64, and never narrows the length.
      const tensorflow::uint64 split_size =
          static_cast<tensorflow::uint64>(split_end) -
          static_cast<tensorflow::uint64>(split_begin);
      // Refuse to read past the end of the values buffer.
      OP_REQUIRES(
          ctx, split_size <= num_values - input_offset,
          tensorflow::Status(
              tensorflow::error::INVALID_ARGUMENT,
              "input_splits cover more elements than input_values contains"));

      // Copy this sentence's ids into a vector of int for the decoder.
      const tensorflow::int32* const split_first =
          input_values_flat.data() + input_offset;
      codes_for_split.assign(split_first, split_first + split_size);
      input_offset += split_size;

      const auto res = tflite::ops::custom::sentencepiece::DecodeString(
          codes_for_split, model_tensor.data());
      OP_REQUIRES(
          ctx,
          res.type ==
              tflite::ops::custom::sentencepiece::DecoderResultType::SUCCESS,
          tensorflow::Status(tensorflow::error::INTERNAL,
                             "Sentencepiece conversion failed"));
      output_flat(i) = res.decoded;
    }
  }
};
} // namespace ops
} // namespace tensorflow
// Registers the CPU kernel for "TFSentencepieceDetokenizeOp" once per
// supported `Tsplits` type. The helper macro only removes the duplication
// between the two registrations and is undefined again right after use.
#define REGISTER_SENTENCEPIECE_DETOKENIZER_KERNEL(SPLITS_TYPE) \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("TFSentencepieceDetokenizeOp")                       \
          .Device(tensorflow::DEVICE_CPU)                       \
          .TypeConstraint<SPLITS_TYPE>("Tsplits"),              \
      tensorflow::ops::TFSentencepieceDetokenizerOp<SPLITS_TYPE>)

REGISTER_SENTENCEPIECE_DETOKENIZER_KERNEL(tensorflow::int32);
REGISTER_SENTENCEPIECE_DETOKENIZER_KERNEL(tensorflow::int64);

#undef REGISTER_SENTENCEPIECE_DETOKENIZER_KERNEL