q_network_trainer.py
import q_network
import cube
# This trainer should be state-, action- and env-agnostic
# Hyperparameters can be specified here, but the environment should be a passable parameter
# For DQN/DNN/RNNs, think of embedding the states using something like word2vec - this might help group states.
# Examples (you could even train these as multitask objectives):
#   - train an embedding to predict how far away a state is from the solved state
#   - train an embedding to predict the middle state between two states
#   - train an embedding to predict all the rotated forms of a cube state
## Steps for my understanding
# Start with an epsilon-greedy approach (epsilon is a hyperparameter) and slowly anneal it down
# Have a stochastic policy function PI which takes in a Q-table and an epsilon
# Q(s, a) = Q(s, a) + learning_rate * (r(s, a) + discount * value(s')_PI - Q(s, a))
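# Worked example of the update rule above (hypothetical numbers, not taken from a real run):
#   with learning_rate = 0.01, discount = 1, Q(s, a) = 0.5, r(s, a) = 1 and value(s') = 0.8,
#   Q(s, a) <- 0.5 + 0.01 * (1 + 1 * 0.8 - 0.5) = 0.513
# Note: update_q_table() below uses max(q_table_old[s_]) as value(s'), i.e. the greedy value
# of the next state (the standard off-policy Q-learning target).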
## HYPER-PARAMETERS ##
PRINT_DIAG = True
LOAD_NETWORK = True
SAVE_NETWORK = True
UPDATE_NETWORK = True
LOAD_NETWORK_NAME = "Q_Network_2STEP_5000_episodes.pickle"
SAVE_NETWORK_NAME = "Q_Network_2STEP_20000_episodes.pickle"
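# Note: the LOAD/SAVE names suggest training resumes from a 5000-episode table and, with
# N_EPISODES = 15000 below, is saved as 5000 + 15000 = 20000 episodes in total (assumed intent).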
## Q-network specific parameters ##
EPSILON_START = 0.0  # Exploration rate at the start of training
EPSILON_END = 0.05  # Exploration rate at the end of training
N_EPISODES = 15000
LEARNING_RATE = 0.01 # Q-Network trainer learning_rate
DISCOUNT_FACTOR = 1 # Q-Network trainer discount_factor
###################################
## Cube specific parameters ##
SIDE = 2
N_MOVES = 2
##############################
## Environment to Q-Network interface parameters ##
N_ACTIONS = 3*SIDE*2
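# Presumably 3 axes * SIDE layers per axis * 2 rotation directions, e.g. 3 * 2 * 2 = 12 actions for SIDE = 2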
# Set N_STEPS_EPISODE_MAX = -1 for an episode structure that is not capped at a fixed move count and only terminates in a terminal state
N_STEPS_EPISODE_MAX = N_MOVES # Max number of steps taken per episode.
###################################################
## Q Network specific Functions ##
##################################
## Ideally all env specific functions would be in the env object
## Interface between Q Network and Environment ##
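# For reference, the environment returned by reset_env() is expected to provide (inferred from
# the calls in this file, not from cube.CubeObject's documentation):
#   get_observation()        -> current state observation
#   apply_action(action_idx) -> (new observation, reward)
#   is_terminal_state()      -> True when the episode should end
#   method_existence, method_index -> lookup helpers passed through to q_network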
def reset_env():
    ## Set Parameters ##
    side = SIDE
    n_moves = N_MOVES
    ####################
    return cube.CubeObject(side, n_moves)
#################################################
## Q Network specific functions ##
def run_episode(q_network_obj, episode, epsilon=1):
    ## Set Parameters ##
    n_episodes = N_EPISODES
    n_steps_episode_max = N_STEPS_EPISODE_MAX
    ####################
    if PRINT_DIAG is True:
        n_states_start = q_network_obj.n_states
        n_states_new = 0
    env = reset_env()
    obs = env.get_observation()
    # if (obs in q_network_obj.state_obs_list) is False:
    if q_network.get_existance(q_network_obj.state_obs_list, obs, method=env.method_existence) is False:
        # New observation - add it
        q_network_obj.add_state(obs)
        if PRINT_DIAG is True:
            n_states_new = n_states_new + 1
    terminate_episode = False
    steps_taken = 0
    state_action_reward_list = []
    while True:
        state_obs_list = q_network_obj.state_obs_list
        # state_idx = state_obs_list.index(obs)
        state_idx = q_network.get_index(state_obs_list, obs, method_existence=env.method_existence,
                                        method_index=env.method_index)
        qtable = q_network_obj.qtable
        action_idx = q_network.policy(qtable, state_idx, epsilon)
        terminate_episode = env.is_terminal_state()
        if n_steps_episode_max != -1 and steps_taken >= n_steps_episode_max:
            terminate_episode = True
        if terminate_episode is True:
            break
        obs, reward = env.apply_action(action_idx)
        # if (obs in state_obs_list) is False:
        if q_network.get_existance(state_obs_list, obs, method=env.method_existence) is False:
            # New observation - add it
            q_network_obj.add_state(obs)
            state_obs_list = q_network_obj.state_obs_list
            if PRINT_DIAG is True:
                n_states_new = n_states_new + 1
        # state_idx_new = state_obs_list.index(obs)
        state_idx_new = q_network.get_index(state_obs_list, obs, method_existence=env.method_existence,
                                            method_index=env.method_index)
        state_action_reward_list += [[state_idx, state_idx_new, action_idx, reward]]
        steps_taken = steps_taken + 1
    if PRINT_DIAG is True:
        print("episode number: {}, number of existing states: {}, number of newly added states: {}".format(
            episode, n_states_start, n_states_new))
    return q_network_obj, state_action_reward_list
def update_q_table(q_network_obj, state_action_reward_list, alpha, gamma):
    q_table_old = q_network_obj.qtable
    q_table_new = q_table_old.copy()
    if PRINT_DIAG is True:
        n_updates = 0
        total_update_abs = 0
    for i in range(len(state_action_reward_list)):
        item = state_action_reward_list[i]
        s = item[0]
        s_ = item[1]
        a = item[2]
        r = item[3]
        q_table_new[s][a] = q_table_old[s][a] + alpha * (r + gamma * max(q_table_old[s_]) - q_table_old[s][a])
        if PRINT_DIAG is True:
            if q_table_old[s][a] != 0:
                n_updates = n_updates + 1
                total_update_abs = total_update_abs + abs((q_table_new[s][a] - q_table_old[s][a]) / q_table_old[s][a])
    if PRINT_DIAG is True:
        if n_updates > 0:
            print("Training stats: number of existing state updates: {}, mean percentage change of weights: {}".format(
                n_updates, 100.0 * total_update_abs / n_updates))
        else:
            # No updates this episode
            print("Training stats: number of existing state updates: {}, mean percentage change of weights: {}".format(
                n_updates, "N/A"))
    q_network_obj.update_q_table(q_table_new)
def run():
    # Main Q-Network builder and trainer
    ## Set Parameters ##
    n_actions = N_ACTIONS
    n_episodes = N_EPISODES
    epsilon_start = EPSILON_START
    epsilon_end = EPSILON_END
    alpha_default = LEARNING_RATE
    gamma_default = DISCOUNT_FACTOR
    load_network = LOAD_NETWORK
    save_network = SAVE_NETWORK
    update_network = UPDATE_NETWORK
    load_network_name = LOAD_NETWORK_NAME
    save_network_name = SAVE_NETWORK_NAME
    ####################
    # Create a q_network with zero states observed, or load a previously saved one
    if load_network is False:
        q_network_obj = q_network.q_network(n_states=0, n_actions=n_actions)
    else:
        q_network_obj = q_network.get_q_network_obj(load_network_name)
    if update_network is True:
        # Run episodes
        for episode in range(n_episodes):
            epsilon = q_network.epsilon_control_algo(epsilon_start, epsilon_end, n_episodes, episode)
            gamma = gamma_default
            if episode == 0:
                alpha = 1
            else:
                alpha = alpha_default
            q_network_obj, state_action_reward_list = run_episode(q_network_obj, episode, epsilon)
            update_q_table(q_network_obj, state_action_reward_list, alpha, gamma)
    if save_network is True:
        q_network.dump_q_network(q_network_obj, save_network_name)
if __name__ == "__main__":
    run()