-
Notifications
You must be signed in to change notification settings - Fork 1
/
frozen_graph.py
executable file
·116 lines (80 loc) · 2.97 KB
/
frozen_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import app
import os
import PIL.Image as Image
import numpy as np
# Parsed command-line namespace. Assigned from parse_args() in the
# __main__ guard and read by freeze_graph() and main().
FLAGS = None
def freeze_graph():
    """Freeze the latest checkpoint in ``FLAGS.input_checkpoint`` to a .pb file.

    Restores the most recent checkpoint recorded in the directory
    ``FLAGS.input_checkpoint``, converts the graph's variables to constants
    for the comma-separated nodes in ``FLAGS.output_names``, and writes the
    serialized ``GraphDef`` to
    ``os.path.join(FLAGS.input_checkpoint, FLAGS.output_graph)``.

    Raises:
        IOError: if no checkpoint can be found in ``FLAGS.input_checkpoint``.
    """
    checkpoint_dir = FLAGS.input_checkpoint
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    # get_checkpoint_state returns None when the directory holds no
    # checkpoint index; report that clearly instead of failing with an
    # AttributeError on None.
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise IOError("No checkpoint found in %r" % (checkpoint_dir,))
    input_checkpoint = checkpoint.model_checkpoint_path
    output_graph = os.path.join(checkpoint_dir, FLAGS.output_graph)
    # Tolerate spaces around names (e.g. "a, b") and stray commas.
    output_node_names = [
        name.strip() for name in FLAGS.output_names.split(",") if name.strip()
    ]

    # clear_devices=True drops device placements stored in the meta graph so
    # the restore is not tied to the devices used when it was saved.
    saver = tf.train.import_meta_graph(input_checkpoint + ".meta",
                                       clear_devices=True)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_node_names)
    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph" % (len(output_graph_def.node)))
def load_graph(frozen_graph_filename):
    """Read a serialized ``GraphDef`` file into a new ``tf.Graph``.

    Args:
        frozen_graph_filename: path to a binary GraphDef (.pb) file.

    Returns:
        A fresh ``tf.Graph`` holding the imported nodes. ``tf.import_graph_def``
        is called without a ``name`` argument, so node names carry
        TensorFlow's default import prefix.
    """
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(frozen_graph_filename, "rb") as pb_file:
        serialized = pb_file.read()
    graph_def.ParseFromString(serialized)

    imported = tf.Graph()
    with imported.as_default():
        tf.import_graph_def(graph_def)
    return imported
def main(unused_args):
    """Freeze the checkpoint, reload the frozen graph and print its nodes.

    Prints every op name in the reloaded graph, then the first output tensor
    of each node listed (comma separated) in ``FLAGS.input_names`` and
    ``FLAGS.output_names``.

    Args:
        unused_args: leftover argv handed over by ``app.run``; ignored.

    Returns:
        0 on success.
    """
    freeze_graph()
    frozen_graph_path = os.path.join(FLAGS.input_checkpoint, FLAGS.output_graph)
    graph = load_graph(frozen_graph_path)
    for op in graph.get_operations():
        print(op.name)
    # Both flags are documented as comma-separated lists, so look up each
    # name on its own instead of the whole flag string. The "import/" prefix
    # matches the scope load_graph's tf.import_graph_def call uses by default.
    for flag_value in (FLAGS.input_names, FLAGS.output_names):
        for name in flag_value.split(","):
            name = name.strip()
            if not name:
                continue
            operation = graph.get_operation_by_name("import/" + name)
            print(operation.outputs[0])
    return 0
def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        The ``(namespace, unparsed)`` pair from
        ``ArgumentParser.parse_known_args``: recognized flags plus a list of
        the arguments this parser did not recognize.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_checkpoint",
        type=str,
        default="tf_files/inception/",
        help="Directory containing the TensorFlow checkpoint to load; the "
             "frozen graph is written to this directory as well.")
    parser.add_argument(
        "--output_graph",
        type=str,
        default="frozen_graph.pb",
        help="Output 'GraphDef' file name, relative to --input_checkpoint.")
    parser.add_argument(
        "--input_names",
        type=str,
        default="DecodeJpeg",
        help="Input node names, comma separated.")
    parser.add_argument(
        "--output_names",
        type=str,
        default="final_result",
        help="Output node names, comma separated.")
    return parser.parse_known_args(argv)
if __name__ == "__main__":
    # Keep recognized flags in the module-level FLAGS; forward the program
    # name plus anything parse_args did not recognize to app.run.
    FLAGS, unparsed = parse_args()
    forwarded_argv = [sys.argv[0]]
    forwarded_argv.extend(unparsed)
    app.run(main=main, argv=forwarded_argv)