Commit
add what may be a tiny breakthrough, which happened earlier last month, which led to Prime by the same author
1 parent 00c118d · commit 984e666
Showing 3 changed files with 137 additions and 3 deletions.
@@ -0,0 +1,124 @@
from __future__ import annotations

from copy import deepcopy
from beartype import beartype

from palm_rlhf_pytorch import PaLM

import torch
from torch.nn import Module
from torch.nn.functional import logsigmoid

from einx import get_at
from einops import rearrange

# helpers

def exists(v):
    return v is not None

# Free Process Rewards without Process Labels
# Yuan et al. https://arxiv.org/abs/2412.01981 - paper that led to Prime

class ImplicitPRM(Module):
    """
    PRM stands for process reward model. an openai paper showed that rewarding the steps a model
    takes towards its outcome works better than rewarding only the final answer or outcome -
    basically the same as when a teacher gives you some credit for showing your steps on an exam
    """
    @beartype
    def __init__(
        self,
        model: Module,
        ref_model: Module | None = None,
        beta = 0.1
    ):
        super().__init__()
        self.model = model

        # only drawback to this technique is needing a reference model

        if not exists(ref_model):
            ref_model = deepcopy(model)

        self.ref_model = ref_model
        ref_model.requires_grad_(False) # insurance

        self.beta = beta

    def parameters(self):
        return self.model.parameters() # only main model is trained
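    # note: ref_model is still registered as a submodule, so it moves with .to(device) and appears
    # in the state dict, but overriding .parameters() above keeps it out of any optimizer that is
    # constructed from this module's parameters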
    def forward(
        self,
        seq,
        labels = None
    ):
        """
        b - batch
        n - sequence
        l - logit dimension (num tokens)
        """

        return_loss = exists(labels)

        model_logits = self.model(seq)
        ref_model_logits = self.ref_model(seq)

        log_probs = model_logits.log_softmax(dim = -1)
        ref_log_probs = ref_model_logits.log_softmax(dim = -1)
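        # gather, at every position, the log prob assigned to the token that actually appears in
        # the sequence, under both the trained model and the frozen reference model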
        log_prob = get_at('b n [l], b n -> b n', log_probs, seq)
        ref_log_prob = get_at('b n [l], b n -> b n', ref_log_probs, seq)

        # main formula is DPO-like, and has some connection with Q-learning https://arxiv.org/abs/2404.12358 . it is all connected

        implicit_rewards = self.beta * (log_prob - ref_log_prob)
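        # per token this is beta * (model log prob - reference log prob), the DPO-style implicit
        # reward, read out densely at every position of the sequence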
        # early return if not training, as in Prime with alternating model and prm training

        if not return_loss:
            return implicit_rewards

        labels = rearrange(labels, 'b -> b 1')

        # otherwise use the cross entropy formulation from their paper (eq 5)
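        # the single outcome label per sequence is broadcast across all tokens, so sequence-level
        # labels are enough to supervise a dense per-token reward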
        loss = -(
            labels * logsigmoid(implicit_rewards) +
            (1. - labels) * logsigmoid(-implicit_rewards) # (1. - sigmoid(x)) == sigmoid(-x)
        )

        return loss.mean() # negated above so that minimizing the loss maximizes the log likelihood
# make it easy for others to copy paste into another project

if __name__ == '__main__':
    palm = PaLM(
        num_tokens = 256,
        dim = 64,
        depth = 2
    )

    ref_palm = PaLM(
        num_tokens = 256,
        dim = 64,
        depth = 2
    )

    implicit_prm = ImplicitPRM(
        palm,
        ref_model = ref_palm
    )

    # mock data

    seq = torch.randint(0, 256, (2, 1024))
    labels = torch.randint(0, 2, (2,))

    loss = implicit_prm(seq, labels)
    loss.backward()

    # after much training

    implicit_rewards = implicit_prm(seq) # Float[2, 1024]

    # there you go, free process reward model
    # now you can use this dense reward for rlhf, beam search whatever
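A rough sketch of the "use this dense reward for beam search" idea from the closing comment, assuming the implicit_prm and 256-token vocabulary from the demo above (candidate sampling is mocked with random tokens): candidates of equal length can be ranked by their summed per-token implicit reward.

# hypothetical best-of-n selection with the implicit process rewards (illustrative only,
# relies on implicit_prm from the demo above)

candidates = torch.randint(0, 256, (8, 1024)) # 8 mock candidate sequences

with torch.no_grad():
    rewards = implicit_prm(candidates)        # (8, 1024) dense per-token rewards
    scores = rewards.sum(dim = -1)            # total reward per candidate

best_candidate = candidates[scores.argmax()]  # keep the highest scoring sequence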