Fix JIT and export of DRL agents
m-wojnar committed Jul 18, 2023
1 parent 80e62dc commit e55c619
Showing 4 changed files with 75 additions and 58 deletions.
47 changes: 28 additions & 19 deletions reinforced_lib/agents/deep/ddpg.py
@@ -174,22 +174,22 @@ def __init__(
experience_replay=er,
noise=noise
))
self.update = partial(
self.update = jax.jit(partial(
self.update,
q_step_fn=jax.jit(partial(
q_step_fn=partial(
gradient_step,
optimizer=q_optimizer,
loss_fn=partial(self.q_loss_fn, q_network=q_network, a_network=a_network, discount=discount)
)),
a_step_fn=jax.jit(partial(
),
a_step_fn=partial(
gradient_step,
optimizer=a_optimizer,
loss_fn=partial(self.a_loss_fn, q_network=q_network, a_network=a_network)
)),
),
experience_replay=er, experience_replay_steps=experience_replay_steps,
noise_decay=noise_decay, noise_min=noise_min,
tau=tau
)
))
self.sample = jax.jit(partial(
self.sample,
a_network=a_network,
@@ -308,6 +308,7 @@ def q_loss_fn(
key: PRNGKey,
ddpg_state: DDPGState,
batch: tuple,
non_zero_loss: jnp.bool_,
q_network: hk.TransformedWithState,
a_network: hk.TransformedWithState,
discount: Scalar
@@ -330,6 +331,8 @@ def q_loss_fn(
The state of the deep deterministic policy gradient agent.
batch : tuple
A batch of transitions from the experience replay buffer.
non_zero_loss : bool
Flag used to avoid updating the Q-network when the experience replay buffer is not full.
q_network : hk.TransformedWithState
The Q-network.
a_network : hk.TransformedWithState
@@ -355,14 +358,15 @@ def q_loss_fn(
target = jax.lax.stop_gradient(target)
loss = optax.l2_loss(q_values, target).mean()

return loss, q_state
return loss * non_zero_loss, q_state

@staticmethod
def a_loss_fn(
a_params: hk.Params,
key: PRNGKey,
ddpg_state: DDPGState,
batch: tuple,
non_zero_loss: jnp.bool_,
q_network: hk.TransformedWithState,
a_network: hk.TransformedWithState
) -> tuple[Scalar, hk.State]:
@@ -381,6 +385,8 @@ def a_loss_fn(
The state of the deep deterministic policy gradient agent.
batch : tuple
A batch of transitions from the experience replay buffer.
non_zero_loss : bool
Flag used to avoid updating the policy network when the experience replay buffer is not full.
q_network : hk.TransformedWithState
The Q-network.
a_network : hk.TransformedWithState
@@ -397,9 +403,9 @@ def a_loss_fn(

actions, a_state = a_network.apply(a_params, ddpg_state.a_state, a_key, states)
q_values, _ = q_network.apply(ddpg_state.q_params, ddpg_state.q_state, q_key, states, actions)

loss = -jnp.mean(q_values)
return loss, a_state

return loss * non_zero_loss, a_state

@staticmethod
def update(
@@ -470,18 +476,21 @@ def update(
a_params, a_net_state, a_opt_state = state.a_params, state.a_state, state.a_opt_state
a_params_target, a_state_target = state.a_params_target, state.a_state_target

if experience_replay.is_ready(replay_buffer):
for _ in range(experience_replay_steps):
batch_key, q_network_key, a_network_key, key = jax.random.split(key, 4)
batch = experience_replay.sample(replay_buffer, batch_key)
non_zero_loss = experience_replay.is_ready(replay_buffer)

for _ in range(experience_replay_steps):
batch_key, q_network_key, a_network_key, key = jax.random.split(key, 4)
batch = experience_replay.sample(replay_buffer, batch_key)

q_params, q_net_state, q_opt_state, _ = q_step_fn(q_params, (q_network_key, state, batch), q_opt_state)
a_params, a_net_state, a_opt_state, _ = a_step_fn(a_params, (a_network_key, state, batch), a_opt_state)
q_params, q_net_state, q_opt_state, _ = q_step_fn(
q_params, (q_network_key, state, batch, non_zero_loss), q_opt_state)
a_params, a_net_state, a_opt_state, _ = a_step_fn(
a_params, (a_network_key, state, batch, non_zero_loss), a_opt_state)

q_params_target, q_state_target = optax.incremental_update(
(q_params, q_net_state), (q_params_target, q_state_target), tau)
a_params_target, a_state_target = optax.incremental_update(
(a_params, a_net_state), (a_params_target, a_state_target), tau)
q_params_target, q_state_target = optax.incremental_update(
(q_params, q_net_state), (q_params_target, q_state_target), tau)
a_params_target, a_state_target = optax.incremental_update(
(a_params, a_net_state), (a_params_target, a_state_target), tau)

return DDPGState(
q_params=q_params,
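All four files apply the same fix, visible first here in ddpg.py: update is now jitted as a whole, so the Python-level branch on experience_replay.is_ready(replay_buffer) is replaced by a non_zero_loss flag that multiplies the loss. A zero loss produces zero gradients, so the gradient step is a no-op until the buffer is ready, and no control flow depends on traced values. A minimal, self-contained sketch of that masking idiom, with a toy model and placeholder names (buffer_size, the threshold 64) that are not part of reinforced_lib:

import jax
import jax.numpy as jnp
import optax

def loss_fn(params, batch, non_zero_loss):
    x, y = batch
    pred = x @ params["w"]                      # toy linear model
    loss = optax.l2_loss(pred, y).mean()
    return loss * non_zero_loss                 # masked loss: zero until the buffer is ready

optimizer = optax.adam(1e-3)

@jax.jit
def update(params, opt_state, batch, buffer_size):
    non_zero_loss = buffer_size >= 64           # traced comparison, no Python if
    grads = jax.grad(loss_fn)(params, batch, non_zero_loss)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state

With zero gradients, Adam's parameter update is exactly zero while the buffer is still filling, so the parameters stay put even though the compiled code path is identical on every step.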
28 changes: 16 additions & 12 deletions reinforced_lib/agents/deep/dqn.py
@@ -136,19 +136,19 @@ def __init__(
experience_replay=er,
epsilon=epsilon
))
self.update = partial(
self.update = jax.jit(partial(
self.update,
step_fn=jax.jit(partial(
step_fn=partial(
gradient_step,
optimizer=optimizer,
loss_fn=partial(self.loss_fn, q_network=q_network, discount=discount)
)),
),
experience_replay=er,
experience_replay_steps=experience_replay_steps,
epsilon_decay=epsilon_decay,
epsilon_min=epsilon_min,
tau=tau
)
))
self.sample = jax.jit(partial(
self.sample,
q_network=q_network,
@@ -245,6 +245,7 @@ def loss_fn(
key: PRNGKey,
dqn_state: DQNState,
batch: tuple,
non_zero_loss: jnp.bool_,
q_network: hk.TransformedWithState,
discount: Scalar
) -> tuple[Scalar, hk.State]:
@@ -265,6 +266,8 @@ def loss_fn(
The state of the double Q-learning agent.
batch : tuple
A batch of transitions from the experience replay buffer.
non_zero_loss : bool
Flag used to avoid updating the Q-network when the experience replay buffer is not full.
q_network : hk.TransformedWithState
The Q-network.
discount : Scalar
@@ -288,7 +291,7 @@ def loss_fn(
target = jax.lax.stop_gradient(target)
loss = optax.l2_loss(q_values, target).mean()

return loss, state
return loss * non_zero_loss, state

@staticmethod
def update(
@@ -352,14 +355,15 @@ def update(
params, net_state, opt_state = state.params, state.state, state.opt_state
params_target, state_target = state.params_target, state.state_target

if experience_replay.is_ready(replay_buffer):
for _ in range(experience_replay_steps):
batch_key, network_key, key = jax.random.split(key, 3)
batch = experience_replay.sample(replay_buffer, batch_key)
non_zero_loss = experience_replay.is_ready(replay_buffer)

for _ in range(experience_replay_steps):
batch_key, network_key, key = jax.random.split(key, 3)
batch = experience_replay.sample(replay_buffer, batch_key)

params, net_state, opt_state, _ = step_fn(params, (network_key, state, batch), opt_state)
params_target, state_target = optax.incremental_update(
(params, net_state), (params_target, state_target), tau)
params, net_state, opt_state, _ = step_fn(params, (network_key, state, batch, non_zero_loss), opt_state)
params_target, state_target = optax.incremental_update(
(params, net_state), (params_target, state_target), tau)

return DQNState(
params=params,
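dqn.py and ddpg.py also keep the soft target update inside the now-unconditional loop: optax.incremental_update blends the target parameters toward the online ones with step size tau on every call, which only pulls the target toward parameters that have not moved while the loss is masked. A short self-contained sketch of that Polyak step, with toy parameter trees standing in for the Haiku params used above:

import jax
import jax.numpy as jnp
import optax

tau = 0.005  # same role as the tau argument of the agents above

@jax.jit
def soft_update(online_params, target_params):
    # optax.incremental_update(new, old, step_size) returns
    # step_size * new + (1 - step_size) * old for every leaf of the pytree.
    return optax.incremental_update(online_params, target_params, tau)

online = {"w": jnp.ones((4, 2)), "b": jnp.zeros(2)}
target = jax.tree_util.tree_map(jnp.zeros_like, online)
target = soft_update(online, target)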
30 changes: 16 additions & 14 deletions reinforced_lib/agents/deep/expected_sarsa.py
@@ -107,17 +107,17 @@ def __init__(
optimizer=optimizer,
experience_replay=er
))
self.update = partial(
self.update = jax.jit(partial(
self.update,
q_network=q_network,
step_fn=jax.jit(partial(
step_fn=partial(
gradient_step,
optimizer=optimizer,
loss_fn=partial(self.loss_fn, q_network=q_network, discount=discount, tau=tau)
)),
),
experience_replay=er,
experience_replay_steps=experience_replay_steps
)
))
self.sample = jax.jit(partial(
self.sample,
q_network=q_network,
@@ -208,6 +208,7 @@ def loss_fn(
params_target: hk.Params,
net_state_target: hk.State,
batch: tuple,
non_zero_loss: jnp.bool_,
q_network: hk.TransformedWithState,
discount: Scalar,
tau: Scalar
@@ -233,6 +234,8 @@ def loss_fn(
The state of the target Q-network.
batch : tuple
A batch of transitions from the experience replay buffer.
non_zero_loss : bool
Flag used to avoid updating the Q-network when the experience replay buffer is not full.
q_network : hk.TransformedWithState
The Q-network.
discount : Scalar
@@ -259,7 +262,7 @@ def loss_fn(
target = jax.lax.stop_gradient(target)
loss = optax.l2_loss(q_values, target).mean()

return loss, state
return loss * non_zero_loss, state

@staticmethod
def update(
@@ -314,17 +317,16 @@ def update(
)

params, net_state, opt_state = state.params, state.state, state.opt_state
params_target, net_state_target = deepcopy(params), deepcopy(net_state)

non_zero_loss = experience_replay.is_ready(replay_buffer)

if experience_replay.is_ready(replay_buffer):
params_target = deepcopy(params)
net_state_target = deepcopy(net_state)

for _ in range(experience_replay_steps):
batch_key, network_key, key = jax.random.split(key, 3)
batch = experience_replay.sample(replay_buffer, batch_key)
for _ in range(experience_replay_steps):
batch_key, network_key, key = jax.random.split(key, 3)
batch = experience_replay.sample(replay_buffer, batch_key)

loss_params = (network_key, net_state, params_target, net_state_target, batch)
params, net_state, opt_state, _ = step_fn(params, loss_params, opt_state)
loss_params = (network_key, net_state, params_target, net_state_target, batch, non_zero_loss)
params, net_state, opt_state, _ = step_fn(params, loss_params, opt_state)

return ExpectedSarsaState(
params=params,
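Across all four __init__ methods, jax.jit moves from the inner step functions to the whole update (and sample), with the non-array pieces (networks, optimizers, the gradient-step helper, the replay configuration) bound through functools.partial first. Arguments bound this way are ordinary Python objects closed over by the partial, so they behave as compile-time constants and only the arrays passed at call time are traced. A condensed sketch of the idiom with a toy loss (mse_loss and the shapes are illustrative, not reinforced_lib code):

from functools import partial

import jax
import optax

def update(params, opt_state, batch, *, optimizer, loss_fn):
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state

def mse_loss(params, batch):
    x, y = batch
    return optax.l2_loss(x @ params["w"], y).mean()

optimizer = optax.adam(1e-3)

# Bind the static pieces first, then jit the result: only params, opt_state
# and batch are traced, while optimizer and loss_fn are baked into the
# compiled function, mirroring the __init__ changes above.
update = jax.jit(partial(update, optimizer=optimizer, loss_fn=mse_loss))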
28 changes: 15 additions & 13 deletions reinforced_lib/agents/deep/q_learning.py
@@ -123,18 +123,18 @@ def __init__(
experience_replay=er,
epsilon=epsilon
))
self.update = partial(
self.update = jax.jit(partial(
self.update,
step_fn=jax.jit(partial(
step_fn=partial(
gradient_step,
optimizer=optimizer,
loss_fn=partial(self.loss_fn, q_network=q_network, discount=discount)
)),
),
experience_replay=er,
experience_replay_steps=experience_replay_steps,
epsilon_decay=epsilon_decay,
epsilon_min=epsilon_min
)
))
self.sample = jax.jit(partial(
self.sample,
q_network=q_network,
@@ -230,6 +230,7 @@ def loss_fn(
params_target: hk.Params,
net_state_target: hk.State,
batch: tuple,
non_zero_loss: jnp.bool_,
q_network: hk.TransformedWithState,
discount: Scalar
) -> tuple[Scalar, hk.State]:
@@ -253,6 +254,8 @@ def loss_fn(
The state of the target Q-network.
batch : tuple
A batch of transitions from the experience replay buffer.
non_zero_loss : bool
Flag used to avoid updating the Q-network when the experience replay buffer is not full.
q_network : hk.TransformedWithState
The Q-network.
discount : Scalar
@@ -276,7 +279,7 @@ def loss_fn(
target = jax.lax.stop_gradient(target)
loss = optax.l2_loss(q_values, target).mean()

return loss, state
return loss * non_zero_loss, state

@staticmethod
def update(
@@ -335,17 +338,16 @@ def update(
)

params, net_state, opt_state = state.params, state.state, state.opt_state
params_target, net_state_target = deepcopy(params), deepcopy(net_state)

if experience_replay.is_ready(replay_buffer):
params_target = deepcopy(params)
net_state_target = deepcopy(net_state)
non_zero_loss = experience_replay.is_ready(replay_buffer)

for _ in range(experience_replay_steps):
batch_key, network_key, key = jax.random.split(key, 3)
batch = experience_replay.sample(replay_buffer, batch_key)
for _ in range(experience_replay_steps):
batch_key, network_key, key = jax.random.split(key, 3)
batch = experience_replay.sample(replay_buffer, batch_key)

loss_params = (network_key, net_state, params_target, net_state_target, batch)
params, net_state, opt_state, _ = step_fn(params, loss_params, opt_state)
loss_params = (network_key, net_state, params_target, net_state_target, batch, non_zero_loss)
params, net_state, opt_state, _ = step_fn(params, loss_params, opt_state)

return QLearningState(
params=params,
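In every update above, the readiness flag travels to the loss through the loss_params tuple handed to step_fn, which is gradient_step with the optimizer and loss function already bound. gradient_step itself is outside this diff; a plausible reconstruction that matches the call sites (step_fn(params, loss_params, opt_state) returning updated params, the network state, the optimizer state and the loss) is sketched below, as an assumption rather than the library's actual code:

import jax
import optax

def gradient_step(params, loss_params, opt_state, *, optimizer, loss_fn):
    # loss_fn returns (loss, new_network_state); the entries of loss_params
    # (key, agent state, batch, non_zero_loss) are forwarded unchanged.
    (loss, net_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, *loss_params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), net_state, opt_state, loss

Under this shape, appending non_zero_loss to loss_params is the only change the loss functions need, which is what the new signatures above reflect.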
