From 2cc7b463aa7bc56ce49287d1394af5350ec1f714 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 6 Jan 2025 07:27:13 -0800
Subject: [PATCH] take care of variable lengthed responses for implicit PRM

---
 palm_rlhf_pytorch/implicit_process_reward.py | 13 ++++++++++---
 setup.py                                     |  2 +-
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/palm_rlhf_pytorch/implicit_process_reward.py b/palm_rlhf_pytorch/implicit_process_reward.py
index 1694dc2..67eb9aa 100644
--- a/palm_rlhf_pytorch/implicit_process_reward.py
+++ b/palm_rlhf_pytorch/implicit_process_reward.py
@@ -3,8 +3,6 @@ from copy import deepcopy
 
 from beartype import beartype
 
-from palm_rlhf_pytorch.palm import PaLM
-
 import torch
 from torch.nn import Module
 from torch.nn.functional import logsigmoid
@@ -56,8 +54,11 @@ def forward(
         n - sequence
         l - logit dimension (num tokens)
         """
+
         source_seq, target_seq = seq[:, :-1], seq[:, 1:]
 
+        mask = target_seq >= 0 # assume any token ids < 0 to be padding
+
         model_logits = self.model(source_seq)
         ref_model_logits = self.ref_model(source_seq)
 
@@ -71,6 +72,10 @@ def forward(
 
         implicit_rewards = self.beta * (log_prob - ref_log_prob)
 
+        # zero out rewards in padding
+
+        implicit_rewards = implicit_rewards.masked_fill(~mask, 0.)
+
         # early return if not training, as in Prime with alternating model and prm training
 
         if not exists(labels):
@@ -85,11 +90,13 @@ def forward(
             (1. - labels) * logsigmoid(-implicit_rewards) # (1. - sigmoid(x)) == sigmoid(-x)
         )
 
-        return loss.mean()
+        return loss[mask].mean()
 
 # make it easy for others to copy paste into another project
 
 if __name__ == '__main__':
+    from palm_rlhf_pytorch import PaLM
+
     palm = PaLM(
         num_tokens = 256,
         dim = 64,
diff --git a/setup.py b/setup.py
index 1c8a8f9..3b4eec7 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'PaLM-rlhf-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.3.3',
+  version = '0.3.4',
   license='MIT',
   description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
   author = 'Phil Wang',
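
For readers following the patch outside the repo, below is a minimal standalone sketch of the padding-aware masking it introduces. It is not the repository's API: it assumes right-padded batches mark padding with token id -1, mocks the model and reference model with random logits, and the get_log_prob helper plus the per-token labels shape are illustrative assumptions only.

# Minimal standalone sketch of the padding-aware implicit PRM reward/loss.
# Assumptions: padding token id is -1, logits are mocked with torch.randn,
# get_log_prob and the per-token label shape are illustrative, not the repo API.

import torch
from torch.nn.functional import logsigmoid

beta = 0.1
num_tokens = 256

# two responses of different lengths, right-padded with -1 to a common length

seq = torch.tensor([
    [5, 12, 7, 9, 3, 11, 2, 6],
    [4, 8, 1, 10, -1, -1, -1, -1],
])

source_seq, target_seq = seq[:, :-1], seq[:, 1:]
mask = target_seq >= 0  # token ids < 0 are padding

# stand-ins for self.model(source_seq) and self.ref_model(source_seq)

model_logits = torch.randn(*source_seq.shape, num_tokens)
ref_model_logits = torch.randn(*source_seq.shape, num_tokens)

def get_log_prob(logits, targets):
    # clamp padded (-1) targets to a valid index; those positions are masked out later
    log_probs = logits.log_softmax(dim = -1)
    return log_probs.gather(-1, targets.clamp(min = 0).unsqueeze(-1)).squeeze(-1)

log_prob = get_log_prob(model_logits, target_seq)
ref_log_prob = get_log_prob(ref_model_logits, target_seq)

# per-token implicit rewards, zeroed at padding as in the patch

implicit_rewards = beta * (log_prob - ref_log_prob)
implicit_rewards = implicit_rewards.masked_fill(~mask, 0.)

# cross-entropy style loss on per-token labels, averaged only over real (non-padding) tokens

labels = torch.randint(0, 2, target_seq.shape).float()
loss = -(labels * logsigmoid(implicit_rewards) + (1. - labels) * logsigmoid(-implicit_rewards))
print(loss[mask].mean())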