diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py
index 50c231203..b4faa061d 100644
--- a/src/open_clip/coca_model.py
+++ b/src/open_clip/coca_model.py
@@ -160,7 +160,7 @@ def forward(
             text: Optional[torch.Tensor] = None,
             image_latent: Optional[torch.Tensor] = None,
             image_embs: Optional[torch.Tensor] = None,
-            is_training=True
+            output_labels: bool = True,
     ):
         if image_latent is None or image_embs is None:
             image_latent, image_embs = self._encode_image(image)
@@ -170,9 +170,10 @@
 
         text_latent, token_embs = self._encode_text(text)
 
-        # TODO: add assertion to avoid bugs?
-        labels = text[:, 1:]
-        if is_training:
+        # FIXME this isn't an ideal solution, would like to improve -RW
+        labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None
+        if output_labels:
+            # align text_embs and thus logits with labels for teacher-forcing caption loss
             token_embs = token_embs[:, :-1]
 
         logits = self.text_decoder(image_embs, token_embs)
@@ -180,9 +181,10 @@
             "image_features": image_latent,
             "text_features": text_latent,
             "logits": logits,
-            "labels": labels,
             "logit_scale": self.logit_scale.exp()
         }
+        if labels is not None:
+            out_dict["labels"] = labels
         if self.logit_bias is not None:
             out_dict["logit_bias"] = self.logit_bias
         return out_dict
@@ -245,8 +247,11 @@ def generate(
                     logit_processor=logit_processor,
                 )
                 if fixed_output_length and output.shape[1] < seq_len:
-                    return torch.cat(
-                        (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
+                    pad_len = seq_len - output.shape[1]
+                    return torch.cat((
+                            output,
+                            torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * self.pad_id
+                        ),
                         dim=1
                     )
                 return output
@@ -272,14 +277,19 @@ def generate(
             if num_dims == 1:
                 text = text[None, :]
 
-            cur_len = text.shape[1]
             self.eval()
             out = text
 
             while True:
                 x = out[:, -max_seq_len:]
                 cur_len = x.shape[1]
-                logits = self(image, x, image_latent=image_latent, image_embs=image_embs, is_training=False)["logits"][:, -1]
+                logits = self(
+                    image,
+                    x,
+                    image_latent=image_latent,
+                    image_embs=image_embs,
+                    output_labels=False,
+                )["logits"][:, -1]
 
                 mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                 sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
@@ -376,7 +386,7 @@ def _generate_beamsearch(
                 model_inputs['text'],
                 image_latent=image_latent,
                 image_embs=image_embs,
-                is_training=False
+                output_labels=False,
             )
 
             for beam_group_idx in range(num_beam_groups):
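
For context, a minimal sketch of how the renamed flag behaves at the call site. `create_model_and_transforms` and the `coca_ViT-B-32` config are existing open_clip entry points; the dummy tensors, batch size, and the 76-token sequence length (assumed here as CoCa's default text context length) are illustrative assumptions, not part of this diff.

```python
import torch
from open_clip import create_model_and_transforms

# Randomly initialized CoCa model; no pretrained weights needed for this sketch.
model, _, _ = create_model_and_transforms("coca_ViT-B-32")

image = torch.randn(2, 3, 224, 224)       # dummy image batch
text = torch.randint(0, 49408, (2, 76))   # dummy token ids, assumed context length 76

# Training path (default): labels are emitted and token_embs is dropped by one
# position, so logits line up with the shifted labels for the captioning loss.
out = model(image, text, output_labels=True)
assert out["logits"].shape[1] == out["labels"].shape[1] == text.shape[1] - 1

# Generation path: no "labels" key in the output dict, and logits keep the full
# sequence length, which is what generate() relies on when reading logits[:, -1].
out = model(image, text, output_labels=False)
assert "labels" not in out
assert out["logits"].shape[1] == text.shape[1]
```

Compared to `is_training`, the new name describes the actual effect (whether labels are produced and logits shifted to match), and `labels` is now omitted from the output dict entirely rather than returned as a stale slice when not needed.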