-
Notifications
You must be signed in to change notification settings - Fork 5
/
model_representation.py
199 lines (180 loc) · 6.33 KB
/
model_representation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
from linklist import Linklist
class ModelStruction:
'''
定义了对模型结构改造的类,按分支输出链路结构
'''
def __init__(self, nodes):
for k, v in nodes.items():
if not v['args']:
# 计算题图的头节点
self.head = k
break
# 维护推理顺序
self.seq = 0
# 储存转化后的分支划分结果
self.blocks = list()
# 保存已处理的节点
self.temp = set()
# 传入的节点
self.nodes = nodes
# 计算图关键点(分支、合并点)
self.key_points = None
# 关键点标记
def branch_point(self):
'''
对计算图中分支合并的点做标记,0表示分支点,1表示合并点, 2 表示既是分支点也是合并点
:return: 包含所有关键点的字典 -> Dict
'''
points = {}
for i, v in self.nodes.items():
# 分支/合并
if v['users'] > 1 and len(v['args']) > 1:
points[i] = 2
# 分支点
if v['users'] > 1 and len(v['args']) == 1:
points[i] = 0
# 合并点
if v['users'] == 1 and len(v['args']) > 1:
points[i] = 1
keys = list(points.keys())
while 'cat' not in keys[-1]:
del points[keys.pop()]
nodes_keys = list(self.nodes.keys())
del_nodes = nodes_keys[nodes_keys.index(keys[-1]) + 1:]
# 去除计算图末尾节点, 方便后续将末端代码打包到一起
for del_node in del_nodes:
del self.nodes[del_node]
self.key_points = points
def bb(self, head):
'''
从分支点开始的传播方式
:param head: 开始节点
:return: 结束节点 -> str
'''
end, _end = None, None
self.seq += 1
for i, v in self.nodes.items():
if head in v['args']:
end, link = self.nodes_forward(i, delete=True)
if self.key_points[end] != 0:
self.blocks.append({self.seq: link})
else: # 分支嵌套分支的情况
flag = 0
dic = dict()
dic[self.seq] = []
dic[self.seq].append({flag: link})
flag += 1
for m, n in self.nodes.items():
if end in n['args']:
_end, link = self.nodes_forward(m, delete=True)
dic[self.seq].append({flag: link})
end = _end
flag += 1
# 此时end为合并点
for x, y in self.nodes.items():
if end == x:
end, link = self.nodes_forward(x, branch_forward=True)
dic[self.seq].append({flag: link})
self.blocks.append(dic)
return end
def nodes_split(self, head):
'''
通过传播方式的节点划分方法,遇到关键点则该条链路传播结束
:param head: 开始节点
:return: 结束节点 -> str
'''
end = None
if head != self.head and head not in self.key_points.keys():
raise ValueError("Param Error")
if head == self.head:
end, link = self.nodes_forward(head)
self.blocks.append({self.seq: link})
elif self.key_points[head] == 0:
end = self.bb(head)
elif self.key_points[head] == 2:
self.seq += 1
for i, v in self.nodes.items():
if head == i:
end, link = self.nodes_forward(i, double=True)
self.blocks.append({self.seq: link})
end = self.bb(end)
elif self.key_points[head] == 1:
self.seq += 1
for i, v in self.nodes.items():
if head == i:
end, link = self.nodes_forward(i, branch_forward=True)
self.blocks.append({self.seq: link})
return end
def refactor(self):
'''
全图传播
:return: None
'''
end = self.nodes_split(self.head)
while True:
try:
end = self.nodes_split(end)
print(end)
if end not in self.temp:
self.temp.add(end)
else:
print("???")
break
except:
print("--over--")
break
# 结点传播
def nodes_forward(self, start, delete=False, double=False, branch_forward=False):
'''
:return: None
'''
linklist = Linklist()
# 添加链表头
linklist.add(start)
if double:
return start, linklist
# 开始向下传播
if branch_forward:
for i, v in self.nodes.items():
if start in v['args']:
if i not in self.key_points:
start = i
linklist.add(start)
break
else:
if self.key_points[i] == 0:
start = i
linklist.add(start)
break
while True:
for i, v in self.nodes.items():
if start in v['args'] and start not in self.key_points:
start = i
linklist.add(start)
break
else:
break
flag = True
for i, v in self.nodes.items():
if i == start:
if v['users'] > 1:
flag = False
if len(v['args']) > 1:
if delete:
linklist.remove(start)
flag = False
if not flag:
break
return start, linklist
def get_blocks(self):
'''
获取所有的模型分支链路
:return: List
'''
# 标记关键点
self.branch_point()
# 执行节点划分方法
self.refactor()
# print(self.key_points["cat_2"])
# print(self.temp)
return self.blocks