Sarsa_vs_Q-learning.py
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import patches
import pandas as pd
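
# Cliff-walking comparison of Sarsa and Q-learning on a 4x12 grid:
# the agent starts at (0, 0), the goal is at (0, 11), and the cells
# (0, 1)..(0, 10) form a cliff that costs -500 and sends the agent back
# to the start. Every other step costs -1; reaching the goal costs 0.
# Both agents behave epsilon-greedily (epsilon = 0.1); they differ only in
# the bootstrap target used when updating Q (see updateValueQ below).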

def chooseNextAction(Q, x, y, is_sarsa=True):
    # Epsilon-greedy selection over the four moves handled in chooseNextState
    # (0/1 change the row index x, 2/3 change the column index y).
    # With is_sarsa=False the exploration branch is skipped, so the call used
    # for the Q-learning bootstrap target returns the purely greedy action.
    epsilon = 0.1
    if np.random.random() < epsilon and is_sarsa:
        action = np.random.choice(range(4))
    else:
        max_value = -np.inf
        for action in range(4):
            value = Q[x, y, action]
            if max_value <= value:
                max_value = value
                max_action = action
        action = max_action
    return action

def chooseNextState(x, y, current_action):
    # Apply the chosen move: actions 0/1 change the row index x (+1/-1),
    # actions 2/3 change the column index y (+1/-1). Moves off the 4x12 grid
    # are clipped, so the agent stays in place.
    x_new, y_new = x, y
    if current_action == 0:
        x_new = x + 1
    elif current_action == 1:
        x_new = x - 1
    elif current_action == 2:
        y_new = y + 1
    elif current_action == 3:
        y_new = y - 1
    x_new = max(0, x_new)
    x_new = min(3, x_new)
    y_new = max(0, y_new)
    y_new = min(11, y_new)
    return x_new, y_new

def updateValueQ(Q, x, y, current_action, is_sarsa=False):
    # One temporal-difference update of Q(s, a).
    # Q-learning target: reward + gamma * max_a Q(s', a)              (is_sarsa=False)
    # Sarsa target:      reward + gamma * Q(s', a'), a' eps-greedy    (is_sarsa=True)
    alpha = 0.1
    gamma = 1
    x_new, y_new = chooseNextState(x, y, current_action)
    if x_new == 0 and y_new == 11:
        reward = 0        # goal state
    elif x_new == 0 and 11 > y_new > 0:
        reward = -500     # fell off the cliff
    else:
        reward = -1       # ordinary step
    next_action = chooseNextAction(Q, x_new, y_new, is_sarsa)
    Q[x, y, current_action] = Q[x, y, current_action] + alpha * (
        reward + gamma * Q[x_new, y_new, next_action] - Q[x, y, current_action])
    if x_new == 0 and 11 > y_new > 0:
        # Falling off the cliff sends the agent back to the start state (0, 0).
        return 0, 0, reward
    else:
        return x_new, y_new, reward
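
# Grid layout implied by chooseNextState and the rewards above
# (rows indexed by x = 0..3, columns by y = 0..11):
#
#   x = 3   .  .  .  .  .  .  .  .  .  .  .  .
#   x = 2   .  .  .  .  .  .  .  .  .  .  .  .
#   x = 1   .  .  .  .  .  .  .  .  .  .  .  .
#   x = 0   S  C  C  C  C  C  C  C  C  C  C  G
#
#   S = start (0, 0), G = goal (0, 11), C = cliff cell (reward -500)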

def algorithm(is_sarsa=False):
    # Run 1000 episodes of cliff walking with either Q-learning or Sarsa-style
    # updates. Returns the path taken in the final episode plus per-episode
    # action counts and accumulated rewards.
    actions_of_last_episode = [(0, 0)]
    actions_per_episode = []
    rewards_per_episode = []
    Q = np.zeros((4, 12, 4))
    for n in range(1000):
        x = 0
        y = 0
        actions = 0
        rewards = 0
        while not (x == 0 and y == 11):
            current_action = chooseNextAction(Q, x, y)
            x_new, y_new, reward = updateValueQ(Q, x, y, current_action, is_sarsa=is_sarsa)
            x = x_new
            y = y_new
            actions = actions + 1
            rewards = rewards + reward
            if n == 999:
                # Record the path of the last episode; restart the trace if the
                # agent falls off the cliff and is sent back to the start.
                actions_of_last_episode.append((x_new, y_new))
                if x == 0 and y == 0:
                    actions_of_last_episode = [(0, 0)]
        actions_per_episode.append(actions)
        rewards_per_episode.append(rewards)
    return actions_of_last_episode, actions_per_episode, rewards_per_episode

def plot(q_steps, q_actions, q_rewards, sarsa_steps, sarsa_actions, sarsa_rewards):
    # Figure 1: the path taken in the last episode by each algorithm on the grid.
    fig1 = plt.figure()
    previous_x = 0
    previous_y = 0
    for position in q_steps:
        x, y = position[1], position[0]
        plt.arrow(previous_x, previous_y, x - previous_x, y - previous_y,
                  head_width=0.1, head_length=0.2, color='red')
        plt.plot(x, y, 'ro', markersize=1)
        previous_x = x
        previous_y = y
    previous_x = 0
    previous_y = 0
    for position in sarsa_steps:
        x, y = position[1], position[0]
        plt.arrow(previous_x, previous_y, x - previous_x, y - previous_y,
                  head_width=0.1, head_length=0.2, color='blue')
        plt.plot(x, y, 'bo', markersize=1)
        previous_x = x
        previous_y = y
    plt.plot(0, 0, 'mo', markersize=15)    # start state
    plt.plot(11, 0, 'go', markersize=15)   # goal state
    axes = plt.gca()
    axes.set_xticks(range(0, 12))
    axes.set_yticks(range(0, 4))
    axes.set_title('Paths of Algorithms')
    red_patch = patches.Patch(color='red', label='Q-Learning')
    blue_patch = patches.Patch(color='blue', label='Sarsa Learning')
    plt.legend(handles=[red_patch, blue_patch])
    plt.grid()
    plt.savefig('paths_of_algorithms.png')

    # Figure 2: smoothed per-episode reward for both algorithms.
    labels = ['Q-learning', 'Sarsa Learning']
    fig2 = plt.figure(figsize=(10, 6))
    smoothing = 100
    rewards_smoothed = pd.Series(q_rewards).rolling(smoothing, min_periods=smoothing).mean()
    plt.plot(rewards_smoothed, 'r')
    rewards_smoothed = pd.Series(sarsa_rewards).rolling(smoothing, min_periods=smoothing).mean()
    plt.plot(rewards_smoothed, 'b')
    plt.xlabel("Episode")
    plt.ylabel("Reward per Episode")
    plt.title("Reward per Episode over Time (Smoothed)")
    plt.legend(labels)
    plt.savefig('episode_reward.png')

q_steps, q_actions, q_rewards = algorithm()
print("Q-Learning Steps")
for step in q_steps:
    print(step)

sarsa_steps, sarsa_actions, sarsa_rewards = algorithm(is_sarsa=True)
print("Sarsa-Learning Steps")
for step in sarsa_steps:
    print(step)

plot(q_steps, q_actions, q_rewards, sarsa_steps, sarsa_actions, sarsa_rewards)
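
# Optional extra (not part of the original script): a quick numeric summary of
# how the two learners compare over the final episodes, using the reward lists
# computed above. The 100-episode window is an arbitrary choice for this summary.
last = 100
print("Mean reward over last %d episodes -- Q-learning: %.1f, Sarsa: %.1f"
      % (last, np.mean(q_rewards[-last:]), np.mean(sarsa_rewards[-last:])))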