From d4faf48d7eda1e95ee9301950de9075df2348e68 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Wed, 22 Mar 2023 06:51:23 -0700 Subject: [PATCH] old action log probs should be the true distribution in the kl div loss, addressing https://github.com/lucidrains/PaLM-rlhf-pytorch/issues/43 --- palm_rlhf_pytorch/ppo.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/palm_rlhf_pytorch/ppo.py b/palm_rlhf_pytorch/ppo.py index a6f2461..2b476da 100644 --- a/palm_rlhf_pytorch/ppo.py +++ b/palm_rlhf_pytorch/ppo.py @@ -505,7 +505,7 @@ def learn( kl_div_loss = 0. if self.kl_div_loss_weight > 0: - kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.kl_div_loss_weight + kl_div_loss = masked_kl_div(old_action_probs, action_probs, mask = action_masks) * self.kl_div_loss_weight # handle non-pooled values diff --git a/setup.py b/setup.py index 8d45f38..cf87dc2 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'PaLM-rlhf-pytorch', packages = find_packages(exclude=[]), - version = '0.1.2', + version = '0.1.4', license='MIT', description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch', author = 'Phil Wang',