# ch6_SARSA.py
# Forked from seungeunrho/RLfrombasics.
import numpy as np
from world import AgentBase, GridWorld2
class QAgent(AgentBase):
    """Tabular SARSA agent.

    This class uses ``self.q_table``, ``self.eps`` and ``self.select_action``
    without defining them; they are presumably provided by ``AgentBase``
    (confirm in the ``world`` module).
    """

    def __init__(self, alpha=0.1, gamma=1.0):
        """Initialise the base agent and store the learning hyper-parameters.

        Args:
            alpha: Step size applied to the TD error in ``update_table``.
            gamma: Discount factor applied to the bootstrapped value
                ``Q(s', a')``. The default of 1.0 reproduces the original
                undiscounted update exactly.
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def update_table(self, transition):
        """Apply one SARSA update to ``self.q_table`` for a single transition.

        Args:
            transition: A 4-tuple ``(s, a, r, s_prime)``. ``s`` and
                ``s_prime`` must each unpack into two indices, and ``a`` is
                used as the third index into ``self.q_table``.
        """
        s, a, r, s_prime = transition
        x, y = s
        next_x, next_y = s_prime

        # Action that would be selected in S' (used only for bootstrapping).
        # NOTE(review): this is not necessarily the action actually taken
        # next -- the training loop calls select_action again for its next
        # step, so with a stochastic policy the two choices may differ.
        a_prime = self.select_action(s_prime)

        # SARSA rule: Q(s,a) += alpha * (r + gamma * Q(s',a') - Q(s,a))
        td_target = r + self.gamma * self.q_table[next_x, next_y, a_prime]
        td_error = td_target - self.q_table[x, y, a]
        self.q_table[x, y, a] += self.alpha * td_error

    def anneal_eps(self):
        """Decrease ``self.eps`` by 0.03, never letting it drop below 0.1."""
        self.eps -= 0.03
        self.eps = max(self.eps, 0.1)
def main():
    """Train a ``QAgent`` on ``GridWorld2`` for 1000 episodes and show its table.

    Each episode resets the environment and steps until the environment
    reports it is done, updating the agent's table after every step.
    ``anneal_eps`` is called once at the end of each episode, and
    ``show_table`` once after all episodes have finished.
    """
    env = GridWorld2()
    agent = QAgent()

    for _episode in range(1000):
        state = env.reset()

        # The episode always takes at least one step, so test "done" after it.
        while True:
            action = agent.select_action(state)
            next_state, reward, done = env.step(action)
            agent.update_table((state, action, reward, next_state))
            state = next_state
            if done:
                break

        agent.anneal_eps()

    agent.show_table()
# Run the training loop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()