ClipSeg broken #34415

Closed
2 of 4 tasks
mcmonkey4eva opened this issue Oct 25, 2024 · 27 comments · Fixed by #34419

@mcmonkey4eva

mcmonkey4eva commented Oct 25, 2024

System Info

Irrelevant; it happens in any environment with updated transformers. Replicated myself on Transformers v4.46.0 (the current release on pip at time of writing).

Who can help?

@amyeroberts @manuelsh

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Here's a fully contained python file if you want a lazy replication:

import torch
import numpy as np
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# Create gray image at 512x512 with 32x32 black spot in the middle for clipseg to find
# Note: intentionally use a size that isn't the ClipSeg size just to prove that the input size is irrelevant
arr = np.ones((512, 512, 3), dtype=np.uint8) * 128
arr[240:272, 240:272] = 0

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

with torch.no_grad():
    processed = processor(text="the black spot", images=arr, return_tensors="pt", padding=True)
    print(f"{processed.pixel_values.shape}") # prints 1, 3, 352, 352
    mask = model(**processed)[0] # ValueError: Input image size (352*352) doesn't match model (224*224).

# Now do something with the mask, maybe save a PIL image,
# doesn't matter cause you can't get to here without an error

Expected behavior

Expected behavior: it works, as it did before updating transformers

But instead, it fails with Input image size (352*352) doesn't match model (224*224).
See my comment on the PR that broke it: https://github.com/huggingface/transformers/pull/32600/files#r1816732699

Secondarily, there's now a warning about one of the arguments (which, again, is in the docs and has been there for years):
UserWarning: The following named arguments are not valid for 'ViTImageProcessor.preprocess' and were ignored: 'padding'

@mcmonkey4eva
Author

Note that a temporary workaround for this error is to manually pip install transformers==4.45.0 to forcibly backdate it, since the error is only present in the absolute latest version, 4.46.0.

@hlky
Contributor

hlky commented Oct 25, 2024

# interpolate_pos_encoding false should return value error
with self.assertRaises(ValueError, msg="doesn't match model"):
    with torch.no_grad():
        model(**inputs, interpolate_pos_encoding=False)
# forward pass
with torch.no_grad():
    outputs = model(**inputs, interpolate_pos_encoding=True)

Try

model(**processed, interpolate_pos_encoding=True)

@mcmonkey4eva
Author

That does make it work.

Note that that is not a solution as (A) an argument shouldn't be required just to work at all here, and (B) that is incompatible with older versions of transformers, so it can't even work as a workaround code patch.
Also offhand not sure what the original code was doing (was it always interpolating? Or is interpolating wrong to do here?)

@hlky
Contributor

hlky commented Oct 25, 2024

A) The argument does appear to be required in this case because Input image size (352*352) doesn't match model (224*224); the purpose of interpolate_pos_encoding is to enable input sizes different from the model's.
B) You have a few options for a workaround: set a minimum version of transformers in your requirements.txt; use importlib.metadata.version/packaging.version to check the installed package version number and determine whether interpolate_pos_encoding=True should be used; or use inspect.signature to check whether model.forward accepts interpolate_pos_encoding.

was it always interpolating

if embeddings.shape[1] != self.num_positions:
    new_shape = int(math.sqrt(embeddings.shape[1] - 1))
    embeddings = embeddings + self.interpolate_position_embeddings((new_shape, new_shape))
    embeddings = embeddings.to(embeddings.dtype)
else:
    embeddings = embeddings + self.position_embedding(self.position_ids)

In some cases.

Imo these changes might have benefited from a deprecation warning, and documentation should have been updated, but I'll leave it to maintainers to make further comment.
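The `embeddings.shape[1] != self.num_positions` check in the excerpt above compares token counts, so the arithmetic behind "In some cases" is worth spelling out. A minimal sketch, assuming a patch size of 16 (the ViT-B/16 backbone that `CIDAS/clipseg-rd64-refined` is assumed to use) and one extra class token; the helper names are illustrative only:

```python
import math


def num_positions(image_size: int, patch_size: int = 16) -> int:
    """Tokens for a square image: one per patch plus one class token (assumed layout)."""
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side + 1


def grid_side(num_tokens: int) -> int:
    """Recover the patch-grid side length as the excerpt does: sqrt(tokens - 1)."""
    return int(math.sqrt(num_tokens - 1))


stored = num_positions(224)     # 14 * 14 + 1 = 197 learned position embeddings
processed = num_positions(352)  # 22 * 22 + 1 = 485 tokens after the 352x352 resize
print(stored, processed, grid_side(processed))  # 197 485 22
```

Under these assumptions, every 352x352 input from the processor takes the interpolation branch, because 485 tokens never match the 197 stored positions.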

@mcmonkey4eva
Author

A) The argument does appear to be required in this case because Input image size (352*352) doesn't match model (224*224)., the purpose of interpolate_pos_encoding is to enable input sizes different than the model's.

I think you missed the point: the input image is not 352. It's 512x512 in my above example, but could be anything. ClipSeg's own code within transformers is resizing it to 352, then complaining that it's the wrong size.

@hlky
Contributor

hlky commented Oct 25, 2024

Input image size in the ValueError refers to the size specified in the processor, 352x352. The image itself, 512x512 in your example, is processed to 352x352. Maybe the ValueError should specify pixel values size instead, but I think the meaning is pretty clear. 352x352 is larger than the model's size, 224x224, so interpolate_pos_encoding is required.

@NielsRogge
Contributor

NielsRogge commented Oct 25, 2024

Thanks for reporting, I contributed CLIPSeg to the Transformers library some years ago. There was no interpolate_pos_encoding method back then.

I'm personally not a fan of just blindly adding this interpolate_pos_encoding method to each model without checking everything still works as expected. For CLIPSeg, the authors already defined this themselves. It's recommended to use CLIPSegProcessor to prepare the inputs for you and perform a forward pass.

Edit: yes, they just introduced a breaking change:

outputs = model(**inputs, interpolate_pos_encoding=True)

@NielsRogge
Contributor

Anyway, I've opened a PR to revert it: #34419

@manuelsh
Contributor

manuelsh commented Oct 25, 2024

All vision models, including the CLIP family models, have been modified (or are in the process of being modified) with the same interpolate function and with interpolate_pos_encoding=True required in the signature (#30579).

It is true that ClipSeg already had a function doing that, which works in the same way as the current function (both use bicubic interpolation with nn.functional.interpolate and align_corners=False). Also, previously there was no need to add interpolate_pos_encoding=True in the signature of ClipSeg.

But this has not been "blindly" substituted. The benefits of the new interpolate_pos_encoding versus the previous interpolate_position_embeddings in ClipSeg are:

  1. Ensures compatibility with TorchScript and dynamic input shapes (this came from 🚨 Fix torch.jit.trace for interpolate_pos_encoding in all vision models #33226),
  2. Includes an early return condition to avoid unnecessary interpolation when the number of patches hasn't changed and the image is square,
  3. Includes docstrings explaining the method and where it is adapted from.
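Point 2 can be sketched as a plain predicate. This is a hypothetical simplification that only mirrors the condition as described above (same patch count, square image). The library's version operates on tensors, and those details are left out. In this sketch both counts exclude the class token, and a patch size of 16 is assumed:

```python
def can_skip_interpolation(num_patches: int, num_positions: int, height: int, width: int) -> bool:
    """Hypothetical helper mirroring the early-return condition described above.

    Both counts exclude the class token. When they match and the image is square,
    the stored position embeddings can be used without resampling.
    """
    return num_patches == num_positions and height == width


# 224x224 input, patch size 16 assumed: 14 * 14 = 196 patches vs 196 stored positions.
print(can_skip_interpolation(196, 196, 224, 224))  # True
# 352x352 input: 22 * 22 = 484 patches, so resampling is needed.
print(can_skip_interpolation(484, 196, 352, 352))  # False
```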

I believe the solution is not to revert to the previous function, but possibly to remove the need to add interpolate_pos_encoding=True (guaranteeing code compatibility) with a deprecation warning, as @hlky suggested, or to be explicit in the error message about the need to add interpolate_pos_encoding=True in the call.

@amyeroberts what do you think?

@mcmonkey4eva
Author

If the behavior is identical to previous with the arg enabled, it should probably be enabled by default?

Again, I need to reiterate: Running clipseg as intended will always produce different shapes that require interpolation. It is explicitly erroneous to have that argument disabled.

@manuelsh
Contributor

@mcmonkey4eva, why is it "explicitly erroneous to have the argument disabled"?

Note that the positional encoding introduced affects only the input image, not the output segmentation of ClipSeg.

@mcmonkey4eva
Author

I am confused why I have to keep repeating this:

If you run the code as intended, it errors. The code does not work unless this interpolation is enabled. The size difference that it needs to interpolate between is produced by ClipSeg itself, not by the caller code.

Again, the input image can be any size you want, the HF Transformers ClipSeg code is what resizes it to 352, and then complains that the 352 it output is not the 224 it expects.

@hlky
Contributor

hlky commented Oct 26, 2024

@mcmonkey4eva You appear to be confused by the currently incorrect documentation, indeed running the code from the incorrect documentation produces an error, it is unfortunate the documentation was not updated to include interpolate_pos_encoding=True.

Please refer to the comments above for a further explanation of why this change was made and the proposed solutions, i.e. updating documentation, adding a deprecation warning, setting the default interpolate_pos_encoding to True, etc.

In the meantime, you can easily resolve this in your integration with one of the following options:

  • set a minimum version of transformers in your requirements.txt
transformers>=4.46.0
  • use importlib.metadata.version/packaging.version to check the installed package version number and determine whether interpolate_pos_encoding=True should be used
import torch
import numpy as np
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

import importlib.metadata
import packaging.version

arr = np.ones((512, 512, 3), dtype=np.uint8) * 128
arr[240:272, 240:272] = 0

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

kwargs = (
    {"interpolate_pos_encoding": True}
    if packaging.version.parse(importlib.metadata.version("transformers"))
    >= packaging.version.Version("4.46.0")
    else {}
)

with torch.no_grad():
    processed = processor(
        text="the black spot", images=arr, return_tensors="pt", padding=True
    )
    print(f"{processed.pixel_values.shape}")
    mask = model(**processed, **kwargs)[0]
  • use inspect.signature to check whether model.forward accepts interpolate_pos_encoding
import torch
import numpy as np
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

import inspect

arr = np.ones((512, 512, 3), dtype=np.uint8) * 128
arr[240:272, 240:272] = 0

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

kwargs = (
    {"interpolate_pos_encoding": True}
    if "interpolate_pos_encoding" in inspect.signature(model.forward).parameters
    else {}
)

with torch.no_grad():
    processed = processor(
        text="the black spot", images=arr, return_tensors="pt", padding=True
    )
    print(f"{processed.pixel_values.shape}")
    mask = model(**processed, **kwargs)[0]

@manuelsh
Contributor

Hi, @mcmonkey4eva is not confused because of that: this has been added to the documentation too; please check https://huggingface.co/docs/transformers/v4.46.0/en/model_doc/clipseg#transformers.CLIPSegModel, where the parameter interpolate_pos_encoding is documented. I think the release notes of the new version could have been more explicit about this, or we could have added a deprecation warning.

@mcmonkey4eva, as mentioned before, in the new version you need to add interpolate_pos_encoding=True for it to run at resolutions different from the one it was pretrained on, and this will solve your problem.

You can also refer to here for an explanation of the changes.

@mcmonkey4eva
Author

@mcmonkey4eva, as mentioned before, in the new version you need to add interpolate_pos_encoding=True for it to run with different resolutions as the one pre trained and this will solve your problem.

Please by god read what I wrote above repeatedly. This is an incredibly frustrating attempt at communication.

@neggles
Contributor

neggles commented Oct 27, 2024

Let me just lay this out nice and neatly so that everyone can get on the same page.

(I'll use "pos embed" to refer to the positional encoding/embedding just for brevity's sake)

  • ClipSeg, like most ViTs, accepts 224x224px input images unless the pos embed is interpolated to account for a larger input dimension
  • In the original implementation the ClipSeg authors rescaled inputs to 352x352px before feeding them into the model
    • In keeping with this, Transformers' ImageProcessor for ClipSeg defaults to 352x352px output resolution
  • The original ClipSeg implementation implicitly interpolates the pos embed for whatever input image size it's given.
    • The original authors intended for pos embed interpolation to always be enabled, given the unconditional interpolation in the original code and use of a resolution that cannot otherwise be accepted by the model.
  • Transformers' original implementation dutifully and faithfully replicated this behaviour:
    def interpolate_position_embeddings(self, new_size):
        if len(new_size) != 2:
            raise ValueError("new_size should consist of 2 values")
        num_patches_one_direction = int(self.num_patches**0.5)
        # we interpolate the position embeddings in 2D
        a = self.position_embedding.weight[1:].T.view(
            1, self.config.hidden_size, num_patches_one_direction, num_patches_one_direction
        )
        b = (
            nn.functional.interpolate(a, new_size, mode="bicubic", align_corners=False)
            .squeeze(0)
            .view(self.config.hidden_size, new_size[0] * new_size[1])
            .T
        )
        result = torch.cat([self.position_embedding.weight[:1], b])
        return result

This was the status quo for quite some time.

  • adding positional encoder changes and tests #32600 replaces the various ViT models' bespoke pos embed interpolation methods/functions with a single uniform implementation
  • This implementation is functionally equivalent to the one ClipSeg (and essentially all ViTs) were already using, so unifying these implementations makes sense.
  • However, the unified impl differs from the previous ClipSeg-specific impl in one major aspect: it is disabled by default.
  • Essentially all code using this model class is likely reliant on the existing behaviour
    • as mentioned above, the model authors clearly intended for it to be enabled by default and there is no reason to not enable interpolation in normal usage.

This is a breaking user-facing API change that is, IMO, entirely unjustified. Merely updating the docs to say "you have to pass interpolate_pos_encoding=True now" is not a satisfactory solution; the new argument cannot be passed to the model under old transformers versions without raising an exception, placing undue burden on users who would like to write code that is compatible with a range of transformers versions.

The question is not "how do i get my code to work now"; @mcmonkey4eva certainly doesn't need help with that given the codebases he maintains. The question is "Can we fix this unnecessary breakage of existing code and unintuitive change in default behaviour?"

The simple and straightforward solution here would be to just change the default for interpolate_pos_encoding to True in ClipSeg's forward() definition.

A slightly more complicated one would be to make the default value of this arg a model configuration property, with the default overridden in the model class.
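Both proposals can be sketched with stand-in classes. Everything here (`ToyVisionConfig`, `ToyVisionModel`, `ToyClipSeg`, `default_interpolate_pos_encoding`) is hypothetical and not the Transformers API. It only shows how a `None` sentinel lets an explicit argument win, then a config field, then a per-class default, and how a size mismatch raises when interpolation ends up disabled:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyVisionConfig:
    image_size: int = 224
    # Hypothetical config field; None defers to the model class default.
    interpolate_pos_encoding: Optional[bool] = None


class ToyVisionModel:
    # Generic ViT-style default: keep the pretrained resolution unless asked.
    default_interpolate_pos_encoding = False

    def __init__(self, config: ToyVisionConfig):
        self.config = config

    def resolve_interpolation(self, interpolate_pos_encoding: Optional[bool]) -> bool:
        # Precedence: explicit call argument, then config field, then class default.
        if interpolate_pos_encoding is not None:
            return interpolate_pos_encoding
        if self.config.interpolate_pos_encoding is not None:
            return self.config.interpolate_pos_encoding
        return self.default_interpolate_pos_encoding

    def forward(self, height: int, width: int, interpolate_pos_encoding: Optional[bool] = None) -> str:
        size = self.config.image_size
        if (height, width) == (size, size):
            return "native"
        if not self.resolve_interpolation(interpolate_pos_encoding):
            raise ValueError(f"Input image size ({height}*{width}) doesn't match model ({size}*{size}).")
        return "interpolated"


class ToyClipSeg(ToyVisionModel):
    # Its processor always emits 352x352, so this class flips the default.
    default_interpolate_pos_encoding = True


model = ToyClipSeg(ToyVisionConfig())
print(model.forward(352, 352))  # interpolated (no extra argument needed)
```

Because an explicit argument still takes precedence, callers that already pass interpolate_pos_encoding=True are unaffected by either variant.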

@hlky
Contributor

hlky commented Oct 27, 2024

Thanks for the thorough recap of the fundamentals, @neggles! In that vein, I'll make some further points:

#32600 replaces the various ViT models' bespoke pos embed interpolation methods/functions with a single uniform implementation
This implementation is functionally equivalent to the one ClipSeg (and essentially all ViTs) were already using, so unifying these implementations makes sense.

That PR added the function to various ViT models and replaced it only in ClipSeg. This was part of a greater project, #30579, to enable dynamic-resolution input for more vision models. Other benefits are explained here, namely TorchScript compatibility.

This is a breaking user-facing API change

Yes, as we've said it probably should have been accompanied by a deprecation warning.

the new argument cannot be passed to the model under old transformers versions without raising an exception, placing undue burden on users who would like to write code that is compatible with a range of transformers versions.
The question is not "how do i get my code to work now"; @mcmonkey4eva certainly doesn't need help with that given the codebases he maintains

The burden of workarounds is simply part and parcel of being a developer, Transformers itself has several. If you'd like the codebase to work in the meantime, please refer to the solutions here

The question is "Can we fix this unnecessary breakage of existing code and unintuitive change in default behaviour?"

Yes, please refer to the proposed solutions here and here.

The simple and straightforward solution here would be to just change the default for interpolate_pos_encoding to True in ClipSeg's forward() definition.

Indeed, this is one of the proposed solutions.

@NielsRogge
Contributor

NielsRogge commented Oct 27, 2024

Good take @neggles, I agree with you! Thanks also @manuelsh for clarifying the background.

I believe making interpolate_pos_encoding default to True would be the best solution here.

However @manuelsh I saw the logits were updated in the integration test, that should not be the case, logits should remain the same to ensure backwards compatibility when passing interpolate_pos_encoding=True.

@neggles
Contributor

neggles commented Oct 27, 2024

@hlky Fair point! I was on my phone so I didn't have the best view of that PR, but my comments still very much apply to the combination of #32600 and #30579 taken together 😝

I'm quite aware of the benefits of the change - and entirely in favor of the overall change, FWIW - but there appeared to be some miscommunication W.R.T. @mcmonkey4eva's purpose with opening this PR, so it seemed like a good idea to try and clear that up.

The burden of workarounds is simply part and parcel of being a developer, Transformers itself has several

Well yeah, but a lot of downstream users of this library are data scientists, not developers, and just because python ML library dependency management is already a dumpster fire doesn't mean we should make it worse without a good reason! One of the stated goals of transformers is to (try to) be simple, straightforward, and user-friendly, and this change is the antithesis of that.

@NielsRogge Thanks! I agree. Seems like the lowest-effort lowest-impact most-obvious choice, and as a bonus it won't cause any problems for anyone who does happen to have implemented a workaround in the meantime.

FWIW, it looks like the actual logit change in #32600 was fairly insignificant:

         expected_masks_slice = torch.tensor(
-            [[-7.4613, -7.4785, -7.3628], [-7.3268, -7.0899, -7.1333], [-6.9838, -6.7900, -6.8913]]
+            [[-7.4613, -7.4785, -7.3627], [-7.3268, -7.0898, -7.1333], [-6.9838, -6.7900, -6.8913]]
         ).to(torch_device)

Probably just down to rounding changes, chasing that down is maybe not entirely worth the effort 🤷‍♀️

@manuelsh
Contributor

@mcmonkey4eva, I understand the frustration with the breaking change in ClipSeg requiring an extra argument. I’m addressing this and appreciate everyone’s pragmatic solution of defaulting interpolate_pos_encoding to True.

What’s unclear is whether we’re keeping the new interpolate_pos_encoding function or reverting to the previous one in ClipSeg. I believe the new one is best (consistent across ViT models, functionally equivalent and improved) yet I see in @NielsRogge's PR #34419 we might revert. @NielsRogge, is there another reason for keeping the old function?

@NielsRogge
Contributor

@manuelsh we're usually pretty strict on logits matching, up to atol=1e-4 to 1e-6.

So perhaps we could modify the method to keep the old behaviour of interpolation (which matches the original one).
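To see where the 1e-4 change in the diff lands relative to those tolerances, the sketch below reimplements the element-wise rule documented for torch.allclose, |input - other| <= atol + rtol * |other|, in plain Python with torch's default rtol=1e-5. The helper name is illustrative, and the integration test's actual rtol/atol settings are not stated in this thread, so the two atol values just bracket the 1e-4 to 1e-6 range mentioned above:

```python
def allclose(actual, expected, rtol=1e-5, atol=1e-8):
    """Pure-Python stand-in for torch.allclose on nested lists.

    Element-wise check: |actual - expected| <= atol + rtol * |expected|.
    """
    return all(
        abs(a - e) <= atol + rtol * abs(e)
        for row_a, row_e in zip(actual, expected)
        for a, e in zip(row_a, row_e)
    )


# Expected mask slices before and after #32600, copied from the diff above.
before = [[-7.4613, -7.4785, -7.3628], [-7.3268, -7.0899, -7.1333], [-6.9838, -6.7900, -6.8913]]
after = [[-7.4613, -7.4785, -7.3627], [-7.3268, -7.0898, -7.1333], [-6.9838, -6.7900, -6.8913]]

max_diff = max(
    abs(a - b) for row_a, row_b in zip(after, before) for a, b in zip(row_a, row_b)
)
print(f"largest element-wise difference: {max_diff:.1e}")
print(allclose(after, before, atol=1e-4))  # True
print(allclose(after, before, atol=1e-6))  # False
```

The change sits at the loose end of the quoted range: it passes at atol=1e-4 (helped by the relative term) but would fail a check at atol=1e-6.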

@ArthurZucker
Collaborator

Hey all! 🤗
Sorry that things got heated when this is entirely on our side.

This is a breaking user-facing API change that is, IMO, entirely unjustified. Merely updating the docs to say "you have to pass interpolate_pos_encoding=True now" is not a satisfactory solution; the new argument cannot be passed to the model under old transformers versions without raising an exception, placing undue burden on users who would like to write code that is compatible with a range of transformers versions.

I completely agree with you, and am sorry that this escaped our reviews, we'll patch this ASAP.

@NielsRogge I think logits depend on the hardware you are using, the one making the PR might not have the same, and the difference is acceptable!

@NielsRogge
Contributor

Hi @manuelsh the issue is fixed, apologies for my comment earlier, and thanks for the great explainer.

@manuelsh
Contributor

manuelsh commented Nov 3, 2024

No worries @NielsRogge, I am glad we landed on a good solution.

@oliverban

I have 4.46.2 and this is still broken, is that to be expected? Which version has the fix? :)

@xdew77

xdew77 commented Nov 17, 2024

I have this error in Pinokio ComfyUI. I understand I need to downgrade to transformers==4.45.0. I ran pip install transformers==4.45.0 from anaconda, but in which folder should it be installed: E:\pinokio\api\comfy.git\app or somewhere else? I also didn't find any test_modeling_clipseg.py, only a modeling_clipseg.py.

@xdew77

xdew77 commented Nov 17, 2024

Note that temporary workaround for this error is just manually pip install transformers==4.45.0 to forcibly backdate it, since the error is only present in absolute latest 4.46.0 version

In which folder do I have to run this command in cmd or anaconda "pip install transformers==4.45.0" (I use PINOKIO ComfyUI)
