""" Extract trainable parameters from a frozen model and store them in numpy arrays.

Usage:
    python tf_frozen_model_extractor.py -m path_to_frozen_model -d path_to_store_the_parameters

Saves each variable to a {variable_name}.npy binary file.

Note that the script permutes the trainable parameters to NCHW format. This is a manual step and is therefore not thoroughly tested.
"""
import argparse
import os

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
# Substrings deleted, in this order, from a tensor name to build the base name
# of its .npy file. Applied to a name ending in "/read:0", removing "read"
# leaves a trailing "/:0", which the second entry then removes.
strings_to_remove = ["read", "/:0"]

# Axis order handed to ndarray.transpose, keyed by the tensor's rank.
# Ranks 1-3 simply reverse the axes.
permutations = {rank: list(range(rank - 1, -1, -1)) for rank in (1, 2, 3)}
# Rank 4 moves axes (0, 1, 2, 3) to (3, 2, 0, 1); for a kernel stored as
# (H, W, in, out) this yields (out, in, H, W) -- TODO confirm the source layout.
permutations[4] = [3, 2, 0, 1]
def _parse_args():
    """Parse the extractor's command-line options.

    Returns:
        argparse.Namespace with ``modelFile`` (path to the frozen ``.pb``
        graph), ``dumpPath`` (output directory, default ``'./'``) and
        ``storeRes`` (``False`` when ``--nostore`` is given).
    """
    # The first positional parameter of ArgumentParser is ``prog``; pass the
    # text as ``description`` so usage lines keep the real script name.
    parser = argparse.ArgumentParser(description='Extract TensorFlow net parameters')
    parser.add_argument('-m', dest='modelFile', type=str, required=True,
                        help='Path to TensorFlow frozen graph file (.pb)')
    parser.add_argument('-d', dest='dumpPath', type=str, required=False, default='./',
                        help='Path to store the resulting files.')
    parser.add_argument('--nostore', dest='storeRes', action='store_false',
                        help='Specify if files should not be stored. Used for debugging.')
    parser.set_defaults(storeRes=True)
    return parser.parse_args()


def _load_graph_def(model_file):
    """Read *model_file* and parse it as a serialized ``GraphDef`` protobuf."""
    with gfile.FastGFile(model_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    return graph_def


def _output_name(tensor_name):
    """Return *tensor_name* with every entry of ``strings_to_remove`` deleted."""
    name = tensor_name
    for s in strings_to_remove:
        name = name.replace(s, "")
    # NOTE(review): str.replace deletes "read" anywhere in the name (e.g. inside
    # a scope called "thread"), not only in the trailing "/read:0".
    return name


def main():
    """Save selected tensors of a frozen graph as ``<dumpPath>/<name>.npy`` files.

    Every tensor whose name ends in ``/read:0`` is evaluated and transposed with
    the ``permutations`` entry for its rank. Unless ``--nostore`` is given, it is
    then written with ``np.save``.
    """
    args = _parse_args()

    # Only touch the file system when results are actually going to be stored.
    if args.storeRes and not os.path.exists(args.dumpPath):
        os.makedirs(args.dumpPath)

    with tf.Graph().as_default() as graph:
        with tf.Session() as sess:
            print("Loading model.")
            graph_def = _load_graph_def(args.modelFile)
            # name="" keeps the original tensor names (no "import/" prefix).
            tf.import_graph_def(graph_def, name="")

            for op in graph.get_operations():
                for op_val in op.values():
                    # strings_to_remove only produces clean names for tensors
                    # ending in "/read:0", so those are treated as parameters.
                    # NOTE(review): this selection step was reconstructed (the loop
                    # previously used `t`/`varname` before assignment); confirm it
                    # matches the graphs this script is run on.
                    if not op_val.name.endswith("/read:0"):
                        continue

                    t = sess.run(op_val)
                    perm = permutations.get(t.ndim)
                    if perm is None:
                        print("Skipping variable {0}: no permutation for rank {1}".format(op_val.name, t.ndim))
                        continue
                    # transpose returns a strided view; make it contiguous before saving.
                    t = np.ascontiguousarray(t.transpose(perm))

                    varname = _output_name(op_val.name)
                    # TensorFlow separates name scopes with '/' on every OS; also map
                    # os.path.sep so np.save never treats a scope as a sub-directory.
                    safe_name = varname.replace("/", "_").replace(os.path.sep, "_")
                    if safe_name != varname:
                        print("Renaming variable {0} to {1}".format(op_val.name, safe_name))
                        varname = safe_name

                    if args.storeRes:
                        print("Saving variable {0} with shape {1} ...".format(varname, t.shape))
                        np.save(os.path.join(args.dumpPath, varname), t)


if __name__ == "__main__":
    main()