diff --git a/reinforced_lib/agents/deep/ddpg.py b/reinforced_lib/agents/deep/ddpg.py
index 503d619..fb99f69 100644
--- a/reinforced_lib/agents/deep/ddpg.py
+++ b/reinforced_lib/agents/deep/ddpg.py
@@ -174,22 +174,22 @@ def __init__(
             experience_replay=er,
             noise=noise
         ))
-        self.update = partial(
+        self.update = jax.jit(partial(
             self.update,
-            q_step_fn=jax.jit(partial(
+            q_step_fn=partial(
                 gradient_step,
                 optimizer=q_optimizer,
                 loss_fn=partial(self.q_loss_fn, q_network=q_network, a_network=a_network, discount=discount)
-            )),
-            a_step_fn=jax.jit(partial(
+            ),
+            a_step_fn=partial(
                 gradient_step,
                 optimizer=a_optimizer,
                 loss_fn=partial(self.a_loss_fn, q_network=q_network, a_network=a_network)
-            )),
+            ),
             experience_replay=er,
             experience_replay_steps=experience_replay_steps,
             noise_decay=noise_decay,
             noise_min=noise_min,
             tau=tau
-        )
+        ))
         self.sample = jax.jit(partial(
             self.sample,
             a_network=a_network,
@@ -308,6 +308,7 @@ def q_loss_fn(
             key: PRNGKey,
             ddpg_state: DDPGState,
             batch: tuple,
+            non_zero_loss: jnp.bool_,
             q_network: hk.TransformedWithState,
             a_network: hk.TransformedWithState,
             discount: Scalar
     ) -> tuple[Scalar, hk.State]:
@@ -330,6 +331,8 @@ def q_loss_fn(
             The state of the deep deterministic policy gradient agent.
         batch : tuple
             A batch of transitions from the experience replay buffer.
+        non_zero_loss : bool
+            Flag used to avoid updating the Q-network when the experience replay buffer is not full.
         q_network : hk.TransformedWithState
             The Q-network.
         a_network : hk.TransformedWithState
@@ -355,7 +358,7 @@ def q_loss_fn(
         target = jax.lax.stop_gradient(target)
         loss = optax.l2_loss(q_values, target).mean()
 
-        return loss, q_state
+        return loss * non_zero_loss, q_state
 
     @staticmethod
     def a_loss_fn(
@@ -363,6 +366,7 @@ def a_loss_fn(
             key: PRNGKey,
             ddpg_state: DDPGState,
             batch: tuple,
+            non_zero_loss: jnp.bool_,
             q_network: hk.TransformedWithState,
             a_network: hk.TransformedWithState
     ) -> tuple[Scalar, hk.State]:
@@ -381,6 +385,8 @@ def a_loss_fn(
             The state of the deep deterministic policy gradient agent.
         batch : tuple
             A batch of transitions from the experience replay buffer.
+        non_zero_loss : bool
+            Flag used to avoid updating the policy network when the experience replay buffer is not full.
         q_network : hk.TransformedWithState
             The Q-network.
         a_network : hk.TransformedWithState
@@ -397,9 +403,9 @@ def a_loss_fn(
 
         actions, a_state = a_network.apply(a_params, ddpg_state.a_state, a_key, states)
         q_values, _ = q_network.apply(ddpg_state.q_params, ddpg_state.q_state, q_key, states, actions)
-
         loss = -jnp.mean(q_values)
-        return loss, a_state
+
+        return loss * non_zero_loss, a_state
 
     @staticmethod
     def update(
@@ -470,18 +476,21 @@ def update(
         a_params, a_net_state, a_opt_state = state.a_params, state.a_state, state.a_opt_state
         a_params_target, a_state_target = state.a_params_target, state.a_state_target
 
-        if experience_replay.is_ready(replay_buffer):
-            for _ in range(experience_replay_steps):
-                batch_key, q_network_key, a_network_key, key = jax.random.split(key, 4)
-                batch = experience_replay.sample(replay_buffer, batch_key)
+        non_zero_loss = experience_replay.is_ready(replay_buffer)
+
+        for _ in range(experience_replay_steps):
+            batch_key, q_network_key, a_network_key, key = jax.random.split(key, 4)
+            batch = experience_replay.sample(replay_buffer, batch_key)
 
-                q_params, q_net_state, q_opt_state, _ = q_step_fn(q_params, (q_network_key, state, batch), q_opt_state)
-                a_params, a_net_state, a_opt_state, _ = a_step_fn(a_params, (a_network_key, state, batch), a_opt_state)
+            q_params, q_net_state, q_opt_state, _ = q_step_fn(
+                q_params, (q_network_key, state, batch, non_zero_loss), q_opt_state)
+            a_params, a_net_state, a_opt_state, _ = a_step_fn(
+                a_params, (a_network_key, state, batch, non_zero_loss), a_opt_state)
 
-                q_params_target, q_state_target = optax.incremental_update(
-                    (q_params, q_net_state), (q_params_target, q_state_target), tau)
-                a_params_target, a_state_target = optax.incremental_update(
-                    (a_params, a_net_state), (a_params_target, a_state_target), tau)
+            q_params_target, q_state_target = optax.incremental_update(
+                (q_params, q_net_state), (q_params_target, q_state_target), tau)
+            a_params_target, a_state_target = optax.incremental_update(
+                (a_params, a_net_state), (a_params_target, a_state_target), tau)
 
         return DDPGState(
             q_params=q_params,
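Note on the pattern applied in this file (and repeated for the remaining agents below): `jax.jit` moves from the inner step functions onto the whole `update`, and the Python branch on `experience_replay.is_ready(...)` is replaced by a `non_zero_loss` flag that masks the loss. A traced boolean cannot drive a Python `if` inside a jitted function, but multiplying the loss by it zeroes the gradients while the buffer is still filling. A minimal, self-contained sketch of the idea (the toy `loss_fn`, `params`, and `batch` below are illustrative placeholders, not the reinforced_lib helpers):

```python
import jax
import jax.numpy as jnp
import optax

optimizer = optax.adam(1e-3)

def loss_fn(params, batch, non_zero_loss):
    x, y = batch
    pred = x @ params["w"]
    # Multiplying by the traced boolean zeroes both the loss and its gradients
    # while the replay buffer is still filling up, so no Python-level branch is
    # needed inside the jitted update.
    return optax.l2_loss(pred, y).mean() * non_zero_loss

@jax.jit
def gradient_step(params, opt_state, batch, non_zero_loss):
    grads = jax.grad(loss_fn)(params, batch, non_zero_loss)
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state

params = {"w": jnp.ones((3, 1))}
opt_state = optimizer.init(params)
batch = (jnp.ones((4, 3)), jnp.zeros((4, 1)))

# Buffer not ready: the masked loss yields zero gradients and parameters stay put.
new_params, _ = gradient_step(params, opt_state, batch, jnp.bool_(False))
assert jnp.allclose(new_params["w"], params["w"])
```

With `non_zero_loss=False` the gradients are exactly zero, so the masked step leaves the parameters unchanged while keeping a single compiled version of `update`.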
diff --git a/reinforced_lib/agents/deep/dqn.py b/reinforced_lib/agents/deep/dqn.py
index 27539e4..5b52302 100644
--- a/reinforced_lib/agents/deep/dqn.py
+++ b/reinforced_lib/agents/deep/dqn.py
@@ -136,19 +136,19 @@ def __init__(
             experience_replay=er,
             epsilon=epsilon
         ))
-        self.update = partial(
+        self.update = jax.jit(partial(
             self.update,
-            step_fn=jax.jit(partial(
+            step_fn=partial(
                 gradient_step,
                 optimizer=optimizer,
                 loss_fn=partial(self.loss_fn, q_network=q_network, discount=discount)
-            )),
+            ),
             experience_replay=er,
             experience_replay_steps=experience_replay_steps,
             epsilon_decay=epsilon_decay,
             epsilon_min=epsilon_min,
             tau=tau
-        )
+        ))
         self.sample = jax.jit(partial(
             self.sample,
             q_network=q_network,
@@ -245,6 +245,7 @@ def loss_fn(
             key: PRNGKey,
             dqn_state: DQNState,
             batch: tuple,
+            non_zero_loss: jnp.bool_,
             q_network: hk.TransformedWithState,
             discount: Scalar
     ) -> tuple[Scalar, hk.State]:
@@ -265,6 +266,8 @@ def loss_fn(
             The state of the double Q-learning agent.
         batch : tuple
             A batch of transitions from the experience replay buffer.
+        non_zero_loss : bool
+            Flag used to avoid updating the Q-network when the experience replay buffer is not full.
         q_network : hk.TransformedWithState
             The Q-network.
         discount : Scalar
@@ -288,7 +291,7 @@ def loss_fn(
         target = jax.lax.stop_gradient(target)
         loss = optax.l2_loss(q_values, target).mean()
 
-        return loss, state
+        return loss * non_zero_loss, state
 
     @staticmethod
     def update(
@@ -352,14 +355,15 @@ def update(
         params, net_state, opt_state = state.params, state.state, state.opt_state
         params_target, state_target = state.params_target, state.state_target
 
-        if experience_replay.is_ready(replay_buffer):
-            for _ in range(experience_replay_steps):
-                batch_key, network_key, key = jax.random.split(key, 3)
-                batch = experience_replay.sample(replay_buffer, batch_key)
+        non_zero_loss = experience_replay.is_ready(replay_buffer)
+
+        for _ in range(experience_replay_steps):
+            batch_key, network_key, key = jax.random.split(key, 3)
+            batch = experience_replay.sample(replay_buffer, batch_key)
 
-                params, net_state, opt_state, _ = step_fn(params, (network_key, state, batch), opt_state)
-                params_target, state_target = optax.incremental_update(
-                    (params, net_state), (params_target, state_target), tau)
+            params, net_state, opt_state, _ = step_fn(params, (network_key, state, batch, non_zero_loss), opt_state)
+            params_target, state_target = optax.incremental_update(
+                (params, net_state), (params_target, state_target), tau)
 
         return DQNState(
             params=params,
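With the `if` removed, the soft target update also runs on every replay step. For reference, `optax.incremental_update(new, old, tau)` computes the Polyak average `tau * new + (1 - tau) * old`; a minimal example of what each call in the loop does:

```python
import jax.numpy as jnp
import optax

online = {"w": jnp.array([1.0, 2.0])}
target = {"w": jnp.array([0.0, 0.0])}
tau = 0.1

# Soft (Polyak) target update: target <- tau * online + (1 - tau) * target
target = optax.incremental_update(online, target, tau)
print(target["w"])  # [0.1 0.2]
```

Assuming the target network starts as a copy of the online network (the usual setup), the extra Polyak steps taken before the buffer fills are effectively no-ops, since the masked gradient steps leave the online parameters unchanged.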
diff --git a/reinforced_lib/agents/deep/expected_sarsa.py b/reinforced_lib/agents/deep/expected_sarsa.py
index 20bb22a..9ab6fc5 100644
--- a/reinforced_lib/agents/deep/expected_sarsa.py
+++ b/reinforced_lib/agents/deep/expected_sarsa.py
@@ -107,17 +107,17 @@ def __init__(
             optimizer=optimizer,
             experience_replay=er
         ))
-        self.update = partial(
+        self.update = jax.jit(partial(
             self.update,
             q_network=q_network,
-            step_fn=jax.jit(partial(
+            step_fn=partial(
                 gradient_step,
                 optimizer=optimizer,
                 loss_fn=partial(self.loss_fn, q_network=q_network, discount=discount, tau=tau)
-            )),
+            ),
             experience_replay=er,
             experience_replay_steps=experience_replay_steps
-        )
+        ))
         self.sample = jax.jit(partial(
             self.sample,
             q_network=q_network,
@@ -208,6 +208,7 @@ def loss_fn(
             params_target: hk.Params,
             net_state_target: hk.State,
             batch: tuple,
+            non_zero_loss: jnp.bool_,
             q_network: hk.TransformedWithState,
             discount: Scalar,
             tau: Scalar
     ) -> tuple[Scalar, hk.State]:
@@ -233,6 +234,8 @@ def loss_fn(
             The state of the target Q-network.
         batch : tuple
             A batch of transitions from the experience replay buffer.
+        non_zero_loss : bool
+            Flag used to avoid updating the Q-network when the experience replay buffer is not full.
         q_network : hk.TransformedWithState
             The Q-network.
         discount : Scalar
@@ -259,7 +262,7 @@ def loss_fn(
         target = jax.lax.stop_gradient(target)
         loss = optax.l2_loss(q_values, target).mean()
 
-        return loss, state
+        return loss * non_zero_loss, state
 
     @staticmethod
     def update(
@@ -314,17 +317,16 @@ def update(
         )
 
         params, net_state, opt_state = state.params, state.state, state.opt_state
+        params_target, net_state_target = deepcopy(params), deepcopy(net_state)
+
+        non_zero_loss = experience_replay.is_ready(replay_buffer)
 
-        if experience_replay.is_ready(replay_buffer):
-            params_target = deepcopy(params)
-            net_state_target = deepcopy(net_state)
-
-            for _ in range(experience_replay_steps):
-                batch_key, network_key, key = jax.random.split(key, 3)
-                batch = experience_replay.sample(replay_buffer, batch_key)
+        for _ in range(experience_replay_steps):
+            batch_key, network_key, key = jax.random.split(key, 3)
+            batch = experience_replay.sample(replay_buffer, batch_key)
 
-                loss_params = (network_key, net_state, params_target, net_state_target, batch)
-                params, net_state, opt_state, _ = step_fn(params, loss_params, opt_state)
+            loss_params = (network_key, net_state, params_target, net_state_target, batch, non_zero_loss)
+            params, net_state, opt_state, _ = step_fn(params, loss_params, opt_state)
 
         return ExpectedSarsaState(
             params=params,
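Because the whole `update` is now jitted, the `for _ in range(experience_replay_steps)` loop is unrolled at trace time; this is only valid while `experience_replay_steps` remains a plain Python integer, which binding it through `functools.partial` guarantees. A small illustration of the same constraint, here using `static_argnums` instead of `partial` to mark the count as static:

```python
import jax
import jax.numpy as jnp

def repeat_add(x, steps: int):
    # The Python loop is unrolled at trace time, so `steps` must be a static
    # Python int -- which is exactly what baking it in via functools.partial
    # guarantees for experience_replay_steps in the diffs above.
    for _ in range(steps):
        x = x + 1.0
    return x

repeat_add_jit = jax.jit(repeat_add, static_argnums=1)
print(repeat_add_jit(jnp.float32(0.0), 3))  # 3.0
```

Large values therefore grow the compiled program linearly, which is worth keeping in mind when choosing `experience_replay_steps`.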
diff --git a/reinforced_lib/agents/deep/q_learning.py b/reinforced_lib/agents/deep/q_learning.py
index f0f4428..2676dc6 100644
--- a/reinforced_lib/agents/deep/q_learning.py
+++ b/reinforced_lib/agents/deep/q_learning.py
@@ -123,18 +123,18 @@ def __init__(
             experience_replay=er,
             epsilon=epsilon
         ))
-        self.update = partial(
+        self.update = jax.jit(partial(
             self.update,
-            step_fn=jax.jit(partial(
+            step_fn=partial(
                 gradient_step,
                 optimizer=optimizer,
                 loss_fn=partial(self.loss_fn, q_network=q_network, discount=discount)
-            )),
+            ),
             experience_replay=er,
             experience_replay_steps=experience_replay_steps,
             epsilon_decay=epsilon_decay,
             epsilon_min=epsilon_min
-        )
+        ))
         self.sample = jax.jit(partial(
             self.sample,
             q_network=q_network,
@@ -230,6 +230,7 @@ def loss_fn(
             params_target: hk.Params,
             net_state_target: hk.State,
             batch: tuple,
+            non_zero_loss: jnp.bool_,
             q_network: hk.TransformedWithState,
             discount: Scalar
     ) -> tuple[Scalar, hk.State]:
@@ -253,6 +254,8 @@ def loss_fn(
             The state of the target Q-network.
         batch : tuple
             A batch of transitions from the experience replay buffer.
+        non_zero_loss : bool
+            Flag used to avoid updating the Q-network when the experience replay buffer is not full.
         q_network : hk.TransformedWithState
             The Q-network.
         discount : Scalar
@@ -276,7 +279,7 @@ def loss_fn(
         target = jax.lax.stop_gradient(target)
         loss = optax.l2_loss(q_values, target).mean()
 
-        return loss, state
+        return loss * non_zero_loss, state
 
     @staticmethod
     def update(
@@ -335,17 +338,16 @@ def update(
         )
 
         params, net_state, opt_state = state.params, state.state, state.opt_state
+        params_target, net_state_target = deepcopy(params), deepcopy(net_state)
 
-        if experience_replay.is_ready(replay_buffer):
-            params_target = deepcopy(params)
-            net_state_target = deepcopy(net_state)
+        non_zero_loss = experience_replay.is_ready(replay_buffer)
 
-            for _ in range(experience_replay_steps):
-                batch_key, network_key, key = jax.random.split(key, 3)
-                batch = experience_replay.sample(replay_buffer, batch_key)
+        for _ in range(experience_replay_steps):
+            batch_key, network_key, key = jax.random.split(key, 3)
+            batch = experience_replay.sample(replay_buffer, batch_key)
 
-                loss_params = (network_key, net_state, params_target, net_state_target, batch)
-                params, net_state, opt_state, _ = step_fn(params, loss_params, opt_state)
+            loss_params = (network_key, net_state, params_target, net_state_target, batch, non_zero_loss)
+            params, net_state, opt_state, _ = step_fn(params, loss_params, opt_state)
 
         return QLearningState(
             params=params,
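One side effect of masking the loss instead of skipping the update entirely: the optimizer state is still advanced by the zero-gradient steps taken while the buffer fills. A sketch of the effect, assuming `optax.adam` for illustration (the agents accept any optax optimizer, so the choice here is an assumption): parameters and moment estimates stay put, but the internal step counter increments, which only shifts Adam's bias-correction schedule.

```python
import jax.numpy as jnp
import optax

optimizer = optax.adam(1e-3)
params = {"w": jnp.zeros(2)}
opt_state = optimizer.init(params)

# A zero gradient (the masked loss) leaves parameters and moments unchanged,
# but the step counter inside the Adam state still advances.
zero_grads = {"w": jnp.zeros(2)}
updates, opt_state = optimizer.update(zero_grads, opt_state)

print(optax.apply_updates(params, updates))  # {'w': Array([0., 0.], ...)}
print(opt_state[0].count)                    # 1
```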