forked from seungeunrho/RLfrombasics
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ch8_DQN.py
114 lines (92 loc) · 3.32 KB
/
ch8_DQN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import collections
import random
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# ---- Hyperparameters -------------------------------------------------------
learning_rate = 5e-4    # Adam step size for the online Q-network
gamma = 0.98            # discount applied to the bootstrapped next-state value
buffer_limit = 50_000   # maximum number of transitions kept in replay memory
batch_size = 32         # transitions drawn from replay memory per gradient step
class ReplayBuffer():
    """Bounded FIFO store of transitions used for experience replay.

    Each stored transition is a 5-tuple ``(s, a, r, s_prime, done_mask)``.
    Once the buffer is full, appending a new transition evicts the oldest one.
    """

    def __init__(self, maxlen=None):
        """Create an empty buffer.

        Args:
            maxlen: capacity in transitions. When omitted, the module-level
                ``buffer_limit`` hyperparameter is used.
        """
        if maxlen is None:
            maxlen = buffer_limit
        self.buffer = collections.deque(maxlen=maxlen)

    def put(self, transition):
        """Append one ``(s, a, r, s_prime, done_mask)`` transition."""
        self.buffer.append(transition)

    def sample(self, n):
        """Draw ``n`` distinct transitions uniformly at random and batch them.

        Returns:
            A tuple ``(s, a, r, s_prime, done_mask)`` of tensors:
            ``s``/``s_prime`` stacked along a new leading dim as float32,
            ``a`` as an ``(n, 1)`` tensor (int64 for int actions), and ``r`` and
            ``done_mask`` as ``(n, 1)`` tensors.

        Raises:
            ValueError: if ``n`` exceeds the number of stored transitions
                (propagated from ``random.sample``).
        """
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = [], [], [], [], []
        for s, a, r, s_prime, done_mask in mini_batch:
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            done_mask_lst.append([done_mask])

        # torch.tensor() on a list of numpy arrays is very slow (PyTorch warns
        # about it); converting each state and stacking is equivalent and fast.
        s_batch = torch.stack([torch.as_tensor(x, dtype=torch.float) for x in s_lst])
        s_prime_batch = torch.stack(
            [torch.as_tensor(x, dtype=torch.float) for x in s_prime_lst])
        return s_batch, torch.tensor(a_lst), \
               torch.tensor(r_lst), s_prime_batch, \
               torch.tensor(done_mask_lst)

    def size(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)
class Qnet(nn.Module):
    """Three-layer ReLU MLP mapping a state vector to one Q-value per action.

    The defaults (4 inputs, 128 hidden units, 2 actions) are the sizes this
    script originally hard-coded for CartPole-v1.
    """

    def __init__(self, state_dim=4, action_dim=2, hidden_dim=128):
        super().__init__()
        self.action_dim = action_dim
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return raw (unactivated) Q-values, last dim of size ``action_dim``."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def sample_action(self, obs, epsilon):
        """Pick an action epsilon-greedily for a single observation.

        With probability ``epsilon`` a uniformly random action index in
        ``[0, action_dim)`` is returned; otherwise the argmax of the Q-values.
        ``argmax()`` is taken over the flattened output, so ``obs`` should be a
        single observation rather than a batch.

        Returns:
            The chosen action as a Python int.
        """
        # Decide first so exploration steps skip the network entirely.
        if random.random() < epsilon:
            return random.randint(0, self.action_dim - 1)
        # Action selection needs no gradients; avoid building an autograd graph.
        with torch.no_grad():
            return self.forward(obs).argmax().item()
def train(q, q_target, memory, optimizer, n_updates=10, sample_size=None, discount=None):
    """Run several DQN gradient steps on minibatches drawn from replay memory.

    For each step the loss is the smooth-L1 distance between Q(s, a) from the
    online network ``q`` and the target
    ``r + discount * max_a' q_target(s_prime, a') * done_mask``.

    Args:
        q: online Q-network; its parameters are the ones ``optimizer`` updates.
        q_target: target Q-network used only to compute the bootstrap target.
        memory: object whose ``sample(n)`` returns ``(s, a, r, s_prime,
            done_mask)`` tensors, with ``a`` an int64 ``(n, 1)`` index tensor.
        optimizer: optimizer over ``q``'s parameters.
        n_updates: number of minibatch updates to perform.
        sample_size: minibatch size; defaults to the module-level ``batch_size``.
        discount: discount factor; defaults to the module-level ``gamma``.
    """
    if sample_size is None:
        sample_size = batch_size
    if discount is None:
        discount = gamma
    for _ in range(n_updates):
        s, a, r, s_prime, done_mask = memory.sample(sample_size)
        q_a = q(s).gather(1, a)
        # The target is a fixed regression label: computing it without autograd
        # stops backward() from accumulating (never-cleared) grads in q_target.
        with torch.no_grad():
            max_q_prime = q_target(s_prime).max(1)[0].unsqueeze(1)
            target = r + discount * max_q_prime * done_mask
        loss = F.smooth_l1_loss(q_a, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
def main(n_episodes=10000, print_interval=20):
    """Train a DQN agent on Gymnasium's CartPole-v1 and print progress.

    Args:
        n_episodes: number of episodes to run.
        print_interval: every this many episodes (episode 0 excluded) the
            target network is synced from the online network, the summed
            score divided by ``print_interval`` is printed, and the running
            score is reset.
    """
    env = gym.make('CartPole-v1')
    # try/finally so the environment is released even if training raises or
    # is interrupted (e.g. Ctrl-C on a long run).
    try:
        q = Qnet()
        q_target = Qnet()
        q_target.load_state_dict(q.state_dict())
        memory = ReplayBuffer()
        score = 0.0
        optimizer = optim.Adam(q.parameters(), lr=learning_rate)

        for n_epi in range(n_episodes):
            epsilon = max(0.01, 0.08 - 0.01 * (n_epi / 200))  # Linear annealing from 8% to 1%
            s, _ = env.reset()
            done = False
            truncated = False
            while not (done or truncated):
                a = q.sample_action(torch.from_numpy(s).float(), epsilon)
                s_prime, r, done, truncated, _ = env.step(a)
                # Only termination zeroes the mask; a truncated step keeps
                # done_mask = 1.0 so its next-state value is still bootstrapped.
                done_mask = 0.0 if done else 1.0
                # Rewards are stored scaled by 1/100; `score` keeps the raw sum.
                memory.put((s, a, r / 100.0, s_prime, done_mask))
                s = s_prime
                score += r

            # Learning starts only once the buffer holds more than 2000 transitions.
            if memory.size() > 2000:
                train(q, q_target, memory, optimizer)

            if n_epi % print_interval == 0 and n_epi != 0:
                q_target.load_state_dict(q.state_dict())
                print("n_episode :{}, score : {:.1f}, n_buffer : {}, eps : {:.1f}%".format(
                    n_epi, score / print_interval, memory.size(), epsilon * 100))
                score = 0.0
    finally:
        env.close()


if __name__ == '__main__':
    main()