Skip to content

Commit

Permalink
fix silly error in masked kl div loss, thanks to @taynoel84
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Mar 12, 2023
1 parent 82fa3d0 commit 2bb47d4
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion palm_rlhf_pytorch/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def masked_kl_div(prob1, prob2, mask = None):
"""
need to account for variable sequence lengths, therefore not using the built-in functional version
"""
kl_divs = (prob1 * (log(prob2) - log(prob1))).sum(dim = -1)
kl_divs = (prob1 * (log(prob1) - log(prob2))).sum(dim = -1)

if not exists(mask):
return kl_divs.mean()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'PaLM-rlhf-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.66',
version = '0.0.67',
license='MIT',
description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 2bb47d4

Please sign in to comment.