Go to the source code of this file.
|
| parser = argparse.ArgumentParser('Extract Caffe mnist image data') |
|
| dest |
|
| type |
|
| str |
|
| required |
|
| True |
|
| help |
|
| default |
|
| args = parser.parse_args() |
|
| images_filename = os.path.join(args.dataDir, 'mnist/t10k-images-idx3-ubyte') |
|
| labels_filename = os.path.join(args.dataDir, 'mnist/t10k-labels-idx1-ubyte') |
|
| images_file = open(images_filename, 'rb') |
|
| labels_file = open(labels_filename, 'rb') |
|
| images_magic |
|
| images_size |
|
| rows |
|
| cols |
|
| labels_magic |
|
| labels_size |
|
| images = array('B', images_file.read()) |
|
| labels = array('b', labels_file.read()) |
|
| input10_path = os.path.join(args.outDir, 'input10.npy') |
|
| input100_path = os.path.join(args.outDir, 'input100.npy') |
|
| labels100_path = os.path.join(args.outDir, 'labels100.npy') |
|
| outputs_10 = np.zeros(( 10, 28, 28, 1), dtype=np.float32) |
|
| outputs_100 = np.zeros((100, 28, 28, 1), dtype=np.float32) |
|
| labels_output = open(labels100_path, 'w') |
|
float | image = np.array(images[i * rows * cols : (i + 1) * rows * cols]).reshape((rows, cols)) / 256.0 |
|