Add dynamic resolution input/interpolate position embedding to SigLIP (huggingface#30719)

* Add interpolate positional encoding to siglip

* Change # of patches for siglip interpolation test

* fix formatting

* Apply nit suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
2 people authored and zucchini-nlp committed May 10, 2024
1 parent 77703ec commit 9baa34c
Showing 2 changed files with 80 additions and 3 deletions.
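The new interpolate_pos_encoding flag is threaded from SiglipModel, SiglipVisionModel, SiglipForImageClassification and get_image_features down into the vision embeddings, where the pre-trained position grid is bicubically resized to the input's patch grid. A minimal usage sketch, mirroring the slow test added below (the checkpoint name, processor settings and expected shape are taken from that test; the image path is a placeholder):

import torch
from PIL import Image
from transformers import SiglipModel, SiglipProcessor

model_name = "google/siglip-base-patch16-224"  # pre-trained at 224x224 with patch size 16, i.e. a 14x14 position grid
model = SiglipModel.from_pretrained(model_name)
# Disable resizing so the model sees the native 640x480 resolution.
processor = SiglipProcessor.from_pretrained(model_name, do_resize=False, size={"height": 480, "width": 640})

image = Image.open("a_640x480_image.png")  # placeholder path
inputs = processor(text="what's in the image", images=image, return_tensors="pt")

with torch.no_grad():
    # The flag interpolates the 14x14 position grid to 30x40 for this input.
    outputs = model(**inputs, interpolate_pos_encoding=True)

print(outputs.vision_model_output.last_hidden_state.shape)  # torch.Size([1, 1200, 768])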
61 changes: 58 additions & 3 deletions src/transformers/models/siglip/modeling_siglip.py
@@ -265,11 +265,53 @@ def __init__(self, config: SiglipVisionConfig):
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows the model to interpolate the pre-trained position encodings so that it can be used with
higher-resolution images. It is adapted for SigLIP, which, unlike most other ViTs, does not have a class embedding.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
position_embeddings = self.position_embedding.weight.unsqueeze(0)
num_patches = embeddings.shape[1]
num_positions = position_embeddings.shape[1]
if num_patches == num_positions and height == width:
return position_embeddings

dim = embeddings.shape[-1]
height = height // self.patch_size
width = width // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
height, width = height + 0.1, width + 0.1

patch_pos_embed = position_embeddings.reshape(
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match the interpolated position embeddings")

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return patch_pos_embed

def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
_, _, height, width = pixel_values.shape
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
embeddings = patch_embeds.flatten(2).transpose(1, 2)

embeddings = embeddings + self.position_embedding(self.position_ids)
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
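
(Not part of the diff.) For intuition, a rough standalone sketch of the shape bookkeeping the new interpolate_pos_encoding method performs for the configuration exercised in the new test (patch size 16, hidden size 768, 224x224 pre-training resolution, 480x640 input); the position table below is a dummy tensor rather than real SigLIP weights:

import math
import torch
import torch.nn as nn

num_positions, dim, patch_size = 196, 768, 16            # 14x14 grid learned at 224x224
pos_embed = torch.randn(1, num_positions, dim)           # stands in for position_embedding.weight.unsqueeze(0)
height, width = 480 // patch_size, 640 // patch_size     # 30 x 40 target patch grid
grid = int(math.sqrt(num_positions))                     # 14

patch_pos = pos_embed.reshape(1, grid, grid, dim).permute(0, 3, 1, 2)    # (1, 768, 14, 14)
patch_pos = nn.functional.interpolate(
    patch_pos,
    scale_factor=((height + 0.1) / grid, (width + 0.1) / grid),  # the +0.1 avoids floating point rounding issues
    mode="bicubic",
    align_corners=False,
)                                                                 # (1, 768, 30, 40)
patch_pos = patch_pos.permute(0, 2, 3, 1).view(1, -1, dim)        # (1, 1200, 768), one embedding per patch
print(patch_pos.shape)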


@@ -564,6 +606,8 @@ def _init_weights(self, module):
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -601,6 +645,8 @@ def _init_weights(self, module):
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -848,6 +894,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
@@ -859,7 +906,7 @@
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

hidden_states = self.embeddings(pixel_values)
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
@@ -935,6 +982,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
@@ -965,6 +1013,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)


@@ -1055,6 +1104,7 @@ def get_image_features(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> torch.FloatTensor:
r"""
Returns:
@@ -1092,6 +1142,7 @@ def get_image_features(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)

pooled_output = vision_outputs[1]
@@ -1110,6 +1161,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, SiglipOutput]:
r"""
Returns:
@@ -1152,6 +1204,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)

text_outputs = self.text_model(
@@ -1226,6 +1279,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[tuple, ImageClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1271,6 +1325,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)

sequence_output = outputs[0]
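(Not part of the diff.) Since get_image_features forwards the flag as well, pooled image embeddings can be extracted at native resolution too. A brief sketch under the same assumptions as above (the pixel values are random stand-ins for an already-preprocessed 480x640 image):

import torch
from transformers import SiglipModel

model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
pixel_values = torch.randn(1, 3, 480, 640)  # dummy preprocessed input (batch, channels, height, width)

with torch.no_grad():
    image_embeds = model.get_image_features(pixel_values=pixel_values, interpolate_pos_encoding=True)

print(image_embeds.shape)  # torch.Size([1, 768]): pooled output, independent of the number of patches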
22 changes: 22 additions & 0 deletions tests/models/siglip/test_modeling_siglip.py
@@ -687,3 +687,25 @@ def test_inference(self):
probs = torch.sigmoid(logits_per_image) # these are the probabilities
expected_probs = torch.tensor([[3.1937e-01, 3.2463e-05]], device=torch_device)
self.assertTrue(torch.allclose(probs, expected_probs, atol=1e-3))

@slow
def test_inference_interpolate_pos_encoding(self):
model_name = "google/siglip-base-patch16-224"
model = SiglipModel.from_pretrained(model_name).to(torch_device)

# 640 x 480 image
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
processor = SiglipProcessor.from_pretrained(model_name, do_resize=False, size={"height": 480, "width": 640})

inputs = processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device)

# forward pass
with torch.no_grad():
outputs = model(**inputs, interpolate_pos_encoding=True)

# verify the shape
# patch size = 16
# batch size 1, (640/16) * (480/16) = 1200 patches, 768 hidden size
expected_shape = torch.Size((1, 1200, 768))

self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape)
