Skip to content

Commit

Permalink
Remove cls_embed arg from forward/encode_image fns
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Oct 23, 2023
1 parent dc463ae commit ae8333f
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions src/open_clip/coca_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ def _encode_image(self, images, normalize: bool = True):
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
return image_latent, tokens_embs

def _encode_text(self, text, normalize: bool = True, embed_cls: bool = True):
text = text[:, :-1] if embed_cls else text # make space for CLS token
def _encode_text(self, text, normalize: bool = True):
text_latent, token_emb = self.text(text)
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
return text_latent, token_emb
Expand All @@ -143,15 +142,14 @@ def encode_image(self, images, normalize: bool = True):
image_latent, _ = self._encode_image(images, normalize=normalize)
return image_latent

def encode_text(self, text, normalize: bool = True, embed_cls: bool = True):
text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
def encode_text(self, text, normalize: bool = True):
text_latent, _ = self._encode_text(text, normalize=normalize)
return text_latent

def forward(
self,
image,
text: Optional[torch.Tensor] = None,
embed_cls: bool = True,
image_latent: Optional[torch.Tensor] = None,
image_embs: Optional[torch.Tensor] = None,
):
Expand All @@ -161,7 +159,7 @@ def forward(
if text is None:
return {"image_features": image_latent, "image_embs": image_embs}

text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
text_latent, token_embs = self._encode_text(text)

# TODO: add assertion to avoid bugs?
labels = text[:, -token_embs.shape[1]:]
Expand Down Expand Up @@ -222,7 +220,7 @@ def generate(

if generation_type == "beam_search":
output = self._generate_beamsearch(
image_inputs = image,
image_inputs=image,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
sot_token_id=sot_token_id,
Expand Down Expand Up @@ -267,7 +265,7 @@ def generate(
while True:
x = out[:, -max_seq_len:]
cur_len = x.shape[1]
logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
logits = self(image, x, image_latent=image_latent, image_embs=image_embs)["logits"][:, -1]
mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

Expand Down Expand Up @@ -362,7 +360,6 @@ def _generate_beamsearch(
outputs = self(
model_inputs['images'],
model_inputs['text'],
embed_cls=False,
image_latent=image_latent,
image_embs=image_embs
)
Expand Down

0 comments on commit ae8333f

Please sign in to comment.