Add support for Whisper #693

Merged
72 commits, Aug 8, 2024

Commits
ad9fe2b
save current progress:
TimoImhof Apr 2, 2024
7ebb5e8
Merge branch 'adapter-hub:main' into dev/whisper
TimoImhof Apr 2, 2024
1e10398
Implement WhisperAdapterModel:
TimoImhof Apr 12, 2024
c775d43
add logger for Attention module
TimoImhof Apr 12, 2024
23f1c68
Add Whisper model to documentation
TimoImhof Apr 12, 2024
f4b3df8
Add WhisperDecoderWrapperAdaptersMixin:
TimoImhof Apr 12, 2024
dc40973
Add tests
TimoImhof Apr 12, 2024
a0a89ed
save progress
TimoImhof Apr 19, 2024
f36133c
save progress
TimoImhof Apr 25, 2024
f011bb3
overwrite get_input_samples method to fix tests requiring simple inpu…
TimoImhof Apr 26, 2024
08464c2
add support for speech samples with "input_features" as tensor name
TimoImhof Apr 26, 2024
32a6434
fix wrong input argument
TimoImhof Apr 26, 2024
d13b6d3
upload dev files for experiments
TimoImhof Apr 28, 2024
f5e4269
upload dev files for experiments
TimoImhof Apr 28, 2024
70f7651
update SpeechTestBase
TimoImhof Apr 28, 2024
e38cd9e
Add copy info and add flash attention
TimoImhof Apr 30, 2024
8587f0f
Changes:
TimoImhof Apr 30, 2024
909fecb
Changes:
TimoImhof Apr 30, 2024
fcaa21e
Delete dev dir
TimoImhof Apr 30, 2024
9bd7065
add TODOS
TimoImhof Apr 30, 2024
25171bb
make method more general
TimoImhof Apr 30, 2024
6405be4
add methods necessary for head usage
TimoImhof May 2, 2024
c52f0c3
Add TODO
TimoImhof May 2, 2024
cc00ed6
remove redundant code
TimoImhof May 2, 2024
24f72d6
add comment & enable all tests
TimoImhof May 2, 2024
182f5a5
Add special check for vision models
TimoImhof May 2, 2024
e44a482
make style
TimoImhof May 2, 2024
ca41958
add speech_classification head
TimoImhof May 5, 2024
158165b
Adapting tests:
TimoImhof May 8, 2024
0eadf6d
update dataset
TimoImhof May 8, 2024
d4117b7
residual updates:
TimoImhof May 8, 2024
091d947
Include adapters.init() support for:
TimoImhof May 14, 2024
5e8bf99
Adapt Testbase
TimoImhof May 15, 2024
c483068
Fixes:
TimoImhof May 15, 2024
480b4b6
Changes:
TimoImhof May 21, 2024
38bbb06
Add custom classification head
TimoImhof May 23, 2024
4ddc919
Fix embedding text:
TimoImhof May 23, 2024
4d6e9cc
Fix generation
TimoImhof May 24, 2024
e54f1c4
Fix composition and invertible adapters
TimoImhof May 25, 2024
1c6aebd
Merge branch 'main' into dev/whisper
TimoImhof May 26, 2024
980a3f4
Revert test changes:
TimoImhof Jun 4, 2024
11daca4
manually handle failing style checks:
TimoImhof Jun 4, 2024
63ca22a
- remove audio classification from WhisperAdapterModel head classes
TimoImhof Jun 5, 2024
4069c94
Remove redundant code:
TimoImhof Jun 7, 2024
6beba77
fix typo
TimoImhof Jun 7, 2024
faf54b6
fix conditional case and remove redundant code line
TimoImhof Jun 7, 2024
fcdc409
fix prompt tuning test
lenglaender Jun 11, 2024
e513263
Add ConversionTests and AudioClassificationMixin
TimoImhof Jun 11, 2024
acd5332
polish docs
TimoImhof Jun 18, 2024
c8427b0
polish docs
TimoImhof Jun 18, 2024
7e3e108
Fix import
TimoImhof Jun 18, 2024
69e3c99
Remove redundant files
TimoImhof Jun 18, 2024
af2ddc2
Update src/adapters/model_mixin.py
TimoImhof Jul 9, 2024
57b411a
Apply suggestions
TimoImhof Jul 9, 2024
61f3742
Merge remote-tracking branch 'origin/dev/whisper' into dev/whisper
TimoImhof Jul 9, 2024
f01b51e
Merge branch 'main' into dev/whisper
TimoImhof Jul 9, 2024
1f8573f
Fix failing test and refactor speech model case handling
TimoImhof Jul 10, 2024
8d04de1
Fix failing test
TimoImhof Jul 10, 2024
5b41382
Fix overwriting arguments
TimoImhof Jul 10, 2024
327381e
make style
TimoImhof Jul 10, 2024
a30bb6c
Address remaining comments, fix conversion test, correct documentatio…
TimoImhof Jul 24, 2024
ad47696
Revert forward function signature modification
TimoImhof Jul 27, 2024
6111b07
Merge branch 'adapter-hub:main' into dev/whisper
TimoImhof Jul 30, 2024
53b9cd9
make style
TimoImhof Jul 30, 2024
7526514
Remove redundant head - not supported by any model
TimoImhof Jul 30, 2024
0588db6
Add Future TODO for seq2seqtrainer
TimoImhof Aug 1, 2024
1751a25
Merge branch 'refs/heads/main' into dev/whisper
TimoImhof Aug 3, 2024
08911d8
Incorporate pyreft tests
TimoImhof Aug 3, 2024
88bd867
Add check for changing hidden_states size
TimoImhof Aug 3, 2024
3de2581
Adapt checking logic
TimoImhof Aug 4, 2024
89377f8
Merge branch 'refs/heads/main' into dev/whisper
TimoImhof Aug 4, 2024
5f4f20c
Fix attention classes and generation
TimoImhof Aug 7, 2024
10 changes: 8 additions & 2 deletions src/adapters/methods/reft.py
@@ -66,9 +66,15 @@ def __init__(self, in_features: int, config: ReftConfig):

     def _gather_adapted_states(self, hidden_states: torch.Tensor):
         context = ForwardContext.get_context()
-        bsz, _, ddim = hidden_states.size()
+        bsz, seq_len, ddim = hidden_states.size()
+
+        # if cached indexing matrices are computed for different hidden_states size -> recompute
+        cache_invalidated = False
+        if hasattr(context, "pref_idx") and hasattr(context, "suff_idx"):
+            cache_invalidated = context.suff_idx.size(1) != seq_len
Member:

great catch! should we check the full shape here since bsz and ddim might also change?

Contributor Author (TimoImhof):

Yes, then we have all potential cases covered 👍

Contributor Author (@TimoImhof, Aug 4, 2024):

However, I now realized that my checking logic was not correct; the indexing matrices and hidden_states never have the same value at dim 1:

  • hidden_states[1] represents the sequence length
  • suff_idx[1] represents the number of positions

We need to check the actual values of suff_idx to see whether the indexing values are out of bounds. I adapted the logic and added checks for the residual dimensions as well. Once the tests pass locally, I will push the changes for review.
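The caching behavior discussed in this thread can be illustrated with a minimal, torch-free sketch. All names below (`Context`, `build_indices`, `get_indices`, the position counts) are illustrative stand-ins, not the actual adapters API: index lists are cached on a context object and recomputed whenever a cached index would be out of bounds for the new sequence length.

```python
class Context:
    """Stand-in for the library's forward context; attributes are set lazily."""
    pass

def build_indices(seq_len, prefix_positions=2, suffix_positions=2):
    # index the first and last few token positions of the sequence
    pref_idx = list(range(prefix_positions))
    suff_idx = list(range(seq_len - suffix_positions, seq_len))
    return pref_idx, suff_idx

def get_indices(context, seq_len):
    cached = hasattr(context, "pref_idx") and hasattr(context, "suff_idx")
    # cache is stale if any cached suffix index is out of bounds for the new length
    invalidated = cached and max(context.suff_idx) >= seq_len
    if not cached or invalidated:
        context.pref_idx, context.suff_idx = build_indices(seq_len)
    return context.pref_idx, context.suff_idx

ctx = Context()
p1, s1 = get_indices(ctx, seq_len=10)  # no cache yet -> compute and store
p2, s2 = get_indices(ctx, seq_len=10)  # same length -> reuse cached lists
p3, s3 = get_indices(ctx, seq_len=6)   # shorter sequence -> indices 8, 9 invalid -> recompute
```

Checking index values against the new sequence length (rather than comparing shapes) is what catches the case the comment describes, since the index matrix's dim 1 is a number of positions, not a sequence length.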


         # no cached indexing matrices available -> compute now
-        if not hasattr(context, "pref_idx") and not hasattr(context, "suff_idx"):
+        if not hasattr(context, "pref_idx") and not hasattr(context, "suff_idx") or cache_invalidated:
             # read offsets & lengths from context
             if hasattr(context, "seqlens"):
                 first_non_padding = context.offsets
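One subtlety in the changed condition above: Python's `and` binds tighter than `or`, so the unparenthesized `not A and not B or C` parses as `(not A and not B) or C`, i.e. "no cache yet, or cache invalidated", which is the intended reading. A quick standalone check:

```python
# `and` binds tighter than `or`, so the unparenthesized condition
# `not A and not B or C` is equivalent to `(not A and not B) or C`.
results = []
for A in (False, True):
    for B in (False, True):
        for C in (False, True):
            assert (not A and not B or C) == ((not A and not B) or C)
            results.append(not A and not B or C)

# With A=True, B=False, C=True the condition is True via the `or C` branch,
# while the alternative grouping `not A and (not B or C)` would be False.
assert (not True and not False or True) is True
assert (not True and (not False or True)) is False
```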