Add support for Whisper #693

Merged: 72 commits, Aug 8, 2024

Commits
ad9fe2b
save current progress:
TimoImhof Apr 2, 2024
7ebb5e8
Merge branch 'adapter-hub:main' into dev/whisper
TimoImhof Apr 2, 2024
1e10398
Implement WhisperAdapterModel:
TimoImhof Apr 12, 2024
c775d43
add logger for Attention module
TimoImhof Apr 12, 2024
23f1c68
Add Whisper model to documentation
TimoImhof Apr 12, 2024
f4b3df8
Add WhisperDecoderWrapperAdaptersMixin:
TimoImhof Apr 12, 2024
dc40973
Add tests
TimoImhof Apr 12, 2024
a0a89ed
save progress
TimoImhof Apr 19, 2024
f36133c
save progress
TimoImhof Apr 25, 2024
f011bb3
overwrite get_input_samples method to fix tests requiring simple inpu…
TimoImhof Apr 26, 2024
08464c2
add support for speech samples with "input_features" as tensor name
TimoImhof Apr 26, 2024
32a6434
fix wrong input argument
TimoImhof Apr 26, 2024
d13b6d3
upload dev files for experiments
TimoImhof Apr 28, 2024
f5e4269
upload dev files for experiments
TimoImhof Apr 28, 2024
70f7651
update SpeechTestBase
TimoImhof Apr 28, 2024
e38cd9e
Add copy info and add flash attention
TimoImhof Apr 30, 2024
8587f0f
Changes:
TimoImhof Apr 30, 2024
909fecb
Changes:
TimoImhof Apr 30, 2024
fcaa21e
Delete dev dir
TimoImhof Apr 30, 2024
9bd7065
add TODOS
TimoImhof Apr 30, 2024
25171bb
make method more general
TimoImhof Apr 30, 2024
6405be4
add methods necessary for head usage
TimoImhof May 2, 2024
c52f0c3
Add TODO
TimoImhof May 2, 2024
cc00ed6
remove redundant code
TimoImhof May 2, 2024
24f72d6
add comment & enable all tests
TimoImhof May 2, 2024
182f5a5
Add special check for vision models
TimoImhof May 2, 2024
e44a482
make style
TimoImhof May 2, 2024
ca41958
add speech_classification head
TimoImhof May 5, 2024
158165b
Adapting tests:
TimoImhof May 8, 2024
0eadf6d
update dataset
TimoImhof May 8, 2024
d4117b7
residual updates:
TimoImhof May 8, 2024
091d947
Include adapters.init() support for:
TimoImhof May 14, 2024
5e8bf99
Adapt Testbase
TimoImhof May 15, 2024
c483068
Fixes:
TimoImhof May 15, 2024
480b4b6
Changes:
TimoImhof May 21, 2024
38bbb06
Add custom classification head
TimoImhof May 23, 2024
4ddc919
Fix embedding text:
TimoImhof May 23, 2024
4d6e9cc
Fix generation
TimoImhof May 24, 2024
e54f1c4
Fix composition and invertible adapters
TimoImhof May 25, 2024
1c6aebd
Merge branch 'main' into dev/whisper
TimoImhof May 26, 2024
980a3f4
Revert test changes:
TimoImhof Jun 4, 2024
11daca4
manually handle failing style checks:
TimoImhof Jun 4, 2024
63ca22a
- remove audio classification from WhisperAdapterModel head classes
TimoImhof Jun 5, 2024
4069c94
Remove redundant code:
TimoImhof Jun 7, 2024
6beba77
fix typo
TimoImhof Jun 7, 2024
faf54b6
fix conditional case and remove redundant code line
TimoImhof Jun 7, 2024
fcdc409
fix prompt tuning test
lenglaender Jun 11, 2024
e513263
Add ConversionTests and AudioClassificationMixin
TimoImhof Jun 11, 2024
acd5332
polish docs
TimoImhof Jun 18, 2024
c8427b0
polish docs
TimoImhof Jun 18, 2024
7e3e108
Fix import
TimoImhof Jun 18, 2024
69e3c99
Remove redundant files
TimoImhof Jun 18, 2024
af2ddc2
Update src/adapters/model_mixin.py
TimoImhof Jul 9, 2024
57b411a
Apply suggestions
TimoImhof Jul 9, 2024
61f3742
Merge remote-tracking branch 'origin/dev/whisper' into dev/whisper
TimoImhof Jul 9, 2024
f01b51e
Merge branch 'main' into dev/whisper
TimoImhof Jul 9, 2024
1f8573f
Fix failing test and refactor speech model case handling
TimoImhof Jul 10, 2024
8d04de1
Fix failing test
TimoImhof Jul 10, 2024
5b41382
Fix overwriting arguments
TimoImhof Jul 10, 2024
327381e
make style
TimoImhof Jul 10, 2024
a30bb6c
Address remaining comments, fix conversion test, correct documentatio…
TimoImhof Jul 24, 2024
ad47696
Revert forward function signature modification
TimoImhof Jul 27, 2024
6111b07
Merge branch 'adapter-hub:main' into dev/whisper
TimoImhof Jul 30, 2024
53b9cd9
make style
TimoImhof Jul 30, 2024
7526514
Remove redundant head - not supported by any model
TimoImhof Jul 30, 2024
0588db6
Add Future TODO for seq2seqtrainer
TimoImhof Aug 1, 2024
1751a25
Merge branch 'refs/heads/main' into dev/whisper
TimoImhof Aug 3, 2024
08911d8
Incorporate pyreft tests
TimoImhof Aug 3, 2024
88bd867
Add check for changing hidden_states size
TimoImhof Aug 3, 2024
3de2581
Adapt checking logic
TimoImhof Aug 4, 2024
89377f8
Merge branch 'refs/heads/main' into dev/whisper
TimoImhof Aug 4, 2024
5f4f20c
Fix attention classes and generation
TimoImhof Aug 7, 2024
25 changes: 25 additions & 0 deletions docs/classes/models/whisper.rst
@@ -0,0 +1,25 @@
Whisper
-----------------------------------------------------------------------------------------------------------------------

The Whisper model was presented in `Robust Speech Recognition via Large-Scale Weak Supervision
<https://arxiv.org/abs/2212.04356>`_ by Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine
McLeavey, Ilya Sutskever.

Whisper is a state-of-the-art speech recognition model from OpenAI, trained on 680,000 hours of multilingual and multitask data.

The abstract from the paper is the following:

*We study the capabilities of speech processing systems trained simply to predict large amounts of
transcripts of audio on the internet. When scaled to 680,000 hours of multilingual and multitask
supervision, the resulting models generalize well to standard benchmarks and are often competitive
with prior fully supervised results but in a zero-shot transfer setting without the need for any fine-tuning. When compared to humans, the models
approach their accuracy and robustness. We are releasing models and inference code to serve as
a foundation for further work on robust speech processing.*


WhisperAdapterModel
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: adapters.WhisperAdapterModel
:members:
:inherited-members: WhisperPreTrainedModel
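
To make the new class concrete, here is a minimal usage sketch; the checkpoint and adapter names are illustrative, not taken from the PR:

```python
from adapters import WhisperAdapterModel

# Load pretrained Whisper weights into the new flex-head model class
model = WhisperAdapterModel.from_pretrained("openai/whisper-tiny")

# Add a bottleneck adapter plus a seq2seq LM head for transcription
model.add_adapter("asr_adapter")
model.add_seq2seq_lm_head("asr_adapter")

# Freeze the base model and train only the adapter (and head) weights
model.train_adapter("asr_adapter")
```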
1 change: 1 addition & 0 deletions docs/index.rst
@@ -88,6 +88,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model
classes/models/roberta
classes/models/t5
classes/models/vit
classes/models/whisper
classes/models/xlmroberta
classes/models/xmod

2 changes: 2 additions & 0 deletions docs/model_overview.md
@@ -10,6 +10,7 @@ The table below further shows which model architectures support which adaptation
E.g., for BERT, this means adapters provides a ``BertAdapterModel`` class, but you can also use ``BertModel``, ``BertForSequenceClassification`` etc. together with adapters.
```


| Model | (Bottleneck)<br> Adapters | Prefix<br> Tuning | LoRA | Compacter | Adapter<br> Fusion | Invertible<br> Adapters | Parallel<br> block | Prompt<br> Tuning | ReFT |
| --------------------------------------- | - | - | - | - | - | - | - | - | - |
| [ALBERT](classes/models/albert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
@@ -33,6 +34,7 @@ The table below further shows which model architectures support which adaptation
| [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ |
| [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [Whisper](classes/models/whisper.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ |
| [XLM-RoBERTa](classes/models/xlmroberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [X-MOD](classes/models/xmod.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |

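Any of the methods checked for Whisper in this table can be configured through the usual config classes; a small sketch using LoRA, with illustrative rank and alpha values:

```python
from adapters import LoRAConfig, WhisperAdapterModel

model = WhisperAdapterModel.from_pretrained("openai/whisper-tiny")
# LoRA is one of the methods marked as supported for Whisper above;
# r and alpha are illustrative, not recommendations
model.add_adapter("whisper_lora", config=LoRAConfig(r=8, alpha=16))
model.set_active_adapters("whisper_lora")
```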
2 changes: 2 additions & 0 deletions src/adapters/__init__.py
@@ -115,6 +115,7 @@
"models.roberta": ["RobertaAdapterModel"],
"models.t5": ["T5AdapterModel"],
"models.vit": ["ViTAdapterModel"],
"models.whisper": ["WhisperAdapterModel"],
"models.xlm_roberta": ["XLMRobertaAdapterModel"],
"models.xmod": ["XmodAdapterModel"],
"trainer": ["AdapterTrainer", "Seq2SeqAdapterTrainer"],
@@ -224,6 +225,7 @@
from .models.roberta import RobertaAdapterModel
from .models.t5 import T5AdapterModel
from .models.vit import ViTAdapterModel
from .models.whisper import WhisperAdapterModel
from .models.xlm_roberta import XLMRobertaAdapterModel
from .models.xmod import XmodAdapterModel
from .trainer import AdapterTrainer, Seq2SeqAdapterTrainer
1 change: 1 addition & 0 deletions src/adapters/composition.py
@@ -139,6 +139,7 @@ def __init__(
"llama",
"mistral",
"electra",
"whisper",
"xmod",
],
}
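This list gates which model types may use the corresponding composition block; assuming it is the list for the `Parallel` block (as the surrounding entries suggest), adding "whisper" enables usage like the following sketch, with illustrative adapter names:

```python
from adapters import WhisperAdapterModel
from adapters.composition import Parallel

model = WhisperAdapterModel.from_pretrained("openai/whisper-tiny")
model.add_adapter("adapter_a")
model.add_adapter("adapter_b")
# Both adapters run side by side; inputs are replicated per channel
model.set_active_adapters(Parallel("adapter_a", "adapter_b"))
```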
11 changes: 10 additions & 1 deletion src/adapters/head_utils.py
@@ -5,7 +5,6 @@

logger = logging.getLogger(__name__)


# The "layers" attributes in the configs below map from static head module names to flex head module names.
# In this context, "None" refers to a flex-head layer without weights (e.g. dropout, acts).
STATIC_TO_FLEX_HEAD_MAP = {
@@ -771,6 +770,16 @@
"generator_lm_head",
],
},
"WhisperForConditionalGeneration": {
"config": {
"head_type": "seq2seq_lm",
"layers": 1,
"activation_function": None,
"layer_norm": False,
"bias": False,
},
"layers": ["proj_out"],
},
}


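This entry lets checkpoints saved with the static `WhisperForConditionalGeneration` class load into the flex-head model: the single `proj_out` projection becomes a one-layer `seq2seq_lm` head with no activation, layer norm, or bias. A sketch of what the mapping enables (checkpoint name illustrative):

```python
from adapters import WhisperAdapterModel

# The static proj_out weights of a WhisperForConditionalGeneration
# checkpoint are converted into a flex "seq2seq_lm" head on load,
# following the STATIC_TO_FLEX_HEAD_MAP entry above.
model = WhisperAdapterModel.from_pretrained("openai/whisper-tiny")
print(model.heads)  # should include the converted seq2seq_lm head
```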
2 changes: 1 addition & 1 deletion src/adapters/heads/language_modeling.py
@@ -131,7 +131,7 @@ def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=Fal
)
labels = torch.cat((prompt_labels, labels), dim=-1)

loss = loss_fct(logits_for_loss.view(-1, self.config["vocab_size"]), labels.view(-1))
loss = loss_fct(logits_for_loss.reshape(-1, self.config["vocab_size"]), labels.reshape(-1))

if return_dict:
return self._create_model_output(loss, lm_logits, outputs)
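The switch from `view` to `reshape` matters because `view` requires a contiguous memory layout, while `reshape` falls back to a copy when needed, which is presumably what some input paths here require. A standalone illustration:

```python
import torch

x = torch.randn(2, 3, 4).transpose(0, 1)  # transpose makes x non-contiguous
try:
    x.view(-1, 4)  # view cannot flatten non-contiguous dimensions
except RuntimeError as err:
    print("view failed:", err)

y = x.reshape(-1, 4)  # reshape copies when necessary and succeeds
print(y.shape)  # torch.Size([6, 4])
```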
15 changes: 9 additions & 6 deletions src/adapters/heads/model_mixin.py
@@ -27,7 +27,6 @@

logger = logging.getLogger(__name__)


MODEL_HEAD_MAP = {
"classification": ClassificationHead,
"multilabel_classification": MultiLabelClassificationHead,
@@ -440,47 +439,51 @@ def add_dependency_parsing_head(self, head_name, num_labels=2, overwrite_ok=Fals
self.add_prediction_head(head, overwrite_ok)

@head_type("masked_lm")
def add_masked_lm_head(self, head_name, activation_function="gelu", overwrite_ok=False):
def add_masked_lm_head(self, head_name, activation_function="gelu", layers=2, overwrite_ok=False):
"""
Adds a masked language modeling head on top of the model.

Args:
head_name (str): The name of the head.
activation_function (str, optional): Activation function. Defaults to 'gelu'.
layers (int, optional): Number of layers. Defaults to 2.
overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False.
"""
head = BertStyleMaskedLMHead(self, head_name, activation_function=activation_function)
head = BertStyleMaskedLMHead(self, head_name, layers=layers, activation_function=activation_function)
self.add_prediction_head(head, overwrite_ok=overwrite_ok)

@head_type("causal_lm")
def add_causal_lm_head(self, head_name, activation_function="gelu", overwrite_ok=False):
def add_causal_lm_head(self, head_name, activation_function="gelu", layers=2, overwrite_ok=False):
"""
Adds a causal language modeling head on top of the model.

Args:
head_name (str): The name of the head.
activation_function (str, optional): Activation function. Defaults to 'gelu'.
layers (int, optional): Number of layers. Defaults to 2.
overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False.
"""
head = CausalLMHead(
self, head_name, layers=2, activation_function=activation_function, layer_norm=True, bias=True
self, head_name, layers=layers, activation_function=activation_function, layer_norm=True, bias=True
)
self.add_prediction_head(head, overwrite_ok=overwrite_ok)

@head_type("seq2seq_lm")
def add_seq2seq_lm_head(
self,
head_name,
layers=1,
overwrite_ok=False,
):
"""
Adds a sequence-to-sequence language modeling head on top of the model.

Args:
head_name (str): The name of the head.
layers (int, optional): Number of layers. Defaults to 1.
overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False.
"""
head = Seq2SeqLMHead(self, head_name)
head = Seq2SeqLMHead(self, head_name, layers=layers)
self.add_prediction_head(head, overwrite_ok=overwrite_ok)

def delete_head(self, head_name: str):
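The new `layers` argument now flows through to each head constructor instead of being hard-coded; a sketch of the resulting call, with illustrative names:

```python
from adapters import WhisperAdapterModel

model = WhisperAdapterModel.from_pretrained("openai/whisper-tiny")
# Whisper's LM head is a single projection, so layers=1 matches the
# head_utils config above; deeper heads remain possible for other models.
model.add_seq2seq_lm_head("transcription", layers=1)
```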
9 changes: 8 additions & 1 deletion src/adapters/methods/prefix_tuning.py
@@ -235,7 +235,14 @@ def forward(self, *args, **kwargs):
prefix_states = {}
if adapter_setup is not None:
# Infer batch size
input_tensor_names = ["input_ids", "decoder_input_ids", "attention_mask", "inputs_embeds", "pixel_values"]
input_tensor_names = [
"input_ids",
"decoder_input_ids",
"attention_mask",
"inputs_embeds",
"pixel_values",
"input_features",
]
batch_size = None
for name in input_tensor_names:
if kwargs.get(name, None) is not None:
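In isolation, the batch-size inference scans the known input names in order and takes the leading dimension of the first tensor present; adding "input_features" is what lets Whisper's log-mel inputs through. A self-contained sketch of that logic:

```python
import torch

def infer_batch_size(**kwargs):
    # Mirrors the scan above: the first matching input tensor wins.
    input_tensor_names = [
        "input_ids", "decoder_input_ids", "attention_mask",
        "inputs_embeds", "pixel_values", "input_features",
    ]
    for name in input_tensor_names:
        if kwargs.get(name, None) is not None:
            return kwargs[name].size(0)
    return None

# Whisper passes log-mel spectrogram features rather than token ids:
print(infer_batch_size(input_features=torch.randn(4, 80, 3000)))  # 4
```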
14 changes: 12 additions & 2 deletions src/adapters/methods/reft.py
@@ -66,9 +66,19 @@ def __init__(self, in_features: int, config: ReftConfig):

def _gather_adapted_states(self, hidden_states: torch.Tensor):
context = ForwardContext.get_context()
bsz, _, ddim = hidden_states.size()
bsz, seq_len, ddim = hidden_states.size()

# if cached indexing matrices are computed for different hidden_states size -> recompute
cache_invalidated = False
if hasattr(context, "pref_idx") and hasattr(context, "suff_idx"):
cache_invalidated = (
torch.max(context.suff_idx) >= seq_len # indices out of bounds
or bsz != context.suff_idx.size(0) # batch size mismatch
or ddim != context.suff_idx.size(2) # hidden size mismatch
)

# no cached indexing matrices available -> compute now
if not hasattr(context, "pref_idx") and not hasattr(context, "suff_idx"):
if not hasattr(context, "pref_idx") and not hasattr(context, "suff_idx") or cache_invalidated:
# read offsets & lengths from context
if hasattr(context, "seqlens"):
first_non_padding = context.offsets
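The invalidation condition recomputes the cached prefix/suffix index matrices whenever their shapes no longer match the incoming hidden states, e.g. when generation changes the sequence length between calls. The check in isolation, as a sketch:

```python
import torch

def cache_is_stale(suff_idx: torch.Tensor, bsz: int, seq_len: int, ddim: int) -> bool:
    # Mirrors the condition above: any shape mismatch invalidates the cache.
    return (
        int(suff_idx.max()) >= seq_len  # cached indices out of bounds
        or bsz != suff_idx.size(0)      # batch size changed
        or ddim != suff_idx.size(2)     # hidden size changed
    )

cached = torch.zeros(2, 5, 768, dtype=torch.long)
print(cache_is_stale(cached, bsz=2, seq_len=10, ddim=768))  # False
print(cache_is_stale(cached, bsz=4, seq_len=10, ddim=768))  # True
```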
13 changes: 12 additions & 1 deletion src/adapters/model_mixin.py
@@ -1412,7 +1412,18 @@ def _prepare_model_inputs(self, *args, **kwargs):
and self.adapters_config.active_setup
and self.adapters_config.active_setup.parallel_channels > 1
):
input_ids = input_ids.repeat(self.adapters_config.active_setup.parallel_channels, 1)
# Extract original shape
input_shape = input_ids.shape
# Replicate input_ids to match the number of parallel channels
# Also works for inputs with more than 2 dimensions
repeat_shape = [
self.adapters_config.active_setup.parallel_channels
] + [ # first dimension is parallel channels
1
] * (
len(input_shape) - 1
) # residual dims should be replicated parallel_channels times
input_ids = input_ids.repeat(repeat_shape)
model_kwargs["adapter_input_parallelized"] = True

return input_ids, input_name, model_kwargs
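The generalized repeat tiles only the leading batch dimension, so it now also handles rank-3 speech inputs rather than just 2D token-id tensors; a minimal illustration:

```python
import torch

parallel_channels = 2
# Whisper-style input: (batch, mel_bins, frames) instead of (batch, seq_len)
input_features = torch.randn(3, 80, 3000)

# First dimension is replicated per parallel channel, the rest stay intact
repeat_shape = [parallel_channels] + [1] * (input_features.dim() - 1)
replicated = input_features.repeat(repeat_shape)
print(replicated.shape)  # torch.Size([6, 80, 3000])
```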
12 changes: 12 additions & 0 deletions src/adapters/models/__init__.py
@@ -33,6 +33,13 @@
T5ModelAdaptersMixin,
)
from .vit.mixin_vit import ViTIntermediateAdaptersMixin, ViTModelAdaptersMixin
from .whisper.mixin_whisper import (
WhisperDecoderAdaptersMixin,
WhisperDecoderWrapperAdaptersMixin,
WhisperEncoderAdaptersMixin,
WhisperForAudioClassificationWithHeadsMixin,
WhisperModelAdaptersMixin,
)
from .xmod.mixin_xmod import XmodModelAdaptersMixin


@@ -95,6 +102,11 @@
"BertGenerationEncoder": BertModelAdaptersMixin,
"BertGenerationLayer": BertLayerAdaptersMixin,
"LlamaModel": LlamaModelAdapterMixin,
"WhisperEncoder": WhisperEncoderAdaptersMixin,
"WhisperDecoder": WhisperDecoderAdaptersMixin,
"WhisperModel": WhisperModelAdaptersMixin,
"WhisperDecoderWrapper": WhisperDecoderWrapperAdaptersMixin,
"WhisperForAudioClassification": WhisperForAudioClassificationWithHeadsMixin,
"LlamaForQuestionAnswering": LlamaForQuestionAnsweringAdapterMixin,
"MistralModel": MistralModelAdapterMixin,
}
1 change: 1 addition & 0 deletions src/adapters/models/auto/adapter_model.py
@@ -29,6 +29,7 @@
("roberta", "RobertaAdapterModel"),
("t5", "T5AdapterModel"),
("vit", "ViTAdapterModel"),
("whisper", "WhisperAdapterModel"),
("xlm-roberta", "XLMRobertaAdapterModel"),
("xmod", "XmodAdapterModel"),
]
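With the "whisper" entry registered, `AutoAdapterModel` resolves Whisper checkpoints to the new class automatically; a sketch (checkpoint name illustrative):

```python
from adapters import AutoAdapterModel

# The mapping above routes whisper configs to WhisperAdapterModel
model = AutoAdapterModel.from_pretrained("openai/whisper-tiny")
print(type(model).__name__)  # WhisperAdapterModel
```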
2 changes: 1 addition & 1 deletion src/adapters/models/mt5/modeling_mt5.py
@@ -77,7 +77,7 @@ def forward(
if past_key_value is not None:
assert (
len(past_key_value) == 2
), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
39 changes: 39 additions & 0 deletions src/adapters/models/whisper/__init__.py
@@ -0,0 +1,39 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The Adapter-Hub Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from transformers.utils import _LazyModule


_import_structure = {
"adapter_model": ["WhisperAdapterModel"],
}


if TYPE_CHECKING:
from .adapter_model import WhisperAdapterModel

else:
import sys

sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
)
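
The `_LazyModule` indirection defers importing `adapter_model` until `WhisperAdapterModel` is first accessed, keeping `import adapters` cheap; a sketch of the observable behavior, under the assumption that the submodule is only registered in `sys.modules` on first attribute access:

```python
import sys

import adapters.models.whisper as whisper_module

# The adapter_model submodule has not been imported yet at this point...
print("adapters.models.whisper.adapter_model" in sys.modules)  # False

# ...and only materializes on first attribute access.
_ = whisper_module.WhisperAdapterModel
print("adapters.models.whisper.adapter_model" in sys.modules)  # True
```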