blob: 2c9478a61131233fb3a2175850a33be2ef631db4 [file] [log] [blame]
#!/usr/bin/env python
"""Extracts MNIST test image data from the Caffe data files and stores it in numpy arrays.

Usage:
    python caffe_mnist_image_extractor.py -d path_to_caffe_data_directory -o desired_output_path

Saves the first 10 test images as input10.npy and the first 100 test images as input100.npy
(float32, NCHW layout with a single channel). The corresponding 100 labels go to labels100.npy.
Despite its .npy extension, labels100.npy is a plain-text file with one label per line.

Tested with Caffe 1.0 on Python 2.7.
"""
import argparse
import os
import struct
import numpy as np
from array import array
# IDX magic numbers for unsigned-byte data: 0x00000803 for the 3-D image file
# and 0x00000801 for the 1-D label file (MNIST file-format description).
_IMAGES_MAGIC = 0x00000803
_LABELS_MAGIC = 0x00000801
# Number of leading test images written to the small and the large output file.
_SMALL_COUNT = 10
_LARGE_COUNT = 100


def _read_idx_prefix(path, expected_magic, num_dims, count):
    """Read the header and the first `count` items of an unsigned-byte IDX file.

    Args:
        path: IDX file to read.
        expected_magic: magic number the header must start with.
        num_dims: number of big-endian uint32 dimension sizes that follow the magic.
        count: number of leading items to return.

    Returns:
        (dims, data): `dims` is the tuple of dimension sizes from the header
        (dims[0] is the total item count). `data` is a flat uint8 array of
        count * prod(dims[1:]) values.

    Raises:
        ValueError: bad magic number, fewer than `count` items, or a truncated file.
    """
    header_size = 4 * (1 + num_dims)
    with open(path, 'rb') as idx_file:
        header = idx_file.read(header_size)
        if len(header) != header_size:
            raise ValueError('%s: truncated IDX header' % path)
        fields = struct.unpack('>%dI' % (1 + num_dims), header)
        magic, dims = fields[0], fields[1:]
        if magic != expected_magic:
            raise ValueError('%s: bad magic number 0x%08x (expected 0x%08x)'
                             % (path, magic, expected_magic))
        if dims[0] < count:
            raise ValueError('%s: holds %d items, need at least %d'
                             % (path, dims[0], count))
        item_size = 1
        for dim in dims[1:]:
            item_size *= dim
        # Read only the bytes we need rather than the whole file.
        raw = idx_file.read(count * item_size)
    if len(raw) != count * item_size:
        raise ValueError('%s: truncated IDX data' % path)
    return dims, np.frombuffer(raw, dtype=np.uint8)


def main(argv=None):
    """Extract the first MNIST test images and labels from a Caffe data directory.

    Writes the following files into the output directory, creating it if needed:
      - input10.npy and input100.npy: float32 arrays of shape (N, 1, rows, cols),
        with pixel values divided by 256.
      - labels100.npy: plain text, one label per line.

    Args:
        argv: command-line arguments (defaults to sys.argv[1:]).

    Raises:
        ValueError: if either input file is malformed or has fewer than 100 items.
    """
    # The first positional argument of ArgumentParser is `prog`, so the text
    # must be passed as `description`.
    parser = argparse.ArgumentParser(description='Extract Caffe mnist image data')
    parser.add_argument('-d', dest='dataDir', type=str, required=True, help='Path to Caffe data directory')
    parser.add_argument('-o', dest='outDir', type=str, default='.', help='Output directory (default = current directory)')
    args = parser.parse_args(argv)

    images_filename = os.path.join(args.dataDir, 'mnist/t10k-images-idx3-ubyte')
    labels_filename = os.path.join(args.dataDir, 'mnist/t10k-labels-idx1-ubyte')
    (_, rows, cols), pixels = _read_idx_prefix(images_filename, _IMAGES_MAGIC, 3, _LARGE_COUNT)
    _, labels = _read_idx_prefix(labels_filename, _LABELS_MAGIC, 1, _LARGE_COUNT)

    # Use NCHW layout with one channel and scale by 1/256. Both the uint8 values
    # and 1/256 are exact in float32.
    images = (pixels.reshape(_LARGE_COUNT, 1, rows, cols) / 256.0).astype(np.float32)

    if not os.path.isdir(args.outDir):
        os.makedirs(args.outDir)
    input10_path = os.path.join(args.outDir, 'input10.npy')
    input100_path = os.path.join(args.outDir, 'input100.npy')
    # NOTE: the .npy name is kept unchanged for compatibility, although the
    # file is plain text (one label per line).
    labels100_path = os.path.join(args.outDir, 'labels100.npy')

    np.save(input10_path, images[:_SMALL_COUNT])
    print('Wrote ' + input10_path)
    with open(labels100_path, 'w') as labels_output:
        labels_output.write(''.join('%d\n' % label for label in labels))
    print('Wrote ' + labels100_path)
    np.save(input100_path, images)
    print('Wrote ' + input100_path)


if __name__ == "__main__":
    main()