Compute Library
 21.08
caffe_mnist_image_extractor.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 """Extracts mnist image data from the Caffe data files and stores them in numpy arrays
3 Usage
4  python caffe_mnist_image_extractor.py -d path_to_caffe_data_directory -o desired_output_path
5 
6 Saves the first 10 images extracted as input10.npy, the first 100 images as input100.npy, and the
7 corresponding labels (plain text, one label per line) to labels100.npy.
8 
9 Tested with Caffe 1.0 on Python 2.7
10 """
11 import argparse
12 import os
13 import struct
14 import numpy as np
15 from array import array
16 
17 
def main(argv=None):
    """Extract the first 100 MNIST test images and labels from the Caffe data files.

    Reads 'mnist/t10k-images-idx3-ubyte' and 'mnist/t10k-labels-idx1-ubyte'
    below the data directory and writes three files to the output directory:

    * input10.npy   - first 10 images, float32, shape (10, 1, rows, cols)
    * input100.npy  - first 100 images, float32, shape (100, 1, rows, cols)
    * labels100.npy - first 100 labels as plain text, one per line
                      (not a NumPy file, despite the extension)

    Pixel values are divided by 256.0 before being stored.

    Args:
        argv: Command-line arguments (without the program name). When None,
            argparse falls back to sys.argv[1:].

    Raises:
        ValueError: If the data files hold fewer than 100 images or labels.
    """
    num_images = 100
    num_preview = 10

    # Parse arguments
    parser = argparse.ArgumentParser('Extract Caffe mnist image data')
    parser.add_argument('-d', dest='dataDir', type=str, required=True, help='Path to Caffe data directory')
    parser.add_argument('-o', dest='outDir', type=str, default='.', help='Output directory (default = current directory)')
    args = parser.parse_args(argv)

    images_filename = os.path.join(args.dataDir, 'mnist/t10k-images-idx3-ubyte')
    labels_filename = os.path.join(args.dataDir, 'mnist/t10k-labels-idx1-ubyte')

    # Each file starts with big-endian 32-bit header fields followed by raw
    # bytes. Only the image dimensions are needed from the headers.
    with open(images_filename, 'rb') as images_file:
        _images_magic, _images_size, rows, cols = struct.unpack('>IIII', images_file.read(16))
        images = array('B', images_file.read())
    with open(labels_filename, 'rb') as labels_file:
        _labels_magic, _labels_size = struct.unpack('>II', labels_file.read(8))
        labels = array('b', labels_file.read())

    # Fail before touching the output directory rather than leaving a
    # half-written set of files behind.
    pixel_count = num_images * rows * cols
    if len(images) < pixel_count or len(labels) < num_images:
        raise ValueError('Expected at least %d images and labels in %s' % (num_images, args.dataDir))

    input10_path = os.path.join(args.outDir, 'input10.npy')
    input100_path = os.path.join(args.outDir, 'input100.npy')
    labels100_path = os.path.join(args.outDir, 'labels100.npy')

    # The images have a single channel, so the (N, C, H, W) layout is a plain
    # reshape of the consecutive row-major images.
    scaled = np.array(images[:pixel_count]) / 256.0
    outputs_100 = scaled.astype(np.float32).reshape((num_images, 1, rows, cols))

    np.save(input10_path, outputs_100[:num_preview])
    print("Wrote " + input10_path)

    with open(labels100_path, 'w') as labels_output:
        labels_output.writelines(str(label) + '\n' for label in labels[:num_images])
    print("Wrote " + labels100_path)

    np.save(input100_path, outputs_100)
    print("Wrote " + input100_path)


if __name__ == "__main__":
    main()