-
Notifications
You must be signed in to change notification settings - Fork 2
/
keras_div.py
104 lines (86 loc) · 3.25 KB
/
keras_div.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import time
from argparse import ArgumentParser, Namespace
from sys import stdout
import numpy as np
from tqdm import tqdm
from tvm import relay, TVMError
from gencog.config import common_ops
from gencog.graph.relay import build_graph
from gencog.metric.div import VertexDiversity, EdgeDiversity
from gencog.spec import OpRegistry
from lemon.gen import LemonGenerator
from muffin.model_generator import MuffinGenerator
from tvm_frontend import from_keras
# Command-line options; populated by _parse_args() before main() runs.
args = Namespace()


def _parse_args():
    """Parse command-line options into the module-level ``args`` namespace.

    Options:
        -g/--generator: 'lemon' or 'muffin'; left as None when omitted.
        -l/--limit: required integer budget on the total number of counted
            operations. It is required because the generation loop compares
            its running count against it; a missing value would otherwise
            only fail after work had already been done.
        -m/--model: 'dag' or 'template'; only meaningful for Muffin.
        -t/--trend: flag; record the diversity trend to a file.

    Raises:
        SystemExit: raised by argparse on invalid or missing options.
    """
    global args
    p = ArgumentParser()
    p.add_argument('-g', '--generator', type=str, choices=['lemon', 'muffin'],
                   help='Method for graph generation.')
    p.add_argument('-l', '--limit', type=int, required=True,
                   help='Limit on total number of operations.')
    p.add_argument('-m', '--model', type=str, choices=['dag', 'template'],
                   help='Graph model to apply, only valid for Muffin.')
    p.add_argument('-t', '--trend', action='store_true',
                   help='Whether to record diversity trend.')
    args = p.parse_args()
def _create_generator():
    """Return the Keras model generator selected by ``args.generator``.

    'lemon' selects ``LemonGenerator``; any other value (including None when
    ``-g`` is omitted) selects ``MuffinGenerator`` with ``args.model``.
    """
    if args.generator == 'lemon':
        return LemonGenerator()
    return MuffinGenerator(args.model)


def _trend_file_path():
    """Return a timestamped path under ``out/`` for the diversity-trend record."""
    if args.generator == 'lemon':
        return time.strftime('out/lemon-%Y%m%d-%H%M%S.txt')
    return time.strftime(f'out/muffin-{args.model}-%Y%m%d-%H%M%S.txt')


def _generate_relay(model_gen):
    """Generate one Keras model and convert it to a type-checked Relay module.

    Returns ``(mod, params)``, or ``None`` when the generator raises
    ValueError or Relay type inference raises TVMError (caller should retry).
    """
    try:
        model = model_gen.generate()
    except ValueError:
        return None
    # Random batch size in [1, 4] (numpy's upper bound is exclusive); the
    # remaining dimensions are taken from each model input.
    batch_size = np.random.randint(1, 5)
    input_shapes = {inp.name: (batch_size,) + tuple(inp.shape.as_list()[1:])
                    for inp in model.inputs}
    mod, params = from_keras(model, shape=input_shapes)
    # Reject models whose converted module does not type-check.
    try:
        mod = relay.transform.InferType()(mod)
    except TVMError:
        return None
    return mod, params


def main():
    """Generate models until the operation budget is met, then print diversity.

    Reads the module-level ``args``. Each iteration generates a Keras model,
    converts it to Relay, type-checks it, builds a graph, and feeds the graph
    to vertex- and edge-diversity metrics over ``common_ops``. Operations whose
    op name is in ``common_ops`` are counted; the loop stops once the count
    reaches ``args.limit`` (at least one model is always processed). When
    ``args.trend`` is set, the ``[opr_count, vert, edge]`` history is written
    to a timestamped file under ``out/`` after every accepted model.

    Raises:
        SystemExit: if ``args.limit`` is not set.
    """
    import os

    # Initialization
    opr_limit = args.limit
    if opr_limit is None:
        # The stop check compares an int against opr_limit; fail early
        # instead of crashing after the first model has been processed.
        raise SystemExit('error: --limit is required')
    model_gen = _create_generator()
    ops = [OpRegistry.get(name) for name in common_ops]
    vert_div = VertexDiversity(ops)
    edge_div = EdgeDiversity(ops)

    # Generation loop
    opr_count = 0
    div_record = []
    trend_file = _trend_file_path()
    if args.trend:
        # np.savetxt does not create missing directories.
        os.makedirs(os.path.dirname(trend_file), exist_ok=True)
    # The context manager closes the progress bar on every exit path.
    with tqdm(total=opr_limit, file=stdout) as progress:
        while True:
            generated = _generate_relay(model_gen)
            if generated is None:
                continue
            mod, params = generated
            # Convert to graph representation and evaluate diversity
            graph = build_graph(mod, params)
            vert_div.evaluate(graph)
            edge_div.evaluate(graph)
            # Only operators listed in common_ops count toward the budget.
            opr_num = sum(opr.op_.name_ in common_ops for opr in graph.oprs_)
            opr_count += opr_num
            progress.update(n=opr_num)
            # Record running diversity
            vd, ed = vert_div.result, edge_div.result
            div_record.append([opr_count, vd, ed])
            progress.set_postfix_str('vert={:.4f}, edge={:.4f}'.format(vd, ed))
            if args.trend:
                # Rewrites the file with the full history each time, so it
                # always holds every record collected so far.
                # noinspection PyTypeChecker
                np.savetxt(trend_file, np.array(div_record), fmt='%.4f')
            # Stop if operation limit is reached
            if opr_count >= opr_limit:
                break

    # Output diversity
    np.set_printoptions(precision=3)
    # print('Operator detail:', vert_div.op_div, sep='\n')
    print('Vertex diversity: {:.4f}'.format(vert_div.result))
    print('Edge diversity: {:.4f}'.format(edge_div.result))
if __name__ == '__main__':
    # Script entry point: options must be parsed into the module-level
    # ``args`` before main() reads them.
    _parse_args()
    main()