Skip to content

Commit

Permalink
update examples
Browse files (browse the repository at this point in the history)
  • Loading branch information
Yuanmo committed May 14, 2024
1 parent 6997269 commit 7eb526e
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions examples/rlexplore/3 rlexplore_with_cleanrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,9 @@ def get_action_and_value(self, x, action=None):
next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())

# ===================== watch the interaction ===================== #
irs.watch(observations=obs[step],
actions=actions[step],
rewards=rewards[step],
terminateds=dones[step],
truncateds=dones[step],
next_observations=next_obs
)
irs.watch(observations=obs[step], actions=actions[step],
rewards=rewards[step], terminateds=dones[step],
truncateds=dones[step], next_observations=next_obs)
# ===================== watch the interaction ===================== #

next_done = np.logical_or(terminations, truncations)
Expand All @@ -255,12 +251,18 @@ def get_action_and_value(self, x, action=None):
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

# ===================== compute the intrinsic rewards ===================== #
intrinsic_rewards = irs.compute(samples=dict(observations=obs,
actions=actions,
rewards=rewards,
terminateds=dones,
truncateds=dones,
next_observations=obs
next_obs = obs.clone()
next_obs[:-1] = obs[1:]
next_obs[-1] = next_obs

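# get real next observations
# NOTE(review): `next_obs` was rebound to `obs.clone()` a few lines above, so the
# last slot below is no longer filled from the value returned by `envs.step`.
# The three `next_obs` lines above look like a leftover duplicate of this
# block; confirm and remove them.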
real_next_obs = obs.clone()
real_next_obs[:-1] = obs[1:]
real_next_obs[-1] = next_obs

intrinsic_rewards = irs.compute(samples=dict(observations=obs, actions=actions,
rewards=rewards, terminateds=dones,
truncateds=dones, next_observations=real_next_obs
))
rewards += intrinsic_rewards
# ===================== compute the intrinsic rewards ===================== #
Expand Down

0 comments on commit 7eb526e

Please sign in to comment.