{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "c8Cx-rUMVX25"
},
"source": [
"##### Copyright 2019 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "I9sUhVL_VZNO"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6Y8E0lw5eYWm"
},
"source": [
"# Post-training float16 quantization"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CGuqeuPSVNo-"
},
"source": [
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://www.tensorflow.org/lite/performance/post_training_float16_quant\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
" </td>\n",
" <td>\n",
" <a href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BTC1rDAuei_1"
},
"source": [
"## Overview\n",
"\n",
"[TensorFlow Lite](https://www.tensorflow.org/lite/) now supports\n",
"converting weights to 16-bit floating point values during model conversion from TensorFlow to TensorFlow Lite's flat buffer format. This results in a 2x reduction in model size. Some harware, like GPUs, can compute natively in this reduced precision arithmetic, realizing a speedup over traditional floating point execution. The Tensorflow Lite GPU delegate can be configured to run in this way. However, a model converted to float16 weights can still run on the CPU without additional modification: the float16 weights are upsampled to float32 prior to the first inference. This permits a significant reduction in model size in exchange for a minimal impacts to latency and accuracy.\n",
"\n",
"In this tutorial, you train an MNIST model from scratch, check its accuracy in TensorFlow, and then convert the model into a Tensorflow Lite flatbuffer\n",
"with float16 quantization. Finally, check the accuracy of the converted model and compare it to the original float32 model."
]
},
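{
"cell_type": "markdown",
"metadata": {},
"source": [
"At a glance, the full float16 conversion recipe, which the rest of this tutorial walks through step by step, is only a few lines. This is a sketch, assuming `model` is an already-trained `tf.keras` model:\n",
"\n",
"```\n",
"converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
"converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
"converter.target_spec.supported_types = [tf.float16]\n",
"tflite_fp16_model = converter.convert()\n",
"```"
]
},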
{
"cell_type": "markdown",
"metadata": {
"id": "2XsEP17Zelz9"
},
"source": [
"## Build an MNIST model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dDqqUIZjZjac"
},
"source": [
"### Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gyqAw1M9lyab"
},
"outputs": [],
"source": [
"import logging\n",
"logging.getLogger(\"tensorflow\").setLevel(logging.DEBUG)\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"import numpy as np\n",
"import pathlib"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "c6nb7OPlXs_3"
},
"outputs": [],
"source": [
"tf.float16"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eQ6Q0qqKZogR"
},
"source": [
"### Train and export the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hWSAjQWagIHl"
},
"outputs": [],
"source": [
"# Load MNIST dataset\n",
"mnist = keras.datasets.mnist\n",
"(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
"\n",
"# Normalize the input image so that each pixel value is between 0 to 1.\n",
"train_images = train_images / 255.0\n",
"test_images = test_images / 255.0\n",
"\n",
"# Define the model architecture\n",
"model = keras.Sequential([\n",
" keras.layers.InputLayer(input_shape=(28, 28)),\n",
" keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
" keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu),\n",
" keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
" keras.layers.Flatten(),\n",
" keras.layers.Dense(10)\n",
"])\n",
"\n",
"# Train the digit classification model\n",
"model.compile(optimizer='adam',\n",
" loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
" metrics=['accuracy'])\n",
"model.fit(\n",
" train_images,\n",
" train_labels,\n",
" epochs=1,\n",
" validation_data=(test_images, test_labels)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5NMaNZQCkW9X"
},
"source": [
"For the example, you trained the model for just a single epoch, so it only trains to ~96% accuracy."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xl8_fzVAZwOh"
},
"source": [
"### Convert to a TensorFlow Lite model\n",
"\n",
"Using the Python [TFLiteConverter](https://www.tensorflow.org/lite/convert/python_api), you can now convert the trained model into a TensorFlow Lite model.\n",
"\n",
"Now load the model using the `TFLiteConverter`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_i8B2nDZmAgQ"
},
"outputs": [],
"source": [
"converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
"tflite_model = converter.convert()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "F2o2ZfF0aiCx"
},
"source": [
"Write it out to a `.tflite` file:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vptWZq2xnclo"
},
"outputs": [],
"source": [
"tflite_models_dir = pathlib.Path(\"/tmp/mnist_tflite_models/\")\n",
"tflite_models_dir.mkdir(exist_ok=True, parents=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ie9pQaQrn5ue"
},
"outputs": [],
"source": [
"tflite_model_file = tflite_models_dir/\"mnist_model.tflite\"\n",
"tflite_model_file.write_bytes(tflite_model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7BONhYtYocQY"
},
"source": [
"To instead quantize the model to float16 on export, first set the `optimizations` flag to use default optimizations. Then specify that float16 is the supported type on the target platform:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HEZ6ET1AHAS3"
},
"outputs": [],
"source": [
"converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
"converter.target_spec.supported_types = [tf.float16]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xW84iMYjHd9t"
},
"source": [
"Finally, convert the model like usual. Note, by default the converted model will still use float input and outputs for invocation convenience."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yuNfl3CoHNK3"
},
"outputs": [],
"source": [
"tflite_fp16_model = converter.convert()\n",
"tflite_model_fp16_file = tflite_models_dir/\"mnist_model_quant_f16.tflite\"\n",
"tflite_model_fp16_file.write_bytes(tflite_fp16_model)"
]
},
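{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check of the note above, you can inspect the converted model's input and output tensors. This is an optional sketch (the `interpreter_check` name is used only here): despite the float16 weights, the input and output types remain float32."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: the float16-quantized model still uses float32\n",
"# input and output tensors by default.\n",
"interpreter_check = tf.lite.Interpreter(model_content=tflite_fp16_model)\n",
"interpreter_check.allocate_tensors()\n",
"print('Input dtype:', interpreter_check.get_input_details()[0]['dtype'])\n",
"print('Output dtype:', interpreter_check.get_output_details()[0]['dtype'])"
]
},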
{
"cell_type": "markdown",
"metadata": {
"id": "PhMmUTl4sbkz"
},
"source": [
"Note how the resulting file is approximately `1/2` the size."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JExfcfLDscu4"
},
"outputs": [],
"source": [
"!ls -lh {tflite_models_dir}"
]
},
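{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also compare the sizes programmatically. A small sketch using the two file paths defined above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare the file sizes directly; the float16 model should be roughly\n",
"# half the size of the float32 model.\n",
"float32_size = tflite_model_file.stat().st_size\n",
"float16_size = tflite_model_fp16_file.stat().st_size\n",
"print('float32 model: %d bytes' % float32_size)\n",
"print('float16 model: %d bytes' % float16_size)\n",
"print('size ratio: %.2f' % (float16_size / float32_size))"
]
},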
{
"cell_type": "markdown",
"metadata": {
"id": "L8lQHMp_asCq"
},
"source": [
"## Run the TensorFlow Lite models"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-5l6-ciItvX6"
},
"source": [
"Run the TensorFlow Lite model using the Python TensorFlow Lite Interpreter."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ap_jE7QRvhPf"
},
"source": [
"### Load the model into the interpreters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Jn16Rc23zTss"
},
"outputs": [],
"source": [
"interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))\n",
"interpreter.allocate_tensors()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "J8Pztk1mvNVL"
},
"outputs": [],
"source": [
"interpreter_fp16 = tf.lite.Interpreter(model_path=str(tflite_model_fp16_file))\n",
"interpreter_fp16.allocate_tensors()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2opUt_JTdyEu"
},
"source": [
"### Test the models on one image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AKslvo2kwWac"
},
"outputs": [],
"source": [
"test_image = np.expand_dims(test_images[0], axis=0).astype(np.float32)\n",
"\n",
"input_index = interpreter.get_input_details()[0][\"index\"]\n",
"output_index = interpreter.get_output_details()[0][\"index\"]\n",
"\n",
"interpreter.set_tensor(input_index, test_image)\n",
"interpreter.invoke()\n",
"predictions = interpreter.get_tensor(output_index)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XZClM2vo3_bm"
},
"outputs": [],
"source": [
"import matplotlib.pylab as plt\n",
"\n",
"plt.imshow(test_images[0])\n",
"template = \"True:{true}, predicted:{predict}\"\n",
"_ = plt.title(template.format(true= str(test_labels[0]),\n",
" predict=str(np.argmax(predictions[0]))))\n",
"plt.grid(False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3gwhv4lKbYZ4"
},
"outputs": [],
"source": [
"test_image = np.expand_dims(test_images[0], axis=0).astype(np.float32)\n",
"\n",
"input_index = interpreter_fp16.get_input_details()[0][\"index\"]\n",
"output_index = interpreter_fp16.get_output_details()[0][\"index\"]\n",
"\n",
"interpreter_fp16.set_tensor(input_index, test_image)\n",
"interpreter_fp16.invoke()\n",
"predictions = interpreter_fp16.get_tensor(output_index)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CIH7G_MwbY2x"
},
"outputs": [],
"source": [
"plt.imshow(test_images[0])\n",
"template = \"True:{true}, predicted:{predict}\"\n",
"_ = plt.title(template.format(true= str(test_labels[0]),\n",
" predict=str(np.argmax(predictions[0]))))\n",
"plt.grid(False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LwN7uIdCd8Gw"
},
"source": [
"### Evaluate the models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "05aeAuWjvjPx"
},
"outputs": [],
"source": [
"# A helper function to evaluate the TF Lite model using \"test\" dataset.\n",
"def evaluate_model(interpreter):\n",
" input_index = interpreter.get_input_details()[0][\"index\"]\n",
" output_index = interpreter.get_output_details()[0][\"index\"]\n",
"\n",
" # Run predictions on every image in the \"test\" dataset.\n",
" prediction_digits = []\n",
" for test_image in test_images:\n",
" # Pre-processing: add batch dimension and convert to float32 to match with\n",
" # the model's input data format.\n",
" test_image = np.expand_dims(test_image, axis=0).astype(np.float32)\n",
" interpreter.set_tensor(input_index, test_image)\n",
"\n",
" # Run inference.\n",
" interpreter.invoke()\n",
"\n",
" # Post-processing: remove batch dimension and find the digit with highest\n",
" # probability.\n",
" output = interpreter.tensor(output_index)\n",
" digit = np.argmax(output()[0])\n",
" prediction_digits.append(digit)\n",
"\n",
" # Compare prediction results with ground truth labels to calculate accuracy.\n",
" accurate_count = 0\n",
" for index in range(len(prediction_digits)):\n",
" if prediction_digits[index] == test_labels[index]:\n",
" accurate_count += 1\n",
" accuracy = accurate_count * 1.0 / len(prediction_digits)\n",
"\n",
" return accuracy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "T5mWkSbMcU5z"
},
"outputs": [],
"source": [
"print(evaluate_model(interpreter))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Km3cY9ry8ZlG"
},
"source": [
"Repeat the evaluation on the float16 quantized model to obtain:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-9cnwiPp6EGm"
},
"outputs": [],
"source": [
"# NOTE: Colab runs on server CPUs. At the time of writing this, TensorFlow Lite\n",
"# doesn't have super optimized server CPU kernels. For this reason this may be\n",
"# slower than the above float interpreter. But for mobile CPUs, considerable\n",
"# speedup can be observed.\n",
"print(evaluate_model(interpreter_fp16))"
]
},
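{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see where the size saving comes from, you can tally the dtypes of the tensors in the float16 quantized model. This is an illustrative sketch using the interpreter's `get_tensor_details()` API; the weight tensors should show up as float16, while activations remain float32:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"# Tally tensor dtypes in the float16 quantized model; weights are stored\n",
"# as float16 and dequantized to float32 at runtime.\n",
"dtype_counts = Counter(\n",
"    d['dtype'].__name__ for d in interpreter_fp16.get_tensor_details())\n",
"print(dtype_counts)"
]
},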
{
"cell_type": "markdown",
"metadata": {
"id": "L7lfxkor8pgv"
},
"source": [
"In this example, you have quantized a model to float16 with no difference in the accuracy.\n",
"\n",
"It's also possible to evaluate the fp16 quantized model on the GPU. To perform all arithmetic with the reduced precision values, be sure to create the `TfLiteGPUDelegateOptions` struct in your app and set `precision_loss_allowed` to `1`, like this:\n",
"\n",
"```\n",
"//Prepare GPU delegate.\n",
"const TfLiteGpuDelegateOptions options = {\n",
" .metadata = NULL,\n",
" .compile_options = {\n",
" .precision_loss_allowed = 1, // FP16\n",
" .preferred_gl_object_type = TFLITE_GL_OBJECT_TYPE_FASTEST,\n",
" .dynamic_batch_enabled = 0, // Not fully functional yet\n",
" },\n",
"};\n",
"```\n",
"\n",
"Detailed documentation on the TFLite GPU delegate and how to use it in your application can be found [here](https://www.tensorflow.org/lite/performance/gpu_advanced?source=post_page---------------------------)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "post_training_float16_quant.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}