{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "h2q27gKz1H20"
},
"source": [
"##### Copyright 2019 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "TUfAcER1oUS6"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Gb7qyhNL1yWt"
},
"source": [
"# Image classification with TensorFlow Lite Model Maker"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nDABAblytltI"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lite/tutorials/model_maker_image_classification\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/tutorials/model_maker_image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/tutorials/model_maker_image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/tutorials/model_maker_image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://tfhub.dev/\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" /\u003eSee TF Hub model\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m86-Nh4pMHqY"
},
"source": [
"The TensorFlow Lite Model Maker library simplifies the process of adapting a TensorFlow neural-network model to particular input data and converting it for deployment in on-device ML applications.\n",
"\n",
"This notebook shows an end-to-end example that uses the Model Maker library to adapt and convert a commonly used image classification model to classify flowers on a mobile device."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bcLF2PKkSbV3"
},
"source": [
"## Prerequisites\n",
"\n",
"To run this example, we first need to install several required packages, including the Model Maker package from its GitHub [repo](https://github.com/tensorflow/examples/tree/master/tensorflow_examples/lite/model_maker)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6cv3K3oaksJv"
},
"outputs": [],
"source": [
"!pip install -q tflite-model-maker-nightly"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Gx1HGRoFQ54j"
},
"source": [
"Import the required packages."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XtxiUeZEiXpt"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"import numpy as np\n",
"\n",
"import tensorflow as tf\n",
"assert tf.__version__.startswith('2')\n",
"\n",
"from tflite_model_maker import model_spec\n",
"from tflite_model_maker import image_classifier\n",
"from tflite_model_maker.config import ExportFormat\n",
"from tflite_model_maker.config import QuantizationConfig\n",
"from tflite_model_maker.image_classifier import DataLoader\n",
"from tflite_model_maker.image_classifier import ImageSpec\n",
"\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KKRaYHABpob5"
},
"source": [
"## Simple End-to-End Example"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SiZZ5DHXotaW"
},
"source": [
"### Get the data path\n",
"\n",
"Let's get some images to use in this simple end-to-end example. A few hundred images is a good start for Model Maker, and more data can achieve better accuracy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "3jz5x0JoskPv"
},
"outputs": [],
"source": [
"image_path = tf.keras.utils.get_file(\n",
"      'flower_photos.tgz',\n",
"      'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',\n",
"      extract=True)\n",
"image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a55MR6i6nuDm"
},
"source": [
"You can replace `image_path` with your own image folders. To upload data to Colab, use the upload button in the left sidebar, marked with a red rectangle in the image below. Try uploading a zip file and unzipping it. The root file path is the current path.\n",
"\n",
"\u003cimg src=\"https://storage.googleapis.com/download.tensorflow.org/models/tflite/screenshots/model_maker_image_classification.png\" alt=\"Upload File\" width=\"800\" hspace=\"100\"\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NNRNv_mloS89"
},
"source": [
"If you prefer not to upload your images to the cloud, you can run the library locally by following the [guide](https://github.com/tensorflow/examples/tree/master/tensorflow_examples/lite/model_maker) on GitHub."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "w-VDriAdsowu"
},
"source": [
"### Run the example\n",
"The example consists of just 4 lines of code, shown below, each representing one step of the overall process."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6ahtcO86tZBL"
},
"source": [
"Step 1. Load input data specific to an on-device ML app. Split it into training data (90%) and testing data (10%)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lANoNS_gtdH1"
},
"outputs": [],
"source": [
"data = DataLoader.from_folder(image_path)\n",
"train_data, test_data = data.split(0.9)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y_9IWyIztuRF"
},
"source": [
"Step 2. Customize the TensorFlow model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yRXMZbrwtyRD"
},
"outputs": [],
"source": [
"model = image_classifier.create(train_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oxU2fDr-t2Ya"
},
"source": [
"Step 3. Evaluate the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wQr02VxJt6Cs"
},
"outputs": [],
"source": [
"loss, accuracy = model.evaluate(test_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eVZw9zU8t84y"
},
"source": [
"Step 4. Export to a TensorFlow Lite model.\n",
"\n",
"Here, we export the TensorFlow Lite model with [metadata](https://www.tensorflow.org/lite/convert/metadata), which provides a standard for model descriptions. The label file is embedded in the metadata.\n",
"\n",
"You can download the model from the left sidebar, the same way as in the uploading part, for your own use."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Zb-eIzfluCoa"
},
"outputs": [],
"source": [
"model.export(export_dir='.')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pyju1qc_v-wy"
},
"source": [
"After these 4 simple steps, we can use the TensorFlow Lite model file in on-device applications such as the [image classification](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification) reference app."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R1QG32ivs9lF"
},
"source": [
"## Detailed Process\n",
"\n",
"Currently, several pre-trained models are supported for image classification, such as the EfficientNet-Lite* models, MobileNetV2, and ResNet50. The library is also flexible: new pre-trained models can be added with just a few lines of code.\n",
"\n",
"The following walks through this end-to-end example step by step to show more detail."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ygEncJxtl-nQ"
},
"source": [
"### Step 1: Load Input Data Specific to an On-device ML App\n",
"\n",
"The flower dataset contains 3670 images belonging to 5 classes. Download the archived version of the dataset and extract it.\n",
"\n",
"The dataset has the following directory structure:\n",
"\n",
"\u003cpre\u003e\n",
"\u003cb\u003eflower_photos\u003c/b\u003e\n",
"|__ \u003cb\u003edaisy\u003c/b\u003e\n",
"    |______ 100080576_f52e8ee070_n.jpg\n",
"    |______ 14167534527_781ceb1b7a_n.jpg\n",
"    |______ ...\n",
"|__ \u003cb\u003edandelion\u003c/b\u003e\n",
"    |______ 10043234166_e6dd915111_n.jpg\n",
"    |______ 1426682852_e62169221f_m.jpg\n",
"    |______ ...\n",
"|__ \u003cb\u003eroses\u003c/b\u003e\n",
"    |______ 102501987_3cdb8e5394_n.jpg\n",
"    |______ 14982802401_a3dfb22afb.jpg\n",
"    |______ ...\n",
"|__ \u003cb\u003esunflowers\u003c/b\u003e\n",
"    |______ 12471791574_bb1be83df4.jpg\n",
"    |______ 15122112402_cafa41934f.jpg\n",
"    |______ ...\n",
"|__ \u003cb\u003etulips\u003c/b\u003e\n",
"    |______ 13976522214_ccec508fe7.jpg\n",
"    |______ 14487943607_651e8062a1_m.jpg\n",
"    |______ ...\n",
"\u003c/pre\u003e"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7tOfUr2KlgpU"
},
"outputs": [],
"source": [
"image_path = tf.keras.utils.get_file(\n",
"      'flower_photos.tgz',\n",
"      'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',\n",
"      extract=True)\n",
"image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E051HBUM5owi"
},
"source": [
"Use the `DataLoader` class to load data.\n",
"\n",
"The `from_folder()` method loads data from a folder. It assumes that images of the same class are in the same subdirectory and that the subdirectory name is the class name. Currently, JPEG-encoded and PNG-encoded images are supported."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I_fOlZsklmlL"
},
"outputs": [],
"source": [
"data = DataLoader.from_folder(image_path)"
]
},
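{
"cell_type": "markdown",
"metadata": {
"id": "mmSketchFolderLabelsMd"
},
"source": [
"To make the folder convention concrete, the following cell is a minimal, plain-Python sketch (not Model Maker's actual implementation) that walks a directory laid out like `flower_photos`, treats each subdirectory name as a class name, and collects JPEG and PNG file paths. The helper `list_labeled_images` is hypothetical, and assigning label indices in alphabetical order of the class names is an assumption made for illustration. It runs on a tiny temporary folder, so it does not touch the real dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mmSketchFolderLabelsCode"
},
"outputs": [],
"source": [
"import os\n",
"import tempfile\n",
"\n",
"# Hypothetical helper: maps each subdirectory name to a label index and\n",
"# collects (path, label_index) pairs for JPEG and PNG files.\n",
"def list_labeled_images(root):\n",
"  # Assumption: label indices follow the alphabetical order of class names.\n",
"  label_names = sorted(\n",
"      name for name in os.listdir(root)\n",
"      if os.path.isdir(os.path.join(root, name)))\n",
"  examples = []\n",
"  for index, name in enumerate(label_names):\n",
"    class_dir = os.path.join(root, name)\n",
"    for filename in sorted(os.listdir(class_dir)):\n",
"      if filename.lower().endswith(('.jpg', '.jpeg', '.png')):\n",
"        examples.append((os.path.join(class_dir, filename), index))\n",
"  return label_names, examples\n",
"\n",
"# Usage on a tiny fake dataset made of empty placeholder files.\n",
"demo_root = tempfile.mkdtemp()\n",
"demo_files = {'daisy': ['a.jpg', 'b.png'], 'roses': ['c.jpg', 'notes.txt']}\n",
"for name, files in demo_files.items():\n",
"  os.makedirs(os.path.join(demo_root, name))\n",
"  for filename in files:\n",
"    open(os.path.join(demo_root, name, filename), 'w').close()\n",
"\n",
"label_names, examples = list_labeled_images(demo_root)\n",
"print(label_names)    # ['daisy', 'roses']\n",
"print(len(examples))  # 3, because notes.txt is skipped"
]
},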
{
"cell_type": "markdown",
"metadata": {
"id": "u501eT4koURB"
},
"source": [
"Split it into training data (80%), validation data (10%, optional), and testing data (10%)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cY4UU5SUobtJ"
},
"outputs": [],
"source": [
"train_data, rest_data = data.split(0.8)\n",
"validation_data, test_data = rest_data.split(0.5)"
]
},
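{
"cell_type": "markdown",
"metadata": {
"id": "mmSketchSplitSizesMd"
},
"source": [
"The two `split` calls above yield the 80/10/10 proportions because the second call halves the remaining 20%. The following cell works through the arithmetic for the 3670-image flower dataset. It is a sketch that assumes `split(fraction)` gives the first part `int(size * fraction)` examples; if the library rounds differently, the counts may be off by one image. The helper `split_sizes` is hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mmSketchSplitSizesCode"
},
"outputs": [],
"source": [
"# Hypothetical helper mirroring the arithmetic of a fractional split.\n",
"# Assumption: the first part gets int(size * fraction) examples.\n",
"def split_sizes(size, fraction):\n",
"  first = int(size * fraction)\n",
"  return first, size - first\n",
"\n",
"total = 3670  # Number of images in the flower dataset.\n",
"train_size, rest_size = split_sizes(total, 0.8)\n",
"validation_size, test_size = split_sizes(rest_size, 0.5)\n",
"print(train_size, validation_size, test_size)  # 2936 367 367"
]
},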
{
"cell_type": "markdown",
"metadata": {
"id": "Z9_MYPie3EMO"
},
"source": [
"Show 25 image examples with labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ih4Wx44I482b"
},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 10))\n",
"for i, (image, label) in enumerate(data.gen_dataset().unbatch().take(25)):\n",
"  plt.subplot(5, 5, i+1)\n",
"  plt.xticks([])\n",
"  plt.yticks([])\n",
"  plt.grid(False)\n",
"  plt.imshow(image.numpy(), cmap=plt.cm.gray)\n",
"  plt.xlabel(data.index_to_label[label.numpy()])\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AWuoensX4vDA"
},
"source": [
"### Step 2: Customize the TensorFlow Model\n",
"\n",
"Create a custom image classifier model based on the loaded data. The default model is EfficientNet-Lite0.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TvYSUuJY3QxR"
},
"outputs": [],
"source": [
"model = image_classifier.create(train_data, validation_data=validation_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4JFOKWnH9x8_"
},
"source": [
"Have a look at the detailed model structure."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QNXAfjl192dC"
},
"outputs": [],
"source": [
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LP5FPk_tOxoZ"
},
"source": [
"### Step 3: Evaluate the Customized Model\n",
"\n",
"Evaluate the model on the test data to get its loss and accuracy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "A8c2ZQ0J3Riy"
},
"outputs": [],
"source": [
"loss, accuracy = model.evaluate(test_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6ZCrYOWoCt05"
},
"source": [
"We can plot the predicted results for 100 test images. Predicted labels in red are incorrect predictions, while the others are correct."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "n9O9Kx7nDQWD"
},
"outputs": [],
"source": [
"# A helper function that returns 'black' if its two inputs match and 'red'\n",
"# otherwise.\n",
"def get_label_color(val1, val2):\n",
"  if val1 == val2:\n",
"    return 'black'\n",
"  else:\n",
"    return 'red'\n",
"\n",
"# Then plot 100 test images and their predicted labels.\n",
"# If a predicted label differs from the ground-truth label in the test\n",
"# dataset, we highlight it in red.\n",
"plt.figure(figsize=(20, 20))\n",
"predicts = model.predict_top_k(test_data)\n",
"for i, (image, label) in enumerate(test_data.gen_dataset().unbatch().take(100)):\n",
"  ax = plt.subplot(10, 10, i+1)\n",
"  plt.xticks([])\n",
"  plt.yticks([])\n",
"  plt.grid(False)\n",
"  plt.imshow(image.numpy(), cmap=plt.cm.gray)\n",
"\n",
"  predict_label = predicts[i][0][0]\n",
"  color = get_label_color(predict_label,\n",
"                          test_data.index_to_label[label.numpy()])\n",
"  ax.xaxis.label.set_color(color)\n",
"  plt.xlabel('Predicted: %s' % predict_label)\n",
"plt.show()"
]
},
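{
"cell_type": "markdown",
"metadata": {
"id": "mmSketchTopKMd"
},
"source": [
"In the cell above, `predicts[i][0][0]` is used as the top-1 label because `predict_top_k` returns, for each image, the most likely labels together with their probabilities. The following NumPy cell is a minimal sketch of that idea, assuming a model that outputs one probability per class; the helper `top_k_labels` and the sample probabilities are hypothetical, and this is not the library's implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mmSketchTopKCode"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical helper: for each row of class probabilities, return the k\n",
"# (label, probability) pairs with the highest probability, best first.\n",
"def top_k_labels(probabilities, index_to_label, k=1):\n",
"  results = []\n",
"  for row in probabilities:\n",
"    # argsort is ascending, so reverse it and keep the first k indices.\n",
"    top_ids = np.argsort(row)[::-1][:k]\n",
"    results.append([(index_to_label[j], float(row[j])) for j in top_ids])\n",
"  return results\n",
"\n",
"flower_labels = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']\n",
"sample_probs = np.array([[0.1, 0.6, 0.1, 0.1, 0.1],\n",
"                         [0.05, 0.05, 0.2, 0.3, 0.4]])\n",
"top2 = top_k_labels(sample_probs, flower_labels, k=2)\n",
"print(top2[0][0][0])  # dandelion\n",
"print(top2[1])        # [('tulips', 0.4), ('sunflowers', 0.3)]"
]
},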
{
"cell_type": "markdown",
"metadata": {
"id": "S3H0rkbLUZAG"
},
"source": [
"If the accuracy doesn't meet the app's requirements, refer to [Advanced Usage](#scrollTo=zNDBP2qA54aK) to explore alternatives such as switching to a larger model or adjusting the re-training parameters."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aeHoGAceO2xV"
},
"source": [
"### Step 4: Export to TensorFlow Lite Model\n",
"\n",
"Convert the trained model to the TensorFlow Lite format with [metadata](https://www.tensorflow.org/lite/convert/metadata). The default TFLite filename is `model.tflite`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Im6wA9lK3TQB"
},
"outputs": [],
"source": [
"model.export(export_dir='.')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ROS2Ay2jMPCl"
},
"source": [
"See [example applications and guides of image classification](https://www.tensorflow.org/lite/models/image_classification/overview#example_applications_and_guides) for more details about how to integrate the TensorFlow Lite model into mobile apps."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "habFnvRxxQ4A"
},
"source": [
"The allowed export formats can be one of, or a list of, the following:\n",
"\n",
"* `ExportFormat.TFLITE`\n",
"* `ExportFormat.LABEL`\n",
"* `ExportFormat.SAVED_MODEL`\n",
"\n",
"By default, it exports only the TensorFlow Lite model with metadata. You can also selectively export different files. For instance, to export only the label file:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BvxWsOTmKG4P"
},
"outputs": [],
"source": [
"model.export(export_dir='.', export_format=ExportFormat.LABEL)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-4jQaxyT5_KV"
},
"source": [
"You can also evaluate the tflite model with the `evaluate_tflite` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "S1YoPX5wOK-u"
},
"outputs": [],
"source": [
"model.evaluate_tflite('model.tflite', test_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zNDBP2qA54aK"
},
"source": [
"## Advanced Usage\n",
"\n",
"The `create` function is the critical part of this library. It uses transfer learning with a pre-trained model, similar to the [tutorial](https://www.tensorflow.org/tutorials/images/transfer_learning).\n",
"\n",
"The `create` function contains the following steps:\n",
"\n",
"1. Split the data into training, validation, and testing data according to the parameters `validation_ratio` and `test_ratio`. The default values of `validation_ratio` and `test_ratio` are both `0.1`.\n",
"2. Download an [Image Feature Vector](https://www.tensorflow.org/hub/common_signatures/images#image_feature_vector) as the base model from TensorFlow Hub. The default pre-trained model is EfficientNet-Lite0.\n",
"3. Add a classifier head, with a dropout layer of rate `dropout_rate` between the head layer and the pre-trained model. The default `dropout_rate` is the default `dropout_rate` value from [make_image_classifier_lib](https://github.com/tensorflow/hub/blob/master/tensorflow_hub/tools/make_image_classifier/make_image_classifier_lib.py#L55) in TensorFlow Hub.\n",
"4. Preprocess the raw input data. Currently, preprocessing consists of normalizing each image pixel value to the model's input scale and resizing the image to the model's input size. EfficientNet-Lite0 has the input scale `[0, 1]` and the input image size `[224, 224, 3]`.\n",
"5. Feed the data into the classifier model. By default, only the classifier head is trained, and training parameters such as the number of epochs, batch size, learning rate, and momentum take the default values from [make_image_classifier_lib](https://github.com/tensorflow/hub/blob/master/tensorflow_hub/tools/make_image_classifier/make_image_classifier_lib.py#L55) in TensorFlow Hub.\n",
"\n",
"In this section, we describe several advanced topics, including switching to a different image classification model and changing the training hyperparameters.\n"
]
},
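{
"cell_type": "markdown",
"metadata": {
"id": "mmSketchPreprocessMd"
},
"source": [
"Step 4 of the list above rescales pixel values to the model's input scale and resizes images to its input size. The following NumPy cell sketches both operations for EfficientNet-Lite0's `[0, 1]` input scale and `[224, 224, 3]` input size. It is an illustration under stated assumptions, not Model Maker's actual preprocessing: it divides `uint8` pixels by 255 and uses a simple nearest-neighbor resize, while the library may interpolate differently. The helper `preprocess_image` is hypothetical; passing `size=(299, 299)` shows the shape an Inception V3 input would need."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mmSketchPreprocessCode"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical helper: rescale uint8 pixels to [0, 1] and resize to `size`\n",
"# with nearest-neighbor sampling (the library may interpolate differently).\n",
"def preprocess_image(image, size=(224, 224)):\n",
"  image = image.astype(np.float32) / 255.0  # [0, 255] -> [0, 1]\n",
"  height, width = image.shape[:2]\n",
"  # Map each output row and column index to a source index by scaling it.\n",
"  rows = np.arange(size[0]) * height // size[0]\n",
"  cols = np.arange(size[1]) * width // size[1]\n",
"  return image[rows][:, cols]\n",
"\n",
"rng = np.random.default_rng(0)\n",
"raw = rng.integers(0, 256, size=(320, 240, 3), dtype=np.uint8)\n",
"processed = preprocess_image(raw)\n",
"print(processed.shape)  # (224, 224, 3)\n",
"print(processed.min() >= 0.0, processed.max() <= 1.0)  # True True\n",
"print(preprocess_image(raw, size=(299, 299)).shape)  # (299, 299, 3)"
]
},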
{
"cell_type": "markdown",
"metadata": {
"id": "Gc4Jk8TvBQfm"
},
"source": [
"## Post-training quantization on the TensorFlow Lite model\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tD8BOYrHBiDt"
},
"source": [
"[Post-training quantization](https://www.tensorflow.org/lite/performance/post_training_quantization) is a conversion technique that can reduce model size while also improving CPU and hardware accelerator latency, with little degradation in model accuracy. Thus, it's widely used to optimize models.\n"
]
},
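{
"cell_type": "markdown",
"metadata": {
"id": "mmSketchQuantMd"
},
"source": [
"Full integer quantization stores each tensor as 8-bit integers together with a scale and a zero point, using the affine mapping `real_value = scale * (quantized_value - zero_point)` from the TensorFlow Lite quantization specification. The following NumPy cell is a simplified, per-tensor sketch of that mapping for `uint8` and `int8`, not the converter's actual algorithm; the helpers `affine_quantize` and `affine_dequantize`, the min/max calibration, and the sample values are assumptions chosen so the numbers come out exact. It also shows where the size reduction comes from: each 4-byte `float32` value becomes a 1-byte integer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mmSketchQuantCode"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical helpers for per-tensor affine quantization:\n",
"#   real_value = scale * (quantized_value - zero_point)\n",
"def affine_quantize(x, qmin, qmax, dtype):\n",
"  # Assumption: calibrate from the min/max of x, keeping 0.0 representable.\n",
"  x_min = min(float(x.min()), 0.0)\n",
"  x_max = max(float(x.max()), 0.0)\n",
"  scale = (x_max - x_min) / (qmax - qmin)\n",
"  zero_point = int(round(qmin - x_min / scale))\n",
"  q = np.clip(np.round(x / scale) + zero_point, qmin, qmax).astype(dtype)\n",
"  return q, scale, zero_point\n",
"\n",
"def affine_dequantize(q, scale, zero_point):\n",
"  return scale * (q.astype(np.float32) - zero_point)\n",
"\n",
"# Sample values chosen so that the scale is exactly 1/128.\n",
"x = np.array([-1.0, -0.25, 0.0, 0.5, 0.9921875], dtype=np.float32)\n",
"q_u8, scale, zp_u8 = affine_quantize(x, 0, 255, np.uint8)\n",
"q_i8, _, zp_i8 = affine_quantize(x, -128, 127, np.int8)\n",
"\n",
"print(q_u8.tolist(), zp_u8)  # [0, 96, 128, 192, 255] 128\n",
"print(q_i8.tolist(), zp_i8)  # [-128, -32, 0, 64, 127] 0\n",
"print(np.allclose(affine_dequantize(q_u8, scale, zp_u8), x))  # True\n",
"print(x.nbytes, q_u8.nbytes)  # 20 5"
]
},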
{
"cell_type": "markdown",
"metadata": {
"id": "iyIo0d5TCzE2"
},
"source": [
"Model Maker supports multiple post-training quantization options. Let's take full integer quantization as an example. First, define the quantization config to enforce full integer quantization for all ops, including the input and output. The input and output types are `uint8` by default. You can also change them to other types, such as `int8`, by setting `inference_input_type` and `inference_output_type` in the config."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "k8hL2mstCxQl"
},
"outputs": [],
"source": [
"config = QuantizationConfig.for_int8(representative_data=test_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "K1gzx_rmFMOA"
},
"source": [
"Then we export the TensorFlow Lite model with this configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WTJzFQnJFMjr"
},
"outputs": [],
"source": [
"model.export(export_dir='.', tflite_filename='model_quant.tflite', quantization_config=config)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Safo0e40wKZW"
},
"source": [
"In Colab, you can download the model named `model_quant.tflite` from the left sidebar, same as the uploading part mentioned above."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A4kiTJtZ_sDm"
},
"source": [
"## Change the model\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "794vgj6ud7Ep"
},
"source": [
"### Change to a model that's supported in this library\n",
"\n",
"This library currently supports EfficientNet-Lite models, MobileNetV2, and ResNet50. [EfficientNet-Lite](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite) is a family of image classification models that achieve state-of-the-art accuracy and are suitable for edge devices. The default model is EfficientNet-Lite0.\n",
"\n",
"We can switch to MobileNetV2 by setting the `model_spec` parameter of the `create` method to the MobileNetV2 model specification."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7JKsJ6-P6ae1"
},
"outputs": [],
"source": [
"model = image_classifier.create(train_data, model_spec=model_spec.get('mobilenet_v2'), validation_data=validation_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gm_B1Wv08AxR"
},
"source": [
"Evaluate the newly retrained MobileNetV2 model to see its accuracy and loss on the testing data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lB2Go3HW8X7_"
},
"outputs": [],
"source": [
"loss, accuracy = model.evaluate(test_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vAciGzVWtmWp"
},
"source": [
"### Change to a model in TensorFlow Hub\n",
"\n",
"We can also switch to other models in the TensorFlow Hub format that take an image as input and output a feature vector.\n",
"\n",
"Using the [Inception V3](https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1) model as an example, we can define `inception_v3_spec`, an `ImageSpec` object that contains the specification of the Inception V3 model.\n",
"\n",
"We need to specify the URL of the TensorFlow Hub model, `uri`, and optionally the model name, `name`. The default value of `input_image_shape` is `[224, 224]`; we need to change it to `[299, 299]` for the Inception V3 model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xdiMF2WMfAR4"
},
"outputs": [],
"source": [
"inception_v3_spec = ImageSpec(\n",
"    uri='https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1')\n",
"inception_v3_spec.input_image_shape = [299, 299]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T_GGIoXZCs5F"
},
"source": [
"Then, by setting the `model_spec` parameter of the `create` method to `inception_v3_spec`, we can retrain the Inception V3 model.\n",
"\n",
"The remaining steps are exactly the same, and we end up with a customized Inception V3 TensorFlow Lite model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UhZ5IRKdeex3"
},
"source": [
"### Change to your own custom model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "svTjlZhrCrcV"
},
"source": [
"If we'd like to use a custom model that's not in TensorFlow Hub, we should create and export a [ModuleSpec](https://www.tensorflow.org/hub/api_docs/python/hub/ModuleSpec) in TensorFlow Hub.\n",
"\n",
"Then define an `ImageSpec` object following the process above."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4M9bn703AHt2"
},
"source": [
"## Change the training hyperparameters\n",
"We can also change training hyperparameters such as `epochs`, `dropout_rate`, and `batch_size`, which can affect the model accuracy. The parameters you can adjust are:\n",
"\n",
"* `epochs`: More epochs can achieve better accuracy until the model converges, but training for too many epochs may lead to overfitting.\n",
"* `dropout_rate`: The dropout rate, used to avoid overfitting. None by default.\n",
"* `batch_size`: Number of samples to use in one training step. None by default.\n",
"* `validation_data`: Validation data. If None, skips the validation process. None by default.\n",
"* `train_whole_model`: If True, the Hub module is trained together with the classification layer on top. Otherwise, only the top classification layer is trained. None by default.\n",
"* `learning_rate`: Base learning rate. None by default.\n",
"* `momentum`: A Python float forwarded to the optimizer. Only used when `use_hub_library` is True. None by default.\n",
"* `shuffle`: Boolean, whether the data should be shuffled. False by default.\n",
"* `use_augmentation`: Boolean, whether to use data augmentation for preprocessing. False by default.\n",
"* `use_hub_library`: Boolean, whether to use `make_image_classifier_lib` from TensorFlow Hub to retrain the model. This training pipeline can achieve better performance on complicated datasets with many categories. True by default.\n",
"* `warmup_steps`: Number of warmup steps for the learning-rate warmup schedule. If None, the default is used, which is the total number of training steps in two epochs. Only used when `use_hub_library` is False. None by default.\n",
"* `model_dir`: Optional, the location of the model checkpoint files. Only used when `use_hub_library` is False. None by default.\n",
"\n",
"Parameters that are None by default, such as `epochs`, take the concrete default values from [make_image_classifier_lib](https://github.com/tensorflow/hub/blob/02ab9b7d3455e99e97abecf43c5d598a5528e20c/tensorflow_hub/tools/make_image_classifier/make_image_classifier_lib.py#L54) in the TensorFlow Hub library or from [train_image_classifier_lib](https://github.com/tensorflow/examples/blob/f0260433d133fd3cea4a920d1e53ecda07163aee/tensorflow_examples/lite/model_maker/core/task/train_image_classifier_lib.py#L61).\n",
"\n",
"For example, we can train with more epochs.\n"
]
},
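{
"cell_type": "markdown",
"metadata": {
"id": "mmSketchWarmupMd"
},
"source": [
"The `warmup_steps` option ramps the learning rate up to its base value during the first training steps. The following cell is a minimal sketch of such a schedule, assuming a linear ramp followed by a constant rate; the library's actual schedule (for example, any decay after warmup) may differ. The helper `warmup_learning_rate`, the batch size, and the base learning rate are hypothetical, and steps per epoch are assumed to be `num_train_examples // batch_size` when computing the two-epoch default described above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mmSketchWarmupCode"
},
"outputs": [],
"source": [
"# Hypothetical sketch of a linear learning-rate warmup.\n",
"# Assumption: the rate rises linearly to base_lr over warmup_steps and then\n",
"# stays constant (the real schedule may decay afterwards).\n",
"def warmup_learning_rate(step, base_lr, warmup_steps):\n",
"  if step < warmup_steps:\n",
"    return base_lr * (step + 1) / warmup_steps\n",
"  return base_lr\n",
"\n",
"num_train_examples = 2936  # e.g. 80% of the 3670 flower images\n",
"batch_size = 32            # hypothetical batch size\n",
"base_lr = 0.01             # hypothetical base learning rate\n",
"# Assumption: one step per full batch, so a partial last batch is dropped.\n",
"steps_per_epoch = num_train_examples // batch_size\n",
"warmup_steps = 2 * steps_per_epoch  # default: two epochs of steps\n",
"\n",
"print(steps_per_epoch, warmup_steps)  # 91 182\n",
"for step in (0, 90, 181, 500):\n",
"  print(step, warmup_learning_rate(step, base_lr, warmup_steps))"
]
},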
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "A3k7mhH54QcK"
},
"outputs": [],
"source": [
"model = image_classifier.create(train_data, validation_data=validation_data, epochs=10)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VaYBQymQDsXU"
},
"source": [
"Evaluate the newly retrained model with 10 training epochs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VafIYpKWD4Sw"
},
"outputs": [],
"source": [
"loss, accuracy = model.evaluate(test_data)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "model_maker_image_classification.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}