import torch
import numpy as np

from spirl.rl.agents.ac_agent import SACAgent
from spirl.utils.general_utils import ParamDict, ConstantSchedule, AttrDict
from spirl.utils.pytorch_utils import check_shape, map2torch


class ActionPriorSACAgent(SACAgent):
    """Implements SAC with non-uniform, learned action / skill prior."""
    def __init__(self, config):
        SACAgent.__init__(self, config)
        self._target_divergence = self._hp.td_schedule(self._hp.td_schedule_params)

    def _default_hparams(self):
        default_dict = ParamDict({
            'alpha_min': None,                  # minimum value alpha is clipped to, no clipping if None
            'td_schedule': ConstantSchedule,    # schedule used for target divergence param
            'td_schedule_params': AttrDict(     # parameters for target divergence schedule
                p=1.,
            ),
        })
        return super()._default_hparams().overwrite(default_dict)
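
    # Example of the corresponding overrides in an agent config (hypothetical values,
    # shown only for illustration; ParamDict / AttrDict / ConstantSchedule are the
    # helpers imported at the top of this file):
    #
    #   ParamDict(
    #       alpha_min=1e-3,                     # clip alpha from below
    #       td_schedule=ConstantSchedule,       # keep a fixed target divergence...
    #       td_schedule_params=AttrDict(p=5.),  # ...of 5 nats
    #   )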

    def update(self, experience_batch):
        info = super().update(experience_batch)
        info.target_divergence = self._target_divergence(self.schedule_steps)
        return info

    def _compute_alpha_loss(self, policy_output):
        """Computes loss for alpha update based on target divergence."""
        return self.alpha * (self._target_divergence(self.schedule_steps) - policy_output.prior_divergence).detach().mean()

    def _compute_policy_loss(self, experience_batch, policy_output):
        """Computes loss for policy update."""
        q_est = torch.min(*[critic(experience_batch.observation, self._prep_action(policy_output.action)).q
                            for critic in self.critics])
        policy_loss = -1 * q_est + self.alpha * policy_output.prior_divergence[:, None]
        check_shape(policy_loss, [self._hp.batch_size, 1])
        return policy_loss.mean()

    def _compute_next_value(self, experience_batch, policy_output):
        """Computes value of next state for target value computation."""
        q_next = torch.min(*[critic_target(experience_batch.observation_next, self._prep_action(policy_output.action)).q
                             for critic_target in self.critic_targets])
        next_val = (q_next - self.alpha * policy_output.prior_divergence[:, None])
        check_shape(next_val, [self._hp.batch_size, 1])
        return next_val.squeeze(-1)

    def _aux_info(self, experience_batch, policy_output):
        """Stores any additional values that should get logged to WandB."""
        aux_info = super()._aux_info(experience_batch, policy_output)
        aux_info.prior_divergence = policy_output.prior_divergence.mean()
        if 'ensemble_divergence' in policy_output:      # when using ensemble thresholded prior divergence
            aux_info.ensemble_divergence = policy_output.ensemble_divergence.mean()
            aux_info.learned_prior_divergence = policy_output.learned_prior_divergence.mean()
            aux_info.below_ensemble_div_thresh = policy_output.below_ensemble_div_thresh.mean()
        return aux_info

    def state_dict(self, *args, **kwargs):
        d = super().state_dict(*args, **kwargs)
        d['update_steps'] = self._update_steps
        return d

    def load_state_dict(self, state_dict, *args, **kwargs):
        self._update_steps = state_dict.pop('update_steps')
        super().load_state_dict(state_dict, *args, **kwargs)

    @property
    def alpha(self):
        if self._hp.alpha_min is not None:
            return torch.clamp(super().alpha, min=self._hp.alpha_min)
        return super().alpha
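
# Note: ActionPriorSACAgent optimizes the prior-regularized SAC objective. The policy loss
# computed in _compute_policy_loss is
#     E[ alpha * D_KL( pi(a|s) || p_a(a|s) ) - Q(s, a) ],
# and _compute_alpha_loss adapts alpha so the divergence tracks the scheduled target
# divergence, L_alpha = alpha * (delta_target - D_KL).detach().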


class RandActScheduledActionPriorSACAgent(ActionPriorSACAgent):
    """Adds scheduled call to random action (aka prior execution) -> used if downstream policy trained from scratch."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._omega = self._hp.omega_schedule(self._hp.omega_schedule_params)

    def _default_hparams(self):
        default_dict = ParamDict({
            'omega_schedule': ConstantSchedule,   # schedule used for omega param
            'omega_schedule_params': AttrDict(    # parameters for omega schedule
                p=0.1,
            ),
        })
        return super()._default_hparams().overwrite(default_dict)

    def _act(self, obs):
        """Calls the random-action (i.e., prior) policy with probability omega."""
        if np.random.rand() <= self._omega(self._update_steps):
            return super()._act_rand(obs)
        else:
            return super()._act(obs)

    def update(self, experience_batch):
        if 'delay' in self._hp.omega_schedule_params and self._update_steps < self._hp.omega_schedule_params.delay:
            # if the schedule has a warmup phase in which *only* the prior is sampled,
            # train the policy to minimize its divergence from the prior
            self.replay_buffer.append(experience_batch)
            experience_batch = self.replay_buffer.sample(n_samples=self._hp.batch_size)
            experience_batch = map2torch(experience_batch, self._hp.device)
            policy_output = self._run_policy(experience_batch.observation)
            policy_loss = policy_output.prior_divergence.mean()
            self._perform_update(policy_loss, self.policy_opt, self.policy)
            self._update_steps += 1
            info = AttrDict(prior_divergence=policy_output.prior_divergence.mean())
        else:
            info = super().update(experience_batch)
        info.omega = self._omega(self._update_steps)
        return info


class CodebookBasedActionPriorSACAgent(ActionPriorSACAgent):
    """Variant of ActionPriorSACAgent with a discrete policy over codebook entries; actions are the selected codes."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.get_codebook = self._hp.codebook

    def act(self, obs):
        obs = map2torch(self._obs_normalizer(obs), self._hp.device)
        if len(obs.shape) == 1:
            output = self.policy.net._compute_output_dist(obs[None])    # add batch dimension
        else:
            output = self.policy.net._compute_output_dist(obs)
        action_dist = torch.distributions.Categorical(probs=output)
        code_idx = action_dist.sample()
        one_hot = torch.zeros(output.shape)     # one-hot encoding of the sampled code (currently unused)
        one_hot[:, code_idx] = 1
        policy_output = self.get_codebook()[code_idx]       # choose code
        return AttrDict(action=policy_output, idx=code_idx, log_prob=(output + 1e-8).log(), prob=output)

    def _compute_policy_loss(self, experience_batch, policy_output):
        """Computes loss for policy update."""
        q_est = torch.min(*[critic(experience_batch.observation).q
                            for critic in self.critics])
        q_value = torch.sum(policy_output.prob * q_est, dim=-1, keepdim=True)    # expected Q under the categorical policy
        policy_loss = -1 * q_value + self.alpha * policy_output.prior_divergence.to(self._hp.device)
        check_shape(policy_loss, [self._hp.batch_size, 1])
        return policy_loss.mean()

    def _compute_next_value(self, experience_batch, policy_output):
        """Computes value of next state for target value computation."""
        q_next = torch.min(*[critic_target(experience_batch.observation_next).q
                                 .gather(1, self._prep_action(policy_output.idx).type(torch.int64).to(self._hp.device))
                             for critic_target in self.critic_targets])
        next_val = q_next - self.alpha * policy_output.prior_divergence.to(self._hp.device)
        check_shape(next_val, [self._hp.batch_size, 1])
        return next_val.squeeze(-1)

    def _compute_q_estimates(self, experience_batch):
        return [critic(experience_batch.observation).q.squeeze(-1)
                    .gather(1, self._prep_action(experience_batch.idx).type(torch.int64).to(self._hp.device).detach()).squeeze()
                for critic in self.critics]      # no gradient propagation into policy here!
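

# ---------------------------------------------------------------------------
# Minimal, illustrative sketch (not part of the agents above): reproduces the
# loss arithmetic with dummy tensors so the shapes are easy to follow. All
# shapes and values below are assumptions chosen purely for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, codebook_size = 4, 16

    # Continuous / skill-space case (ActionPriorSACAgent):
    q_est = torch.randn(batch_size, 1)          # stand-in for min_i Q_i(s, a)
    prior_div = torch.rand(batch_size)          # stand-in for D_KL(pi(.|s) || prior(.|s))
    alpha = torch.tensor(0.1)
    target_div = 1.0                            # stand-in for td_schedule output
    policy_loss = (-1 * q_est + alpha * prior_div[:, None]).mean()
    alpha_loss = (alpha * (target_div - prior_div).detach()).mean()

    # Discrete codebook case (CodebookBasedActionPriorSACAgent): expected Q under
    # the categorical policy before adding the alpha-weighted prior divergence.
    probs = torch.softmax(torch.randn(batch_size, codebook_size), dim=-1)
    q_codes = torch.randn(batch_size, codebook_size)
    q_value = torch.sum(probs * q_codes, dim=-1, keepdim=True)
    discrete_policy_loss = (-1 * q_value + alpha * torch.rand(batch_size, 1)).mean()

    print(f"policy loss: {policy_loss.item():.3f}, alpha loss: {alpha_loss.item():.3f}, "
          f"discrete policy loss: {discrete_policy_loss.item():.3f}")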