Add SigLIP #26522

Merged
merged 99 commits into huggingface:main on Jan 8, 2024

Conversation

@NielsRogge (Contributor) commented on Oct 1, 2023

What does this PR do?

This PR adds Google's new SigLIP model (CLIP with a better loss function). It's based on the Google Colab provided by the authors.

cc @patil-suraj feel free to take over this one

To do:

  • add SiglipTokenizer (or use T5Tokenizer? The vocab is defined here)
  • add tests for the image processor, tokenizer and processor
  • add fast tokenizer and enable fast tokenizer tests => skip fast tokenizer for now, see branch add_siglip_fast_tokenizer
  • add loss function for training => won't do since various torch.distributed utilities would have to be incorporated
  • important one: make sure that weights of SiglipVisionModel can be properly loaded without from_pretrained complaining
  • make sure attention_mask is not passed for siglip checkpoints by updating model_input_names for checkpoints
  • set split_special_tokens=True? => no but users can pass this flag
  • transfer checkpoints, update organization name
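
Below is a quick, hedged sketch of how the model added in this PR could be used for zero-shot image classification once checkpoints are on the hub; the checkpoint name and the exact output attribute names are assumptions based on the CLIP-style API the PR follows.

```python
# Hedged sketch, not an official example: checkpoint name and output attributes are assumptions.
import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

checkpoint = "google/siglip-base-patch16-224"  # assumed hub name after the organization transfer
model = AutoModel.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
texts = ["a photo of 2 cats", "a photo of a dog"]

# padding="max_length" matters for SigLIP checkpoints (see the discussion further down)
inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# SigLIP applies a sigmoid to the image-text logits instead of a softmax
probs = torch.sigmoid(outputs.logits_per_image)
print(probs)
```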

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

github-actions bot commented Nov 1, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ArthurZucker (Collaborator) left a comment

Not a final review! Hope it helps.

Review comments (now resolved) were left on:

  • src/transformers/models/siglip/tokenization_siglip.py (4 comments)
  • src/transformers/models/siglip/tokenization_siglip_fast.py
  • src/transformers/models/siglip/test.py
  • src/transformers/convert_slow_tokenizer.py
  • tests/models/siglip/test_tokenization_siglip.py (2 comments)
@NielsRogge (Contributor, Author) commented on Jan 2, 2024

@ArthurZucker I added 26590d2 for split_special_tokens=True, which required me to override some tests from tokenization_common.py. Could you have a look?

Also, this isn't supported by tokenizers yet, right?

To me it feels a bit weird to enable this behaviour by default just to match the original implementation, since the original implementation never keeps special tokens anyway.
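
For context, a minimal sketch of what the flag changes from a user's perspective; the checkpoint name is an assumption and the exact token output depends on the final vocabulary and text canonicalization, so it is not shown.

```python
# Minimal sketch, assuming the final SiglipTokenizer API; checkpoint name is a placeholder.
from transformers import SiglipTokenizer

tokenizer = SiglipTokenizer.from_pretrained("google/siglip-base-patch16-224")

text = "a photo of a cat </s>"

# Default: special tokens appearing in the text are kept as single tokens
kept = tokenizer.tokenize(text)

# Opt-in: special tokens in the text are split like ordinary text
split = tokenizer.tokenize(text, split_special_tokens=True)

print(kept)
print(split)
```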

@ArthurZucker self-requested a review on January 3, 2024 at 16:09
@ArthurZucker (Collaborator) left a comment

Thanks for adding this new model!
Feel free to merge; the last test would be a nice-to-have. The only thing to address is the max_length padding that you force in the pipeline.

Review comment on src/transformers/models/siglip/__init__.py (resolved)
input_ids = self.tokenizer.encode("▁He is not ▁He")
self.assertEqual(input_ids, [37, 46, 44, 37, 2])
tokens = self.tokenizer.tokenize("▁He is not ▁He")
self.assertEqual(tokens, ["▁he", "▁is", "▁not", "▁he"]) # spaces are eaten by spm even if not start
A Collaborator commented: Last thing is this!

@ArthurZucker ArthurZucker merged commit 3b742ea into huggingface:main Jan 8, 2024
23 checks passed
staghado pushed a commit to staghado/transformers that referenced this pull request Jan 15, 2024
* Add first draft

* Use appropriate gelu function

* More improvements

* More improvements

* More improvements

* Convert checkpoint

* More improvements

* Improve docs, remove print statements

* More improvements

* Add link

* remove unused masking function

* begin tokenizer

* do_lower_case

* debug

* set split_special_tokens=True

* Remove script

* Fix style

* Fix rebase

* Use same design as CLIP

* Add fast tokenizer

* Add SiglipTokenizer to init, remove extra_ids

* Improve conversion script

* Use smaller inputs in conversion script

* Update conversion script

* More improvements

* Add processor to conversion script

* Add tests

* Remove print statements

* Add tokenizer tests

* Fix more tests

* More improvements related to weight initialization

* More improvements

* Make more tests pass

* More improvements

* More improvements

* Add copied from

* Add canonicalize_text

* Enable fast tokenizer tests

* More improvements

* Fix most slow tokenizer tests

* Address comments

* Fix style

* Remove script

* Address some comments

* Add copied from to tests

* Add more copied from

* Add more copied from

* Add more copied from

* Remove is_flax_available

* More updates

* Address comment

* Remove SiglipTokenizerFast for now

* Add caching

* Remove umt5 test

* Add canonicalize_text inside _tokenize, thanks Arthur

* Fix image processor tests

* Skip tests which are not applicable

* Skip test_initialization

* More improvements

* Compare pixel values

* Fix doc tests, add integration test

* Add do_normalize

* Remove causal mask and leverage ignore copy

* Fix attention_mask

* Fix remaining tests

* Fix dummies

* Rename temperature and bias

* Address comments

* Add copied from to tokenizer tests

* Add SiglipVisionModel to auto mapping

* Add copied from to image processor tests

* Improve doc

* Remove SiglipVisionModel from index

* Address comments

* Improve docs

* Simplify config

* Add first draft

* Make it like mistral

* More improvements

* Fix attention_mask

* Fix output_attentions

* Add note in docs

* Convert multilingual model

* Convert large checkpoint

* Convert more checkpoints

* Add pipeline support, correct image_mean and image_std

* Use padding=max_length by default

* Make processor like llava

* Add code snippet

* Convert more checkpoints

* Set keep_punctuation_string=None as in OpenCLIP

* Set normalized=False for special tokens

* Fix doc test

* Update integration test

* Add figure

* Update organization

* Happy new year

* Use AutoModel everywhere

---------

Co-authored-by: patil-suraj <surajp815@gmail.com>
@VictorSanh (Contributor) commented on Jan 18, 2024

Thanks for adding this!

Is there a reason why processor(text=["hello bonjour", "bonjour"], return_tensors="pt", padding=True) does not return an attention mask?

Perhaps this relates to the to-do item "make sure attention_mask is not passed for siglip checkpoints by updating model_input_names for checkpoints", but I'm not sure I understand why.

>>> processor.tokenizer(["hello bonjour", "bonjour"], padding=True, return_attention_mask=True)
{'input_ids': [[14647, 10048, 20852, 1], [10048, 20852, 1, 1]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 0]]}
>>> processor(text=["hello bonjour", "bonjour"], padding=True, return_attention_mask=True)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: __call__() got an unexpected keyword argument 'return_attention_mask'

It looks like return_attention_mask is not passed on to the tokenizer when calling the processor.

@NielsRogge (Contributor, Author)

Hi @VictorSanh, SigLIP was trained without an attention_mask (in other words, the text encoder attends to all tokens, including padding tokens!). Hence I explicitly had to set model_input_names to just "input_ids" for the checkpoints on the hub, so that the model internally attends to all tokens.

We still provide the possibility to create an attention_mask if you want padding tokens to be ignored, although predictions with the existing checkpoints will be pretty bad, as that's not how they were trained.

Regarding the return_attention_mask argument not being passed to the tokenizer: indeed, that's not supported yet. I'll add it as part of #28578.
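
A short sketch of what that looks like in practice (the checkpoint name is an assumption):

```python
# Hedged sketch of the behaviour described above; checkpoint name is an assumption.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

# The hub checkpoints restrict the text inputs to input_ids only,
# which is why the processor does not return an attention_mask.
print(processor.tokenizer.model_input_names)  # expected: ['input_ids']

batch = processor(text=["hello bonjour", "bonjour"], padding=True, return_tensors="pt")
print(batch.keys())  # no 'attention_mask' key
```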

@VictorSanh (Contributor)

Got it, I didn't see the issue.
It's quite odd that the attention mask was not used.

@xenova (Contributor) commented on Jan 19, 2024

@VictorSanh another thing to note (which tripped me up) is that you need to use padding='max_length'; otherwise, the output differs wildly (see here for more info).

@VictorSanh (Contributor)

Interesting, thanks for the info.

These are rather odd behaviors compared to how other tokenizers and models behave. Do you think we could document this somewhere, in the docs or the model card for instance?

wgifford pushed a commit to wgifford/transformers that referenced this pull request Jan 21, 2024
@amyeroberts (Collaborator) commented on Jan 22, 2024

@VictorSanh Behaviour and doc examples were updated in #28578

@VictorSanh (Contributor)

thank you!

@HugoLaurencon (Contributor)

Hi, could someone explain why you chose bicubic interpolation over bilinear for resizing the images? In the official big_vision repo, I find bilinear methods but not bicubic ones:
https://github.com/google-research/big_vision/blob/main/big_vision/pp/ops_image.py

AjayP13 pushed a commit to AjayP13/transformers that referenced this pull request Jan 22, 2024
@amyeroberts (Collaborator)

> Hi, could someone explain why you chose bicubic interpolation over bilinear for resizing the images? In the official big_vision repo, I find bilinear methods but not bicubic ones. https://github.com/google-research/big_vision/blob/main/big_vision/pp/ops_image.py

@NielsRogge good motivation to fill out #28180
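
If someone does want to match big_vision's bilinear resize, here is a hedged sketch of overriding the resampling filter; the argument and checkpoint names follow the usual image-processor conventions and are assumptions, not something prescribed by this PR.

```python
# Hedged sketch: override the default bicubic resampling with bilinear.
from transformers import SiglipImageProcessor
from transformers.image_utils import PILImageResampling

image_processor = SiglipImageProcessor.from_pretrained(
    "google/siglip-base-patch16-224",       # assumed checkpoint name
    resample=PILImageResampling.BILINEAR,   # the class defaults to BICUBIC
)
```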

self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches
@davidgxue (Contributor) commented on May 7, 2024

@NielsRogge Quick question if you don't mind: why is self.num_positions = self.num_patches? CLIP and other vision transformers have +1 for the number of positions. Is it because there's no CLASS embedding/token? I dug into the original ViT implementation (google-research/vision_transformer#61), and they add a CLASS embedding to keep it consistent with the generic Transformer architecture.

I thought: well, maybe SigLIP is different; maybe Google implemented it without the CLASS embedding this time, hence no +1. But in Google's repo (https://github.com/google-research/big_vision/blob/d0b297bbb8e073861d2d09ebd57ca46496357e36/big_vision/configs/proj/image_text/siglip_lit_coco.py#L81), this line sets the pool type to tok, which, if you follow their ViT modeling file, still adds the CLASS embedding.

So I'm guessing you simply didn't add the CLASS token/embedding in this implementation? I'm trying to figure out the reasoning behind it; perhaps there's some code you followed that does the same?
Validating my theory: I also see that self.class_embedding is not declared in the SiglipVisionEmbeddings class, whereas models like CLIP or ViT initialize it in their embeddings classes.

I'm trying to work on #30579 for SigLIP, so I'd like to understand this better. If the class embedding isn't added, the interpolate function would differ slightly. I think it's safe to assume the CLASS token is simply not added in this implementation, but I'm not 100% sure, so I'm confirming to be safe.

@NielsRogge (Contributor, Author) replied:

Indeed, I don't think SigLIP has a CLS token.
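
For anyone following along, a small illustrative sketch of the difference being discussed (the numbers are just assumed base-model defaults):

```python
# Illustration only: how the position-embedding count differs with and without a CLS token.
image_size, patch_size = 224, 16
num_patches = (image_size // patch_size) ** 2  # 196 patch tokens

# CLIP/ViT-style embeddings reserve one extra position for the [CLS] token
clip_num_positions = num_patches + 1           # 197

# SigLIP (this PR) has no CLS token, so positions == patches
siglip_num_positions = num_patches             # 196

print(clip_num_positions, siglip_num_positions)
```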

@rootAvish commented on May 30, 2024

@NielsRogge If you don't mind answering a question on this:

> add loss function for training => won't do since various torch.distributed utilities would have to be incorporated

The docs say to use the SigLIP loss from the open_clip repository. Going over the open_clip code, I see that it is indeed implemented using a lot of torch.distributed functionality, but I don't quite understand why it can't be implemented without torch.distributed; the function in open_clip also allows for a world size of 1, right (which is equivalent to not using torch.distributed)?

@NielsRogge (Contributor, Author)

@rootAvish feel free to open a PR; I'm not sure it would be equivalent in that case (e.g. if you then use the 🤗 Trainer API and run on multiple devices, would gradients be synced in the same way?).
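
For reference, here is a minimal single-device sketch of the sigmoid loss from the SigLIP paper, assuming all image-text pairs fit on one device; it deliberately skips the chunked, torch.distributed-based variant used by big_vision/open_clip, so behaviour under the Trainer on multiple devices would still need to be verified.

```python
# Minimal single-device sketch of the SigLIP sigmoid loss; not the distributed version.
import torch
import torch.nn.functional as F

def siglip_loss(image_embeds: torch.Tensor, text_embeds: torch.Tensor,
                logit_scale: torch.Tensor, logit_bias: torch.Tensor) -> torch.Tensor:
    """image_embeds and text_embeds are L2-normalized, shape (batch, dim)."""
    logits = logit_scale * image_embeds @ text_embeds.t() + logit_bias
    batch_size = logits.size(0)
    # +1 on the diagonal (matching pairs), -1 everywhere else
    labels = 2 * torch.eye(batch_size, device=logits.device) - 1
    return -F.logsigmoid(labels * logits).sum() / batch_size
```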
