diff --git a/palm_rlhf_pytorch/ppo.py b/palm_rlhf_pytorch/ppo.py
index bdfac8a..a6f2461 100644
--- a/palm_rlhf_pytorch/ppo.py
+++ b/palm_rlhf_pytorch/ppo.py
@@ -219,7 +219,11 @@ def exists(val):
     return val is not None
 
 def default(val, d):
-    return val if exists(val) else d
+    # `d` is either a plain fallback value or a zero-argument factory; a callable
+    # `d` is invoked only when `val` is None, so costly fallbacks are built lazily
+    if exists(val):
+        return val
+    return d() if callable(d) else d
 
 def masked_normalize(t, eps = 1e-5, mask = None, dim = None):
     dim = default(dim, tuple(range(t.ndim)))
@@ -630,7 +634,12 @@ def train(
 
                 sequence = rearrange(sequence, 'n -> 1 n')
                 prompt_mask = rearrange(prompt_mask, 'n -> 1 n')
-                mask = rearrange(mask, 'n -> 1 n') if exists(mask) else torch.ones(sequence.shape, dtype = torch.bool, device = device)
+                # an existing mask still needs the same 'n -> 1 n' reshape as `sequence`,
+                # so it cannot be routed through `default`, which would return it unchanged
+                if exists(mask):
+                    mask = rearrange(mask, 'n -> 1 n')
+                else:
+                    mask = torch.ones(sequence.shape, dtype = torch.bool, device = device)
 
                 reward = self.reward_model(
                     sequence,