-
Notifications
You must be signed in to change notification settings - Fork 38
/
TD3.py
237 lines (176 loc) · 8.73 KB
/
TD3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import utils
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from logger import logger
from logger import create_stats_ordered_dict
# Implementation of Twin Delayed Deep Deterministic Policy Gradients (TD3)
# Paper: https://arxiv.org/abs/1802.09477
class Actor(nn.Module):
    """Deterministic policy network: maps a batch of states to bounded actions.

    The network is a 400-300 ReLU MLP followed by a tanh output layer. The
    tanh output is multiplied by ``max_action``, so every action component
    lies in ``[-max_action, max_action]``.
    """

    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()

        # NOTE: the attribute names l1/l2/l3 end up as state_dict keys, so
        # they must stay stable for checkpoints to remain loadable.
        self.l1 = nn.Linear(state_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, action_dim)

        # Scale applied to the tanh output in forward().
        self.max_action = max_action

    def forward(self, x):
        """Return the policy action for every row of the state batch ``x``."""
        hidden = F.relu(self.l1(x))
        hidden = F.relu(self.l2(hidden))
        squashed = torch.tanh(self.l3(hidden))
        return self.max_action * squashed
class Critic(nn.Module):
    """Pair of independent Q-networks evaluated on the same (state, action) input.

    Each head is a 400-300 ReLU MLP with a single linear output. ``forward``
    returns both estimates; ``Q1`` returns only the first head's estimate.
    """

    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        joint_dim = state_dim + action_dim

        # First Q head (attribute names are state_dict keys; keep them stable).
        self.l1 = nn.Linear(joint_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, 1)

        # Second Q head.
        self.l4 = nn.Linear(joint_dim, 400)
        self.l5 = nn.Linear(400, 300)
        self.l6 = nn.Linear(300, 1)

    def _first_head(self, xu):
        """Apply the first Q head to an already concatenated input."""
        hidden = F.relu(self.l1(xu))
        hidden = F.relu(self.l2(hidden))
        return self.l3(hidden)

    def _second_head(self, xu):
        """Apply the second Q head to an already concatenated input."""
        hidden = F.relu(self.l4(xu))
        hidden = F.relu(self.l5(hidden))
        return self.l6(hidden)

    def forward(self, x, u):
        """Return ``(Q1, Q2)`` for state batch ``x`` and action batch ``u``.

        ``x`` and ``u`` are concatenated along dimension 1 before being fed
        to both heads.
        """
        joint = torch.cat([x, u], 1)
        return self._first_head(joint), self._second_head(joint)

    def Q1(self, x, u):
        """Return only the first head's estimate for ``(x, u)``."""
        joint = torch.cat([x, u], 1)
        return self._first_head(joint)
