# blob: 0d2cd0ee5ddadca640f3ed9cc4c3508094a95c6b [file] [log] [blame]
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for auto_quant."""
import tempfile
import numpy as np
import tensorflow as tf
from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.tools import flatbuffer_utils
from tensorflow.lite.tools.optimize.auto_quant import auto_quant
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
class AutoQuantTest(test_util.TensorFlowTestCase):
  """Tests for `auto_quant.quant_single_layers`."""

  def test_quant_single_layers(self):
    """Checks every model from `quant_single_layers` has one QUANTIZE op.

    Saves a SavedModel whose only function computes `x * x`, passes it with a
    five-sample representative dataset to `auto_quant.quant_single_layers`,
    and verifies that each returned flatbuffer registers the QUANTIZE builtin
    and references it from exactly one operator across all subgraphs.
    """

    class Multiplier(tf.Module):
      """Module exposing a single function that squares its float32 input."""

      @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
      def multiply(self, x):
        return x * x

    # Create the SavedModel directory inside the test case's own temp
    # directory instead of directly in the system temp location, so the test
    # does not leave a stray directory behind on every run.
    model_path = tempfile.mkdtemp(dir=self.get_temp_dir())
    tf.saved_model.save(Multiplier(), model_path)

    def representative_dataset():
      # NOTE(review): each sample has shape (1,) while the input signature
      # above declares a scalar (shape=[]) -- confirm this is intended.
      for i in range(5):
        yield [np.array([i], dtype=np.float32)]

    quanted_models = auto_quant.quant_single_layers(model_path,
                                                    representative_dataset)

    # Count the results so the test cannot pass vacuously when nothing is
    # produced (the assertions below only run inside the loop).
    num_models = 0
    for flatbuffer, _ in quanted_models:
      num_models += 1
      tflite_model = flatbuffer_utils.convert_bytearray_to_object(flatbuffer)

      # Locate the QUANTIZE builtin in the model's operator-code table;
      # operators refer to their opcode by index into this table.
      quant_index = -1
      for i, opcode in enumerate(tflite_model.operatorCodes):
        if opcode.builtinCode == schema_fb.BuiltinOperator.QUANTIZE:
          quant_index = i
          break
      self.assertNotEqual(quant_index, -1)

      # Exactly one operator, over all subgraphs, must use that opcode.
      quant_cnt = sum(
          1
          for graph in tflite_model.subgraphs
          for op in graph.operators
          if op.opcodeIndex == quant_index)
      self.assertEqual(quant_cnt, 1)

    self.assertGreater(
        num_models, 0, msg='quant_single_layers returned no models.')
# Run the tests through the TensorFlow test runner when executed as a script.
if __name__ == '__main__':
  test.main()