add what may be a tiny breakthrough, which happened earlier last month, which led to Prime by the same author
lucidrains committed Jan 6, 2025
1 parent 00c118d commit 984e666
Showing 3 changed files with 137 additions and 3 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -223,3 +223,12 @@ answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 204
url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
}
```

```bibtex
@inproceedings{Yuan2024FreePR,
title = {Free Process Rewards without Process Labels},
author = {Lifan Yuan and Wendi Li and Huayu Chen and Ganqu Cui and Ning Ding and Kaiyan Zhang and Bowen Zhou and Zhiyuan Liu and Hao Peng},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:274445748}
}
```
124 changes: 124 additions & 0 deletions palm_rlhf_pytorch/implicit_process_reward.py
@@ -0,0 +1,124 @@
from __future__ import annotations

from copy import deepcopy
from beartype import beartype

from palm_rlhf_pytorch import PaLM

import torch
from torch.nn import Module
from torch.nn.functional import logsigmoid

from einx import get_at
from einops import rearrange

# helpers

def exists(v):
return v is not None

# Free Process Rewards without Process Labels
# Yuan et al. https://arxiv.org/abs/2412.01981 - paper that led to Prime

class ImplicitPRM(Module):
""" PRM stands for process reward model, an openai paper that shows that rewarding the steps a model takes to its outcome is better than only rewarding based on final answer or outcome. basically same as when a teacher gives you some credit for showing your steps on an exam """

@beartype
def __init__(
self,
model: Module,
ref_model: Module | None = None,
beta = 0.1
):
super().__init__()
self.model = model

# only drawback to this technique is needing a reference model

if not exists(ref_model):
ref_model = deepcopy(model)

self.ref_model = ref_model
ref_model.requires_grad_(False) # insurance

self.beta = beta

def parameters(self):
return self.model.parameters() # only main model is trained

def forward(
self,
seq,
labels = None
):
"""
b - batch
n - sequence
l - logit dimension (num tokens)
"""

return_loss = exists(labels)

model_logits = self.model(seq)
ref_model_logits = self.ref_model(seq)

log_probs = model_logits.log_softmax(dim = -1)
ref_log_probs = ref_model_logits.log_softmax(dim = -1)

log_prob = get_at('b n [l], b n -> b n', log_probs, seq)
ref_log_prob = get_at('b n [l], b n -> b n', ref_log_probs, seq)

# main formula is DPO-like, and has some connection with Q-learning https://arxiv.org/abs/2404.12358 . it is all connected
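        # per token, r_t = beta * (log_prob_model(token_t) - log_prob_ref(token_t)) - the policy's own log-ratio against the frozen reference serves as a dense, label-free process reward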

implicit_rewards = self.beta * (log_prob - ref_log_prob)

# early return if not training, as in Prime with alternating model and prm training

if not return_loss:
return implicit_rewards

labels = rearrange(labels, 'b -> b 1')

        # otherwise use the cross entropy formulation from their paper (eq 5) - negative log likelihood of the 0/1 outcome labels

        loss = -(
            labels * logsigmoid(implicit_rewards) +
            (1. - labels) * logsigmoid(-implicit_rewards) # (1. - sigmoid(x)) == sigmoid(-x)
        )

        return loss.mean()

# make it easy for others to copy paste into another project

if __name__ == '__main__':
palm = PaLM(
num_tokens = 256,
dim = 64,
depth = 2
)

ref_palm = PaLM(
num_tokens = 256,
dim = 64,
depth = 2
)

implicit_prm = ImplicitPRM(
palm,
ref_model = ref_palm
)

# mock data

seq = torch.randint(0, 256, (2, 1024))
labels = torch.randint(0, 2, (2,))

loss = implicit_prm(seq, labels)
loss.backward()

# after much training

implicit_rewards = implicit_prm(seq) # Float[2, 1024]

# there you go, free process reward model
# now you can use this dense reward for rlhf, beam search whatever
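
    # one hypothetical downstream use (a sketch, not prescribed by the paper): collapse the
    # per-token implicit rewards into a single score per sequence and rank sampled candidates

    candidate_scores = implicit_rewards.sum(dim = -1)   # (2,)
    best_candidate = seq[candidate_scores.argmax()]     # (1024,) - the higher scoring mock sequence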
7 changes: 4 additions & 3 deletions setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'PaLM-rlhf-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.4',
+  version = '0.3.0',
   license='MIT',
   description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
   author = 'Phil Wang',
@@ -22,9 +22,10 @@
     'accelerate',
     'adam-atan2-pytorch',
     'beartype',
-    'einops>=0.6',
+    'einx>=0.3.0',
+    'einops>=0.8',
     'lion-pytorch',
-    'torch>=1.6',
+    'torch>=2.2',
     'tqdm'
   ],
   classifiers=[
