DQN.py (forked from ailec0623/DQN_HollowKnight)
import tensorflow as tf
import numpy as np


class DQN:
    def __init__(self, model, gamma=0.9, learning_rate=0.0001):
        self.model = model
        self.act_dim = model.act_dim
        self.act_model = model.act_model
        self.move_model = model.move_model
        # target networks; assumed to be exposed by the model wrapper the same
        # way act_model/move_model are (they are referenced by the
        # *_replace_target methods below but were never assigned here)
        self.act_target_model = model.act_target_model
        self.move_target_model = model.move_target_model
        self.gamma = gamma
        self.lr = learning_rate
        # -------------------------- training setup -------------------------- #
        self.act_model.optimizer = tf.optimizers.Adam(learning_rate=self.lr)
        self.act_model.loss_func = tf.losses.MeanSquaredError()
        self.move_model.optimizer = tf.optimizers.Adam(learning_rate=self.lr)
        self.move_model.loss_func = tf.losses.MeanSquaredError()
        # self.act_model.train_loss = tf.metrics.Mean(name="train_loss")
        # --------------------------------------------------------------------- #
        self.act_global_step = 0
        self.move_global_step = 0
        self.update_target_steps = 100  # copy model weights into the target models every 100 training steps

    # train functions for the act model
    def act_predict(self, obs):
        """Use self.act_model's value network to get [Q(s,a1), Q(s,a2), ...]."""
        return self.act_model.predict(obs)

    def act_train_step(self, action, features, labels):
        """Single training step."""
        with tf.GradientTape() as tape:
            # loss is the mean squared error between Q(s,a) and target_Q
            predictions = self.act_model(features, training=True)
            enum_action = list(enumerate(action))
            pred_action_value = tf.gather_nd(predictions, indices=enum_action)
            loss = self.act_model.loss_func(labels, pred_action_value)
        gradients = tape.gradient(loss, self.act_model.trainable_variables)
        self.act_model.optimizer.apply_gradients(zip(gradients, self.act_model.trainable_variables))
        self.model.act_loss.append(loss)
        # self.act_model.train_loss.update_state(loss)
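
    # Note on the indexing above: enumerate(action) builds [(0, a0), (1, a1), ...],
    # so tf.gather_nd picks predictions[i, a_i] for each batch row, i.e. the
    # Q-value of the action that was actually taken. A tiny illustration:
    #   q = tf.constant([[1., 2.], [3., 4.]])
    #   tf.gather_nd(q, [(0, 1), (1, 0)])  # -> [2., 3.]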

    def act_train_model(self, action, features, labels, epochs=1):
        """Train the model for the given number of epochs."""
        for epoch in tf.range(1, epochs + 1):
            self.act_train_step(action, features, labels)

    def act_learn(self, obs, action, reward, next_obs, terminal):
        """Update self.act_model's value network with the DQN algorithm.

        Note: `reward` is used directly as the regression label; `next_obs`
        and `terminal` are unused here because the target-Q computation is
        commented out below.
        """
        # sync model and target_model every self.update_target_steps training steps
        # if self.act_global_step % self.update_target_steps == 0:
        #     self.act_replace_target()
        # fetch max Q' from target_model to compute target_Q
        self.act_train_model(action, obs, reward, epochs=1)
        self.act_global_step += 1
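
    # The commented-out lines in act_learn point at the standard DQN target.
    # A minimal sketch of that computation, assuming the target network has
    # the same predict() interface (hypothetical helper, not in the original file):
    def _compute_target_q(self, reward, next_obs, terminal, target_model):
        # target_Q = r + gamma * max_a' Q'(s', a'); the bootstrap term is
        # dropped on terminal transitions
        next_q = target_model.predict(next_obs)
        max_next_q = np.max(next_q, axis=1)
        done = np.asarray(terminal, dtype=np.float32)
        return reward + self.gamma * max_next_q * (1.0 - done)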

    def act_replace_target(self):
        """Copy the prediction model's weights into the target model."""
        for i, l in enumerate(self.act_target_model.get_layer(index=1).get_layer(index=0).get_layer(index=0).get_layers()):
            l.set_weights(self.act_model.get_layer(index=1).get_layer(index=0).get_layer(index=0).get_layer(index=i).get_weights())
        for i, l in enumerate(self.act_target_model.get_layer(index=1).get_layer(index=0).get_layer(index=1).get_layers()):
            l.set_weights(self.act_model.get_layer(index=1).get_layer(index=0).get_layer(index=1).get_layer(index=i).get_weights())
        # for i, l in enumerate(self.act_target_model.get_layer(index=1).get_layer(index=1).get_layer(index=0).get_layers()):
        #     l.set_weights(self.act_model.get_layer(index=1).get_layer(index=1).get_layer(index=0).get_layer(index=i).get_weights())
        # for i, l in enumerate(self.act_target_model.get_layer(index=1).get_layer(index=1).get_layer(index=1).get_layers()):
        #     l.set_weights(self.act_model.get_layer(index=1).get_layer(index=1).get_layer(index=1).get_layer(index=i).get_weights())
        # for i, l in enumerate(self.act_target_model.get_layer(index=1).get_layer(index=2).get_layer(index=0).get_layers()):
        #     l.set_weights(self.act_model.get_layer(index=1).get_layer(index=2).get_layer(index=0).get_layer(index=i).get_weights())
        # for i, l in enumerate(self.act_target_model.get_layer(index=1).get_layer(index=2).get_layer(index=1).get_layers()):
        #     l.set_weights(self.act_model.get_layer(index=1).get_layer(index=2).get_layer(index=1).get_layer(index=i).get_weights())
        self.act_target_model.get_layer(index=1).get_layer(index=2).set_weights(self.act_model.get_layer(index=1).get_layer(index=2).get_weights())
        # self.act_target_model.get_layer(index=1).get_layer(index=6).set_weights(self.act_model.get_layer(index=1).get_layer(index=6).get_weights())

    # train functions for the move model
    def move_predict(self, obs):
        """Use self.move_model's value network to get [Q(s,a1), Q(s,a2), ...]."""
        return self.move_model.predict(obs)

    def move_train_step(self, action, features, labels):
        """Single training step."""
        with tf.GradientTape() as tape:
            # loss is the mean squared error between Q(s,a) and target_Q
            predictions = self.move_model(features, training=True)
            enum_action = list(enumerate(action))
            pred_action_value = tf.gather_nd(predictions, indices=enum_action)
            loss = self.move_model.loss_func(labels, pred_action_value)
        gradients = tape.gradient(loss, self.move_model.trainable_variables)
        self.move_model.optimizer.apply_gradients(zip(gradients, self.move_model.trainable_variables))
        self.model.move_loss.append(loss)
        # self.move_plot_loss()
        # print("Move loss: ", loss)
        # self.move_model.train_loss.update_state(loss)

    def move_train_model(self, action, features, labels, epochs=1):
        """Train the model for the given number of epochs."""
        for epoch in tf.range(1, epochs + 1):
            self.move_train_step(action, features, labels)

    def move_learn(self, obs, action, reward, next_obs, terminal):
        """Update self.move_model's value network with the DQN algorithm.

        As in act_learn, `reward` is used directly as the regression label;
        `next_obs` and `terminal` are unused.
        """
        self.move_train_model(action, obs, reward, epochs=1)
        self.move_global_step += 1

    def move_replace_target(self):
        """Copy the prediction model's weights into the target model."""
        for i, l in enumerate(self.move_target_model.get_layer(index=1).get_layer(index=0).get_layer(index=0).get_layers()):
            l.set_weights(self.move_model.get_layer(index=1).get_layer(index=0).get_layer(index=0).get_layer(index=i).get_weights())
        for i, l in enumerate(self.move_target_model.get_layer(index=1).get_layer(index=0).get_layer(index=1).get_layers()):
            l.set_weights(self.move_model.get_layer(index=1).get_layer(index=0).get_layer(index=1).get_layer(index=i).get_weights())
        # for i, l in enumerate(self.move_target_model.get_layer(index=1).get_layer(index=1).get_layer(index=0).get_layers()):
        #     l.set_weights(self.move_model.get_layer(index=1).get_layer(index=1).get_layer(index=0).get_layer(index=i).get_weights())
        # for i, l in enumerate(self.move_target_model.get_layer(index=1).get_layer(index=1).get_layer(index=1).get_layers()):
        #     l.set_weights(self.move_model.get_layer(index=1).get_layer(index=1).get_layer(index=1).get_layer(index=i).get_weights())
        # for i, l in enumerate(self.move_target_model.get_layer(index=1).get_layer(index=2).get_layer(index=0).get_layers()):
        #     l.set_weights(self.move_model.get_layer(index=1).get_layer(index=2).get_layer(index=0).get_layer(index=i).get_weights())
        # for i, l in enumerate(self.move_target_model.get_layer(index=1).get_layer(index=2).get_layer(index=1).get_layers()):
        #     l.set_weights(self.move_model.get_layer(index=1).get_layer(index=2).get_layer(index=1).get_layer(index=i).get_weights())
        self.move_target_model.get_layer(index=1).get_layer(index=2).set_weights(self.move_model.get_layer(index=1).get_layer(index=2).get_weights())
        # self.move_target_model.get_layer(index=1).get_layer(index=6).set_weights(self.move_model.get_layer(index=1).get_layer(index=6).get_weights())

    def replace_target(self):
        """Copy all prediction-model weights into the target models."""
        # print("replace target")
        # copy conv3d_1
        self.model.shared_target_model.get_layer(index=0).set_weights(self.model.shared_model.get_layer(index=0).get_weights())
        # copy batch_normalization_1
        self.model.shared_target_model.get_layer(index=1).set_weights(self.model.shared_model.get_layer(index=1).get_weights())
        # copy the shared resnet blocks
        for i, l in enumerate(self.model.shared_target_model.get_layer(index=4).get_layer(index=0).get_layers()):
            l.set_weights(self.model.shared_model.get_layer(index=4).get_layer(index=0).get_layer(index=i).get_weights())
        for i, l in enumerate(self.model.shared_target_model.get_layer(index=4).get_layer(index=1).get_layers()):
            l.set_weights(self.model.shared_model.get_layer(index=4).get_layer(index=1).get_layer(index=i).get_weights())
        for i, l in enumerate(self.model.shared_target_model.get_layer(index=5).get_layer(index=0).get_layers()):
            l.set_weights(self.model.shared_model.get_layer(index=5).get_layer(index=0).get_layer(index=i).get_weights())
        for i, l in enumerate(self.model.shared_target_model.get_layer(index=5).get_layer(index=1).get_layers()):
            l.set_weights(self.model.shared_model.get_layer(index=5).get_layer(index=1).get_layer(index=i).get_weights())
        self.move_replace_target()
        self.act_replace_target()
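
# ---------------------------------------------------------------------------
# A minimal usage sketch (hypothetical): `Model` stands for whatever network
# wrapper exposes act_dim, act_model, move_model, the *_target_model
# attributes, and the act_loss/move_loss lists this class expects; it is not
# defined in this file.
#
#   model = Model(...)
#   agent = DQN(model, gamma=0.9, learning_rate=0.0001)
#   q_values = agent.act_predict(obs_batch)   # [Q(s, a1), Q(s, a2), ...]
#   agent.act_learn(obs_batch, actions, target_q, next_obs_batch, terminals)
#   if agent.act_global_step % agent.update_target_steps == 0:
#       agent.replace_target()                # sync weights into target nets
# ---------------------------------------------------------------------------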