From 6a5068eeecddb53b64ebed5c19adea349293132c Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 3 Jul 2024 11:24:11 -0700
Subject: [PATCH] Fix AttentionalPooler batch_first change, remove device arg
 from logit processor as it's very new, move sot/eos token ids to tensors
 beforehand

---
 src/open_clip/coca_model.py  | 16 ++++++++++++----
 src/open_clip/transformer.py |  5 ++---
 2 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py
index 618614e96..dda3faba5 100644
--- a/src/open_clip/coca_model.py
+++ b/src/open_clip/coca_model.py
@@ -78,6 +78,14 @@ def _build_text_decoder_tower(
     return decoder
 
 
+def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor:
+    if not isinstance(token_id, torch.Tensor):
+        if isinstance(token_id, int):
+            token_id = [token_id]
+        token_id = torch.tensor(token_id, device=device)
+    return token_id
+
+
 class CoCa(nn.Module):
     def __init__(
             self,
@@ -218,12 +226,12 @@ def generate(
         device = image.device
 
         with torch.no_grad():
-            sot_token_id = 49406 if sot_token_id is None else sot_token_id
-            eos_token_id = 49407 if eos_token_id is None else eos_token_id
+            sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device)
+            eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device)
             pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
             logit_processor = LogitsProcessorList(
                 [
-                    MinLengthLogitsProcessor(min_seq_len, eos_token_id, device=device),
+                    MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                     RepetitionPenaltyLogitsProcessor(repetition_penalty),
                 ]
             )
@@ -248,7 +256,7 @@ def generate(
                     pad_len = seq_len - output.shape[1]
                     return torch.cat((
                             output,
-                            torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * self.pad_id
+                            torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id
                         ),
                         dim=1
                     )
diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py
index 84ac32ecc..4932abf2c 100644
--- a/src/open_clip/transformer.py
+++ b/src/open_clip/transformer.py
@@ -200,10 +200,10 @@ def __init__(
         self.ln_k = norm_layer(context_dim)
 
     def forward(self, x: torch.Tensor):
+        N = x.shape[0]
         x = self.ln_k(x)
-        N = x.shape[1]
         q = self.ln_q(self.query)
-        out = self.attn(q.unsqueeze(1).expand(-1, N, -1), x, x, need_weights=False)[0]
+        out = self.attn(q.unsqueeze(0).expand(N, -1, -1), x, x, need_weights=False)[0]
         return out
 
 
@@ -823,7 +823,6 @@ def __init__(
             output_dim: int = 512,
             batch_first: bool = True,
     ):
-
         super().__init__(
             width=width,
             layers=layers,
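
Note (not part of the patch): a minimal standalone sketch of the batch-first
shape contract the AttentionalPooler change restores. Names mirror the diff
above; the sizes are arbitrary and the modules here are stand-ins built from
the __init__ context shown in the hunk, not the real class. The sot/eos
change is similar in spirit: converting the token ids to device tensors up
front avoids relying on the logit processor's very new device argument.

    import torch
    import torch.nn as nn

    d_model, context_dim, n_head, n_queries = 64, 32, 8, 16
    N, seq_len = 4, 50  # batch size, input sequence length

    # Same pieces the pooler builds in __init__ (per the diff context):
    query = nn.Parameter(torch.randn(n_queries, d_model))
    attn = nn.MultiheadAttention(
        d_model, n_head, kdim=context_dim, vdim=context_dim, batch_first=True)

    x = torch.randn(N, seq_len, context_dim)  # batch-first input
    # Pre-fix code took the batch size from x.shape[1] and expanded the query
    # along dim 1, which assumes (seq, batch, dim) ordering. With
    # batch_first=True the batch size is x.shape[0] and the query must be
    # expanded along dim 0 instead:
    q = query.unsqueeze(0).expand(N, -1, -1)   # (N, n_queries, d_model)
    out = attn(q, x, x, need_weights=False)[0]
    assert out.shape == (N, n_queries, d_model)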