Skip to content

Commit

Permalink
old action log probs should be the true distribution in the kl div loss, addressing #43
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Mar 22, 2023
1 parent ad001d8 commit d4faf48
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion palm_rlhf_pytorch/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def learn(
kl_div_loss = 0.

if self.kl_div_loss_weight > 0:
- kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.kl_div_loss_weight
+ kl_div_loss = masked_kl_div(old_action_probs, action_probs, mask = action_masks) * self.kl_div_loss_weight

# handle non-pooled values

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'PaLM-rlhf-pytorch',
packages = find_packages(exclude=[]),
- version = '0.1.2',
+ version = '0.1.4',
license='MIT',
description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit d4faf48

Please sign in to comment.