| # Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
"""Tests for the speech commands training script (train.py)."""
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import os |
| |
| import tensorflow as tf |
| |
| from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio |
| from tensorflow.examples.speech_commands import train |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.platform import gfile |
| from tensorflow.python.platform import test |
| |
| |
# Lightweight attribute bag: each keyword argument becomes an attribute, so an
# instance can stand in for parsed command-line flags in tests.
class DictStruct(object):
  """Expose every keyword argument given to the constructor as an attribute."""

  def __init__(self, **entries):
    # vars(self) is the instance __dict__, so this stores entries verbatim.
    vars(self).update(entries)
| |
| |
class TrainTest(test.TestCase):
  """Smoke tests that run `train.main` on a tiny synthetic WAV dataset."""

  def _getWavData(self):
    """Return WAV-encoded bytes for an all-zero clip of shape [32000, 2].

    The clip is encoded at a 16000 Hz sample rate, i.e. 2 seconds long.
    """
    with self.cached_session():
      sample_data = tf.zeros([32000, 2])
      wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
      return self.evaluate(wav_encoder)

  def _saveTestWavFile(self, filename, wav_data):
    """Write the raw `wav_data` bytes to `filename`."""
    with open(filename, 'wb') as f:
      f.write(wav_data)

  def _saveWavFolders(self, root_dir, labels, how_many, wav_data=None):
    """Create one folder per label under `root_dir`, filled with WAV copies.

    Args:
      root_dir: Existing directory in which the label folders are created.
      labels: Iterable of label names; each becomes a subdirectory.
      how_many: Number of WAV files written into each label folder.
      wav_data: Optional pre-encoded WAV bytes. When None, a clip is encoded
        via `_getWavData`; passing it lets callers avoid a second TF run.
    """
    if wav_data is None:
      wav_data = self._getWavData()
    for label in labels:
      dir_name = os.path.join(root_dir, label)
      os.mkdir(dir_name)
      for i in range(how_many):
        file_path = os.path.join(dir_name, 'some_audio_%d.wav' % i)
        self._saveTestWavFile(file_path, wav_data)

  def _prepareDummyTrainingData(self):
    """Build a `wavs` directory in the test temp dir with dummy clips.

    Returns:
      Path to a directory holding label folders `a`, `b`, `c` (100 clips
      each) and a `_background_noise_` folder (10 clips).
    """
    wav_dir = os.path.join(self.get_temp_dir(), 'wavs')
    os.mkdir(wav_dir)
    # Every file is the same silent clip, so encode it once and reuse it.
    wav_data = self._getWavData()
    self._saveWavFolders(wav_dir, ['a', 'b', 'c'], 100, wav_data)
    background_dir = os.path.join(wav_dir, '_background_noise_')
    os.mkdir(background_dir)
    for i in range(10):
      file_path = os.path.join(background_dir, 'background_audio_%d.wav' % i)
      self._saveTestWavFile(file_path, wav_data)
    return wav_dir

  def _getDefaultFlags(self):
    """Return a `DictStruct` of flags for a very short training run.

    Side effect: creates the dummy dataset via `_prepareDummyTrainingData`.
    Numeric schedule flags are strings because `train` receives them in
    that form here (e.g. `how_many_training_steps='2'`).
    """
    tmp_dir = self.get_temp_dir()
    flags = {
        'data_url': '',
        'data_dir': self._prepareDummyTrainingData(),
        'wanted_words': 'a,b,c',
        'sample_rate': 16000,
        'clip_duration_ms': 1000,
        'window_size_ms': 30,
        'window_stride_ms': 20,
        'feature_bin_count': 40,
        'preprocess': 'mfcc',
        'silence_percentage': 25,
        'unknown_percentage': 25,
        'validation_percentage': 10,
        'testing_percentage': 10,
        'summaries_dir': os.path.join(tmp_dir, 'summaries'),
        'train_dir': os.path.join(tmp_dir, 'train'),
        'time_shift_ms': 100,
        'how_many_training_steps': '2',
        'learning_rate': '0.01',
        'quantize': False,
        'model_architecture': 'conv',
        'check_nans': False,
        'start_checkpoint': '',
        'batch_size': 1,
        'background_volume': 0.25,
        'background_frequency': 0.8,
        'eval_step_interval': 1,
        'save_step_interval': 1,
    }
    return DictStruct(**flags)

  def _installFlags(self, **overrides):
    """Set `train.FLAGS` to the default flags with `overrides` applied.

    `train.FLAGS` is a module-level global; its previous value is restored
    when the test finishes so the mutation does not leak into other tests.
    """
    self.addCleanup(setattr, train, 'FLAGS', getattr(train, 'FLAGS', None))
    flags = self._getDefaultFlags()
    for name, value in overrides.items():
      setattr(flags, name, value)
    train.FLAGS = flags

  def _assertTrainingOutputsExist(self, flags):
    """Assert that the graph, labels file and step-1 checkpoint were written."""
    prefix = os.path.join(flags.train_dir, flags.model_architecture)
    for suffix in ('.pbtxt', '_labels.txt', '.ckpt-1.meta'):
      path = prefix + suffix
      self.assertTrue(gfile.Exists(path), 'Missing training output: %s' % path)

  @test_util.run_deprecated_v1
  def testTrain(self):
    self._installFlags()
    train.main('')
    self._assertTrainingOutputsExist(train.FLAGS)

  @test_util.run_deprecated_v1
  def testQuantizedTrain(self):
    self._installFlags(quantize=True, model_architecture='tiny_conv')
    train.main('')
    self._assertTrainingOutputsExist(train.FLAGS)
| |
| |
if __name__ == '__main__':
  # Run directly as a script (not imported): hand off to TensorFlow's test main.
  test.main()