dqnx debug
emailweixu committed Dec 16, 2024
1 parent 5fe4220 commit 3530d14
Showing 2 changed files with 124 additions and 64 deletions.
178 changes: 117 additions & 61 deletions alf/algorithms/dqnx_algorithm.py
@@ -27,12 +27,13 @@
from alf.networks import QNetwork
from alf.nest.utils import convert_device
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
from alf.utils import losses, dist_utils, tensor_utils, value_ops
from alf.utils import common, losses, dist_utils, tensor_utils, value_ops
from alf.utils.summary_utils import safe_mean_hist_summary


class DQNXState(NamedTuple):
q: Tensor
target_q: Tensor = ()


class DQNXInfo(NamedTuple):
@@ -59,6 +60,7 @@ def __init__(self,
num_replicas=1,
entropy_regularization=0.3,
alpha=0.99,
target_update_period=1,
use_entropy_reward=True,
log_pi_clip=-1.0,
delta_log_pi_clip=0.2,
@@ -116,9 +118,12 @@ def __init__(self,
assert action_spec.is_discrete
assert action_spec.shape == ()
assert entropy_regularization > 0, "Not supported"
assert 0 <= alpha <= 1, f"Invalid alpha: {alpha}"

self._num_replicas = num_replicas
self._q_dim = reward_spec.numel + 1 # one for entropy part
self._q_dim = reward_spec.numel
if target_update_period == 0:
self._q_dim += 1 # one for entropy part
if epsilon_greedy is None:
epsilon_greedy = alf.utils.common.get_epsilon_greedy(config)
self._epsilon_greedy = epsilon_greedy
@@ -130,7 +135,9 @@ def __init__(self,
input_tensor_spec=observation_spec, action_spec=action_spec)
q_networks = q_network.make_parallel(num_replicas * self._q_dim)

train_state_spec = DQNXState(q=q_networks.state_spec)
train_state_spec = DQNXState(
q=q_networks.state_spec,
target_q=q_network.state_spec if target_update_period > 0 else ())
super().__init__(
observation_spec=original_observation_spec,
action_spec=action_spec,
@@ -158,6 +165,24 @@ def __init__(self,
self._use_entropy_reward = use_entropy_reward
self._delta_log_pi_clip = delta_log_pi_clip

self._target_q_networks = None
if target_update_period > 0:
self._target_q_networks = self._q_networks.copy(
name='target_q_networks')
self._target_q_networks.requires_grad_(False)
self._update_target = common.TargetUpdater(
models=[self._q_networks],
target_models=[self._target_q_networks],
tau=1 - alpha,
period=target_update_period)

def _trainable_attributes_to_ignore(self):
return ['_target_q_networks']

def after_update(self, root_inputs, info: DQNXInfo):
if self._target_q_networks is not None:
self._update_target()

def _compute_q_values(self, observation, state):
"""
@@ -167,34 +192,58 @@ def _compute_q_values(self, observation, state):
- action_dist:
- state: the updated state
"""
# [B, num_replicas * q_dim, num_actions]
q_values, state = self._q_networks(observation, state)
q_values = q_values.reshape(
q_values.size(0), self._num_replicas, self._q_dim, -1)
# [B, num_replicas, num_actions, q_dim]
q_values = q_values.transpose(2, 3)

if self._num_replicas == 1:
min_q_values = q_values[:, 0, :, :]
elif self.has_multidim_reward():
sign = self.reward_weights.sign()
min_q_values = (q_values * sign).min(dim=1)[0] * sign

def _calc_q(net, state):
# [B, num_replicas * q_dim, num_actions]
q_values, q_state = net(observation, state)
q_values = q_values.reshape(
q_values.size(0), self._num_replicas, self._q_dim, -1)
# [B, num_replicas, num_actions, q_dim]
q_values = q_values.transpose(2, 3)
return q_values, q_state

def _min_q(q):
if self._num_replicas == 1:
return q[:, 0, :, :]
elif self.has_multidim_reward():
sign = self.reward_weights.sign()
return (q * sign).min(dim=1)[0] * sign
else:
return q.min(dim=1)[0]

q_values, q_state = _calc_q(self._q_networks, state.q)
min_q_values = _min_q(q_values)

if self._target_q_networks is None:
summed_q_values = min_q_values @ self._reward_weights
# Need this so that the gradient of the imitation loss will not overwhelm the
# TD loss
summed_q_values = tensor_utils.scale_gradient(
summed_q_values, self._entropy_regularization)
target_q_state = ()
else:
min_q_values = q_values.min(dim=1)[0]

summed_q_values = min_q_values @ self._reward_weights
# Need this so that the gradient imitation loss will not overwhelm the
# TD loss
action_logits = tensor_utils.scale_gradient(
summed_q_values,
self._entropy_regularization) / self._entropy_regularization
with torch.no_grad():
target_q_values, target_q_state = _calc_q(
self._target_q_networks, state.target_q)
# (1-alpha)*q_values + alpha*target_q_values
combined_q_values = torch.lerp(q_values, target_q_values,
self._alpha)
summed_q_values = _min_q(combined_q_values) @ self._reward_weights
# Need this so that the gradient of the imitation loss will not overwhelm the
# TD loss
summed_q_values = tensor_utils.scale_gradient(
summed_q_values,
self._entropy_regularization / (1 - self._alpha))

action_logits = summed_q_values / self._entropy_regularization
action_dist = td.Categorical(logits=action_logits)

return q_values, min_q_values, action_dist, state
return q_values, min_q_values, action_dist, DQNXState(
q=q_state, target_q=target_q_state)

def predict_step(self, inputs: TimeStep, state: DQNXState):
_, q_values, action_dist, new_q_state = self._compute_q_values(
inputs.observation, state.q)
_, q_values, action_dist, new_state = self._compute_q_values(
inputs.observation, state)
if self._epsilon_greedy_uniform:
logits = action_dist.logits
greedy_action = logits.argmax(dim=1)
@@ -209,11 +258,13 @@ def predict_step(self, inputs: TimeStep, state: DQNXState):
action = dist_utils.epsilon_greedy_sample(action_dist,
self._epsilon_greedy)

return AlgStep(output=action, state=DQNXState(q=new_q_state))
return AlgStep(output=action, state=new_state)

def rollout_step(self, inputs: TimeStep, state: DQNXState):
_, q_values, action_dist, new_q_state = self._compute_q_values(
inputs.observation, state.q)
if alf.summary.get_global_counter() == 969:
print("here")
_, q_values, action_dist, new_state = self._compute_q_values(
inputs.observation, state)
action = dist_utils.sample_action_distribution(action_dist)
# [B, num_rewards]
v_values = torch.einsum('bar,ba->br', q_values, action_dist.probs)
@@ -224,7 +275,7 @@ def rollout_step(self, inputs: TimeStep, state: DQNXState):

return AlgStep(
output=action,
state=DQNXState(q=new_q_state),
state=new_state,
info=DQNXInfo(
action=action,
action_distribution=action_dist,
@@ -238,15 +289,17 @@ def rollout_step(self, inputs: TimeStep, state: DQNXState):

def train_step(self, inputs: TimeStep, state: DQNXState,
rollout_info: DQNXInfo):
q_values, _, action_dist, new_q_state = self._compute_q_values(
inputs.observation, state.q)
if alf.summary.get_global_counter() == 969:
print("here")
q_values, _, action_dist, new_state = self._compute_q_values(
inputs.observation, state)
action = rollout_info.action
B = torch.arange(action.shape[0])
action_q_values = q_values[B, :, action]
log_pi = action_dist.logits[B, action]
return AlgStep(
output=action,
state=DQNXState(q=new_q_state),
state=new_state,
info=DQNXInfo(
action_distribution=action_dist,
step_type=rollout_info.step_type,
@@ -265,9 +318,9 @@ def calc_loss(self, info: DQNXInfo):
delta_log_pi = info.log_pi - info.rollout_log_pi
delta_log_pi = delta_log_pi[:, :, None, None]
clipped = ((delta_log_pi > self._delta_log_pi_clip) &
(td_error > 0)) | (
(td_error < 0)) | (
(delta_log_pi < -self._delta_log_pi_clip) &
(td_error < 0))
(td_error > 0))
loss = self._td_error_loss_fn(target_q_values, q_values)
loss = loss * ~clipped
loss = loss.reshape(*loss.shape[:2], -1).mean(-1)
Expand Down Expand Up @@ -304,6 +357,8 @@ def _summarize(v, r, td, suffix):
def preprocess_experience(self, root_inputs: TimeStep, rollout_info,
batch_info):
"""Compute advantages and put it into exp.rollout_info."""
if alf.summary.get_global_counter() == 968:
print("here")

# The device of rollout_info can be different from the default device
# when ReplayBuffer.gather_all.convert_to_default_device is configured
@@ -313,46 +368,47 @@ def preprocess_experience(self, root_inputs: TimeStep, rollout_info,
discount = convert_device(rollout_info.discount)
reward = convert_device(rollout_info.reward).reshape(B, T, -1)
value = convert_device(rollout_info.v_values)
log_pi = convert_device(rollout_info.log_pi)
discounts = discount * self._gamma

reward_dim = reward.size(-1)
advantages = value_ops.generalized_advantage_estimation(
rewards=reward,
values=value[:, :, :-1],
values=value[:, :, :reward_dim],
step_types=step_type,
discounts=discounts,
td_lambda=self._td_lambda,
time_major=False)
# [B, T, q_dim-1]
advantages = tensor_utils.tensor_extend_zero(advantages, dim=1)
target_q_values = value[:, :, :-1] + advantages

if self._use_entropy_reward:
entropy = convert_device(rollout_info.entropy)
entropy = discount * entropy
if self._alpha > 0:
log_pi = convert_device(rollout_info.log_pi)[:, :-1]
if self._log_pi_clip < 0:
log_pi = log_pi.clamp(
min=self._log_pi_clip / self._entropy_regularization)
entropy[:, 1:] += self._alpha * log_pi
# [B, T-1]
target_q_m = value_ops.one_step_discounted_return(
rewards=self._entropy_regularization * entropy,
values=value[:, :, -1],
step_types=step_type,
discounts=discounts,
time_major=False)
# [B, T]
target_q_m = torch.cat([target_q_m, value[:, -1:, -1]], dim=-1)
elif self._alpha > 0:
log_pi = convert_device(rollout_info.log_pi)
target_q_values = value[:, :, :reward_dim] + advantages

if self._target_q_networks is None:
if self._log_pi_clip < 0:
log_pi = log_pi.clamp(
min=self._log_pi_clip / self._entropy_regularization)
target_q_m = self._alpha * self._entropy_regularization * log_pi

target_q_values = torch.cat(
[target_q_values, target_q_m.unsqueeze(-1)], dim=-1)
if self._use_entropy_reward:
entropy = convert_device(rollout_info.entropy)
entropy = discount * entropy
if self._alpha > 0:
entropy[:, 1:] += self._alpha * log_pi[:, :-1]
# [B, T-1]
target_q_m = value_ops.one_step_discounted_return(
rewards=self._entropy_regularization * entropy,
values=value[:, :, -1],
step_types=step_type,
discounts=discounts,
time_major=False)
# [B, T]
target_q_m = torch.cat([target_q_m, value[:, -1:, -1]], dim=-1)
elif self._alpha > 0:
target_q_m = self._alpha * self._entropy_regularization * log_pi
else:
target_q_m = torch.zeros_like(log_pi)

target_q_values = torch.cat(
[target_q_values, target_q_m.unsqueeze(-1)], dim=-1)
return root_inputs, rollout_info._replace(
target_q_values=target_q_values)

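The target-network machinery added in __init__ above pairs a frozen copy of the Q-networks with common.TargetUpdater(tau=1 - alpha, period=target_update_period), and _compute_q_values blends the two sets of predictions with torch.lerp. Below is a minimal, hypothetical sketch of what a periodic soft (Polyak) update and the lerp blend amount to; soft_update is a stand-in helper, not the ALF TargetUpdater implementation.

import torch
import torch.nn as nn

def soft_update(target: nn.Module, online: nn.Module, tau: float) -> None:
    # Polyak averaging: target <- (1 - tau) * target + tau * online
    with torch.no_grad():
        for t_param, o_param in zip(target.parameters(), online.parameters()):
            t_param.mul_(1.0 - tau).add_(o_param, alpha=tau)

# torch.lerp(q, target_q, weight) == (1 - weight) * q + weight * target_q,
# matching the "(1-alpha)*q_values + alpha*target_q_values" comment in the diff.
q = torch.tensor([1.0, 2.0])
target_q = torch.tensor([3.0, 0.0])
alpha = 0.95
combined = torch.lerp(q, target_q, alpha)  # tensor([2.9000, 0.1000])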
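The policy construction at the end of _compute_q_values is a Boltzmann (softmax) policy in which entropy_regularization acts as the temperature. The snippet below is a toy, self-contained sketch of that step with made-up numbers; it is not part of this commit and does not touch the ALF API.

import torch
import torch.distributions as td

# summed_q_values has shape [B, num_actions]; dividing by the entropy
# temperature gives pi(a|s) proportional to exp(Q(s, a) / temperature).
summed_q_values = torch.tensor([[1.0, 1.5, 0.5]])
entropy_regularization = 0.1
action_logits = summed_q_values / entropy_regularization
action_dist = td.Categorical(logits=action_logits)
sampled_action = action_dist.sample()        # stochastic (Boltzmann) exploration
greedy_action = action_logits.argmax(dim=1)  # greedy choice used for prediction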
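The sign change in calc_loss above flips which samples are masked out of the TD loss: with the new condition, a sample is dropped when delta_log_pi exceeds +delta_log_pi_clip while the TD error is negative, or falls below -delta_log_pi_clip while the TD error is positive. The toy check below (hypothetical numbers, plain PyTorch, not ALF code) just evaluates that mask.

import torch

delta_log_pi_clip = 0.2
delta_log_pi = torch.tensor([0.5, 0.5, -0.5, 0.0])
td_error = torch.tensor([-1.0, 1.0, 1.0, 1.0])
clipped = ((delta_log_pi > delta_log_pi_clip) & (td_error < 0)) | \
          ((delta_log_pi < -delta_log_pi_clip) & (td_error > 0))
print(clipped)  # tensor([ True, False,  True, False])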
10 changes: 7 additions & 3 deletions alf/examples/dqnx_cart_pole_conf.py
@@ -39,9 +39,10 @@

alf.config(
'DQNXAlgorithm',
epsilon_greedy=0.01,
entropy_regularization=0.1,
alpha=1.0,
epsilon_greedy=1.0,
entropy_regularization=0.01,
alpha=0.95,
target_update_period=1,
use_entropy_reward=False,
log_pi_clip=0,
delta_log_pi_clip=0.2,
@@ -54,6 +55,7 @@
alf.config(
'TrainerConfig',
algorithm_ctor=Agent,
random_seed=5,
mini_batch_length=1,
unroll_length=32,
mini_batch_size=128,
@@ -67,6 +69,8 @@
confirm_checkpoint_upon_crash=False,
debug_summaries=True,
summarize_grads_and_vars=True,
update_counter_every_mini_batch=True,
clear_replay_buffer_but_keep_one_step=True,
summary_interval=1)

alf.config('summarize_gradients', with_histogram=False)
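As a rough sanity check on the target-network settings in this conf (alpha=0.95, target_update_period=1): since DQNXAlgorithm wires its TargetUpdater with tau = 1 - alpha, the target networks follow the online networks as an exponential moving average with a time constant on the order of target_update_period / tau training iterations. A small, hypothetical back-of-the-envelope calculation:

alpha = 0.95
target_update_period = 1
tau = 1.0 - alpha                               # per-update mixing weight
effective_horizon = target_update_period / tau  # roughly 20 iterations
print(f"tau={tau:.2f}, EMA horizon ~ {effective_horizon:.0f} iterations")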
