forked from miha-skalic/youtube8mchallenge
-
Notifications
You must be signed in to change notification settings - Fork 0
/
weights_average.py
47 lines (37 loc) · 1.71 KB
/
weights_average.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import argparse
import numpy as np
import tensorflow as tf
import os
from tensorflow.python import pywrap_tensorflow
import shutil
# Command-line interface: an output folder that must not exist yet, and one or
# more checkpoint prefixes whose weights will be averaged.
parser = argparse.ArgumentParser(description='TF graph weights-averager.')
parser.add_argument(
    '--save_folder',
    type=str,
    required=True,
    action="store",
)
parser.add_argument(
    '--models',
    type=str,
    nargs='+',
    required=True,
    action="store",
    help="models must be type '/path/to/model/inference_model'",
)
params = parser.parse_args()
save_folder = params.save_folder
in_models = params.models
# Validate all inputs before doing any expensive work. These checks use
# parser.error (prints usage and exits non-zero) instead of `assert`, because
# assertions are stripped when Python runs with -O.

# Refuse to overwrite anything already at the output location (directory or file).
if os.path.exists(save_folder):
    parser.error("Point to non-existing directory! {} already exists.".format(save_folder))

# Every input must come with its .meta graph file.
for in_model in in_models:
    if not os.path.isfile(in_model + ".meta"):
        parser.error("Missing graph file {}.meta".format(in_model))

# model_flags.json is copied from the last model's directory at the end of the
# script; check it now so a missing file does not leave a half-written output.
flags_file = os.path.join(os.path.dirname(in_models[-1]), "model_flags.json")
if not os.path.isfile(flags_file):
    parser.error("Missing flags file {}".format(flags_file))

# The first model supplies the graph structure; all models are averaged.
in_model = in_models[0]
n_models = len(in_models)
# Rebuild the graph from the first model's .meta file, overwrite every global
# variable with its element-wise mean across all input checkpoints, and save
# the result as a new checkpoint under save_folder.
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    readers = [pywrap_tensorflow.NewCheckpointReader(xmodel) for xmodel in in_models]
    # Only the imported graph is needed here; the Saver that import_meta_graph
    # returns is unused because a fresh Saver over global_vars is built below.
    tf.train.import_meta_graph(in_model + ".meta", clear_devices=True)
    global_vars = tf.global_variables()
    for xtensor in global_vars:
        # Checkpoint keys are the variable names without the ":0" output suffix.
        var_name = xtensor.name.split(":")[0]
        first = np.asarray(readers[0].get_tensor(var_name))
        if np.issubdtype(first.dtype, np.floating):
            # Accumulate in float64 for precision, then cast back to the
            # variable's own dtype so it can be loaded into the graph.
            total = first.astype(np.float64)
            for xmodel, xreader in zip(in_models[1:], readers[1:]):
                other = xreader.get_tensor(var_name)
                # Reject shape mismatches; `+=` could otherwise silently broadcast.
                if np.shape(other) != first.shape:
                    raise ValueError(
                        "Variable {} has shape {} in {} but {} in {}".format(
                            var_name, first.shape, in_models[0],
                            np.shape(other), xmodel))
                total += other
            final_t = (total / n_models).astype(first.dtype)
        else:
            # Non-floating variables (e.g. an integer step counter, if the
            # graph has one) cannot be averaged in place: numpy refuses the
            # float -> int cast in `/=`. Keep the first model's value instead.
            final_t = first
        xtensor.load(final_t, session=sess)
    # Create the output folder explicitly rather than relying on the
    # checkpoint writer to do it.
    if not os.path.isdir(save_folder):
        os.makedirs(save_folder)
    saver = tf.train.Saver(global_vars)
    saver.save(sess, os.path.join(save_folder, "inference_model"))
# Copy model_flags.json from the directory of the last input model into the
# output folder next to the averaged checkpoint.
flags_path = os.path.join(os.path.dirname(in_models[-1]), "model_flags.json")
shutil.copy(flags_path, save_folder)
print("We are done!")