better decoder
gpucce committed Aug 14, 2023
1 parent 52e7b3e commit 4f0b832
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions src/open_clip/transformer.py
@@ -670,7 +670,7 @@ def __init__(
):

super().__init__()

self.width = width
self.layers = layers
self.grad_checkpointing = False
@@ -681,21 +681,21 @@ def __init__(

self.resblocks = nn.ModuleList([])
self.cross_attn = nn.ModuleList([])

for l_idx in range(layers):

-            _, _r = divmod(l_idx, self.cross_step)
-            has_cross_attn = _r == 0
+            _, r = divmod(l_idx, self.cross_step)
+            has_cross_attn = r == 0

self.resblocks.append(
ResidualAttentionBlock(
-                    width,
-                    heads,
-                    mlp_ratio,
-                    ls_init_value=ls_init_value,
-                    act_layer=act_layer,
+                    width,
+                    heads,
+                    mlp_ratio,
+                    ls_init_value=ls_init_value,
+                    act_layer=act_layer,
norm_layer=norm_layer,
-                    has_mlp=has_cross_attn or has_mlp,
+                    has_mlp=(not has_cross_attn) or has_mlp,
)
)
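
As read from this hunk, a cross-attention block is scheduled on every layer whose index is a multiple of `self.cross_step`, and the corrected `has_mlp` expression keeps the MLP in every self-attention block that has no cross-attention (and only honors the `has_mlp` flag on the layers that do); the previous `has_cross_attn or has_mlp` inverted that choice whenever `has_mlp` was False. A minimal sketch of that scheduling, using a hypothetical `plan_layers` helper rather than the real constructor:

```python
# Minimal sketch (not open_clip code) of the layer scheduling implied by this hunk:
# which decoder layers get a cross-attention block, and which self-attention blocks
# keep an MLP after the corrected `has_mlp` expression.
def plan_layers(layers: int, cross_step: int, has_mlp: bool):
    plan = []
    for l_idx in range(layers):
        _, r = divmod(l_idx, cross_step)
        has_cross_attn = r == 0                      # cross-attend every `cross_step` layers
        block_has_mlp = (not has_cross_attn) or has_mlp
        plan.append((l_idx, has_cross_attn, block_has_mlp))
    return plan

# With 6 layers, cross_step=2 and has_mlp=False: layers 0, 2, 4 get cross-attention
# and no MLP in their self-attention block; layers 1, 3, 5 always keep their MLP.
print(plan_layers(6, 2, has_mlp=False))
```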

@@ -711,15 +711,15 @@ def __init__(
is_cross_attention=True,
)
)

assert len(self.cross_attn) == n_cross_attn, "the number of cross attn is incorrect"

self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)

self.ln_final = norm_layer(width)
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
self.does_full_decoding = does_full_decoding

if self.does_full_decoding:
self.num_pos = self.context_length
self.token_embedding = nn.Embedding(vocab_size, width)
@@ -728,7 +728,7 @@ def __init__(
self.num_pos = None
self.token_embedding = None
self.positional_embedding = None

self.output_tokens = output_tokens

self.init_parameters()
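
The `does_full_decoding` branch above decides whether the decoder owns its own token and positional embeddings or sets them to `None` and expects text that is already embedded. A simplified sketch of just that toggle, with assumed constructor arguments (part of the if-branch is collapsed in this diff):

```python
import torch
import torch.nn as nn

# Simplified sketch of the `does_full_decoding` toggle above, with assumed
# constructor arguments; the real class wires this into the full decoder.
class DecoderEmbeddings(nn.Module):
    def __init__(self, width, vocab_size, context_length, does_full_decoding):
        super().__init__()
        self.does_full_decoding = does_full_decoding
        if does_full_decoding:
            # the decoder embeds raw token ids itself
            self.num_pos = context_length
            self.token_embedding = nn.Embedding(vocab_size, width)
            # assumed shape; the corresponding line is collapsed in the diff
            self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
        else:
            # text is expected to arrive already embedded (e.g. from a separate text tower)
            self.num_pos = None
            self.token_embedding = None
            self.positional_embedding = None
```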
@@ -749,11 +749,13 @@ def build_attention_mask(self):
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask

def get_cast_dtype(self) -> torch.dtype:
-        if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'):
-            return self.resblocks[0].mlp.c_fc.int8_original_dtype
-        return self.resblocks[0].mlp.c_fc.weight.dtype
+        for resblock in self.resblocks:
+            if hasattr(resblock, 'mlp') and resblock.mlp is not None:
+                if hasattr(resblock.mlp.c_fc, 'int8_original_dtype'):
+                    return resblock.mlp.c_fc.int8_original_dtype
+                return resblock.mlp.c_fc.weight.dtype

def forward(self, image_embs, text_embs):
seq_len = text_embs.shape[1]
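
The old `get_cast_dtype` always dereferenced `self.resblocks[0].mlp`, which presumably breaks once the first block is built with `has_mlp=False` and its `mlp` is `None`; the new version scans for the first block that actually has an MLP. A standalone sketch of that lookup under the same assumption:

```python
import torch

# Standalone sketch of the fixed lookup: find the first residual block that actually
# has an MLP before reading its dtype; blocks built with has_mlp=False are assumed
# to have `mlp is None`, which would break the old `self.resblocks[0].mlp...` access.
def get_cast_dtype(resblocks) -> torch.dtype:
    for resblock in resblocks:
        if getattr(resblock, "mlp", None) is None:
            continue
        # int8-quantized layers remember the original dtype on the module
        if hasattr(resblock.mlp.c_fc, "int8_original_dtype"):
            return resblock.mlp.c_fc.int8_original_dtype
        return resblock.mlp.c_fc.weight.dtype
    return torch.float32  # assumed fallback if no block has an MLP
```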
@@ -766,8 +768,6 @@ def forward(self, image_embs, text_embs):
if image_embs is not None:
image_embs = image_embs.permute(1, 0, 2) # NLD -> LND



# TODO: handle different cases better, currently
# differentiates coca from mammut based on image_embs
if image_embs is not None:
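
This hunk stops at the branch on `image_embs`; per the TODO, whether `image_embs` is `None` is what currently separates the CoCa-style path from the MaMMUT-style one. The permute switches from batch-first to sequence-first layout, as `nn.MultiheadAttention` expects by default (the matching permute of `text_embs` is presumably in the collapsed lines above). A small shape sketch with illustrative sizes, not repository values:

```python
import torch

# Shape sketch of the NLD -> LND permute above, with illustrative sizes only:
# nn.MultiheadAttention defaults to sequence-first tensors, so (batch, seq, dim)
# becomes (seq, batch, dim) before entering the residual blocks.
text_embs = torch.randn(8, 77, 512)       # N, L, D
image_embs = torch.randn(8, 255, 512)     # N, L, D (e.g. visual tokens)

text_embs = text_embs.permute(1, 0, 2)    # -> (77, 8, 512)
image_embs = image_embs.permute(1, 0, 2)  # -> (255, 8, 512)
print(text_embs.shape, image_embs.shape)
```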
@@ -799,7 +799,7 @@ def forward(self, image_embs, text_embs):

if self.text_projection is not None:
logits = x @ self.text_projection

if self.output_tokens:
return logits, x
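
`text_projection` is a `(width, output_dim)` parameter, so the matmul maps every decoded token from the transformer width to the output dimension, and with `output_tokens` set the unprojected tokens are returned alongside the logits. A shape sketch with assumed sizes:

```python
import torch

# Shape sketch of the final projection with assumed, illustrative sizes;
# output_dim here stands in for whatever the decoder projects to (e.g. the vocab).
N, L, width, output_dim = 8, 77, 512, 49408
x = torch.randn(N, L, width)                      # decoded tokens, batch-first
text_projection = torch.empty(width, output_dim).normal_(std=width ** -0.5)

logits = x @ text_projection                      # -> (8, 77, 49408)
print(logits.shape)                               # with output_tokens set, x is returned too
```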

