From 7079e7e1279d21cb0073a042bd538e25a901cbb3 Mon Sep 17 00:00:00 2001
From: Giovanni Puccetti
Date: Thu, 26 Oct 2023 17:11:09 +0200
Subject: [PATCH] fix coca training

---
 src/open_clip/coca_model.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py
index 272b2cc06..50c231203 100644
--- a/src/open_clip/coca_model.py
+++ b/src/open_clip/coca_model.py
@@ -160,6 +160,7 @@ def forward(
             text: Optional[torch.Tensor] = None,
             image_latent: Optional[torch.Tensor] = None,
             image_embs: Optional[torch.Tensor] = None,
+            is_training=True
     ):
         if image_latent is None or image_embs is None:
             image_latent, image_embs = self._encode_image(image)
@@ -170,7 +171,9 @@ def forward(
         text_latent, token_embs = self._encode_text(text)
 
         # TODO: add assertion to avoid bugs?
-        labels = text[:, -token_embs.shape[1]:]
+        labels = text[:, 1:]
+        if is_training:
+            token_embs = token_embs[:, :-1]
 
         logits = self.text_decoder(image_embs, token_embs)
         out_dict = {
@@ -276,7 +279,7 @@ def generate(
             while True:
                 x = out[:, -max_seq_len:]
                 cur_len = x.shape[1]
-                logits = self(image, x, image_latent=image_latent, image_embs=image_embs)["logits"][:, -1]
+                logits = self(image, x, image_latent=image_latent, image_embs=image_embs, is_training=False)["logits"][:, -1]
                 mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                 sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
 
@@ -372,7 +375,8 @@ def _generate_beamsearch(
             model_inputs['images'],
             model_inputs['text'],
             image_latent=image_latent,
-            image_embs=image_embs
+            image_embs=image_embs,
+            is_training=False
         )
 
         for beam_group_idx in range(num_beam_groups):
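
A minimal sketch (not part of the patch) of why the training path drops the last token embedding: with labels = text[:, 1:], trimming token_embs to [:, :-1] keeps the decoder inputs aligned with their next-token targets. The shapes below are illustrative values chosen for the example, not taken from the patch.

import torch

# Hypothetical shapes for illustration only.
batch, seq_len, width, vocab_size = 2, 8, 16, 100

text = torch.randint(0, vocab_size, (batch, seq_len))   # token ids: [<start>, t1, ..., <end>]
token_embs = torch.randn(batch, seq_len, width)         # one embedding per input position

labels = text[:, 1:]               # targets: the token that follows each position
token_embs = token_embs[:, :-1]    # training inputs: drop the last position (it has no target)

# Decoder logits computed from token_embs now cover seq_len - 1 positions,
# matching labels, so a standard next-token cross-entropy loss lines up.
assert labels.shape[1] == token_embs.shape[1] == seq_len - 1

At generation time (is_training=False) no trimming is applied, since the logits for the final position are exactly what the sampling loop reads via ["logits"][:, -1].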