class VAE(nn.Module):
    """Conditional variational auto-encoder over actions given states.

    The encoder maps ``(state, action)`` to the mean and standard deviation
    of a Gaussian latent; the decoder maps ``(state, latent)`` to an action
    bounded to ``[-max_action, max_action]`` by a scaled tanh.

    Latent noise is drawn with ``np.random.normal`` and then converted to
    the device and dtype of the tensor it is combined with, so the module
    does not depend on a module-level device setting.
    """

    def __init__(self, state_dim, action_dim, latent_dim, max_action):
        super(VAE, self).__init__()

        # Encoder: (state, action) -> latent mean / log-std.
        self.e1 = nn.Linear(state_dim + action_dim, 750)
        self.e2 = nn.Linear(750, 750)
        self.mean = nn.Linear(750, latent_dim)
        self.log_std = nn.Linear(750, latent_dim)

        # Decoder: (state, latent) -> action.
        self.d1 = nn.Linear(state_dim + latent_dim, 750)
        self.d2 = nn.Linear(750, 750)
        self.d3 = nn.Linear(750, action_dim)

        self.max_action = max_action
        self.latent_dim = latent_dim

    @staticmethod
    def _standard_normal(shape, like):
        """Draw N(0, 1) noise of ``shape`` on the device/dtype of ``like``."""
        noise = np.random.normal(0, 1, size=shape)
        return torch.from_numpy(noise).to(device=like.device, dtype=like.dtype)

    def _prior_sample(self, shape, like):
        """Sample a latent from the prior, clipped to [-0.5, 0.5]."""
        return self._standard_normal(shape, like).clamp(-0.5, 0.5)

    def forward(self, state, action):
        """Encode ``(state, action)``, sample a latent and reconstruct the action.

        Returns:
            ``(u, mean, std)``: the reconstructed action and the mean and
            standard deviation of the latent Gaussian.
        """
        h = F.relu(self.e1(torch.cat([state, action], 1)))
        h = F.relu(self.e2(h))

        mean = self.mean(h)
        # Clamped for numerical stability
        log_std = self.log_std(h).clamp(-4, 15)
        std = torch.exp(log_std)

        # Reparameterised sample: z = mean + std * eps, eps ~ N(0, 1).
        z = mean + std * self._standard_normal(tuple(std.shape), std)

        u = self.decode(state, z)
        return u, mean, std

    def decode_softplus(self, state, z=None):
        """Run the first two decoder layers; currently returns ``None``.

        NOTE(review): this method looks unfinished. No output layer is
        applied and nothing is returned. The behaviour (including the noise
        draw when ``z`` is omitted) is kept as-is until the intended output
        is confirmed.
        """
        if z is None:
            z = self._prior_sample((state.size(0), self.latent_dim), state)

        a = F.relu(self.d1(torch.cat([state, z], 1)))
        a = F.relu(self.d2(a))
        return None

    def decode(self, state, z=None):
        """Decode one action per state.

        Args:
            state: state batch; concatenated with ``z`` along dimension 1.
            z: latent batch. When omitted, one latent per state is sampled
                from N(0, 1) and clipped to [-0.5, 0.5].

        Returns:
            Actions bounded to ``[-max_action, max_action]``.
        """
        # When sampling from the VAE, the latent vector is clipped to [-0.5, 0.5]
        if z is None:
            z = self._prior_sample((state.size(0), self.latent_dim), state)

        a = F.relu(self.d1(torch.cat([state, z], 1)))
        a = F.relu(self.d2(a))
        return self.max_action * torch.tanh(self.d3(a))

    def decode_multiple(self, state, z=None, num_decode=10):
        """Decode ``num_decode`` actions per state (10 by default).

        Args:
            state: 2-D state batch; a sample axis is inserted at position 1
                and the state is repeated ``num_decode`` times along it.
            z: latents, concatenated with the repeated state along
                dimension 2. When omitted, ``num_decode`` latents per state
                are sampled from N(0, 1) and clipped to [-0.5, 0.5].
            num_decode: number of samples decoded per state.

        Returns:
            ``(actions, raw)``: ``raw`` is the output of the last decoder
            layer and ``actions`` is ``max_action * tanh(raw)``.
        """
        if z is None:
            z = self._prior_sample((state.size(0), num_decode, self.latent_dim), state)

        # (batch, state_dim) -> (batch, num_decode, state_dim)
        tiled_state = state.unsqueeze(1).repeat(1, num_decode, 1)

        a = F.relu(self.d1(torch.cat([tiled_state, z], 2)))
        a = F.relu(self.d2(a))

        # Evaluate the output layer once and reuse it for both return values.
        raw = self.d3(a)
        return self.max_action * torch.tanh(raw), raw
