Update ML Decoder #2045
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Currently have 3 variants: the original with compatibility changes for PT 2.1+, a version identical to the original but with a performance fix that's giving ~35% training speedup, and a WIP reimplementation that updates styling and changes the decoder block implementation. The last one involves architectural changes that remove some of the odd components from #1012 (residual on dropouts and queries, MLP residual location, norm locations). I'm in the process of testing these changes. @mrT23 do you have any comments?

@rwightman are you aware of any PyTorch transformer decoder/cross attention implementations that follow the style of this library or that I could reference? I inferred the cross attention impl from the ViT attn impl and the block structure from the original impl, but I would rather this impl follow something standard (less the self attn in the decoder) than introduce other odd implementation choices.

Model compatibility is iffy: there are a few models with odd architectures (combining multiple feature maps, distillation architectures, etc.) that would be a pain to special case and probably won't be used, and a few other more prominent architectures that don't work because they use NHWC. Overall it seems like it would be difficult to maintain alongside #2048.

@rwightman Do you think a revised ClassifierHead that supports additional pooling and head formats would be better? There are quite a few structures that can be placed there (pool->ffn, various ml-decoder-like mechanisms). Since many models already have ...
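For reference, a minimal sketch of the kind of cross attention module described above, written in the style of timm's ViT `Attention` block; the class name and the separate q/kv projections are illustrative assumptions, not the exact code in this PR.

```python
import torch
import torch.nn as nn


class CrossAttention(nn.Module):
    """Cross attention in the style of timm's ViT Attention (illustrative sketch)."""

    def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)       # projects the class/query embeddings
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)  # projects the backbone tokens
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, x):
        # q: (B, Nq, C) query embeddings, x: (B, N, C) backbone tokens
        B, Nq, C = q.shape
        N = x.shape[1]
        q = self.q(q).reshape(B, Nq, self.num_heads, self.head_dim).transpose(1, 2)
        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
```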
@fffffgggg54 curious what your goals are for this impl, what sort of applications, etc.
Goals are primarily performance, compatibility, and consistent styling with newer timm implementations. The legacy version provides support and improved performance, and the reimplementation attempts to match other timm models and removes the odd components noted above.

I work almost exclusively with a dataset (danbooru) that poses a multi-label positive-unlabeled problem. I use MLDecoder for this sometimes and recently noticed the PT 2.1 and extensive model compat issues, along with the slow groupFC impl, odd dropout, etc., prompting both versions. In addition to PU-specific techniques (often in math-heavy papers that are a bit of a headache to read), some researchers focus on aspects of the model, often what comes after the backbone (GNNs, text towers, MLDecoder, activation functions). The labeling scheme of danbooru is set up such that the labels present in an image can be mapped out hierarchically, similar to a scene graph. This is also done internally via label implications. I'm also working on a from-scratch impl of DependencyViT for this (not going well); hopefully it can exploit the tree structure. I have a drop-in replacement for ...
Added tests from my own testing script; they will fail because there are models that don't work, 95 variants specifically, mostly due to distillation heads, multiple feature maps, or wrong input shape. The code to add the head is messy; a universal head should fix that.
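A rough sketch of the kind of per-model smoke test described here; the model list is a placeholder subset and the import path of `add_ml_decoder_head` is an assumption, not the exact tests added in this PR.

```python
import pytest
import torch
import timm

from timm.models.layers.ml_decoder import add_ml_decoder_head  # assumed path; adjust to this branch

MODEL_SUBSET = ['resnet50', 'regnety_004']  # placeholder; the real sweep covers many more variants


@pytest.mark.parametrize('model_name', MODEL_SUBSET)
def test_ml_decoder_head_forward(model_name):
    model = timm.create_model(model_name, pretrained=False, num_classes=10)
    model = add_ml_decoder_head(model)  # exact signature may differ in this branch
    out = model(torch.randn(2, 3, 224, 224))
    assert out.ndim == 2 and out.shape[0] == 2
```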
Experimental feature. I want to wait to merge this until a universal head is implemented. This and other things I'm working on are a pain to implement/use/add to timm without a universal head.
allow external class embed (ex text embeddings of class descriptions), head version toggle
force-pushed from bf08a92 to b927237
Update ML Decoder's `TransformerDecoderLayerOptimal` module to comply with what `nn.TransformerDecoder` expects. Current changes work with resnet50; `add_ml_decoder_head` needs to be updated for other models. In my limited testing, the following case works with RegNet:
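A minimal usage sketch of that RegNet case, assuming `add_ml_decoder_head` takes a timm model and returns it with the ML Decoder head attached; the import path and exact signature are assumptions, not verbatim from this PR.

```python
import torch
import timm

from timm.models.layers.ml_decoder import add_ml_decoder_head  # assumed path; adjust to this branch

model = timm.create_model('regnety_004', pretrained=False, num_classes=1000)
model = add_ml_decoder_head(model)  # replaces the classifier with an ML Decoder head

out = model(torch.randn(2, 3, 224, 224))
print(out.shape)  # expected (2, 1000) if the head keeps num_classes
```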