blob: edfb887b3c5d393cc9354289a7f28e1b316dd9de [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for lite.py functionality related to select TF op usage."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl.testing import parameterized
import numpy as np
from tensorflow.core.framework import graph_pb2
from tensorflow.lite.python import lite
from tensorflow.lite.python import test_util as tflite_test_util
from tensorflow.lite.python.convert import register_custom_opdefs
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.lite.python.testdata import double_op
from tensorflow.python.client import session
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework.importer import import_graph_def
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
from tensorflow.python.training.tracking import tracking
class FromSessionTest(test_util.TensorFlowTestCase, parameterized.TestCase):
@parameterized.named_parameters(
('EnableMlirConverter', True), # enable mlir
('DisableMlirConverter', False)) # disable mlir
def testFlexMode(self, enable_mlir):
with ops.Graph().as_default():
in_tensor = array_ops.placeholder(shape=[1, 4], dtype=dtypes.float32)
out_tensor = in_tensor + in_tensor
sess = session.Session()
# Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS])
converter.experimental_new_converter = enable_mlir
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# Check the model works with TensorFlow ops.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], test_input)
interpreter.invoke()
output_details = interpreter.get_output_details()
expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.float32)
output_data = interpreter.get_tensor(output_details[0]['index'])
self.assertTrue((expected_output == output_data).all())
def testFlexWithAutomaticPassThrough(self):
# Create a graph that has one L2Loss op.
with ops.Graph().as_default():
with session.Session() as sess:
in_tensor = array_ops.placeholder(
shape=[4], dtype=dtypes.float32, name='input')
out_tensor = nn_ops.l2_loss(in_tensor)
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS])
converter._experimental_allow_all_select_tf_ops = True
tflite_model = converter.convert()
self.assertTrue(tflite_model)
self.assertIn('FlexL2Loss', tflite_test_util.get_ops_list(tflite_model))
def testDeprecatedFlags(self):
with ops.Graph().as_default():
in_tensor = array_ops.placeholder(shape=[1, 4], dtype=dtypes.float32)
out_tensor = in_tensor + in_tensor
sess = session.Session()
# Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
converter.target_ops = set([lite.OpsSet.SELECT_TF_OPS])
# Ensure `target_ops` is set to the correct value after flag deprecation.
self.assertEqual(converter.target_ops, set([lite.OpsSet.SELECT_TF_OPS]))
self.assertEqual(converter.target_spec.supported_ops,
set([lite.OpsSet.SELECT_TF_OPS]))
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# Check the model works with TensorFlow ops.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], test_input)
interpreter.invoke()
output_details = interpreter.get_output_details()
expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.float32)
output_data = interpreter.get_tensor(output_details[0]['index'])
self.assertTrue((expected_output == output_data).all())
class FromConcreteFunctionTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@parameterized.named_parameters(
('EnableMlirConverter', True), # enable mlir
('DisableMlirConverter', False)) # disable mlir
@test_util.run_v2_only
def testFloat(self, enable_mlir):
input_data = constant_op.constant(1., shape=[1])
root = tracking.AutoTrackable()
root.v1 = variables.Variable(3.)
root.v2 = variables.Variable(2.)
root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
concrete_func = root.f.get_concrete_function(input_data)
# Convert model.
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
root)
converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS])
converter.experimental_new_converter = enable_mlir
tflite_model = converter.convert()
# Check the model works with TensorFlow ops.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
test_input = np.array([4.0], dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], test_input)
interpreter.invoke()
output_details = interpreter.get_output_details()
expected_output = np.array([24.0], dtype=np.float32)
output_data = interpreter.get_tensor(output_details[0]['index'])
self.assertTrue((expected_output == output_data).all())
class WithCustomOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def _createGraphWithCustomOp(self, opname='CustomAdd'):
custom_opdefs_str = (
'name: \'' + opname + '\' input_arg: {name: \'Input1\' type: DT_FLOAT} '
'input_arg: {name: \'Input2\' type: DT_FLOAT} output_arg: {name: '
'\'Output\' type: DT_FLOAT}')
# Create a graph that has one add op.
new_graph = graph_pb2.GraphDef()
with ops.Graph().as_default():
with session.Session() as sess:
in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input')
out_tensor = in_tensor + in_tensor
inputs = {'x': in_tensor}
outputs = {'z': out_tensor}
new_graph.CopyFrom(sess.graph_def)
# Rename Add op name to opname.
for node in new_graph.node:
if node.op.startswith('Add'):
node.op = opname
del node.attr['T']
# Register custom op defs to import modified graph def.
register_custom_opdefs([custom_opdefs_str])
return (new_graph, inputs, outputs)
def testFlexWithCustomOp(self):
new_graph, inputs, outputs = self._createGraphWithCustomOp(
opname='CustomAdd4')
# Import to load the custom opdef.
saved_model_dir = os.path.join(self.get_temp_dir(), 'model')
with ops.Graph().as_default():
with session.Session() as sess:
import_graph_def(new_graph, name='')
saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS])
converter.target_spec.experimental_select_user_tf_ops = ['CustomAdd4']
tflite_model = converter.convert()
self.assertIn('FlexCustomAdd4', tflite_test_util.get_ops_list(tflite_model))
def testFlexWithDoubleOp(self):
# Create a graph that has one double op.
saved_model_dir = os.path.join(self.get_temp_dir(), 'model2')
with ops.Graph().as_default():
with session.Session() as sess:
in_tensor = array_ops.placeholder(
shape=[1, 4], dtype=dtypes.int32, name='input')
out_tensor = double_op.double(in_tensor)
inputs = {'x': in_tensor}
outputs = {'z': out_tensor}
saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS])
converter.target_spec.experimental_select_user_tf_ops = ['Double']
tflite_model = converter.convert()
self.assertTrue(tflite_model)
self.assertIn('FlexDouble', tflite_test_util.get_ops_list(tflite_model))
# Check the model works with TensorFlow ops.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.int32)
interpreter.set_tensor(input_details[0]['index'], test_input)
interpreter.invoke()
output_details = interpreter.get_output_details()
expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.int32)
output_data = interpreter.get_tensor(output_details[0]['index'])
self.assertTrue((expected_output == output_data).all())
if __name__ == '__main__':
test.main()