
Fix rebase
NielsRogge committed Nov 20, 2023
1 parent a768559 commit 48db516
Showing 5 changed files with 24 additions and 36 deletions.
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -243,6 +243,7 @@ Flax), PyTorch, and/or TensorFlow.
| [SegFormer](model_doc/segformer) ||||
| [SEW](model_doc/sew) ||||
| [SEW-D](model_doc/sew-d) ||||
| [SigLIP](model_doc/siglip) ||||
| [Speech Encoder decoder](model_doc/speech-encoder-decoder) ||||
| [Speech2Text](model_doc/speech_to_text) ||||
| [SpeechT5](model_doc/speecht5) ||||
17 changes: 13 additions & 4 deletions src/transformers/models/siglip/configuration_siglip.py
@@ -63,15 +63,21 @@ class SiglipTextConfig(PretrainedConfig):
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
initializer_factor (`float`, *optional*, defaults to 1):
initializer_factor (`float`, *optional*, defaults to 1.0):
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
testing).
pad_token_id (`int`, *optional*, defaults to 1):
The id of the padding token in the vocabulary.
bos_token_id (`int`, *optional*, defaults to 49406):
The id of the beginning-of-sequence token in the vocabulary.
eos_token_id (`int`, *optional*, defaults to 49407):
The id of the end-of-sequence token in the vocabulary.
Example:
@@ -87,6 +93,7 @@ class SiglipTextConfig(PretrainedConfig):
>>> # Accessing the model configuration
>>> configuration = model.config
```"""

model_type = "siglip_text_model"

def __init__(
@@ -161,20 +168,22 @@ class SiglipVisionConfig(PretrainedConfig):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
Number of channels in the input images.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 32):
The size (resolution) of each patch.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
initializer_factor (`float`, *optional*, defaults to 1):
initializer_factor (`float`, *optional*, defaults to 1.0):
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
testing).
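The updated docstrings above spell out the token-id, epsilon, and initializer defaults for the SigLIP configuration classes. As a quick check of what those documented defaults look like at runtime, here is a minimal sketch, assuming a `transformers` build that already ships the SigLIP classes; the attribute names are the ones documented in this diff:

```python
# Minimal sketch (not part of this commit): read back the defaults that the
# updated docstrings describe. Assumes a transformers install that includes SigLIP.
from transformers import SiglipTextConfig, SiglipVisionConfig

text_config = SiglipTextConfig()
print(text_config.layer_norm_eps)       # 1e-06
print(text_config.pad_token_id)         # 1
print(text_config.bos_token_id)         # 49406
print(text_config.eos_token_id)         # 49407
print(text_config.initializer_factor)   # 1.0

vision_config = SiglipVisionConfig()
print(vision_config.num_channels)       # 3
print(vision_config.image_size)         # 224
```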
2 changes: 1 addition & 1 deletion src/transformers/models/siglip/image_processing_siglip.py
@@ -54,7 +54,7 @@ class SiglipImageProcessor(BaseImageProcessor):
`do_resize` in the `preprocess` method.
size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
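The docstring tweak above only changes how the default resampling filter is rendered; behaviour is unchanged. For context, a short usage sketch of the image processor with its documented 224×224 default, assuming SigLIP is available in the installed `transformers` (the dummy PIL image is just an illustration):

```python
# Minimal sketch (not part of this commit): apply the default 224x224 bilinear
# resize documented above to a placeholder image.
from PIL import Image
from transformers import SiglipImageProcessor

image_processor = SiglipImageProcessor()             # size defaults to {"height": 224, "width": 224}
image = Image.new("RGB", (640, 480), color="white")  # dummy input image

inputs = image_processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)                  # torch.Size([1, 3, 224, 224])
```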
39 changes: 8 additions & 31 deletions src/transformers/models/siglip/modeling_siglip.py
@@ -23,6 +23,7 @@
from torch import nn

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel
from ...utils import (
@@ -45,21 +46,6 @@
]


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len

expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

inverted_mask = 1.0 - expanded_mask

return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/2021-03-07-siglip.html
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
@@ -149,8 +135,7 @@ class SiglipOutput(ModelOutput):
text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim)`):
The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim)`):
The image embeddings obtained by applying the projection layer to the pooled output of
[`SiglipVisionModel`].
The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
text_model_output(`BaseModelOutputWithPooling`):
The output of the [`SiglipTextModel`].
vision_model_output(`BaseModelOutputWithPooling`):
@@ -442,9 +427,7 @@ def _init_weights(self, module):
nn.init.normal_(module.out_proj.weight, std=out_proj_std)
elif isinstance(module, SiglipMLP):
factor = self.config.initializer_factor
in_proj_std = (
(module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
)
in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
nn.init.normal_(module.fc1.weight, std=fc_std)
nn.init.normal_(module.fc2.weight, std=in_proj_std)
@@ -627,18 +610,12 @@ def forward(
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)

return custom_forward

layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
@@ -703,11 +680,11 @@ def forward(

hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

# note: SigLIP's text model does not use q causal mask, unlike the original CLIP model.
# note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
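Two things change in `modeling_siglip.py`: the local `_expand_mask` copy is dropped in favour of the shared `_prepare_4d_attention_mask` utility, and the hand-rolled `create_custom_forward` closure is replaced by `self._gradient_checkpointing_func`. The mask expansion itself is unchanged; the standalone sketch below (a hypothetical helper named `expand_attention_mask`, mirroring the deleted code) shows what both the old helper and the shared utility compute:

```python
# Standalone sketch of the mask expansion, mirroring the deleted `_expand_mask`
# helper; `_prepare_4d_attention_mask` performs the same transformation as a
# shared transformers utility.
import torch

def expand_attention_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Expand `[bsz, seq_len]` to `[bsz, 1, seq_len, seq_len]`, with the dtype's
    minimum value at padded positions and 0.0 at attended positions."""
    bsz, src_len = mask.size()
    expanded = mask[:, None, None, :].expand(bsz, 1, src_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)

attention_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding
mask_4d = expand_attention_mask(attention_mask, torch.float32)
print(mask_4d.shape)      # torch.Size([1, 1, 4, 4])
print(mask_4d[0, 0, 0])   # 0.0 for attended positions, dtype min (~-3.4e38) for padding
```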
1 change: 1 addition & 0 deletions src/transformers/models/siglip/processing_siglip.py
@@ -34,6 +34,7 @@ class SiglipProcessor(ProcessorMixin):
tokenizer ([`T5TokenizerFast`]):
The tokenizer is a required input.
"""

attributes = ["image_processor", "tokenizer"]
image_processor_class = "SiglipImageProcessor"
tokenizer_class = ("T5Tokenizer", "T5TokenizerFast")
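Finally, `SiglipProcessor` gains only a docstring spacing fix. For completeness, here is a hedged end-to-end sketch of the processor pairing `SiglipImageProcessor` with a T5 tokenizer; the checkpoint name is an assumption and may not match the identifier that eventually lands on the Hub:

```python
# Hedged sketch (checkpoint id is an assumption, not part of this commit):
# the processor wraps SiglipImageProcessor and a T5 tokenizer and returns
# text and image tensors in one call.
from PIL import Image
from transformers import SiglipProcessor

processor = SiglipProcessor.from_pretrained("google/siglip-base-patch16-224")
inputs = processor(
    text=["a photo of a cat"],
    images=Image.new("RGB", (256, 256)),
    padding="max_length",      # SigLIP is typically used with fixed-length padding
    return_tensors="pt",
)
print(sorted(inputs.keys()))   # expect at least 'input_ids' and 'pixel_values'
```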
