| # Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| |
| # Lint as: python3 |
| """Python class that implements Sentencepiece tokenizer. |
| |
| It follows TF.text designers design. |
| |
| """ |
| import tensorflow.compat.v2 as tf # pylint: disable=g-direct-tensorflow-import |
| from tensorflow.python.ops.ragged import ragged_tensor # pylint: disable=g-direct-tensorflow-import |
| from tensorflow.python.framework import load_library |
| from tensorflow.python.platform import resource_loader |
| gen_sentencepiece_detokenizer_op = load_library.load_op_library(resource_loader.get_path_to_datafile('../kernel/sentencepiece/sentencepiece_detokenizer_op.so')) |
| from tensorflow.python.framework import load_library |
| from tensorflow.python.platform import resource_loader |
| gen_sentencepiece_tokenizer_op = load_library.load_op_library(resource_loader.get_path_to_datafile('../kernel/sentencepiece/sentencepiece_tokenizer_op.so')) |
| from tensorflow_lite_support.custom_ops.kernel.sentencepiece.py import pywrap_model_converter as model_converter |
| |
| |
| class SentencepieceTokenizer: |
| """Sentencepiece tokenizer with tf.text interface.""" |
| |
| def __init__(self, model, reverse=False, add_bos=False, add_eos=False): |
| converted_model = model_converter.convert_sentencepiece_model(model) |
| converted_model_detokenizer = model_converter.convert_sentencepiece_model_for_decoder( |
| model) |
| # Use uint8 tensor as a buffer for the model to avoid any possible changes, |
| # for example truncation by '\0'. |
| self._converted_model = tf.constant(list(converted_model), dtype=tf.uint8) |
| self._converted_model_detokenizer = tf.constant( |
| list(converted_model_detokenizer), dtype=tf.uint8) |
| self._vocab_size = model_converter.get_vocabulary_size(converted_model) |
| self._reverse = reverse |
| self._add_bos = add_bos |
| self._add_eos = add_eos |
| |
| def tokenize(self, inputs): |
| """The main tokenization function.""" |
| input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(inputs) |
| if input_tensor.shape.ndims is None: |
| raise ValueError("Rank of input_tensor must be statically known.") |
| if ragged_tensor.is_ragged(input_tensor): |
| # Ensure that input has row_split_dtype is int32 |
| input_tensor = input_tensor.with_row_splits_dtype(tf.int32) |
| # Recursively process the values of the ragged tensor. |
| tokens = self.tokenize(input_tensor.flat_values) |
| return input_tensor.with_flat_values(tokens) |
| else: |
| if input_tensor.shape.ndims > 1: |
| # Convert the input tensor to ragged and process it. |
| return self.tokenize( |
| tf.RaggedTensor.from_tensor( |
| input_tensor, row_splits_dtype=tf.int32)) |
| elif input_tensor.shape.ndims == 0: |
| tokens = self.tokenize(tf.stack([input_tensor])) |
| return tokens.values |
| else: |
| # Our rank 1 tensor is the correct shape, so we can process it as |
| # normal. |
| (output_values, row_splits) = ( |
| gen_sentencepiece_tokenizer_op.tf_sentencepiece_tokenize_op( |
| self._converted_model, input_tensor, 0, 0, self._add_bos, |
| self._add_eos, self._reverse)) |
| tokens = tf.RaggedTensor.from_nested_row_splits( |
| flat_values=output_values, |
| nested_row_splits=[row_splits], |
| validate=False) |
| return tokens |
| |
| def detokenize(self, input): # pylint: disable=redefined-builtin |
| """Detokenizes tokens into preprocessed text. |
| |
| Args: |
| input: A `RaggedTensor` or `Tensor` with int32 encoded text with rank >= |
| 1. |
| |
| Returns: |
| A N-1 dimensional string Tensor or RaggedTensor of the detokenized text. |
| """ |
| input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(input) |
| if input_tensor.shape.ndims is None: |
| raise ValueError("Rank of input_tensor must be statically known.") |
| if input_tensor.shape.ndims == 0: |
| raise ValueError("Rank of input_tensor must be at least 1.") |
| if ragged_tensor.is_ragged(input_tensor): |
| if input_tensor.flat_values.shape.ndims > 1: |
| # If the flat_values of our ragged tensor is multi-dimensional, we can |
| # process it separately and our output will have the same nested |
| # splits as our input. |
| tokens = self.detokenize(input_tensor.flat_values) |
| return input_tensor.with_flat_values(tokens) |
| elif input_tensor.ragged_rank > 1: |
| # Recursively process the values of the ragged tensor. |
| tokens = self.detokenize(input_tensor.values) |
| return input_tensor.with_values(tokens) |
| else: |
| return gen_sentencepiece_detokenizer_op.tf_sentencepiece_detokenize_op( |
| self._converted_model_detokenizer, input_tensor.flat_values, |
| input_tensor.row_splits) |
| else: |
| if input_tensor.shape.ndims > 1: |
| # Convert the input tensor to ragged and process it. |
| return self.detokenize( |
| tf.RaggedTensor.from_tensor( |
| input_tensor, row_splits_dtype=tf.int32)) |
| else: |
| tokens = self.detokenize(tf.stack([input_tensor])) |
| return tf.reshape(tokens, []) |
| |
| def vocab_size(self): |
| """Returns size of the vocabulary in Sentencepiece model.""" |
| return self._vocab_size |