From 5e1c32ac7161ec750cc631206afdcf7fa947293f Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Mon, 27 May 2024 12:56:26 +0500 Subject: [PATCH 1/5] Initial attempt --- src/transformers/models/beit/modeling_beit.py | 79 +++++++++++++++++-- tests/models/beit/test_modeling_beit.py | 24 ++++++ 2 files changed, 98 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index a9b38d4ee39066..59b4c8b54912e5 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -137,6 +137,7 @@ def __init__(self, config: BeitConfig) -> None: else: self.mask_token = None self.patch_embeddings = BeitPatchEmbeddings(config) + self.patch_size = config.patch_size num_patches = self.patch_embeddings.num_patches if config.use_absolute_position_embeddings: self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) @@ -144,9 +145,52 @@ def __init__(self, config: BeitConfig) -> None: self.position_embeddings = None self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor: + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows the model to interpolate the pre-trained position encodings so that it can be used on + higher resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h = height // self.patch_size + w = width // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h, w = h + 0.1, w + 0.1 + + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]: + raise ValueError("Width or height does not match with the interpolated position embeddings") + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: embeddings, (patch_height, patch_width) = self.patch_embeddings( - pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None + pixel_values, + self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None, + interpolate_pos_encoding, ) batch_size, seq_len, _ = embeddings.size() @@ -158,7 +202,11 @@ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Bo cls_tokens = self.cls_token.expand(batch_size, -1, -1) if 
self.position_embeddings is not None: - cls_tokens = cls_tokens + self.position_embeddings[:, :1, :] + if interpolate_pos_encoding: + _, _, height, width = pixel_values.shape + cls_tokens = cls_tokens + self.interpolate_pos_encoding(embeddings, height, width) + else: + cls_tokens = cls_tokens + self.position_embeddings[:, :1, :] embeddings = torch.cat((cls_tokens, embeddings), dim=1) @@ -191,13 +239,25 @@ def __init__(self, config): self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) - def forward(self, pixel_values: torch.Tensor, position_embedding: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, + pixel_values: torch.Tensor, + position_embedding: Optional[torch.Tensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( "Make sure that the channel dimension of the pixel values match with the one set in the configuration." ) + if not interpolate_pos_encoding: + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values) patch_height, patch_width = embeddings.shape[2], embeddings.shape[3] @@ -658,6 +718,7 @@ def forward( head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[tuple, BeitModelOutputWithPooling]: r""" @@ -680,7 +741,9 @@ def forward( # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values, bool_masked_pos) + embedding_output, (patch_height, patch_width) = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) encoder_outputs = self.encoder( embedding_output, @@ -755,6 +818,7 @@ def forward( labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[tuple, MaskedLMOutput]: r""" @@ -800,6 +864,7 @@ def forward( head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -858,6 +923,7 @@ def forward( labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[tuple, ImageClassifierOutput]: r""" @@ -872,6 +938,7 @@ def forward( head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1215,6 +1282,7 @@ def forward( labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[tuple, SemanticSegmenterOutput]: r""" @@ -1252,6 +1320,7 @@ def forward( head_mask=head_mask, 
output_attentions=output_attentions, output_hidden_states=True, # we need the intermediate hidden states + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) diff --git a/tests/models/beit/test_modeling_beit.py b/tests/models/beit/test_modeling_beit.py index 1010c6007d66d2..398af093db62a6 100644 --- a/tests/models/beit/test_modeling_beit.py +++ b/tests/models/beit/test_modeling_beit.py @@ -545,6 +545,30 @@ def test_post_processing_semantic_segmentation(self): expected_shape = torch.Size((160, 160)) self.assertEqual(segmentation[0].shape, expected_shape) + @slow + def test_inference_interpolate_pos_encoding(self): + model_name = "microsoft/beit-base-patch16-224-pt22k" + model = BeitModel.from_pretrained(model_name, **{"use_absolute_position_embeddings": True}).to(torch_device) + + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + processor = BeitImageProcessor.from_pretrained(model_name) + inputs = processor(images=image, return_tensors="pt", size=480) + pixel_values = inputs.pixel_values.to(torch_device) + + # with interpolate_pos_encoding being False an exception should be raised with higher resolution + # images than what the model supports. + with torch.no_grad(): + with self.assertRaises(ValueError, msg="doesn't match model"): + model(pixel_values, interpolate_pos_encoding=False) + + # with interpolate_pos_encoding being True the model should process the higher resolution image + # successfully and produce the expected output. + with torch.no_grad(): + outputs = model(pixel_values, interpolate_pos_encoding=True) + + expected_shape = torch.Size((1, 1200, 768)) + self.assertEqual(outputs.last_hidden_state.shape, expected_shape) + @require_torch class BeitBackboneTest(unittest.TestCase, BackboneTesterMixin): From 39af204a65ba3bbd940028aa1a1ebe5bc2000b1b Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Mon, 3 Jun 2024 18:51:36 +0500 Subject: [PATCH 2/5] Updates: PR suggestions --- src/transformers/models/beit/modeling_beit.py | 14 +++----------- tests/models/beit/test_modeling_beit.py | 3 ++- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 59b4c8b54912e5..266073d435992e 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -188,9 +188,7 @@ def forward( interpolate_pos_encoding: bool = False, ) -> torch.Tensor: embeddings, (patch_height, patch_width) = self.patch_embeddings( - pixel_values, - self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None, - interpolate_pos_encoding, + pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None ) batch_size, seq_len, _ = embeddings.size() @@ -243,7 +241,6 @@ def forward( self, pixel_values: torch.Tensor, position_embedding: Optional[torch.Tensor] = None, - interpolate_pos_encoding: bool = False, ) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: @@ -251,13 +248,6 @@ def forward( "Make sure that the channel dimension of the pixel values match with the one set in the configuration." ) - if not interpolate_pos_encoding: - if height != self.image_size[0] or width != self.image_size[1]: - raise ValueError( - f"Input image size ({height}*{width}) doesn't match model" - f" ({self.image_size[0]}*{self.image_size[1]})." 
- ) - embeddings = self.projection(pixel_values) patch_height, patch_width = embeddings.shape[2], embeddings.shape[3] @@ -667,6 +657,8 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ diff --git a/tests/models/beit/test_modeling_beit.py b/tests/models/beit/test_modeling_beit.py index 398af093db62a6..78ee4980f620c4 100644 --- a/tests/models/beit/test_modeling_beit.py +++ b/tests/models/beit/test_modeling_beit.py @@ -552,11 +552,12 @@ def test_inference_interpolate_pos_encoding(self): image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") processor = BeitImageProcessor.from_pretrained(model_name) - inputs = processor(images=image, return_tensors="pt", size=480) + inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480}) pixel_values = inputs.pixel_values.to(torch_device) # with interpolate_pos_encoding being False an exception should be raised with higher resolution # images than what the model supports. + self.assertFalse(processor.do_center_crop) with torch.no_grad(): with self.assertRaises(ValueError, msg="doesn't match model"): model(pixel_values, interpolate_pos_encoding=False) From e2834dee81371bfa5d3bbc80e9755bd052a54114 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Tue, 4 Jun 2024 13:46:15 +0500 Subject: [PATCH 3/5] Interpolate the relative position bias when interpolate_pos_encoding is True --- src/transformers/models/beit/modeling_beit.py | 48 ++++++-- .../data2vec/modeling_data2vec_vision.py | 111 ++++++++++++++++-- tests/models/beit/test_modeling_beit.py | 3 +- .../data2vec/test_modeling_data2vec_vision.py | 26 ++++ 4 files changed, 169 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 266073d435992e..460dfe4bd82b39 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -138,6 +138,11 @@ def __init__(self, config: BeitConfig) -> None: self.mask_token = None self.patch_embeddings = BeitPatchEmbeddings(config) self.patch_size = config.patch_size + self.image_size = ( + config.image_size + if isinstance(config.image_size, collections.abc.Iterable) + else (config.image_size, config.image_size) + ) num_patches = self.patch_embeddings.num_patches if config.use_absolute_position_embeddings: self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) @@ -187,6 +192,13 @@ def forward( bool_masked_pos: Optional[torch.BoolTensor] = None, interpolate_pos_encoding: bool = False, ) -> torch.Tensor: + _, _, height, width = pixel_values.shape + if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." 
+ ) + embeddings, (patch_height, patch_width) = self.patch_embeddings( pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None ) @@ -201,7 +213,6 @@ def forward( cls_tokens = self.cls_token.expand(batch_size, -1, -1) if self.position_embeddings is not None: if interpolate_pos_encoding: - _, _, height, width = pixel_values.shape cls_tokens = cls_tokens + self.interpolate_pos_encoding(embeddings, height, width) else: cls_tokens = cls_tokens + self.position_embeddings[:, :1, :] @@ -301,6 +312,7 @@ def forward( head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, relative_position_bias: Optional["BeitRelativePositionBias"] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: mixed_query_layer = self.query(hidden_states) @@ -315,7 +327,9 @@ def forward( # Add relative position bias if present. if self.relative_position_bias is not None: - attention_scores = attention_scores + self.relative_position_bias().unsqueeze(0) + attention_scores = attention_scores + self.relative_position_bias( + interpolate_pos_encoding, attention_scores.shape[2] + ).unsqueeze(0) # Add shared relative position bias if provided. if relative_position_bias is not None: @@ -392,8 +406,11 @@ def forward( head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, relative_position_bias: Optional["BeitRelativePositionBias"] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: - self_outputs = self.attention(hidden_states, head_mask, output_attentions, relative_position_bias) + self_outputs = self.attention( + hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding + ) attention_output = self.output(self_outputs[0], hidden_states) @@ -457,12 +474,14 @@ def forward( head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, relative_position_bias: Optional["BeitRelativePositionBias"] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: self_attention_outputs = self.attention( self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention head_mask, output_attentions=output_attentions, relative_position_bias=relative_position_bias, + interpolate_pos_encoding=interpolate_pos_encoding, ) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:] # add self attentions if we output attention weights @@ -521,12 +540,21 @@ def __init__(self, config: BeitConfig, window_size: tuple) -> None: self.register_buffer("relative_position_index", relative_position_index, persistent=False) - def forward(self) -> torch.Tensor: + def forward(self, interpolate_pos_encoding: bool = False, dim_size: Optional[int] = None) -> torch.Tensor: relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1 ) # Wh*Ww,Wh*Ww,nH - return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + if interpolate_pos_encoding: + relative_position_bias = nn.functional.interpolate( + relative_position_bias.unsqueeze(1), + size=(dim_size, dim_size), + mode="bilinear", + align_corners=False, + ).squeeze(1) + + return relative_position_bias class 
BeitEncoder(nn.Module): @@ -558,6 +586,7 @@ def forward( head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, output_hidden_states: bool = False, + interpolate_pos_encoding: bool = False, return_dict: bool = True, ) -> Union[tuple, BaseModelOutput]: all_hidden_states = () if output_hidden_states else None @@ -578,9 +607,13 @@ def forward( ) else: relative_position_bias = ( - self.relative_position_bias() if self.relative_position_bias is not None else None + self.relative_position_bias(interpolate_pos_encoding, hidden_states.shape[1]) + if self.relative_position_bias is not None + else None + ) + layer_outputs = layer_module( + hidden_states, layer_head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding ) - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias) hidden_states = layer_outputs[0] @@ -743,6 +776,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) sequence_output = encoder_outputs[0] sequence_output = self.layernorm(sequence_output) diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index 03b8170e6710b5..776e013584eb95 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -136,6 +136,12 @@ def __init__(self, config: Data2VecVisionConfig) -> None: else: self.mask_token = None self.patch_embeddings = Data2VecVisionPatchEmbeddings(config) + self.patch_size = config.patch_size + self.image_size = ( + config.image_size + if isinstance(config.image_size, collections.abc.Iterable) + else (config.image_size, config.image_size) + ) num_patches = self.patch_embeddings.num_patches if config.use_absolute_position_embeddings: self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) @@ -143,7 +149,55 @@ def __init__(self, config: Data2VecVisionConfig) -> None: self.position_embeddings = None self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor: + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows the model to interpolate the pre-trained position encodings so that it can be used on + higher resolution images. 
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h = height // self.patch_size + w = width // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h, w = h + 0.1, w + 0.1 + + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]: + raise ValueError("Width or height does not match with the interpolated position embeddings") + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + _, _, height, width = pixel_values.shape + if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings, (patch_height, patch_width) = self.patch_embeddings( pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None ) @@ -157,7 +211,10 @@ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Bo cls_tokens = self.cls_token.expand(batch_size, -1, -1) if self.position_embeddings is not None: - cls_tokens = cls_tokens + self.position_embeddings[:, :1, :] + if interpolate_pos_encoding: + cls_tokens = cls_tokens + self.interpolate_pos_encoding(embeddings, height, width) + else: + cls_tokens = cls_tokens + self.position_embeddings[:, :1, :] embeddings = torch.cat((cls_tokens, embeddings), dim=1) @@ -191,7 +248,11 @@ def __init__(self, config): self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) - def forward(self, pixel_values: torch.Tensor, position_embedding: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, + pixel_values: torch.Tensor, + position_embedding: Optional[torch.Tensor] = None, + ) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( @@ -252,6 +313,7 @@ def forward( head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: mixed_query_layer = self.query(hidden_states) @@ -266,7 +328,9 @@ def forward( # Add relative position bias if present. 
if self.relative_position_bias is not None: - attention_scores = attention_scores + self.relative_position_bias().unsqueeze(0) + attention_scores = attention_scores + self.relative_position_bias( + interpolate_pos_encoding, attention_scores.shape[2] + ).unsqueeze(0) # Add shared relative position bias if provided. if relative_position_bias is not None: @@ -345,8 +409,11 @@ def forward( head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: - self_outputs = self.attention(hidden_states, head_mask, output_attentions, relative_position_bias) + self_outputs = self.attention( + hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding + ) attention_output = self.output(self_outputs[0], hidden_states) @@ -415,12 +482,14 @@ def forward( head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: self_attention_outputs = self.attention( self.layernorm_before(hidden_states), # in Data2VecVision, layernorm is applied before self-attention head_mask, output_attentions=output_attentions, relative_position_bias=relative_position_bias, + interpolate_pos_encoding=interpolate_pos_encoding, ) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:] # add self attentions if we output attention weights @@ -480,12 +549,21 @@ def __init__(self, config: Data2VecVisionConfig, window_size: tuple) -> None: self.register_buffer("relative_position_index", relative_position_index, persistent=False) - def forward(self) -> torch.Tensor: + def forward(self, interpolate_pos_encoding: bool = False, dim_size: Optional[int] = None) -> torch.Tensor: relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1 ) # Wh*Ww,Wh*Ww,nH - return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + if interpolate_pos_encoding: + relative_position_bias = nn.functional.interpolate( + relative_position_bias.unsqueeze(1), + size=(dim_size, dim_size), + mode="bilinear", + align_corners=False, + ).squeeze(1) + + return relative_position_bias # Copied from transformers.models.beit.modeling_beit.BeitEncoder with Beit->Data2VecVision @@ -518,6 +596,7 @@ def forward( head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, output_hidden_states: bool = False, + interpolate_pos_encoding: bool = False, return_dict: bool = True, ) -> Union[tuple, BaseModelOutput]: all_hidden_states = () if output_hidden_states else None @@ -538,9 +617,13 @@ def forward( ) else: relative_position_bias = ( - self.relative_position_bias() if self.relative_position_bias is not None else None + self.relative_position_bias(interpolate_pos_encoding, hidden_states.shape[1]) + if self.relative_position_bias is not None + else None + ) + layer_outputs = layer_module( + hidden_states, layer_head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding ) - layer_outputs = layer_module(hidden_states, layer_head_mask, 
output_attentions, relative_position_bias) hidden_states = layer_outputs[0] @@ -670,6 +753,7 @@ def forward( head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[tuple, Data2VecVisionModelOutputWithPooling]: r""" @@ -692,7 +776,9 @@ def forward( # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values, bool_masked_pos) + embedding_output, (patch_height, patch_width) = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) encoder_outputs = self.encoder( embedding_output, @@ -700,6 +786,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) sequence_output = encoder_outputs[0] sequence_output = self.layernorm(sequence_output) @@ -772,6 +859,7 @@ def forward( labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[tuple, ImageClassifierOutput]: r""" @@ -786,6 +874,7 @@ def forward( head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1141,6 +1230,7 @@ def forward( labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[tuple, SemanticSegmenterOutput]: r""" @@ -1178,6 +1268,7 @@ def forward( head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=True, # we need the intermediate hidden states + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) diff --git a/tests/models/beit/test_modeling_beit.py b/tests/models/beit/test_modeling_beit.py index 78ee4980f620c4..453d3ec033b8c6 100644 --- a/tests/models/beit/test_modeling_beit.py +++ b/tests/models/beit/test_modeling_beit.py @@ -545,7 +545,6 @@ def test_post_processing_semantic_segmentation(self): expected_shape = torch.Size((160, 160)) self.assertEqual(segmentation[0].shape, expected_shape) - @slow def test_inference_interpolate_pos_encoding(self): model_name = "microsoft/beit-base-patch16-224-pt22k" model = BeitModel.from_pretrained(model_name, **{"use_absolute_position_embeddings": True}).to(torch_device) @@ -567,7 +566,7 @@ def test_inference_interpolate_pos_encoding(self): with torch.no_grad(): outputs = model(pixel_values, interpolate_pos_encoding=True) - expected_shape = torch.Size((1, 1200, 768)) + expected_shape = torch.Size((1, 1801, 768)) self.assertEqual(outputs.last_hidden_state.shape, expected_shape) diff --git a/tests/models/data2vec/test_modeling_data2vec_vision.py b/tests/models/data2vec/test_modeling_data2vec_vision.py index 99cbd66fbbcf1b..3fffda02aa3be8 100644 --- a/tests/models/data2vec/test_modeling_data2vec_vision.py +++ b/tests/models/data2vec/test_modeling_data2vec_vision.py @@ -341,3 +341,29 @@ def test_inference_image_classification_head_imagenet_1k(self): expected_top2 = 
[model.config.label2id[i] for i in ["remote control, remote", "tabby, tabby cat"]] self.assertEqual(logits[0].topk(2).indices.cpu().tolist(), expected_top2) + + def test_inference_interpolate_pos_encoding(self): + model_name = "facebook/data2vec-vision-base-ft1k" + model = Data2VecVisionModel.from_pretrained(model_name, **{"use_absolute_position_embeddings": True}).to( + torch_device + ) + + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + processor = BeitImageProcessor.from_pretrained("facebook/data2vec-vision-base-ft1k") + inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480}) + pixel_values = inputs.pixel_values.to(torch_device) + + # with interpolate_pos_encoding being False an exception should be raised with higher resolution + # images than what the model supports. + self.assertFalse(processor.do_center_crop) + with torch.no_grad(): + with self.assertRaises(ValueError, msg="doesn't match model"): + model(pixel_values, interpolate_pos_encoding=False) + + # with interpolate_pos_encoding being True the model should process the higher resolution image + # successfully and produce the expected output. + with torch.no_grad(): + outputs = model(pixel_values, interpolate_pos_encoding=True) + + expected_shape = torch.Size((1, 1801, 768)) + self.assertEqual(outputs.last_hidden_state.shape, expected_shape) From a7cd981006682084037093d97bf41f2fb21a9604 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Tue, 4 Jun 2024 13:54:30 +0500 Subject: [PATCH 4/5] Add slow tag for the added tests --- tests/models/beit/test_modeling_beit.py | 1 + tests/models/data2vec/test_modeling_data2vec_vision.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/models/beit/test_modeling_beit.py b/tests/models/beit/test_modeling_beit.py index 453d3ec033b8c6..0fd17efaf67c7a 100644 --- a/tests/models/beit/test_modeling_beit.py +++ b/tests/models/beit/test_modeling_beit.py @@ -545,6 +545,7 @@ def test_post_processing_semantic_segmentation(self): expected_shape = torch.Size((160, 160)) self.assertEqual(segmentation[0].shape, expected_shape) + @slow def test_inference_interpolate_pos_encoding(self): model_name = "microsoft/beit-base-patch16-224-pt22k" model = BeitModel.from_pretrained(model_name, **{"use_absolute_position_embeddings": True}).to(torch_device) diff --git a/tests/models/data2vec/test_modeling_data2vec_vision.py b/tests/models/data2vec/test_modeling_data2vec_vision.py index 3fffda02aa3be8..fabf543c021a43 100644 --- a/tests/models/data2vec/test_modeling_data2vec_vision.py +++ b/tests/models/data2vec/test_modeling_data2vec_vision.py @@ -342,6 +342,7 @@ def test_inference_image_classification_head_imagenet_1k(self): expected_top2 = [model.config.label2id[i] for i in ["remote control, remote", "tabby, tabby cat"]] self.assertEqual(logits[0].topk(2).indices.cpu().tolist(), expected_top2) + @slow def test_inference_interpolate_pos_encoding(self): model_name = "facebook/data2vec-vision-base-ft1k" model = Data2VecVisionModel.from_pretrained(model_name, **{"use_absolute_position_embeddings": True}).to( From 2080a212201afe9e4371ecb5938d581e78c93e0f Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Thu, 6 Jun 2024 18:02:50 +0500 Subject: [PATCH 5/5] Add in DATA2VEC_VISION_INPUTS_DOCSTRING --- src/transformers/models/data2vec/modeling_data2vec_vision.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index 
776e013584eb95..4504701a1f3c8e 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -701,6 +701,8 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """
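
For reviewers, a minimal usage sketch of the flag this series adds. It mirrors the slow test_inference_interpolate_pos_encoding tests above (same checkpoint, same {"height": 480, "width": 480} size dict, and the same use_absolute_position_embeddings=True override); the fixture image path is the one used in the tests, and any RGB image works in its place. This is illustrative only, not part of the diff.

import torch
from PIL import Image
from transformers import BeitImageProcessor, BeitModel

# Checkpoint and preprocessing follow the new slow BEiT test.
model_name = "microsoft/beit-base-patch16-224-pt22k"
model = BeitModel.from_pretrained(model_name, use_absolute_position_embeddings=True)
processor = BeitImageProcessor.from_pretrained(model_name)

image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480})

with torch.no_grad():
    # With interpolate_pos_encoding=False, a 480x480 input to this 224x224 checkpoint raises
    # the ValueError introduced in this series; with interpolate_pos_encoding=True the absolute
    # position embeddings are resized (bicubic) and, per PATCH 3/5, the relative position bias
    # is resized (bilinear) to match the longer patch sequence.
    outputs = model(**inputs, interpolate_pos_encoding=True)

print(outputs.last_hidden_state.shape)  # sequence length grows with the input resolution

The same call works for Data2VecVisionModel, which picks up the identical changes through the "# Copied from transformers.models.beit" mechanism shown in the second diff.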