take care of variable-length responses for implicit PRM
lucidrains committed Jan 6, 2025
1 parent 41a1a22 commit 2cc7b46
Showing 2 changed files with 11 additions and 4 deletions.
palm_rlhf_pytorch/implicit_process_reward.py (10 additions, 3 deletions)

@@ -3,8 +3,6 @@
 from copy import deepcopy
 from beartype import beartype
 
-from palm_rlhf_pytorch.palm import PaLM
-
 import torch
 from torch.nn import Module
 from torch.nn.functional import logsigmoid
@@ -56,8 +54,11 @@ def forward(
         n - sequence
         l - logit dimension (num tokens)
         """
 
         source_seq, target_seq = seq[:, :-1], seq[:, 1:]
+
+        mask = target_seq >= 0 # assume any token ids < 0 to be padding
+
         model_logits = self.model(source_seq)
         ref_model_logits = self.ref_model(source_seq)
@@ -71,6 +72,10 @@
 
         implicit_rewards = self.beta * (log_prob - ref_log_prob)
 
+        # zero out rewards in padding
+
+        implicit_rewards = implicit_rewards.masked_fill(~mask, 0.)
+
         # early return if not training, as in Prime with alternating model and prm training
 
         if not exists(labels):
@@ -85,11 +90,13 @@
             (1. - labels) * logsigmoid(-implicit_rewards) # (1. - sigmoid(x)) == sigmoid(-x)
         )
 
-        return loss.mean()
+        return loss[mask].mean()
 
 # make it easy for others to copy paste into another project
 
 if __name__ == '__main__':
+    from palm_rlhf_pytorch import PaLM
+
     palm = PaLM(
         num_tokens = 256,
         dim = 64,
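For context, here is a minimal self-contained sketch of what this commit changes, decoupled from the repository's PaLM module. The standalone function name implicit_prm_loss, its signature, the clamp guard that keeps gather valid at padded positions, the beta default, and the random-tensor demo are illustrative assumptions; the padding convention (token ids < 0), the masked_fill on the rewards, the logsigmoid cross entropy, and the loss[mask].mean() reduction come straight from the diff.

import torch
import torch.nn.functional as F

def implicit_prm_loss(
    model_logits,      # (batch, seq, num_tokens) - logits from the model being trained
    ref_model_logits,  # (batch, seq, num_tokens) - logits from the frozen reference model
    target_seq,        # (batch, seq) - next-token ids, with ids < 0 marking padding
    labels,            # (batch, seq) - process reward labels in {0., 1.}
    beta = 0.1         # assumed default for the reward scaling coefficient
):
    # padding convention from the commit: any token ids < 0 are padding
    mask = target_seq >= 0

    # clamp so gather has valid indices at padded positions (an illustrative guard)
    safe_targets = target_seq.clamp(min = 0).unsqueeze(-1)

    # per-token log prob of the realized next token under each model
    log_prob = model_logits.log_softmax(dim = -1).gather(-1, safe_targets).squeeze(-1)
    ref_log_prob = ref_model_logits.log_softmax(dim = -1).gather(-1, safe_targets).squeeze(-1)

    # implicit per-token reward is the beta-scaled log ratio against the reference
    implicit_rewards = beta * (log_prob - ref_log_prob)

    # zero out rewards in padding, as in the commit
    implicit_rewards = implicit_rewards.masked_fill(~mask, 0.)

    # binary cross entropy in logit space, using (1. - sigmoid(x)) == sigmoid(-x)
    loss = -(
        labels * F.logsigmoid(implicit_rewards) +
        (1. - labels) * F.logsigmoid(-implicit_rewards)
    )

    # average only over real tokens - padded positions no longer dilute the loss
    return loss[mask].mean()

if __name__ == '__main__':
    batch, seq, num_tokens = 2, 8, 256

    target_seq = torch.randint(0, num_tokens, (batch, seq))
    target_seq[0, 5:] = -1  # variable-length responses: pad out the first sequence

    model_logits = torch.randn(batch, seq, num_tokens)
    ref_model_logits = torch.randn(batch, seq, num_tokens)
    labels = torch.randint(0, 2, (batch, seq)).float()

    print(implicit_prm_loss(model_logits, ref_model_logits, target_seq, labels))

Masking before the mean matters: averaging over padded positions would dilute the loss for short responses in a batch of variable-length rollouts, which is exactly what this commit addresses.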
setup.py (1 addition, 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'PaLM-rlhf-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.3.3',
+  version = '0.3.4',
   license='MIT',
   description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
   author = 'Phil Wang',
