ch6_QLearning.py (forked from seungeunrho/RLfrombasics)
import numpy as np
from world import AgentBase, GridWorld2


class QAgent(AgentBase):
    def __init__(self):
        super().__init__()
        self.alpha = 0.1  # learning rate

    def update_table(self, transition):
        s, a, r, s_prime = transition
        x, y = s
        next_x, next_y = s_prime
        # apply the Q-learning update rule
        self.q_table[x, y, a] += self.alpha * (r + np.amax(self.q_table[next_x, next_y, :]) - self.q_table[x, y, a])
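        # This is the standard Q-learning update; the max over next-state
        # actions is what makes the method off-policy:
        #     Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
        # Note the line above applies no discounting, i.e. it assumes gamma = 1.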

    def anneal_eps(self):
        self.eps -= 0.01  # in Q-learning, let epsilon decay a bit more slowly
        self.eps = max(self.eps, 0.2)  # keep a floor of 0.2 so exploration never stops


def main():
    env = GridWorld2()
    agent = QAgent()

    for _ in range(1000):  # train for 1000 episodes
        done = False
        s = env.reset()
        while not done:
            a = agent.select_action(s)
            s_prime, r, done = env.step(a)
            agent.update_table((s, a, r, s_prime))
            s = s_prime
        agent.anneal_eps()

    agent.show_table()  # display the learned Q-table


if __name__ == '__main__':
    main()
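
# ----------------------------------------------------------------------
# Dependency note: `world` is a local module from the book's repository
# and is not shown in this file. A minimal sketch of the interface this
# script assumes, inferred from the calls above (the grid size, action
# count, and initial epsilon below are illustrative guesses, not the
# repository's actual values):
#
#     class AgentBase:
#         def __init__(self):
#             self.q_table = np.zeros((5, 7, 4))  # Q(x, y, action)
#             self.eps = 0.9                      # exploration rate
#         def select_action(self, s): ...  # epsilon-greedy over q_table[s]
#         def show_table(self): ...        # print the greedy action per cell
#
#     class GridWorld2:
#         def reset(self): ...     # -> initial state (x, y)
#         def step(self, a): ...   # -> (next_state, reward, done)
# ----------------------------------------------------------------------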