blob: 5ac0910675908fc9a6003cd3df77caf9083abbd8 [file] [log] [blame]
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for analyzer package."""
from tensorflow.lite.python import analyzer
from tensorflow.python.framework import test_util
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
class ConvertTest(test_util.TensorFlowTestCase):
  """Checks the txt, html and mlir outputs of analyzer.ModelAnalyzer.analyze."""

  # Model shared by the txt/html/mlir tests. The assertions below expect its
  # dumps to contain a single FULLY_CONNECTED op.
  _PERMUTE_FLOAT_MODEL = "testdata/permute_float.tflite"

  def _analyze_datafile(self, relative_path, output_format):
    """Resolves a test data file and returns the analyzer output for it.

    Args:
      relative_path: Data-file path handed to
        `resource_loader.get_path_to_datafile` for resolution.
      output_format: Format string forwarded unchanged to
        `analyzer.ModelAnalyzer.analyze` ("txt", "html" or "mlir" in the
        tests below).

    Returns:
      The value returned by `analyzer.ModelAnalyzer.analyze`. The tests
      treat it as a string and search it with `assertIn`.
    """
    # Kept in this file so the data path is resolved from the same module
    # as before the helper was extracted.
    resolved_path = resource_loader.get_path_to_datafile(relative_path)
    return analyzer.ModelAnalyzer.analyze(resolved_path, output_format)

  def testTxt(self):
    """The text dump lists the main subgraph and its FULLY_CONNECTED op."""
    dump = self._analyze_datafile(self._PERMUTE_FLOAT_MODEL, "txt")
    for expected in ("Subgraph#0 main(T#0) -> [T#2]",
                     "Op#0 FULLY_CONNECTED(T#0, T#1, T#-1) -> [T#2]"):
      self.assertIn(expected, dump)

  def testHtml(self):
    """The HTML dump starts an HTML document and mentions FULLY_CONNECTED."""
    dump = self._analyze_datafile(self._PERMUTE_FLOAT_MODEL, "html")
    for expected in ("<html>\n<head>", "FULLY_CONNECTED (0)"):
      self.assertIn(expected, dump)

  def testMlir(self):
    """The MLIR dump has the expected @main signature and fully_connected op."""
    dump = self._analyze_datafile(self._PERMUTE_FLOAT_MODEL, "mlir")
    expected_signature = (
        "func @main(%arg0: tensor<1x4xf32>) -> tensor<1x4xf32> attributes "
        '{tf.entry_function = {inputs = "input", outputs = "output"}}')
    expected_fc_op = (
        '%1 = "tfl.fully_connected"(%arg0, %0, %cst) {fused_activation_function'
        ' = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : '
        "(tensor<1x4xf32>, tensor<4x4xf32>, none) -> tensor<1x4xf32>")
    for expected in (expected_signature, expected_fc_op):
      self.assertIn(expected, dump)

  def testMlirHugeConst(self):
    """The MLIR dump shows the pseudo_const value as an opaque placeholder."""
    dump = self._analyze_datafile("../testdata/conv_huge_im2col.bin", "mlir")
    # NOTE(review): the test name suggests this opaque "0xDEADBEEF" form is how
    # a large constant is printed instead of its contents; confirm.
    expected_const = (
        '%1 = "tfl.pseudo_const"() {value = opaque<"_", "0xDEADBEEF"> : '
        "tensor<3x3x3x8xf32>} : () -> tensor<3x3x3x8xf32>")
    self.assertIn(expected_const, dump)
if __name__ == "__main__":
  # Run this module's tests through TensorFlow's test entry point.
  test.main()