delta_log_pi clip for DQNX
emailweixu committed Dec 13, 2024
1 parent 463803f commit 5fe4220
Showing 2 changed files with 40 additions and 26 deletions.
59 changes: 36 additions & 23 deletions alf/algorithms/dqnx_algorithm.py
@@ -42,10 +42,11 @@ class DQNXInfo(NamedTuple):
     discount: Tensor = ()
     action_distribution: td.Distribution = ()
     v_values: Tensor = ()  # [B, q_dim]
-    q_values: Tensor = ()  # [B, num_critic_replicas, q_dim]
+    q_values: Tensor = ()  # [B, num_replicas, q_dim]
     target_q_values: Tensor = ()  # [B, q_dim]
     entropy: Tensor = ()  # [B]
     log_pi: Tensor = ()  # [B]
+    rollout_log_pi: Tensor = ()  # [B]
 
 
 @alf.configurable
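
The new rollout_log_pi field stores the log-probability of the replayed action as it was at rollout time; train_step below forwards rollout_info.log_pi into it so that calc_loss can compare it with the log_pi recomputed under the current network.
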
@@ -55,11 +56,12 @@ def __init__(self,
                  action_spec: BoundedTensorSpec,
                  reward_spec=TensorSpec(()),
                  q_network_ctor=QNetwork,
-                 num_critic_replicas=1,
-                 entropy_regularization=0.03,
-                 alpha=0.9,
+                 num_replicas=1,
+                 entropy_regularization=0.3,
+                 alpha=0.99,
                  use_entropy_reward=True,
-                 log_pi_clip=-1.0,
+                 delta_log_pi_clip=0.2,
                  gamma=0.99,
                  td_lambda=0.95,
                  td_error_loss_fn=losses.element_wise_squared_loss,
@@ -84,10 +86,11 @@ def __init__(self,
             q_network_ctor (Callable): is used to construct QNetwork for estimating ``Q(s,a)``
                 given that the action is discrete. Its output spec must be consistent with
                 the discrete action in ``action_spec``.
-            num_critic_replicas=1,
+            num_replicas=1,
             entropy_regularization=0.03,
             alpha=0.9,
-            log_pi_clip=-1.0,
+            delta_log_pi_clip:
             gamma=0.99,
             td_lambda=0.95,
             td_error_loss_fn=losses.element_wise_squared_loss,
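
The new delta_log_pi_clip argument is listed without a description; judging from calc_loss below, it is the threshold on how far the current log_pi may drift from rollout_log_pi before TD errors that push further in the same direction are masked out of the critic loss. The signature defaults also change: entropy_regularization 0.03 -> 0.3 and alpha 0.9 -> 0.99.
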
@@ -114,7 +117,7 @@ def __init__(self,
         assert action_spec.shape == ()
         assert entropy_regularization > 0, "Not supported"
 
-        self._num_critic_replicas = num_critic_replicas
+        self._num_replicas = num_replicas
         self._q_dim = reward_spec.numel + 1  # one for entropy part
         if epsilon_greedy is None:
             epsilon_greedy = alf.utils.common.get_epsilon_greedy(config)
@@ -125,7 +128,7 @@ def __init__(self,
         original_observation_spec = observation_spec
         q_network = q_network_ctor(
             input_tensor_spec=observation_spec, action_spec=action_spec)
-        q_networks = q_network.make_parallel(num_critic_replicas * self._q_dim)
+        q_networks = q_network.make_parallel(num_replicas * self._q_dim)
 
         train_state_spec = DQNXState(q=q_networks.state_spec)
         super().__init__(
@@ -153,24 +156,25 @@ def __init__(self,
         reward_weights[:self._reward_spec.numel] = self._reward_weights
         self._reward_weights = reward_weights
         self._use_entropy_reward = use_entropy_reward
+        self._delta_log_pi_clip = delta_log_pi_clip
 
     def _compute_q_values(self, observation, state):
         """
         Returns:
-            - q_values: [B, num_critic_replicas, num_actions, q_dim]
+            - q_values: [B, num_replicas, num_actions, q_dim]
             - min_q_values: min q_values across replicas, [B, num_actions, q_dim]
             - action_dist:
             - state: the updated state
         """
+        # [B, num_replicas * q_dim, num_actions]
         q_values, state = self._q_networks(observation, state)
         q_values = q_values.reshape(
-            q_values.size(0), self._num_critic_replicas, self._q_dim, -1)
-        # [B, num_critic_replicas, num_actions, q_dim]
+            q_values.size(0), self._num_replicas, self._q_dim, -1)
+        # [B, num_replicas, num_actions, q_dim]
         q_values = q_values.transpose(2, 3)
 
-        if self._num_critic_replicas == 1:
+        if self._num_replicas == 1:
             min_q_values = q_values[:, 0, :, :]
         elif self.has_multidim_reward():
             sign = self.reward_weights.sign()
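
A minimal standalone sketch of the shape bookkeeping above, with made-up sizes; q_net_out stands in for the parallel Q-network output, and the plain min over replicas follows the docstring's "min q_values across replicas" while ignoring the reward-sign handling of the multi-dim branch:

import torch

B, num_replicas, q_dim, num_actions = 4, 2, 3, 5

# Stand-in for the parallel Q-network output: [B, num_replicas * q_dim, num_actions]
q_net_out = torch.randn(B, num_replicas * q_dim, num_actions)

# Split replicas from the reward/entropy dims, then swap the last two axes:
# [B, num_replicas, q_dim, num_actions] -> [B, num_replicas, num_actions, q_dim]
q_values = q_net_out.reshape(B, num_replicas, q_dim, -1).transpose(2, 3)

# Pessimistic estimate across replicas (plain min; the multi-dim reward branch
# in the algorithm additionally uses the sign of the reward weights).
min_q_values = q_values[:, 0, :, :] if num_replicas == 1 else q_values.min(dim=1).values

print(q_values.shape)      # torch.Size([4, 2, 5, 3])
print(min_q_values.shape)  # torch.Size([4, 5, 3])
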
@@ -215,11 +219,8 @@ def rollout_step(self, inputs: TimeStep, state: DQNXState):
         v_values = torch.einsum('bar,ba->br', q_values, action_dist.probs)
 
         entropy = action_dist.entropy()
-
-        log_pi = ()
-        if self._alpha > 0:
-            B = torch.arange(action.shape[0])
-            log_pi = action_dist.logits[B, action]
+        B = torch.arange(action.shape[0])
+        log_pi = action_dist.logits[B, action]
 
         return AlgStep(
             output=action,
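
With the `if self._alpha > 0` guard gone, log_pi is gathered on every rollout step. For a td.Categorical (which the advanced indexing above assumes), .logits holds normalized log-probabilities, so the gather is exactly the per-action log-probability; a tiny standalone check with made-up sizes:

import torch
import torch.distributions as td

B, num_actions = 4, 5
dist = td.Categorical(logits=torch.randn(B, num_actions))
action = dist.sample()                    # [B]
batch_idx = torch.arange(B)
log_pi = dist.logits[batch_idx, action]   # [B]
assert torch.allclose(log_pi, dist.log_prob(action))
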
@@ -230,6 +231,7 @@ def rollout_step(self, inputs: TimeStep, state: DQNXState):
                 reward=inputs.reward,
                 step_type=inputs.step_type,
                 discount=inputs.discount,
+                q_values=q_values,
                 v_values=v_values,
                 entropy=entropy,
                 log_pi=log_pi))
@@ -241,6 +243,7 @@ def train_step(self, inputs: TimeStep, state: DQNXState,
         action = rollout_info.action
         B = torch.arange(action.shape[0])
         action_q_values = q_values[B, :, action]
+        log_pi = action_dist.logits[B, action]
         return AlgStep(
             output=action,
             state=DQNXState(q=new_q_state),
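
train_step recomputes log_pi for the replayed action under the current network, while rollout_info.log_pi keeps the value recorded at rollout time; both are forwarded below as log_pi and rollout_log_pi so that calc_loss can take their difference.
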
@@ -249,14 +252,25 @@
                 step_type=rollout_info.step_type,
                 entropy=rollout_info.entropy,
                 q_values=action_q_values,
+                log_pi=log_pi,
+                rollout_log_pi=rollout_info.log_pi,
                 target_q_values=rollout_info.target_q_values))
 
     def calc_loss(self, info: DQNXInfo):
-        q_values = info.q_values  # [T, B, num_critic_replicas, q_dim]
-        # [T, B, num_critic_replicas, q_dim]
+        q_values = info.q_values  # [T, B, num_replicas, q_dim]
+        # [T, B, num_replicas, q_dim]
         target_q_values = info.target_q_values[:, :, None, :].expand_as(
             q_values)
         td_error = target_q_values - q_values
+        delta_log_pi = info.log_pi - info.rollout_log_pi
+        delta_log_pi = delta_log_pi[:, :, None, None]
+        clipped = ((delta_log_pi > self._delta_log_pi_clip) &
+                   (td_error > 0)) | (
+                       (delta_log_pi < -self._delta_log_pi_clip) &
+                       (td_error < 0))
+        loss = self._td_error_loss_fn(target_q_values, q_values)
+        loss = loss * ~clipped
+        loss = loss.reshape(*loss.shape[:2], -1).mean(-1)
 
         if self._debug_summaries and alf.summary.should_record_summaries():
             mask = info.step_type != StepType.LAST
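
The new masking is the heart of the commit and is easy to isolate. Below is a standalone sketch under assumed shapes (the squared error stands in for whatever td_error_loss_fn is configured): a TD-error term is dropped when the current policy's log-probability of the replayed action has already moved past the threshold, relative to its rollout value, in the direction that the update would push it further, i.e. log_pi up by more than the clip while the TD error is positive, or down by more than the clip while it is negative. The returned mask is what the new td_error_clip_ratio summary averages.

import torch

def masked_td_loss(q_values, target_q_values, log_pi, rollout_log_pi,
                   delta_log_pi_clip=0.2):
    """Squared TD loss with delta_log_pi clipping (illustrative shapes).

    q_values:               [T, B, num_replicas, q_dim]
    target_q_values:        [T, B, q_dim]
    log_pi, rollout_log_pi: [T, B]
    """
    target = target_q_values[:, :, None, :].expand_as(q_values)
    td_error = target - q_values

    delta_log_pi = (log_pi - rollout_log_pi)[:, :, None, None]
    clipped = (((delta_log_pi > delta_log_pi_clip) & (td_error > 0)) |
               ((delta_log_pi < -delta_log_pi_clip) & (td_error < 0)))

    loss = (target - q_values) ** 2   # stand-in for the configured td_error_loss_fn
    loss = loss * ~clipped            # masked terms contribute nothing
    # average over replicas and the reward/entropy dims -> [T, B]
    return loss.reshape(*loss.shape[:2], -1).mean(-1), clipped

# Smoke test with made-up data.
T, B, R, Q = 6, 4, 2, 3
loss, clipped = masked_td_loss(
    torch.randn(T, B, R, Q), torch.randn(T, B, Q),
    torch.randn(T, B) * 0.3, torch.randn(T, B) * 0.3)
print(loss.shape)              # torch.Size([6, 4])
print(clipped.float().mean())  # fraction of masked terms, cf. td_error_clip_ratio
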
@@ -270,18 +284,17 @@ def _summarize(v, r, td, suffix):
                     safe_mean_hist_summary('returns' + suffix, r, mask)
                     safe_mean_hist_summary("critic_td_error" + suffix, td,
                                            mask)
+            alf.summary.scalar("td_error_clip_ratio",
+                               clipped.sum() / clipped.numel())
 
-            num_critic_replicas = q_values.size(2)
-            for r in range(num_critic_replicas):
+            num_replicas = q_values.size(2)
+            for r in range(num_replicas):
                 for i in range(q_values.size(3)):
                     suffix = f'/replica_{r}/{i}'
                     _summarize(q_values[..., r, i],
                                target_q_values[..., r, i],
                                td_error[..., r, i], suffix)
 
-        loss = self._td_error_loss_fn(target_q_values, q_values)
-        loss = loss.reshape(*loss.shape[:2], -1).mean(-1)
-
         return LossInfo(
             loss=loss, extra={
                 'critic': loss,
7 changes: 4 additions & 3 deletions alf/examples/dqnx_cart_pole_conf.py
@@ -40,10 +40,11 @@
 alf.config(
     'DQNXAlgorithm',
     epsilon_greedy=0.01,
-    entropy_regularization=0.01,
-    alpha=0.9,
+    entropy_regularization=0.1,
+    alpha=1.0,
     use_entropy_reward=False,
-    log_pi_clip=0,
+    delta_log_pi_clip=0.2,
     q_network_ctor=QNetwork,
     optimizer=alf.optimizers.Adam(lr=1e-3))
 
@@ -61,7 +62,7 @@
     num_checkpoints=5,
     whole_replay_buffer_training=True,
     clear_replay_buffer=True,
-    evaluate=False,
+    evaluate=True,
     eval_interval=50,
     confirm_checkpoint_upon_crash=False,
     debug_summaries=True,
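
The example now relies on the new delta_log_pi_clip=0.2 together with a larger entropy coefficient and alpha=1.0, and evaluation is switched back on. As a purely hypothetical ablation (not part of this commit), a variant of the DQNXAlgorithm block above could keep everything else and simply raise the threshold until the mask in calc_loss never fires:

import alf

# Hypothetical ablation conf: same alf.config pattern as dqnx_cart_pole_conf.py,
# with a threshold so large that the delta_log_pi mask is effectively inactive.
alf.config(
    'DQNXAlgorithm',
    epsilon_greedy=0.01,
    entropy_regularization=0.1,
    alpha=1.0,
    use_entropy_reward=False,
    delta_log_pi_clip=1e6,  # vs. 0.2 in the committed conf
    optimizer=alf.optimizers.Adam(lr=1e-3))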