Add support for Whisper #693
Conversation
- setup import structure and files
- implement mixins
- implement WithAdapter classes for modeling
- create WhisperSdpaAttentionAdapters module
- create WhisperAdapterModel
- make style
- add whisper to CONFIG_CLASS_KEYS_MAPPING
- add whisper to model import structure
- make style
- add proof of concept head to head_utils
- add whisper to parallel composition white list
…t and add documentation
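As a rough illustration of the steps listed above: the mixins and WithAdapter classes wrap the original Whisper modules so that adapter layers can run inside the forward pass. The sketch below is schematic PyTorch with hypothetical names and does not reproduce the library's internal classes:

```python
import torch
import torch.nn as nn

class WhisperEncoderLayerWithAdapters(nn.Module):
    """Schematic wrapper (hypothetical): run the wrapped encoder layer, then a bottleneck adapter."""

    def __init__(self, base_layer: nn.Module, hidden_size: int, reduction_factor: int = 16):
        super().__init__()
        self.base_layer = base_layer
        bottleneck_dim = hidden_size // reduction_factor
        # Illustrative bottleneck adapter: down-project, non-linearity, up-project, residual add.
        self.adapter = nn.Sequential(
            nn.Linear(hidden_size, bottleneck_dim),
            nn.GELU(),
            nn.Linear(bottleneck_dim, hidden_size),
        )

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        hidden_states = self.base_layer(hidden_states, **kwargs)
        return hidden_states + self.adapter(hidden_states)
```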
This looks good overall, thanks so much for working on it!
Did a first pass & left a couple of comments, mainly questions to understand the changes you made.
Co-authored-by: calpt <calpt@mail.de>
Looks good! Have you already trained an adapter on a task to see that our implementation yields the expected results?
lgtm
# Conflicts:
#	tests/test_adapter_heads.py
src/adapters/methods/reft.py
# if cached indexing matrices are computed for different hidden_states size -> recompute
cache_invalidated = False
if hasattr(context, "pref_idx") and hasattr(context, "suff_idx"):
    cache_invalidated = context.suff_idx.size(1) != seq_len
great catch! should we check the full shape here since bsz and ddim might also change?
Yes, then we have all potential cases covered 👍
However, I now realized that my checking logic was not correct; the indexing matrices and hidden_states never have the same value at dim 1:
- hidden_states[1] represents the sequence length
- suff_idx[1] represents the number of positions
We need to check the actual values of suff_idx to see whether the indexing values are out of bounds. I adapted the logic and added checks for the residual dimensions as well (see the sketch below).
Once the tests pass locally, I will push the changes for review.
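A minimal sketch of such an adapted check, assuming (as in the snippet above) that `bsz`, `seq_len`, and `ddim` come from `hidden_states.size()` and that `context.pref_idx` / `context.suff_idx` hold the cached gather indices; variable names and shapes are assumptions, not the merged implementation:

```python
import torch

# Recompute the cached indexing matrices if the cached positions no longer fit the
# current hidden_states: out-of-bounds indices, a different batch size, or a
# different residual (hidden) dimension all invalidate the cache.
cache_invalidated = False
if hasattr(context, "pref_idx") and hasattr(context, "suff_idx"):
    cache_invalidated = (
        torch.max(context.suff_idx) >= seq_len   # cached positions point past the current sequence
        or context.suff_idx.size(0) != bsz       # batch size changed
        or context.suff_idx.size(-1) != ddim     # residual dimension changed
    )
```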
This PR adds adapter support for the Whisper model from OpenAI and builds upon work done previously in #572.
Key Additions:
- Adapter support for the Whisper model
- AudioClassificationHead class
- Enhanced head functions
- Preprocessing scripts for audio datasets
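A minimal usage sketch of what this could look like from the user side (not taken from the PR): `WhisperAdapterModel` is named in the commits above, while the checkpoint name, adapter config, and the `add_audio_classification_head` signature are assumptions modeled on the library's existing `add_*_head` helpers:

```python
from adapters import WhisperAdapterModel

# Load Whisper with adapter support (checkpoint chosen for illustration).
model = WhisperAdapterModel.from_pretrained("openai/whisper-small")

# Add a bottleneck adapter and an audio classification head, then train only the adapter weights.
model.add_adapter("keyword_spotting", config="seq_bn")
model.add_audio_classification_head("keyword_spotting", num_labels=10)  # assumed signature
model.train_adapter("keyword_spotting")
```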