Enable dynamic resolution input for Beit #31053

Merged 5 commits on Jun 6, 2024
Changes from 1 commit
79 changes: 74 additions & 5 deletions src/transformers/models/beit/modeling_beit.py
@@ -137,16 +137,60 @@ def __init__(self, config: BeitConfig) -> None:
else:
self.mask_token = None
self.patch_embeddings = BeitPatchEmbeddings(config)
self.patch_size = config.patch_size
num_patches = self.patch_embeddings.num_patches
if config.use_absolute_position_embeddings:
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
else:
self.position_embeddings = None
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows the model to interpolate the pre-trained position encodings so that it can be used on
higher resolution images.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings

class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h = height // self.patch_size
w = width // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h, w = h + 0.1, w + 0.1

patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
embeddings, (patch_height, patch_width) = self.patch_embeddings(
pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
pixel_values,
self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None,
interpolate_pos_encoding,
)
batch_size, seq_len, _ = embeddings.size()

@@ -158,7 +202,11 @@ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Bo

cls_tokens = self.cls_token.expand(batch_size, -1, -1)
if self.position_embeddings is not None:
cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]
if interpolate_pos_encoding:
_, _, height, width = pixel_values.shape
cls_tokens = cls_tokens + self.interpolate_pos_encoding(embeddings, height, width)
else:
cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]

embeddings = torch.cat((cls_tokens, embeddings), dim=1)

@@ -191,13 +239,25 @@ def __init__(self, config):

self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

def forward(self, pixel_values: torch.Tensor, position_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
self,
pixel_values: torch.Tensor,
position_embedding: Optional[torch.Tensor] = None,
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)

if not interpolate_pos_encoding:
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)

embeddings = self.projection(pixel_values)
patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]

@@ -658,6 +718,7 @@ def forward(
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[tuple, BeitModelOutputWithPooling]:
r"""
@@ -680,7 +741,9 @@
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values, bool_masked_pos)
embedding_output, (patch_height, patch_width) = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)

encoder_outputs = self.encoder(
embedding_output,
@@ -755,6 +818,7 @@ def forward(
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[tuple, MaskedLMOutput]:
r"""
@@ -800,6 +864,7 @@ def forward(
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

@@ -858,6 +923,7 @@ def forward(
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[tuple, ImageClassifierOutput]:
r"""
@@ -872,6 +938,7 @@
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

@@ -1215,6 +1282,7 @@ def forward(
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[tuple, SemanticSegmenterOutput]:
r"""
@@ -1252,6 +1320,7 @@ def forward(
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=True, # we need the intermediate hidden states
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

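For context, a minimal usage sketch of the new interpolate_pos_encoding flag (not part of the diff): it assumes a BEiT checkpoint with absolute position embeddings, and the image path, the 480x480 target size, and the do_center_crop override are illustrative choices rather than values taken from the PR.

# Usage sketch (illustrative, not part of the PR): run BEiT on a resolution larger than the
# 224x224 it was pre-trained at by interpolating the position embeddings.
import torch
from PIL import Image
from transformers import BeitImageProcessor, BeitModel

checkpoint = "microsoft/beit-base-patch16-224-pt22k"  # assumed checkpoint with absolute position embeddings
model = BeitModel.from_pretrained(checkpoint, use_absolute_position_embeddings=True)
processor = BeitImageProcessor.from_pretrained(checkpoint)

image = Image.open("example.jpg")  # placeholder image path
# Ask the processor for a 480x480 tensor and disable center cropping so the full
# resolution actually reaches the model.
inputs = processor(
    images=image,
    size={"height": 480, "width": 480},
    do_center_crop=False,
    return_tensors="pt",
)

with torch.no_grad():
    # Without interpolate_pos_encoding=True this call would raise a ValueError,
    # because 480x480 does not match the configured 224x224 image size.
    outputs = model(**inputs, interpolate_pos_encoding=True)

print(outputs.last_hidden_state.shape)  # (1, 1 + (480 // 16) ** 2, hidden_size) = (1, 901, 768)

Since the flag defaults to False, existing callers keep the strict image-size check introduced in BeitPatchEmbeddings.
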
24 changes: 24 additions & 0 deletions tests/models/beit/test_modeling_beit.py
@@ -545,6 +545,30 @@ def test_post_processing_semantic_segmentation(self):
expected_shape = torch.Size((160, 160))
self.assertEqual(segmentation[0].shape, expected_shape)

@slow
def test_inference_interpolate_pos_encoding(self):
model_name = "microsoft/beit-base-patch16-224-pt22k"
model = BeitModel.from_pretrained(model_name, use_absolute_position_embeddings=True).to(torch_device)

image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
processor = BeitImageProcessor.from_pretrained(model_name)
inputs = processor(images=image, return_tensors="pt", size=480)
pixel_values = inputs.pixel_values.to(torch_device)

# with interpolate_pos_encoding=False, a ValueError should be raised for images with a
# higher resolution than what the model supports.
with torch.no_grad():
with self.assertRaises(ValueError, msg="doesn't match model"):
model(pixel_values, interpolate_pos_encoding=False)
Comment on lines +562 to +563 (Collaborator):

Just to make sure this still holds if anything happens upstream and to make things explicit, could you add the following above:

self.assertFalse(processor.do_center_crop)


# with interpolate_pos_encoding=True, the model should process the higher-resolution image
# successfully and produce the expected output.
with torch.no_grad():
outputs = model(pixel_values, interpolate_pos_encoding=True)

expected_shape = torch.Size((1, 1200, 768))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)


@require_torch
class BeitBackboneTest(unittest.TestCase, BackboneTesterMixin):
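As a quick sanity check on the shapes involved, a small sketch of how the output sequence length scales with resolution once the position encodings are interpolated; it assumes patch_size=16 and a prepended [CLS] token as in the base checkpoint, and the input sizes are illustrative only.

# Sketch: token count for a BEiT-style encoder with a [CLS] token, assuming the input
# height and width are multiples of the patch size. Values below are illustrative.
def beit_seq_len(height: int, width: int, patch_size: int = 16) -> int:
    # one token per patch plus the [CLS] token
    return (height // patch_size) * (width // patch_size) + 1

print(beit_seq_len(224, 224))  # 197: the pre-trained 14 x 14 grid plus [CLS]
print(beit_seq_len(480, 480))  # 901: the position grid is interpolated up to 30 x 30

Only the sequence dimension grows with resolution; the hidden size (768 for the base model) is unchanged.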