gaussian_ensemble.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


def soft_clamp(x: torch.Tensor, _min=None, _max=None):
    # Clamp tensor values while maintaining the gradient.
    if _max is not None:
        x = _max - F.softplus(_max - x)
    if _min is not None:
        x = _min + F.softplus(x - _min)
    return x
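
# Illustrative sketch, not part of the original file: unlike torch.clamp,
# whose gradient is zero outside the bounds, soft_clamp stays differentiable
# everywhere (and lets gradients reach learnable bounds such as max_logstd).
#   x = torch.tensor([3.0], requires_grad=True)
#   y = soft_clamp(x, _min=torch.tensor(-1.0), _max=torch.tensor(1.0))
#   y.backward()   # x.grad is nonzero even though x lies above _max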


class EnsembleLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, ensemble_size=7):
        super().__init__()
        self.ensemble_size = ensemble_size
        self.register_parameter('weight', torch.nn.Parameter(torch.zeros(ensemble_size, in_features, out_features)))
        self.register_parameter('bias', torch.nn.Parameter(torch.zeros(ensemble_size, 1, out_features)))
        torch.nn.init.trunc_normal_(self.weight, std=1 / (2 * in_features ** 0.5))
        # Snapshots of the parameters, used to roll selected members back to
        # their last saved state (see set_select / update_save).
        self.register_parameter('saved_weight', torch.nn.Parameter(self.weight.detach().clone()))
        self.register_parameter('saved_bias', torch.nn.Parameter(self.bias.detach().clone()))
        self.select = list(range(0, self.ensemble_size))

    def forward(self, x):
        weight = self.weight[self.select]
        bias = self.bias[self.select]
        if len(x.shape) == 2:
            # Shared input: broadcast the same batch to every selected member.
            x = torch.einsum('ij,bjk->bik', x, weight)
        else:
            # Per-member input with a leading ensemble axis.
            x = torch.einsum('bij,bjk->bik', x, weight)
        x = x + bias
        return x

    def set_select(self, indexes):
        assert len(indexes) <= self.ensemble_size and max(indexes) < self.ensemble_size
        self.select = indexes
        self.weight.data[indexes] = self.saved_weight.data[indexes]
        self.bias.data[indexes] = self.saved_bias.data[indexes]

    def update_save(self, indexes):
        self.saved_weight.data[indexes] = self.weight.data[indexes]
        self.saved_bias.data[indexes] = self.bias.data[indexes]
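
# Illustrative sketch, not part of the original file (sizes are assumptions):
# EnsembleLinear applies one linear map per ensemble member in a single
# batched einsum. A 2-D input is shared across members; a 3-D input already
# carries the leading ensemble axis.
#   layer = EnsembleLinear(in_features=4, out_features=8, ensemble_size=7)
#   shared = torch.randn(32, 4)          # layer(shared).shape == (7, 32, 8)
#   per_member = torch.randn(7, 32, 4)   # layer(per_member).shape == (7, 32, 8)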


class EnsembleTransition(torch.nn.Module):
    def __init__(self, obs_dim, action_dim, hidden_features, hidden_layers, ensemble_size=7, mode='local', with_reward=True):
        super().__init__()
        self.obs_dim = obs_dim
        self.mode = mode
        self.with_reward = with_reward
        self.ensemble_size = ensemble_size
        self.activation = Swish()
        module_list = []
        for i in range(hidden_layers):
            if i == 0:
                module_list.append(EnsembleLinear(obs_dim + action_dim, hidden_features, ensemble_size))
            else:
                module_list.append(EnsembleLinear(hidden_features, hidden_features, ensemble_size))
        self.backbones = torch.nn.ModuleList(module_list)
        # The head outputs a mean and a log-std for the next observation (and the reward, if enabled).
        self.output_layer = EnsembleLinear(hidden_features, 2 * (obs_dim + self.with_reward), ensemble_size)
        self.register_parameter('max_logstd', torch.nn.Parameter(torch.ones(obs_dim + self.with_reward) * 1, requires_grad=True))
        self.register_parameter('min_logstd', torch.nn.Parameter(torch.ones(obs_dim + self.with_reward) * -5, requires_grad=True))

    def forward(self, obs_action):
        output = obs_action
        for layer in self.backbones:
            output = self.activation(layer(output))
        mu, logstd = torch.chunk(self.output_layer(output), 2, dim=-1)
        logstd = soft_clamp(logstd, self.min_logstd, self.max_logstd)
        if self.mode == 'local':
            # In 'local' mode the network predicts a delta; add the current
            # observation so mu is s' (not delta s).
            if self.with_reward:
                obs, reward = torch.split(mu, [self.obs_dim, 1], dim=-1)
                obs = obs + obs_action[..., :self.obs_dim]
                mu = torch.cat([obs, reward], dim=-1)
            else:
                mu = mu + obs_action[..., :self.obs_dim]
        return torch.distributions.Normal(mu, torch.exp(logstd))

    def set_select(self, indexes):
        for layer in self.backbones:
            layer.set_select(indexes)
        self.output_layer.set_select(indexes)

    def update_save(self, indexes):
        for layer in self.backbones:
            layer.update_save(indexes)
        self.output_layer.update_save(indexes)
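

# A minimal usage sketch, not part of the original file; the dimensions below
# are illustrative assumptions. It builds the ensemble dynamics model, queries
# it on a random batch of (obs, action) inputs, samples next-state/reward
# predictions, then restricts and checkpoints an elite subset of members.
if __name__ == '__main__':
    obs_dim, action_dim, batch_size = 17, 6, 32   # hypothetical sizes
    model = EnsembleTransition(obs_dim, action_dim,
                               hidden_features=200, hidden_layers=4,
                               ensemble_size=7)
    obs_action = torch.randn(batch_size, obs_dim + action_dim)
    dist = model(obs_action)    # Normal over [next_obs, reward] per member
    sample = dist.sample()      # shape: (ensemble_size, batch_size, obs_dim + 1)
    # Keep only an elite subset of members, then snapshot their weights.
    model.set_select([0, 2, 4])
    model.update_save([0, 2, 4])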