class TD3(object):
    """TD3 agent that also fits a conditional VAE to the replay-buffer actions.

    Holds an actor and a twin critic together with their target copies, plus
    a VAE (latent size ``2 * action_dim``) that is trained on every sampled
    batch. In ``train`` the VAE is used only to compute a logged divergence
    between the VAE's sampled actions and the actor's actions; it does not
    enter the actor or critic losses.
    """

    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters())

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = Critic(state_dim, action_dim).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters())

        latent_dim = action_dim * 2
        self.vae = VAE(state_dim, action_dim, latent_dim, max_action).to(device)
        self.vae_optimizer = torch.optim.Adam(self.vae.parameters())

        self.max_action = max_action

    def select_action(self, state):
        """Return the actor's action for a single state as a flat NumPy array."""
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            return self.actor(state).cpu().numpy().flatten()

    @staticmethod
    def _soft_update(source, target, tau):
        """Polyak-average ``source`` parameters into ``target`` in place."""
        for param, target_param in zip(source.parameters(), target.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    def _fit_vae(self, state, action):
        """Take one optimiser step on the VAE reconstruction + KL objective."""
        recon, mean, std = self.vae(state, action)
        recon_loss = F.mse_loss(recon, action)
        KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
        vae_loss = recon_loss + 0.5 * KL_loss

        self.vae_optimizer.zero_grad()
        vae_loss.backward()
        self.vae_optimizer.step()

    def train(self, replay_buffer, iterations, batch_size=100, discount=0.99, tau=0.005,
              policy_noise=0.2, noise_clip=0.5, policy_freq=2):
        """Run ``iterations`` gradient steps and log statistics of the last one.

        Args:
            replay_buffer: object whose ``sample(batch_size)`` returns six
                arrays, used here as (state, next state, action, reward, d,
                mask). ``d`` is presumably a terminal flag, since ``1 - d``
                gates the bootstrapped target; ``mask`` is not used.
            iterations: number of update steps. With ``iterations <= 0``
                nothing is updated or logged.
            batch_size: batch size requested from the buffer.
            discount: factor applied to the bootstrapped target value.
            tau: interpolation weight for the target-network update.
            policy_noise: std of the Gaussian noise added to target actions.
            noise_clip: bound applied to that noise.
            policy_freq: the actor and target networks are updated on every
                ``policy_freq``-th iteration (iteration 0 included).
        """
        if iterations <= 0:
            # Nothing is computed, so the logging below would hit unbound names.
            return

        for it in range(iterations):
            # Sample replay buffer
            x, y, u, r, d, mask = replay_buffer.sample(batch_size)
            state = torch.FloatTensor(x).to(device)
            action = torch.FloatTensor(u).to(device)
            next_state = torch.FloatTensor(y).to(device)
            not_done = torch.FloatTensor(1 - d).to(device)
            reward = torch.FloatTensor(r).to(device)

            self._fit_vae(state, action)

            # The target is a constant for the critic loss; skip autograd.
            with torch.no_grad():
                # Select action according to policy and add clipped noise
                noise = torch.FloatTensor(u).normal_(0, policy_noise).to(device)
                noise = noise.clamp(-noise_clip, noise_clip)
                next_action = (self.actor_target(next_state) + noise).clamp(
                    -self.max_action, self.max_action)

                # Compute the target Q value
                target_Q1, target_Q2 = self.critic_target(next_state, next_action)
                target_Q = torch.min(target_Q1, target_Q2)
                target_Q = reward + not_done * discount * target_Q

            # Get current Q estimates
            current_Q1, current_Q2 = self.critic(state, action)

            # Logged diagnostic only: std of the two critic estimates.
            std_loss = torch.std(
                torch.stack([current_Q1.detach(), current_Q2.detach()], 0),
                dim=0, unbiased=False,
            )

            # Compute critic loss
            critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

            # Optimize the critic
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # Delayed policy updates
            if it % policy_freq == 0:
                # Compute actor loss (one actor forward pass, reused below)
                actor_actions = self.actor(state)
                actor_loss = -self.critic.Q1(state, actor_actions).mean()

                # Logged diagnostic only: squared distance between VAE samples
                # and the actor's actions.
                with torch.no_grad():
                    sampled_actions = self.vae.decode(state)
                    action_divergence = ((sampled_actions - actor_actions.detach()) ** 2).sum(-1)

                # Optimize the actor
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()

                # Update the frozen target models
                self._soft_update(self.critic, self.critic_target, tau)
                self._soft_update(self.actor, self.actor_target, tau)

        # actor_loss / action_divergence come from the most recent policy
        # update; iteration 0 always performs one, so they are bound here.
        logger.record_dict(create_stats_ordered_dict(
            'Q_target',
            target_Q.cpu().data.numpy(),
        ))
        logger.record_tabular('Actor Loss', actor_loss.detach().cpu().numpy())
        logger.record_tabular('Critic Loss', critic_loss.detach().cpu().numpy())
        logger.record_tabular('Std Loss', std_loss.cpu().numpy().mean())
        logger.record_dict(create_stats_ordered_dict(
            'Action_Divergence',
            action_divergence.cpu().numpy()
        ))

    def save(self, filename, directory):
        """Write the actor and critic weights under ``directory``.

        Only the actor and critic are saved; the VAE, the target networks
        and the optimiser states are not.
        """
        torch.save(self.actor.state_dict(), '%s/%s_actor.pth' % (directory, filename))
        torch.save(self.critic.state_dict(), '%s/%s_critic.pth' % (directory, filename))

    def load(self, filename, directory):
        """Load actor and critic weights written by ``save``.

        Checkpoints are mapped onto the current ``device`` so that files
        saved on a GPU can be loaded on a CPU-only host. The target networks
        are then reset to the loaded weights, so that training resumed after
        a load does not bootstrap from stale targets.
        """
        self.actor.load_state_dict(
            torch.load('%s/%s_actor.pth' % (directory, filename), map_location=device))
        self.critic.load_state_dict(
            torch.load('%s/%s_critic.pth' % (directory, filename), map_location=device))

        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target.load_state_dict(self.critic.state_dict())