| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| r"""Simple speech recognition to spot a limited number of keywords. |
| |
| This is a self-contained example script that will train a very basic audio |
| recognition model in TensorFlow. It downloads the necessary training data and |
runs with reasonable defaults, so it should train within a few hours even
using only a CPU.
| For more information, please see |
| https://www.tensorflow.org/tutorials/audio_recognition. |
| |
| It is intended as an introduction to using neural networks for audio |
| recognition, and is not a full speech recognition system. For more advanced |
| speech systems, I recommend looking into Kaldi. This network uses a keyword |
| detection style to spot discrete words from a small vocabulary, consisting of |
| "yes", "no", "up", "down", "left", "right", "on", "off", "stop", and "go". |
| |
| To run the training process, use: |
| |
| bazel run tensorflow/examples/speech_commands:train |
| |
This will write out checkpoints to /tmp/speech_commands_train/, and will
download over 1GB of open source training data, so you'll need enough free
space and a good internet connection. The default data is a collection of
thousands of one-second .wav files, each containing one spoken word. This data
set was collected from https://aiyprojects.withgoogle.com/open_speech_recording;
please consider contributing to help improve this and other models!
| |
| As training progresses, it will print out its accuracy metrics, which should |
| rise above 90% by the end. Once it's complete, you can run the freeze script to |
| get a binary GraphDef that you can easily deploy on mobile applications. |
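
With the default flags, the final checkpoint should be written to
/tmp/speech_commands_train/conv.ckpt-18000, and a freeze invocation along these
lines should produce a frozen graph (adjust the checkpoint suffix to match the
number of steps you actually trained for, and see freeze.py for the full set of
flags it expects):

bazel run tensorflow/examples/speech_commands:freeze -- \
--start_checkpoint=/tmp/speech_commands_train/conv.ckpt-18000 \
--output_file=/tmp/my_frozen_graph.pb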
| |
| If you want to train on your own data, you'll need to create .wavs with your |
| recordings, all at a consistent length, and then arrange them into subfolders |
| organized by label. For example, here's a possible file structure: |
| |
| my_wavs > |
| up > |
| audio_0.wav |
| audio_1.wav |
| down > |
| audio_2.wav |
| audio_3.wav |
  other >
| audio_4.wav |
| audio_5.wav |
| |
| You'll also need to tell the script what labels to look for, using the |
| `--wanted_words` argument. In this case, 'up,down' might be what you want, and |
| the audio in the 'other' folder would be used to train an 'unknown' category. |
| |
| To pull this all together, you'd run: |
| |
| bazel run tensorflow/examples/speech_commands:train -- \ |
| --data_dir=my_wavs --wanted_words=up,down |
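
While training is running, you can watch the accuracy and loss curves in
TensorBoard by pointing it at the summary logs (written to /tmp/retrain_logs
unless you override `--summaries_dir`):

tensorboard --logdir /tmp/retrain_logs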
| |
| """ |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import argparse |
| import os.path |
| import sys |
| |
| import numpy as np |
| from six.moves import xrange # pylint: disable=redefined-builtin |
| import tensorflow as tf |
| |
| import input_data |
| import models |
| from tensorflow.python.platform import gfile |
| |
| FLAGS = None |
| |
| |
| def main(_): |
| # We want to see all the logging messages for this tutorial. |
| tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) |
| |
| # Start a new TensorFlow session. |
| sess = tf.compat.v1.InteractiveSession() |
| |
| # Begin by making sure we have the training data we need. If you already have |
  # training data of your own, pass an empty `--data_url=` on the command line
  # to avoid downloading.
| model_settings = models.prepare_model_settings( |
| len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))), |
| FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms, |
| FLAGS.window_stride_ms, FLAGS.feature_bin_count, FLAGS.preprocess) |
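  # The AudioProcessor downloads the data set if it isn't already present,
  # deterministically partitions the .wav files into training, validation, and
  # testing sets, and mixes in the requested proportions of silence and
  # unknown-word examples.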
| audio_processor = input_data.AudioProcessor( |
| FLAGS.data_url, FLAGS.data_dir, |
| FLAGS.silence_percentage, FLAGS.unknown_percentage, |
| FLAGS.wanted_words.split(','), FLAGS.validation_percentage, |
| FLAGS.testing_percentage, model_settings, FLAGS.summaries_dir) |
| fingerprint_size = model_settings['fingerprint_size'] |
| label_count = model_settings['label_count'] |
| time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000) |
| # Figure out the learning rates for each training phase. Since it's often |
| # effective to have high learning rates at the start of training, followed by |
| # lower levels towards the end, the number of steps and learning rates can be |
| # specified as comma-separated lists to define the rate at each stage. For |
| # example --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001 |
| # will run 13,000 training loops in total, with a rate of 0.001 for the first |
| # 10,000, and 0.0001 for the final 3,000. |
| training_steps_list = list(map(int, FLAGS.how_many_training_steps.split(','))) |
| learning_rates_list = list(map(float, FLAGS.learning_rate.split(','))) |
| if len(training_steps_list) != len(learning_rates_list): |
| raise Exception( |
| '--how_many_training_steps and --learning_rate must be equal length ' |
| 'lists, but are %d and %d long instead' % (len(training_steps_list), |
| len(learning_rates_list))) |
| |
| input_placeholder = tf.compat.v1.placeholder( |
| tf.float32, [None, fingerprint_size], name='fingerprint_input') |
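  # When quantization is requested, wrap the input in a fake-quant node so that
  # training sees the same reduced-precision values the eight-bit deployed
  # model will.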
| if FLAGS.quantize: |
| fingerprint_min, fingerprint_max = input_data.get_features_range( |
| model_settings) |
| fingerprint_input = tf.quantization.fake_quant_with_min_max_args( |
| input_placeholder, fingerprint_min, fingerprint_max) |
| else: |
| fingerprint_input = input_placeholder |
| |
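  # create_model returns the logits along with a dropout placeholder; the
  # training loop below feeds it 0.5 during training and 1.0 during evaluation.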
| logits, dropout_prob = models.create_model( |
| fingerprint_input, |
| model_settings, |
| FLAGS.model_architecture, |
| is_training=True) |
| |
| # Define loss and optimizer |
| ground_truth_input = tf.compat.v1.placeholder( |
| tf.int64, [None], name='groundtruth_input') |
| |
| # Optionally we can add runtime checks to spot when NaNs or other symptoms of |
| # numerical errors start occurring during training. |
| control_dependencies = [] |
| if FLAGS.check_nans: |
| checks = tf.compat.v1.add_check_numerics_ops() |
| control_dependencies = [checks] |
| |
| # Create the back propagation and training evaluation machinery in the graph. |
| with tf.compat.v1.name_scope('cross_entropy'): |
| cross_entropy_mean = tf.compat.v1.losses.sparse_softmax_cross_entropy( |
| labels=ground_truth_input, logits=logits) |
| if FLAGS.quantize: |
| tf.contrib.quantize.create_training_graph(quant_delay=0) |
  with tf.compat.v1.name_scope('train'), tf.control_dependencies(
      control_dependencies):
| learning_rate_input = tf.compat.v1.placeholder( |
| tf.float32, [], name='learning_rate_input') |
| train_step = tf.compat.v1.train.GradientDescentOptimizer( |
| learning_rate_input).minimize(cross_entropy_mean) |
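  # Evaluation ops: the predicted label is the argmax of the logits, accuracy
  # is the fraction of predictions matching the ground truth, and the confusion
  # matrix tallies predictions against labels for more detailed reporting.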
| predicted_indices = tf.argmax(input=logits, axis=1) |
| correct_prediction = tf.equal(predicted_indices, ground_truth_input) |
  confusion_matrix = tf.math.confusion_matrix(
      labels=ground_truth_input,
      predictions=predicted_indices,
      num_classes=label_count)
  evaluation_step = tf.reduce_mean(
      input_tensor=tf.cast(correct_prediction, tf.float32))
| with tf.compat.v1.get_default_graph().name_scope('eval'): |
| tf.compat.v1.summary.scalar('cross_entropy', cross_entropy_mean) |
| tf.compat.v1.summary.scalar('accuracy', evaluation_step) |
| |
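  # Track a global step counter so that training can resume from the correct
  # step when restarting from a checkpoint.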
| global_step = tf.compat.v1.train.get_or_create_global_step() |
| increment_global_step = tf.compat.v1.assign(global_step, global_step + 1) |
| |
| saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables()) |
| |
| # Merge all the summaries and write them out to /tmp/retrain_logs (by default) |
| merged_summaries = tf.compat.v1.summary.merge_all(scope='eval') |
| train_writer = tf.compat.v1.summary.FileWriter(FLAGS.summaries_dir + '/train', |
| sess.graph) |
  validation_writer = tf.compat.v1.summary.FileWriter(
      FLAGS.summaries_dir + '/validation')
| |
| tf.compat.v1.global_variables_initializer().run() |
| |
| start_step = 1 |
| |
| if FLAGS.start_checkpoint: |
| models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint) |
| start_step = global_step.eval(session=sess) |
| |
| tf.compat.v1.logging.info('Training from step: %d ', start_step) |
| |
| # Save graph.pbtxt. |
| tf.io.write_graph(sess.graph_def, FLAGS.train_dir, |
| FLAGS.model_architecture + '.pbtxt') |
| |
| # Save list of words. |
| with gfile.GFile( |
| os.path.join(FLAGS.train_dir, FLAGS.model_architecture + '_labels.txt'), |
| 'w') as f: |
| f.write('\n'.join(audio_processor.words_list)) |
| |
| # Training loop. |
| training_steps_max = np.sum(training_steps_list) |
| for training_step in xrange(start_step, training_steps_max + 1): |
| # Figure out what the current learning rate is. |
| training_steps_sum = 0 |
| for i in range(len(training_steps_list)): |
| training_steps_sum += training_steps_list[i] |
| if training_step <= training_steps_sum: |
| learning_rate_value = learning_rates_list[i] |
| break |
| # Pull the audio samples we'll use for training. |
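    # Background noise mixing and random time shifting act as data
    # augmentation, so they are only applied here; the validation and testing
    # calls below pass zeros for those arguments.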
| train_fingerprints, train_ground_truth = audio_processor.get_data( |
| FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency, |
| FLAGS.background_volume, time_shift_samples, 'training', sess) |
| # Run the graph with this batch of training data. |
| train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run( |
| [ |
| merged_summaries, |
| evaluation_step, |
| cross_entropy_mean, |
| train_step, |
| increment_global_step, |
| ], |
| feed_dict={ |
| fingerprint_input: train_fingerprints, |
| ground_truth_input: train_ground_truth, |
| learning_rate_input: learning_rate_value, |
| dropout_prob: 0.5 |
| }) |
| train_writer.add_summary(train_summary, training_step) |
    tf.compat.v1.logging.info(
        'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
        (training_step, learning_rate_value, train_accuracy * 100,
         cross_entropy_value))
| is_last_step = (training_step == training_steps_max) |
| if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step: |
| set_size = audio_processor.set_size('validation') |
| total_accuracy = 0 |
| total_conf_matrix = None |
| for i in xrange(0, set_size, FLAGS.batch_size): |
| validation_fingerprints, validation_ground_truth = ( |
| audio_processor.get_data(FLAGS.batch_size, i, model_settings, 0.0, |
| 0.0, 0, 'validation', sess)) |
        # Run a validation step and capture validation summaries for
        # TensorBoard with the `merged_summaries` op.
| validation_summary, validation_accuracy, conf_matrix = sess.run( |
| [merged_summaries, evaluation_step, confusion_matrix], |
| feed_dict={ |
| fingerprint_input: validation_fingerprints, |
| ground_truth_input: validation_ground_truth, |
| dropout_prob: 1.0 |
| }) |
| validation_writer.add_summary(validation_summary, training_step) |
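        # The last batch may be smaller than batch_size, so weight each batch's
        # accuracy by its actual size when accumulating the running total.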
| batch_size = min(FLAGS.batch_size, set_size - i) |
| total_accuracy += (validation_accuracy * batch_size) / set_size |
| if total_conf_matrix is None: |
| total_conf_matrix = conf_matrix |
| else: |
| total_conf_matrix += conf_matrix |
| tf.compat.v1.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix)) |
| tf.compat.v1.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' % |
| (training_step, total_accuracy * 100, set_size)) |
| |
| # Save the model checkpoint periodically. |
| if (training_step % FLAGS.save_step_interval == 0 or |
| training_step == training_steps_max): |
| checkpoint_path = os.path.join(FLAGS.train_dir, |
| FLAGS.model_architecture + '.ckpt') |
| tf.compat.v1.logging.info('Saving to "%s-%d"', checkpoint_path, training_step) |
| saver.save(sess, checkpoint_path, global_step=training_step) |
| |
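  # After training finishes, measure the final accuracy on the held-out testing
  # set, which was never used for training or validation.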
| set_size = audio_processor.set_size('testing') |
| tf.compat.v1.logging.info('set_size=%d', set_size) |
| total_accuracy = 0 |
| total_conf_matrix = None |
| for i in xrange(0, set_size, FLAGS.batch_size): |
| test_fingerprints, test_ground_truth = audio_processor.get_data( |
| FLAGS.batch_size, i, model_settings, 0.0, 0.0, 0, 'testing', sess) |
| test_accuracy, conf_matrix = sess.run( |
| [evaluation_step, confusion_matrix], |
| feed_dict={ |
| fingerprint_input: test_fingerprints, |
| ground_truth_input: test_ground_truth, |
| dropout_prob: 1.0 |
| }) |
| batch_size = min(FLAGS.batch_size, set_size - i) |
| total_accuracy += (test_accuracy * batch_size) / set_size |
| if total_conf_matrix is None: |
| total_conf_matrix = conf_matrix |
| else: |
| total_conf_matrix += conf_matrix |
| tf.compat.v1.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix)) |
  tf.compat.v1.logging.info(
      'Final test accuracy = %.1f%% (N=%d)' %
      (total_accuracy * 100, set_size))
| |
| |
| if __name__ == '__main__': |
| parser = argparse.ArgumentParser() |
| parser.add_argument( |
| '--data_url', |
| type=str, |
| # pylint: disable=line-too-long |
| default='http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz', |
| # pylint: enable=line-too-long |
| help='Location of speech training data archive on the web.') |
| parser.add_argument( |
| '--data_dir', |
| type=str, |
| default='/tmp/speech_dataset/', |
| help="""\ |
| Where to download the speech training data to. |
| """) |
| parser.add_argument( |
| '--background_volume', |
| type=float, |
| default=0.1, |
| help="""\ |
| How loud the background noise should be, between 0 and 1. |
| """) |
| parser.add_argument( |
| '--background_frequency', |
| type=float, |
| default=0.8, |
| help="""\ |
| How many of the training samples have background noise mixed in. |
| """) |
| parser.add_argument( |
| '--silence_percentage', |
| type=float, |
| default=10.0, |
| help="""\ |
| How much of the training data should be silence. |
| """) |
| parser.add_argument( |
| '--unknown_percentage', |
| type=float, |
| default=10.0, |
| help="""\ |
| How much of the training data should be unknown words. |
| """) |
| parser.add_argument( |
| '--time_shift_ms', |
| type=float, |
| default=100.0, |
| help="""\ |
| Range to randomly shift the training audio by in time. |
| """) |
| parser.add_argument( |
| '--testing_percentage', |
| type=int, |
| default=10, |
| help='What percentage of wavs to use as a test set.') |
| parser.add_argument( |
| '--validation_percentage', |
| type=int, |
| default=10, |
| help='What percentage of wavs to use as a validation set.') |
| parser.add_argument( |
| '--sample_rate', |
| type=int, |
| default=16000, |
| help='Expected sample rate of the wavs',) |
| parser.add_argument( |
| '--clip_duration_ms', |
| type=int, |
| default=1000, |
| help='Expected duration in milliseconds of the wavs',) |
| parser.add_argument( |
| '--window_size_ms', |
| type=float, |
| default=30.0, |
| help='How long each spectrogram timeslice is.',) |
| parser.add_argument( |
| '--window_stride_ms', |
| type=float, |
| default=10.0, |
      help='How far to move in time between spectrogram timeslices.',)
| parser.add_argument( |
| '--feature_bin_count', |
| type=int, |
| default=40, |
| help='How many bins to use for the MFCC fingerprint', |
| ) |
| parser.add_argument( |
| '--how_many_training_steps', |
| type=str, |
| default='15000,3000', |
| help='How many training loops to run',) |
| parser.add_argument( |
| '--eval_step_interval', |
| type=int, |
| default=400, |
| help='How often to evaluate the training results.') |
| parser.add_argument( |
| '--learning_rate', |
| type=str, |
| default='0.001,0.0001', |
| help='How large a learning rate to use when training.') |
| parser.add_argument( |
| '--batch_size', |
| type=int, |
| default=100, |
| help='How many items to train with at once',) |
| parser.add_argument( |
| '--summaries_dir', |
| type=str, |
| default='/tmp/retrain_logs', |
| help='Where to save summary logs for TensorBoard.') |
| parser.add_argument( |
| '--wanted_words', |
| type=str, |
| default='yes,no,up,down,left,right,on,off,stop,go', |
| help='Words to use (others will be added to an unknown label)',) |
| parser.add_argument( |
| '--train_dir', |
| type=str, |
| default='/tmp/speech_commands_train', |
| help='Directory to write event logs and checkpoint.') |
| parser.add_argument( |
| '--save_step_interval', |
| type=int, |
| default=100, |
| help='Save model checkpoint every save_steps.') |
| parser.add_argument( |
| '--start_checkpoint', |
| type=str, |
| default='', |
| help='If specified, restore this pretrained model before any training.') |
| parser.add_argument( |
| '--model_architecture', |
| type=str, |
| default='conv', |
| help='What model architecture to use') |
| parser.add_argument( |
| '--check_nans', |
| type=bool, |
| default=False, |
| help='Whether to check for invalid numbers during processing') |
| parser.add_argument( |
| '--quantize', |
| type=bool, |
| default=False, |
| help='Whether to train the model for eight-bit deployment') |
| parser.add_argument( |
| '--preprocess', |
| type=str, |
| default='mfcc', |
| help='Spectrogram processing mode. Can be "mfcc", "average", or "micro"') |
| |
| FLAGS, unparsed = parser.parse_known_args() |
| tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed) |