# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Tests for tensorflow_lite_support.custom_ops.kernel.whitespace_tokenizer."""
import os
import sys
import timeit
from absl import logging
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import tensorflow_text as tf_text
# pylint: disable=g-direct-tensorflow-import
from tensorflow.lite.python import interpreter as interpreter_wrapper
from tensorflow.python.platform import resource_loader

# Force loaded shared object symbols to be globally visible. This is needed so
# that the interpreter_wrapper, in one .so file, can see the op resolver
# in a different .so file. Note that this may already be set by default.
# pylint: disable=g-import-not-at-top,g-bad-import-order,unused-import
if hasattr(sys, 'setdlopenflags') and hasattr(sys, 'getdlopenflags'):
  sys.setdlopenflags(sys.getdlopenflags() | os.RTLD_GLOBAL)
from tensorflow_lite_support.custom_ops.kernel import _pywrap_whitespace_tokenizer_op_resolver
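
# Test inputs shared by the equivalence and latency tests below. The cases
# cover rank-1 (a batch of strings) and rank-2 (nested batches) inputs, so
# both the 1-D and 2-D ragged-output models get exercised.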
TEST_CASES = [
    ['this is a test'],
    ['extra   spaces    in     here'],
    ['a four token sentence', 'a five token sentence thing.'],
    [['a multi dimensional test case', 'a b c d', 'e f g'],
     ['h i j', 'k l m 2 3', 'n o p'], ['q r s 0 1', 't u v', 'w x y z']],
]
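
# Number of interpreter.invoke() calls averaged over in each latency test.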
INVOKES_FOR_SINGLE_OP_BENCHMARK = 1000
INVOKES_FOR_FLEX_DELEGATE_BENCHMARK = 10


@tf.function
def _call_whitespace_tokenizer_to_tensor(test_case):
  """Tokenizes `test_case` with TF Text and densifies the ragged result."""
  tokenizer = tf_text.WhitespaceTokenizer()
  return tokenizer.tokenize(test_case).to_tensor()


@tf.function
def _call_whitespace_tokenizer_to_ragged(test_case):
  """Tokenizes `test_case` with TF Text, keeping the ragged result."""
  tokenizer = tf_text.WhitespaceTokenizer()
  return tokenizer.tokenize(test_case)


class WhitespaceTokenizerTest(parameterized.TestCase):

  @parameterized.parameters([t] for t in TEST_CASES)
  def testToTensorEquivalence(self, test_case):
    # Reference output from the TF Text op running in TensorFlow.
    tf_output = _call_whitespace_tokenizer_to_tensor(test_case)

    # Output from the TFLite model that wraps the custom op.
    model_filename = resource_loader.get_path_to_datafile(
        'testdata/whitespace_tokenizer_to_tensor.tflite')
    with open(model_filename, 'rb') as file:
      model = file.read()
    interpreter = interpreter_wrapper.InterpreterWithCustomOps(
        model_content=model,
        custom_op_registerers=['AddWhitespaceTokenizerCustomOp'])
    np_test_case = np.array(test_case, dtype=np.str_)
    interpreter.resize_tensor_input(0, np_test_case.shape)
    interpreter.allocate_tensors()
    interpreter.set_tensor(interpreter.get_input_details()[0]['index'],
                           np_test_case)
    interpreter.invoke()
    tflite_output = interpreter.get_tensor(
        interpreter.get_output_details()[0]['index'])

    self.assertEqual(tf_output.numpy().tolist(), tflite_output.tolist())

  @parameterized.parameters([t] for t in TEST_CASES)
  def testToRaggedEquivalence(self, test_case):
    tf_output = _call_whitespace_tokenizer_to_ragged(test_case)

    np_test_case = np.array(test_case, dtype=np.str_)
    rank = len(np_test_case.shape)

    model_filename = resource_loader.get_path_to_datafile(
        'testdata/whitespace_tokenizer_to_ragged_{}d_input.tflite'.format(rank))
    with open(model_filename, 'rb') as file:
      model = file.read()
    interpreter = interpreter_wrapper.InterpreterWithCustomOps(
        model_content=model,
        custom_op_registerers=['AddWhitespaceTokenizerCustomOp'])
    interpreter.resize_tensor_input(0, np_test_case.shape)
    interpreter.allocate_tensors()
    interpreter.set_tensor(interpreter.get_input_details()[0]['index'],
                           np_test_case)
    interpreter.invoke()

    # The TFLite model returns the ragged result as its components: output 0
    # holds the flat values and outputs 1..rank hold the nested row_splits.
    # Walk the TF ragged tensor from the outermost row_splits inward and
    # compare each level.
    for i in range(rank):
      tflite_output_cur_row_splits = interpreter.get_tensor(
          interpreter.get_output_details()[1 + i]['index'])
      self.assertEqual(tf_output.row_splits.numpy().tolist(),
                       tflite_output_cur_row_splits.tolist())
      tf_output = tf_output.values

    tflite_output_values = interpreter.get_tensor(
        interpreter.get_output_details()[0]['index'])
    self.assertEqual(tf_output.numpy().tolist(), tflite_output_values.tolist())

  def testSingleOpLatency(self):
    model_filename = resource_loader.get_path_to_datafile(
        'testdata/whitespace_tokenizer_to_tensor.tflite')
    with open(model_filename, 'rb') as file:
      model = file.read()
    interpreter = interpreter_wrapper.InterpreterWithCustomOps(
        model_content=model,
        custom_op_registerers=['AddWhitespaceTokenizerCustomOp'])

    latency = 0.0
    for test_case in TEST_CASES:
      np_test_case = np.array(test_case, dtype=np.str_)
      interpreter.resize_tensor_input(0, np_test_case.shape)
      interpreter.allocate_tensors()
      interpreter.set_tensor(interpreter.get_input_details()[0]['index'],
                             np_test_case)
      start_time = timeit.default_timer()
      for _ in range(INVOKES_FOR_SINGLE_OP_BENCHMARK):
        interpreter.invoke()
      latency = latency + timeit.default_timer() - start_time

    # Report the mean latency per invoke() across all test cases.
    latency = latency / (INVOKES_FOR_SINGLE_OP_BENCHMARK * len(TEST_CASES))
    logging.info('Latency: %fms', latency * 1000.0)

  def testFlexDelegateLatency(self):
    model_filename = resource_loader.get_path_to_datafile(
        'testdata/whitespace_tokenizer_flex_delegate.tflite')
    with open(model_filename, 'rb') as file:
      model = file.read()
    interpreter = interpreter_wrapper.Interpreter(model_content=model)

    latency = 0.0
    for test_case in TEST_CASES:
      np_test_case = np.array(test_case, dtype=np.str_)
      interpreter.resize_tensor_input(0, np_test_case.shape)
      interpreter.allocate_tensors()
      interpreter.set_tensor(interpreter.get_input_details()[0]['index'],
                             np_test_case)
      start_time = timeit.default_timer()
      for _ in range(INVOKES_FOR_FLEX_DELEGATE_BENCHMARK):
        interpreter.invoke()
      latency = latency + timeit.default_timer() - start_time

    # Report the mean latency per invoke() across all test cases.
    latency = latency / (INVOKES_FOR_FLEX_DELEGATE_BENCHMARK * len(TEST_CASES))
    logging.info('Latency: %fms', latency * 1000.0)


if __name__ == '__main__':
  tf.test.main()