diff --git a/palm_rlhf_pytorch/ppo.py b/palm_rlhf_pytorch/ppo.py
index 2b476da..fbfa4b7 100644
--- a/palm_rlhf_pytorch/ppo.py
+++ b/palm_rlhf_pytorch/ppo.py
@@ -17,7 +17,7 @@
 from torch.utils.data import Dataset, DataLoader
 from torch.nn.utils.rnn import pad_sequence
 
-from einops import rearrange, repeat
+from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
 
 from palm_rlhf_pytorch.palm import PaLM
@@ -263,16 +263,17 @@ def masked_entropy(prob, dim = -1, mask = None):
     entropies = (prob * log(prob)).sum(dim = -1)
     return masked_mean(entropies, mask = mask).mean()
 
-def masked_kl_div(prob1, prob2, mask = None):
+def masked_kl_div(prob1, prob2, mask = None, reduce_batch = False):
     """
     need to account for variable sequence lengths, therefore not using the built-in functional version
     """
     kl_divs = (prob1 * (log(prob1) - log(prob2))).sum(dim = -1)
+    loss = masked_mean(kl_divs, mask)
 
-    if not exists(mask):
-        return kl_divs.mean()
+    if reduce_batch:
+        return loss.mean()
 
-    return masked_mean(kl_divs, mask).mean()
+    return loss
 
 def clipped_value_loss(values, rewards, old_values, clip):
     value_clipped = old_values + (values - old_values).clamp(-clip, clip)
@@ -502,10 +503,14 @@ def learn(
 
                 # calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not
 
-                kl_div_loss = 0.
+                kl_penalty = 0.
 
                 if self.kl_div_loss_weight > 0:
-                    kl_div_loss = masked_kl_div(old_action_probs, action_probs, mask = action_masks) * self.kl_div_loss_weight
+                    kl_penalty = masked_kl_div(old_action_probs, action_probs, mask = action_masks) * self.kl_div_loss_weight
+
+                # subtract the kl penalty from the rewards
+
+                rewards = rewards - kl_penalty
 
                 # handle non-pooled values
 
@@ -536,7 +541,7 @@ def learn(
 
                 # combine losses
 
-                loss = policy_loss.mean() + kl_div_loss
+                loss = policy_loss.mean()
 
                 # update actor
 
@@ -552,7 +557,7 @@ def learn(
 
                 # calculate value loss and update value network separate from policy network
 
-                value_loss = clipped_value_loss(values, rewards, old_values, self.value_clip)
+                value_loss = clipped_value_loss(values, rewards.detach(), old_values, self.value_clip)
                 value_loss = value_loss.mean()
 
                 self.print(f'critic_loss: {value_loss.item():.3f}')
diff --git a/setup.py b/setup.py
index cf87dc2..13aff00 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'PaLM-rlhf-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.4',
+  version = '0.2.0',
   license='MIT',
   description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
   author = 'Phil Wang',
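
For context, here is a minimal, self-contained sketch of the behavior this patch introduces: the per-sequence KL divergence between the old and new action distributions is subtracted from the rewards as a penalty, rather than added to the policy loss, and the penalized rewards are detached before the critic's value loss. The `log` and `masked_mean` helpers below only approximate the repository's own implementations, and the tensor shapes and weight value are illustrative, not taken from the patch.

```python
import torch

def log(t, eps = 1e-20):
    # numerically safe log, approximating the repo's helper
    return torch.log(t.clamp(min = eps))

def masked_mean(seq, mask = None, dim = 1):
    # mean over the sequence dimension, counting only unmasked (action) positions
    if mask is None:
        return seq.mean(dim = dim)
    masked = seq.masked_fill(~mask, 0.)
    return masked.sum(dim = dim) / mask.sum(dim = dim).clamp(min = 1)

def masked_kl_div(prob1, prob2, mask = None, reduce_batch = False):
    # per-sequence KL(prob1 || prob2), following the patched function above
    kl_divs = (prob1 * (log(prob1) - log(prob2))).sum(dim = -1)
    loss = masked_mean(kl_divs, mask)
    return loss.mean() if reduce_batch else loss

# toy shapes (illustrative): batch 2, sequence length 5, vocab 7
old_action_probs = torch.softmax(torch.randn(2, 5, 7), dim = -1)
action_probs     = torch.softmax(torch.randn(2, 5, 7), dim = -1)
action_masks     = torch.ones(2, 5, dtype = torch.bool)
rewards          = torch.randn(2)
kl_div_loss_weight = 0.1

# per-sequence kl penalty, shaped like the per-sequence rewards
kl_penalty = masked_kl_div(old_action_probs, action_probs, mask = action_masks) * kl_div_loss_weight

# shape the rewards with the penalty; detach before the critic's clipped value loss
rewards = rewards - kl_penalty
critic_target = rewards.detach()
```

Because `reduce_batch` defaults to `False`, the penalty keeps its per-sequence shape and lines up with per-sequence rewards; the `.detach()` mirrors the change to `clipped_value_loss`, so the critic's loss does not backpropagate through the KL penalty into the actor.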