-
Notifications
You must be signed in to change notification settings - Fork 74
/
ch6_SARSA.py
130 lines (109 loc) · 3.27 KB
/
ch6_SARSA.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import random
import numpy as np
class GridWorld():
    """Deterministic 5 x 7 grid world with internal walls.

    The position is ``(x, y)``: ``x`` is the row (0-4; "up" decreases it,
    "down" increases it) and ``y`` is the column (0-6; "left" decreases it,
    "right" increases it). Episodes start at (0, 0) and end on reaching the
    goal cell (4, 6). Every step yields a reward of -1.

    Walls fill column 2 for rows 0-2 and column 4 for rows 2-4. A move that
    would enter a wall or leave the grid leaves the position unchanged.
    """

    def __init__(self):
        self.x, self.y = 0, 0

    def step(self, a):
        """Apply action ``a`` and return ``((x, y), reward, done)``.

        Actions: 0 = left, 1 = up, 2 = right, 3 = down. Any other value
        leaves the position unchanged but still yields the -1 reward.
        """
        # Checked in order with ``==``; the first matching code wins.
        actions = (
            (0, self.move_left),
            (1, self.move_up),
            (2, self.move_right),
            (3, self.move_down),
        )
        for code, move in actions:
            if a == code:
                move()
                break
        return (self.x, self.y), -1, self.is_done()  # reward is always -1

    def move_left(self):
        """Step one column left unless the border or a wall is in the way."""
        blocked = (
            self.y == 0
            or (self.y == 3 and self.x in (0, 1, 2))  # wall in column 2
            or (self.y == 5 and self.x in (2, 3, 4))  # wall in column 4
        )
        if not blocked:
            self.y -= 1

    def move_right(self):
        """Step one column right unless the border or a wall is in the way."""
        blocked = (
            (self.y == 1 and self.x in (0, 1, 2))  # wall in column 2
            or (self.y == 3 and self.x in (2, 3, 4))  # wall in column 4
            or self.y == 6
        )
        if not blocked:
            self.y += 1

    def move_up(self):
        """Step one row up unless at the top edge or below the wall at (2, 2)."""
        if self.x != 0 and (self.x, self.y) != (3, 2):
            self.x -= 1

    def move_down(self):
        """Step one row down unless at the bottom edge or above the wall at (2, 4)."""
        if self.x != 4 and (self.x, self.y) != (1, 4):
            self.x += 1

    def is_done(self):
        """Return True once the agent stands on the goal cell (4, 6)."""
        return (self.x, self.y) == (4, 6)

    def reset(self):
        """Move the agent back to the start cell and return its position."""
        self.x, self.y = 0, 0
        return (self.x, self.y)
class QAgent():
    """Tabular SARSA agent with an epsilon-greedy behaviour policy.

    The Q table holds one value per (row, column, action) and starts at zero.
    The defaults reproduce the original fixed settings: a 5 x 7 grid with 4
    actions, step size 0.1, no discounting (gamma = 1.0), and epsilon
    annealed from 0.9 down to 0.1 in steps of 0.03.
    """

    def __init__(self, n_rows=5, n_cols=7, n_actions=4, alpha=0.1, gamma=1.0,
                 eps=0.9, eps_decay=0.03, eps_min=0.1):
        self.q_table = np.zeros((n_rows, n_cols, n_actions))  # Q table, all zeros
        self.alpha = alpha          # step size of the SARSA update
        self.gamma = gamma          # discount factor (1.0 = undiscounted)
        self.eps = eps              # current exploration probability
        self.eps_decay = eps_decay  # amount subtracted by each anneal_eps()
        self.eps_min = eps_min      # floor that eps never drops below

    def select_action(self, s):
        """Choose an action for state ``s = (x, y)`` epsilon-greedily.

        With probability ``eps`` a uniformly random action is returned;
        otherwise the action with the largest Q value (``np.argmax`` breaks
        ties in favour of the lowest action index).
        """
        x, y = s
        if random.random() < self.eps:
            # randint is inclusive on both ends.
            return random.randint(0, self.q_table.shape[-1] - 1)
        return int(np.argmax(self.q_table[x, y, :]))

    def update_table(self, transition):
        """Apply one SARSA update from ``(s, a, r, s_prime[, a_prime])``.

        For true (on-policy) SARSA, ``a_prime`` must be the action the agent
        will actually execute in ``s_prime``. The legacy 4-tuple form is
        still accepted. In that case ``a_prime`` is sampled from the current
        policy here, is not the action executed next, and so the update only
        approximates SARSA.
        """
        transition = tuple(transition)
        if len(transition) == 5:
            s, a, r, s_prime, a_prime = transition
        else:
            s, a, r, s_prime = transition
            a_prime = self.select_action(s_prime)
        x, y = s
        next_x, next_y = s_prime
        # Q(S,A) <- Q(S,A) + alpha * (R + gamma * Q(S',A') - Q(S,A))
        td_target = r + self.gamma * self.q_table[next_x, next_y, a_prime]
        self.q_table[x, y, a] += self.alpha * (td_target - self.q_table[x, y, a])

    def anneal_eps(self):
        """Decrease epsilon by ``eps_decay`` without going below ``eps_min``."""
        self.eps = max(self.eps - self.eps_decay, self.eps_min)

    def show_table(self):
        """Print the greedy action (argmax over Q) for every grid cell."""
        # Float dtype keeps the printout in the same format as before.
        greedy = np.argmax(self.q_table, axis=-1).astype(float)
        print(greedy)


def main():
    """Train a SARSA agent on GridWorld for 1000 episodes, then print its policy."""
    env = GridWorld()
    agent = QAgent()
    for _ in range(1000):
        done = False
        s = env.reset()
        a = agent.select_action(s)  # first action of the episode
        while not done:
            s_prime, r, done = env.step(a)
            # Pick A' now and actually execute it on the next step, so the
            # update bootstraps from the action the policy really takes.
            # No update ever starts from the goal cell (the loop stops once
            # done), so its Q values stay 0 and bootstrapping from it adds 0.
            a_prime = agent.select_action(s_prime)
            agent.update_table((s, a, r, s_prime, a_prime))
            s, a = s_prime, a_prime
        agent.anneal_eps()
    agent.show_table()


if __name__ == '__main__':
    main()