
Commit

start wiring up dense rewarding with implicit prm
lucidrains committed Jan 6, 2025
1 parent f3e20cf commit 89ab8ba
Showing 2 changed files with 43 additions and 20 deletions.
61 changes: 42 additions & 19 deletions palm_rlhf_pytorch/ppo.py
@@ -27,8 +27,8 @@

from palm_rlhf_pytorch.palm import PaLM
from palm_rlhf_pytorch.reward import RewardModel
from palm_rlhf_pytorch.implicit_process_reward import ImplicitPRM
from palm_rlhf_pytorch.utils import masked_mean, eval_decorator

from accelerate import Accelerator

# actor critic - PaLM with lora
@@ -47,7 +47,7 @@ class ActorCritic(Module):
def __init__(
self,
palm: PaLM,
critic_palm: PaLM | None = None,
critic: PaLM | ImplicitPRM | None = None,
pooled_values = False,
actor_lora = True,
critic_lora = True,
@@ -61,13 +61,26 @@ def __init__(
super().__init__()
self.actor_palm = palm

self.critic_palm = critic_palm
# detect implicit prm and auto-set some hyperparameters

critic_is_prm = isinstance(critic, ImplicitPRM)

critic_lora &= not critic_is_prm
pooled_values |= critic_is_prm

self.critic_is_prm = critic_is_prm

# critic

self.critic = critic

if not exists(self.critic_palm):
self.critic_palm = copy.deepcopy(palm)
if not exists(self.critic):
self.critic = copy.deepcopy(palm)

self.actor_palm.set_dropout(actor_dropout)
self.critic_palm.set_dropout(critic_dropout)

if not critic_is_prm:
self.critic.set_dropout(critic_dropout)

self.actor_lora = actor_lora
self.critic_lora = critic_lora
@@ -79,16 +92,19 @@ def __init__(
self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r)

if self.critic_lora:
self.critic_palm.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)
self.critic.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)

self.pooled_values = pooled_values
self.value_head = nn.Sequential(
nn.Linear(palm.dim, 1),
Rearrange('... 1 -> ...')
)
self.value_head = nn.Identity()

if not critic_is_prm:
self.value_head = nn.Sequential(
nn.Linear(palm.dim, 1),
Rearrange('... 1 -> ...')
)

nn.init.zeros_(self.value_head[0].bias)
nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))
nn.init.zeros_(self.value_head[0].bias)
nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))

def actor_parameters(self):
if not self.actor_lora:
@@ -99,11 +115,14 @@ def actor_parameters(self):
]

def critic_parameters(self):
if self.critic_is_prm:
return self.critic.parameters()

if not self.actor_lora:
return [*self.critic_palm.parameters(), *self.value_head.parameters()]
return [*self.critic.parameters(), *self.value_head.parameters()]

return [
*self.critic_palm.finetune_parameters(self.critic_lora_scope),
*self.critic.finetune_parameters(self.critic_lora_scope),
*self.value_head.parameters()
]

@@ -170,7 +189,11 @@ def forward(
if not return_values:
return action_logits, None

critic_embeds = self.critic_palm(
if self.critic_is_prm:
values = self.critic(x)
return action_logits, values

critic_embeds = self.critic(
x,
return_only_embedding = True,
finetune_scope = self.critic_lora_scope
@@ -287,8 +310,8 @@ def clipped_value_loss(values, rewards, old_values, clip):

# rlhf trainer

@beartype
class RLHFTrainer(Module):
@beartype
def __init__(
self,
*,
@@ -298,7 +321,7 @@ def __init__(
tokenizer: Callable | None = None,
palm: PaLM,
reward_model: RewardModel,
critic_palm: PaLM | None = None,
critic: PaLM | ImplicitPRM | None = None,
actor_critic: ActorCritic | None = None,
actor_lr = 1e-4,
critic_lr = 1e-4,
@@ -351,7 +374,7 @@ def __init__(
if not exists(actor_critic):
actor_critic = ActorCritic(
palm = palm,
critic_palm = critic_palm,
critic = critic,
actor_lora = actor_lora,
critic_lora = critic_lora,
actor_lora_r = actor_lora_r,
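After this change, the critic handed to ActorCritic (and to RLHFTrainer via the renamed critic keyword) may be either a PaLM or an ImplicitPRM; when it is an ImplicitPRM, LoRA finetuning of the critic and the separate value head are skipped, and the ImplicitPRM's output is used as the value estimate directly. A minimal usage sketch under assumptions: the ImplicitPRM constructor arguments and the model hyperparameters below are illustrative only and are not taken from this commit.

import torch
from palm_rlhf_pytorch import PaLM
from palm_rlhf_pytorch.implicit_process_reward import ImplicitPRM
from palm_rlhf_pytorch.ppo import ActorCritic

palm = PaLM(num_tokens = 20000, dim = 512, depth = 12)

# assumption: ImplicitPRM wraps its own PaLM used as the dense (process) reward model;
# see palm_rlhf_pytorch/implicit_process_reward.py for the actual constructor signature
implicit_prm = ImplicitPRM(PaLM(num_tokens = 20000, dim = 512, depth = 12))

actor_critic = ActorCritic(
    palm = palm,
    critic = implicit_prm   # critic_lora is switched off and pooled_values switched on automatically
)

seq = torch.randint(0, 20000, (1, 1024))
action_logits, values = actor_critic(seq, return_values = True)   # values come straight from the implicit PRM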
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'PaLM-rlhf-pytorch',
packages = find_packages(exclude=[]),
version = '0.3.7',
version = '0.3.9',
license='MIT',
description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
author = 'Phil Wang',