diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py
index 3bfd57d4da8041..744f4044d91d09 100644
--- a/src/transformers/models/siglip/modeling_siglip.py
+++ b/src/transformers/models/siglip/modeling_siglip.py
@@ -265,11 +265,53 @@ def __init__(self, config: SiglipVisionConfig):
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
         self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
 
-    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method is an adapted method for SigLIP (due to SigLIP not having class embedding unlike other ViTs)
+        that allows the model to interpolate the pre-trained position encodings such that it can be usable on
+        higher resolution images.
+
+        Source:
+        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        """
+        position_embeddings = self.position_embedding.weight.unsqueeze(0)
+        num_patches = embeddings.shape[1]
+        num_positions = position_embeddings.shape[1]
+        if num_patches == num_positions and height == width:
+            return position_embeddings
+
+        dim = embeddings.shape[-1]
+        height = height // self.patch_size
+        width = width // self.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        height, width = height + 0.1, width + 0.1
+
+        patch_pos_embed = position_embeddings.reshape(
+            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
+        )
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
+            mode="bicubic",
+            align_corners=False,
+        )
+        if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
+            raise ValueError("Width or height does not match with the interpolated position embeddings")
+
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return patch_pos_embed
+
+    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
+        _, _, height, width = pixel_values.shape
         patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
         embeddings = patch_embeds.flatten(2).transpose(1, 2)
 
-        embeddings = embeddings + self.position_embedding(self.position_ids)
+        if interpolate_pos_encoding:
+            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+        else:
+            embeddings = embeddings + self.position_embedding(self.position_ids)
         return embeddings
@@ -564,6 +606,8 @@ def _init_weights(self, module):
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
@@ -601,6 +645,8 @@ def _init_weights(self, module):
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
@@ -848,6 +894,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: Optional[bool] = False,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
         r"""
         Returns:
@@ -859,7 +906,7 @@
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        hidden_states = self.embeddings(pixel_values)
+        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
 
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
@@ -935,6 +982,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
         r"""
         Returns:
@@ -965,6 +1013,7 @@
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
         )
@@ -1055,6 +1104,7 @@ def get_image_features(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
     ) -> torch.FloatTensor:
         r"""
         Returns:
@@ -1092,6 +1142,7 @@
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
         )
 
         pooled_output = vision_outputs[1]
@@ -1110,6 +1161,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
     ) -> Union[Tuple, SiglipOutput]:
         r"""
         Returns:
@@ -1152,6 +1204,7 @@
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
         )
 
         text_outputs = self.text_model(
@@ -1226,6 +1279,7 @@
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
     ) -> Union[tuple, ImageClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1271,6 +1325,7 @@
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
         )
 
         sequence_output = outputs[0]
diff --git a/tests/models/siglip/test_modeling_siglip.py b/tests/models/siglip/test_modeling_siglip.py
index 8880168484eced..e0a4825d0ede60 100644
--- a/tests/models/siglip/test_modeling_siglip.py
+++ b/tests/models/siglip/test_modeling_siglip.py
@@ -687,3 +687,25 @@ def test_inference(self):
         probs = torch.sigmoid(logits_per_image)  # these are the probabilities
         expected_probs = torch.tensor([[3.1937e-01, 3.2463e-05]], device=torch_device)
         self.assertTrue(torch.allclose(probs, expected_probs, atol=1e-3))
+
+    @slow
+    def test_inference_interpolate_pos_encoding(self):
+        model_name = "google/siglip-base-patch16-224"
+        model = SiglipModel.from_pretrained(model_name).to(torch_device)
+
+        # 640 x 480 image
+        image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+        processor = SiglipProcessor.from_pretrained(model_name, do_resize=False, size={"height": 480, "width": 640})
+
+        inputs = processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device)
+
+        # forward pass
+        with torch.no_grad():
+            outputs = model(**inputs, interpolate_pos_encoding=True)
+
+        # verify the shape
+        # patch size = 16
+        # batch size 1, (640/16) * (480/16) = 1200 patches, 768 hidden size
+        expected_shape = torch.Size((1, 1200, 768))
+
+        self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape)
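For reference, a minimal usage sketch of the new flag, assuming this patch is applied. The local image path ("cats.png") is a placeholder for any 640x480 RGB image; the checkpoint name and shapes mirror the test above.

import torch
from PIL import Image
from transformers import SiglipModel, SiglipProcessor

model_name = "google/siglip-base-patch16-224"
model = SiglipModel.from_pretrained(model_name)
# do_resize=False keeps the native 640x480 resolution instead of resizing to 224x224
processor = SiglipProcessor.from_pretrained(model_name, do_resize=False, size={"height": 480, "width": 640})

image = Image.open("cats.png")  # placeholder: any 640x480 RGB image
inputs = processor(text=["what's in the image"], images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs, interpolate_pos_encoding=True)

# (640 / 16) * (480 / 16) = 1200 patches instead of the pre-trained (224 / 16) ** 2 = 196
print(outputs.vision_model_output.last_hidden_state.shape)  # torch.Size([1, 1200, 768])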