Compute Library
 22.05
tensorflow_data_extractor.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 """Extracts trainable parameters from Tensorflow models and stores them in numpy arrays.
3 Usage
4  python tensorflow_data_extractor.py -m path_to_binary_checkpoint_file -n path_to_metagraph_file
5 
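6 Saves each trainable variable to a {variable_name}.npy binary file in the current working directory; path separators in the variable name are replaced with '_'.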
7 
8 Note that since Tensorflow version 0.11 the binary checkpoint file which contains the values for each parameter has the format of:
9  {model_name}.data-{step}-of-{max_step}
10 instead of:
11  {model_name}.ckpt
12 When dealing with binary files with version >= 0.11, only pass {model_name} to -m option;
13 when dealing with binary files with version < 0.11, pass the whole file name {model_name}.ckpt to -m option.
14 
15 Also note that this script relies on the parameters to be extracted being in the
16 'trainable_variables' tensor collection. By default all variables are automatically added to this collection unless
17 specified otherwise by the user. Thus should a user alter this default behavior and/or want to extract parameters from other
18 collections, tf.GraphKeys.TRAINABLE_VARIABLES should be replaced accordingly.
19 
20 Tested with Tensorflow 1.2, 1.3 on Python 2.7.6 and Python 3.4.3.
21 """
22 import argparse
23 import numpy as np
24 import os
25 import tensorflow as tf
26 
27 
28 if __name__ == "__main__":
29  # Parse arguments
30  parser = argparse.ArgumentParser('Extract Tensorflow net parameters')
31  parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Path to Tensorflow checkpoint binary\
32  file. For Tensorflow version >= 0.11, only include model name; for Tensorflow version < 0.11, include\
33  model name with ".ckpt" extension')
34  parser.add_argument('-n', dest='netFile', type=str, required=True, help='Path to Tensorflow MetaGraph file')
35  args = parser.parse_args()
36 
37  # Load Tensorflow Net
38  saver = tf.train.import_meta_graph(args.netFile)
39  with tf.Session() as sess:
40  # Restore session
41  saver.restore(sess, args.modelFile)
42  print('Model restored.')
43  # Save trainable variables to numpy arrays
44  for t in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
45  varname = t.name
46  if os.path.sep in t.name:
47  varname = varname.replace(os.path.sep, '_')
48  print("Renaming variable {0} to {1}".format(t.name, varname))
49  print("Saving variable {0} with shape {1} ...".format(varname, t.shape))
50  # Dump as binary
51  np.save(varname, sess.run(t))