diff --git a/docs/classes/models/whisper.rst b/docs/classes/models/whisper.rst new file mode 100644 index 0000000000..8afdf24f24 --- /dev/null +++ b/docs/classes/models/whisper.rst @@ -0,0 +1,25 @@ +Whisper +----------------------------------------------------------------------------------------------------------------------- + +The Whisper model was presented in `Robust Speech Recognition via Large-Scale Weak Supervision +`_ by Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine +McLeavey, Ilya Sutskever. + +Whisper is a state-of-the-art speech recognition model trained on 680,000 hours of multilingual and multitask data, presented by OpenAI. + +The abstract from the paper is the following: + +*We study the capabilities of speech processing systems trained simply to predict large amounts of +transcripts of audio on the internet. When scaled to 680,000 hours of multilingual and multitask +supervision, the resulting models generalize well to standard benchmarks and are often competitive +with prior fully supervised results but in a zeroshot transfer setting without the need for any finetuning. When compared to humans, the models +approach their accuracy and robustness. We are releasing models and inference code to serve as +a foundation for further work on robust speech processing.* + + +WhisperAdapterModel +~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: adapters.WhisperAdapterModel + :members: + :inherited-members: WhisperPreTrainedModel \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index c57bbb32dc..0c10de8a18 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -88,6 +88,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model classes/models/roberta classes/models/t5 classes/models/vit + classes/models/whisper classes/models/xlmroberta classes/models/xmod diff --git a/docs/model_overview.md b/docs/model_overview.md index eb3c109cba..b364ab6ebb 100644 --- a/docs/model_overview.md +++ b/docs/model_overview.md @@ -10,6 +10,7 @@ The table below further shows which model architectures support which adaptation E.g., for BERT, this means adapters provides a ``BertAdapterModel`` class, but you can also use ``BertModel``, ``BertForSequenceClassification`` etc. together with adapters. ``` + | Model | (Bottleneck)
<br>Adapters | Prefix<br>Tuning | LoRA | Compacter | Adapter<br>Fusion | Invertible<br>Adapters | Parallel<br>block | Prompt<br>
Tuning | ReFT | | --------------------------------------- | -| - | - | - | - | - | - |- | - | | [ALBERT](classes/models/albert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | @@ -33,6 +34,7 @@ The table below further shows which model architectures support which adaptation | [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | | [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [Whisper](classes/models/whisper.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | | [XLM-RoBERTa](classes/models/xlmroberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [X-MOD](classes/models/xmod.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py index 20d8eaf77a..0864a9bc95 100644 --- a/src/adapters/__init__.py +++ b/src/adapters/__init__.py @@ -115,6 +115,7 @@ "models.roberta": ["RobertaAdapterModel"], "models.t5": ["T5AdapterModel"], "models.vit": ["ViTAdapterModel"], + "models.whisper": ["WhisperAdapterModel"], "models.xlm_roberta": ["XLMRobertaAdapterModel"], "models.xmod": ["XmodAdapterModel"], "trainer": ["AdapterTrainer", "Seq2SeqAdapterTrainer"], @@ -224,6 +225,7 @@ from .models.roberta import RobertaAdapterModel from .models.t5 import T5AdapterModel from .models.vit import ViTAdapterModel + from .models.whisper import WhisperAdapterModel from .models.xlm_roberta import XLMRobertaAdapterModel from .models.xmod import XmodAdapterModel from .trainer import AdapterTrainer, Seq2SeqAdapterTrainer diff --git a/src/adapters/composition.py b/src/adapters/composition.py index 62c2854acc..48a6bc8acf 100644 --- a/src/adapters/composition.py +++ b/src/adapters/composition.py @@ -139,6 +139,7 @@ def __init__( "llama", "mistral", "electra", + "whisper", "xmod", ], } diff --git a/src/adapters/head_utils.py b/src/adapters/head_utils.py index 8226d1ed6b..1e3e0760dd 100644 --- a/src/adapters/head_utils.py +++ b/src/adapters/head_utils.py @@ -5,7 +5,6 @@ logger = logging.getLogger(__name__) - # The "layers" attributes in the configs below map from static head module names to flex head module names. # In this context, "None" refers to a flex-head layer without weights (e.g. dropout, acts). 
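The `WhisperForConditionalGeneration` entry added in this hunk maps the static `proj_out` projection to a single-layer flex `seq2seq_lm` head without activation, layer norm, or bias. A minimal sketch of the effect, assuming the checkpoint id is only illustrative and the name assigned to the converted head depends on the conversion logic:

```python
from adapters import WhisperAdapterModel

# Loading a checkpoint saved as WhisperForConditionalGeneration converts its
# static proj_out layer into a flexible "seq2seq_lm" prediction head via the
# mapping below (checkpoint id chosen only for illustration).
model = WhisperAdapterModel.from_pretrained("openai/whisper-tiny")
print(model.heads)  # the converted seq2seq_lm head should be listed here
```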
STATIC_TO_FLEX_HEAD_MAP = { @@ -771,6 +770,16 @@ "generator_lm_head", ], }, + "WhisperForConditionalGeneration": { + "config": { + "head_type": "seq2seq_lm", + "layers": 1, + "activation_function": None, + "layer_norm": False, + "bias": False, + }, + "layers": ["proj_out"], + }, } diff --git a/src/adapters/heads/language_modeling.py b/src/adapters/heads/language_modeling.py index d805d1fe6b..95953a4d16 100644 --- a/src/adapters/heads/language_modeling.py +++ b/src/adapters/heads/language_modeling.py @@ -131,7 +131,7 @@ def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=Fal ) labels = torch.cat((prompt_labels, labels), dim=-1) - loss = loss_fct(logits_for_loss.view(-1, self.config["vocab_size"]), labels.view(-1)) + loss = loss_fct(logits_for_loss.reshape(-1, self.config["vocab_size"]), labels.reshape(-1)) if return_dict: return self._create_model_output(loss, lm_logits, outputs) diff --git a/src/adapters/heads/model_mixin.py b/src/adapters/heads/model_mixin.py index 9a27bbd764..fa09dafa6e 100644 --- a/src/adapters/heads/model_mixin.py +++ b/src/adapters/heads/model_mixin.py @@ -27,7 +27,6 @@ logger = logging.getLogger(__name__) - MODEL_HEAD_MAP = { "classification": ClassificationHead, "multilabel_classification": MultiLabelClassificationHead, @@ -440,30 +439,32 @@ def add_dependency_parsing_head(self, head_name, num_labels=2, overwrite_ok=Fals self.add_prediction_head(head, overwrite_ok) @head_type("masked_lm") - def add_masked_lm_head(self, head_name, activation_function="gelu", overwrite_ok=False): + def add_masked_lm_head(self, head_name, activation_function="gelu", layers=2, overwrite_ok=False): """ Adds a masked language modeling head on top of the model. Args: head_name (str): The name of the head. activation_function (str, optional): Activation function. Defaults to 'gelu'. + layers (int, optional): Number of layers. Defaults to 2. overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. """ - head = BertStyleMaskedLMHead(self, head_name, activation_function=activation_function) + head = BertStyleMaskedLMHead(self, head_name, layers=layers, activation_function=activation_function) self.add_prediction_head(head, overwrite_ok=overwrite_ok) @head_type("causal_lm") - def add_causal_lm_head(self, head_name, activation_function="gelu", overwrite_ok=False): + def add_causal_lm_head(self, head_name, activation_function="gelu", layers=2, overwrite_ok=False): """ Adds a causal language modeling head on top of the model. Args: head_name (str): The name of the head. activation_function (str, optional): Activation function. Defaults to 'gelu'. + layers (int, optional): Number of layers. Defaults to 2. overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. """ head = CausalLMHead( - self, head_name, layers=2, activation_function=activation_function, layer_norm=True, bias=True + self, head_name, layers=layers, activation_function=activation_function, layer_norm=True, bias=True ) self.add_prediction_head(head, overwrite_ok=overwrite_ok) @@ -471,6 +472,7 @@ def add_causal_lm_head(self, head_name, activation_function="gelu", overwrite_ok def add_seq2seq_lm_head( self, head_name, + layers=1, overwrite_ok=False, ): """ @@ -478,9 +480,10 @@ def add_seq2seq_lm_head( Args: head_name (str): The name of the head. + layers (int, optional): Number of layers. Defaults to 1. overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. 
""" - head = Seq2SeqLMHead(self, head_name) + head = Seq2SeqLMHead(self, head_name, layers=layers) self.add_prediction_head(head, overwrite_ok=overwrite_ok) def delete_head(self, head_name: str): diff --git a/src/adapters/methods/prefix_tuning.py b/src/adapters/methods/prefix_tuning.py index 1f7d4094bd..17ab177a45 100644 --- a/src/adapters/methods/prefix_tuning.py +++ b/src/adapters/methods/prefix_tuning.py @@ -235,7 +235,14 @@ def forward(self, *args, **kwargs): prefix_states = {} if adapter_setup is not None: # Infer batch size - input_tensor_names = ["input_ids", "decoder_input_ids", "attention_mask", "inputs_embeds", "pixel_values"] + input_tensor_names = [ + "input_ids", + "decoder_input_ids", + "attention_mask", + "inputs_embeds", + "pixel_values", + "input_features", + ] batch_size = None for name in input_tensor_names: if kwargs.get(name, None) is not None: diff --git a/src/adapters/methods/reft.py b/src/adapters/methods/reft.py index cd22fc57ba..0914e8d3aa 100644 --- a/src/adapters/methods/reft.py +++ b/src/adapters/methods/reft.py @@ -66,9 +66,19 @@ def __init__(self, in_features: int, config: ReftConfig): def _gather_adapted_states(self, hidden_states: torch.Tensor): context = ForwardContext.get_context() - bsz, _, ddim = hidden_states.size() + bsz, seq_len, ddim = hidden_states.size() + + # if cached indexing matrices are computed for different hidden_states size -> recompute + cache_invalidated = False + if hasattr(context, "pref_idx") and hasattr(context, "suff_idx"): + cache_invalidated = ( + torch.max(context.suff_idx) >= seq_len # indices out of bounds + or bsz != context.suff_idx.size(0) # batch size mismatch + or ddim != context.suff_idx.size(2) # hidden size mismatch + ) + # no cached indexing matrices available -> compute now - if not hasattr(context, "pref_idx") and not hasattr(context, "suff_idx"): + if not hasattr(context, "pref_idx") and not hasattr(context, "suff_idx") or cache_invalidated: # read offsets & lengths from context if hasattr(context, "seqlens"): first_non_padding = context.offsets diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 1802595819..659a6cfcff 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -1412,7 +1412,18 @@ def _prepare_model_inputs(self, *args, **kwargs): and self.adapters_config.active_setup and self.adapters_config.active_setup.parallel_channels > 1 ): - input_ids = input_ids.repeat(self.adapters_config.active_setup.parallel_channels, 1) + # Extract original shape + input_shape = input_ids.shape + # Replicate input_ids to match the number of parallel channels + # Also works for inputs with more than 2 dimensions + repeat_shape = [ + self.adapters_config.active_setup.parallel_channels + ] + [ # first dimension is parallel channels + 1 + ] * ( + len(input_shape) - 1 + ) # residual dims should be replicated parallel_channels times + input_ids = input_ids.repeat(repeat_shape) model_kwargs["adapter_input_parallelized"] = True return input_ids, input_name, model_kwargs diff --git a/src/adapters/models/__init__.py b/src/adapters/models/__init__.py index 8e759698db..77f569835d 100644 --- a/src/adapters/models/__init__.py +++ b/src/adapters/models/__init__.py @@ -33,6 +33,13 @@ T5ModelAdaptersMixin, ) from .vit.mixin_vit import ViTIntermediateAdaptersMixin, ViTModelAdaptersMixin +from .whisper.mixin_whisper import ( + WhisperDecoderAdaptersMixin, + WhisperDecoderWrapperAdaptersMixin, + WhisperEncoderAdaptersMixin, + WhisperForAudioClassificationWithHeadsMixin, + WhisperModelAdaptersMixin, 
+) from .xmod.mixin_xmod import XmodModelAdaptersMixin @@ -95,6 +102,11 @@ "BertGenerationEncoder": BertModelAdaptersMixin, "BertGenerationLayer": BertLayerAdaptersMixin, "LlamaModel": LlamaModelAdapterMixin, + "WhisperEncoder": WhisperEncoderAdaptersMixin, + "WhisperDecoder": WhisperDecoderAdaptersMixin, + "WhisperModel": WhisperModelAdaptersMixin, + "WhisperDecoderWrapper": WhisperDecoderWrapperAdaptersMixin, + "WhisperForAudioClassification": WhisperForAudioClassificationWithHeadsMixin, "LlamaForQuestionAnswering": LlamaForQuestionAnsweringAdapterMixin, "MistralModel": MistralModelAdapterMixin, } diff --git a/src/adapters/models/auto/adapter_model.py b/src/adapters/models/auto/adapter_model.py index 7ab5cd80fa..6711752054 100644 --- a/src/adapters/models/auto/adapter_model.py +++ b/src/adapters/models/auto/adapter_model.py @@ -29,6 +29,7 @@ ("roberta", "RobertaAdapterModel"), ("t5", "T5AdapterModel"), ("vit", "ViTAdapterModel"), + ("whisper", "WhisperAdapterModel"), ("xlm-roberta", "XLMRobertaAdapterModel"), ("xmod", "XmodAdapterModel"), ] diff --git a/src/adapters/models/mt5/modeling_mt5.py b/src/adapters/models/mt5/modeling_mt5.py index b982d34d62..f09a56d7d1 100644 --- a/src/adapters/models/mt5/modeling_mt5.py +++ b/src/adapters/models/mt5/modeling_mt5.py @@ -77,7 +77,7 @@ def forward( if past_key_value is not None: assert ( len(past_key_value) == 2 - ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] diff --git a/src/adapters/models/whisper/__init__.py b/src/adapters/models/whisper/__init__.py new file mode 100644 index 0000000000..41b38a0756 --- /dev/null +++ b/src/adapters/models/whisper/__init__.py @@ -0,0 +1,39 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2020 The Adapter-Hub Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
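`WhisperAdapterModel` is the new entry point exposed by this package. A minimal usage sketch following the usual adapters API, with the checkpoint id, adapter name, and config string chosen only for illustration:

```python
from adapters import WhisperAdapterModel

model = WhisperAdapterModel.from_pretrained("openai/whisper-tiny")

# Add a bottleneck adapter plus a matching seq2seq LM head and train only the
# adapter (and head) weights; the pretrained Whisper weights stay frozen.
model.add_adapter("asr_adapter", config="seq_bn")
model.add_seq2seq_lm_head("asr_adapter")
model.train_adapter("asr_adapter")

# Training then typically runs via Seq2SeqAdapterTrainer on batches of
# log-Mel "input_features" and tokenized "labels".
```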
+ +from typing import TYPE_CHECKING + +from transformers.utils import _LazyModule + + +_import_structure = { + "adapter_model": ["WhisperAdapterModel"], +} + + +if TYPE_CHECKING: + from .adapter_model import WhisperAdapterModel + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + ) diff --git a/src/adapters/models/whisper/adapter_model.py b/src/adapters/models/whisper/adapter_model.py new file mode 100644 index 0000000000..d76ae610c5 --- /dev/null +++ b/src/adapters/models/whisper/adapter_model.py @@ -0,0 +1,225 @@ +import torch + +from transformers import EncoderDecoderCache, StaticCache +from transformers.models.whisper.modeling_whisper import ( + WHISPER_INPUTS_DOCSTRING, + WHISPER_START_DOCSTRING, + WhisperConfig, + WhisperModel, + WhisperPreTrainedModel, + shift_tokens_right, +) +from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward + +from ...heads import ModelWithFlexibleHeadsAdaptersMixin +from ...model_mixin import EmbeddingAdaptersWrapperMixin +from ...wrappers import init + + +@add_start_docstrings( + "WHISPER Model with the option to add multiple flexible prediction heads on top.", WHISPER_START_DOCSTRING +) +class WhisperAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, WhisperPreTrainedModel): + _tied_weights_keys = [] + head_types = ["seq2seq_lm"] + + def __init__(self, config: WhisperConfig, **kwargs): + super().__init__(config, **kwargs) + + self.model = WhisperModel(config) + init(self.model) + + self._init_head_modules() + + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def freeze_encoder(self): + """ + Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will + not be updated during training. + """ + # Adapted from WhisperModel in transformers/models/whisper/modeling_whisper.py as the tests in + # test_modeling_whisper.py require the functionality to freeze the encoder + self.model.encoder._freeze_parameters() + + @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING) + def forward( + self, + input_features=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + head=None, + output_adapter_gating_scores=False, + output_adapter_fusion_attentions=False, + **kwargs, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if "labels" in kwargs: + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + kwargs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id + ) + + # Future TODO: + # the seq2seqtrainer has the parameter `predict_with_generate` + # If set to True, we get the following error: + # transformers\generation\utils.py", line 1130, in _validate_model_kwargs"> + # raise ValueError(ValueError: The following model_kwargs are not used by the model: ['labels'] + # This is because we do not specify labels as parameter in the forward method + + outputs, context = self.model( + input_features=input_features, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + output_adapter_gating_scores=output_adapter_gating_scores, + output_adapter_fusion_attentions=output_adapter_fusion_attentions, + adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, + ) + + # required e.g. for prompt tuning in all models + kwargs["context"] = context + + head_outputs = self.forward_head( + outputs, + head_name=head, + attention_mask=attention_mask, + return_dict=return_dict, + get_cls_from_eos_tokens=True, + # `get_cls_from_eos_tokens` requires passing eos mask + eos_mask=input_features.eq(self.config.eos_token_id) if input_features is not None else None, + **kwargs, + ) + + return head_outputs + + # Copied from WhisperForConditionalGeneration + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + use_cache=None, + encoder_outputs=None, + attention_mask=None, + decoder_attention_mask=None, + cache_position=None, + **kwargs, + ): + decoder_position_ids = None + if decoder_attention_mask is not None: + decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0) + + past_length = 0 + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + else: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + + if decoder_position_ids is not None: + decoder_position_ids = decoder_position_ids[:, remove_prefix_length:] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. 
+ decoder_position_ids = decoder_position_ids.clone(memory_format=torch.contiguous_format) + + if cache_position is None: + cache_position = torch.arange( + past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device + ) + elif use_cache: + cache_position = cache_position[-decoder_input_ids.shape[1] :] + + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + decoder_input_ids = decoder_input_ids.contiguous() + + if ( + isinstance(past_key_values, EncoderDecoderCache) + and ( + isinstance(past_key_values.self_attention_cache, StaticCache) + or isinstance(past_key_values.cross_attention_cache, StaticCache) + ) + and decoder_attention_mask is not None + and decoder_attention_mask.ndim == 2 + ): + batch_size, sequence_length = decoder_input_ids.shape + device = decoder_input_ids.device + + dtype = self.proj_out.weight.dtype + min_dtype = torch.finfo(dtype).min + + decoder_attention_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + decoder_attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.self_attention_cache.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + + return { + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "use_cache": use_cache, + "decoder_attention_mask": decoder_attention_mask, + "decoder_position_ids": decoder_position_ids, + "cache_position": cache_position, + # >>> START AH Changes <<< + "adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False), + # >>> END AH Changes <<< + } + + # Copied from WhisperForConditionalGeneration + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past diff --git a/src/adapters/models/whisper/mixin_whisper.py b/src/adapters/models/whisper/mixin_whisper.py new file mode 100644 index 0000000000..e4a7d6da90 --- /dev/null +++ b/src/adapters/models/whisper/mixin_whisper.py @@ -0,0 +1,137 @@ +from typing import Iterable, Optional, Tuple + +import torch +import torch.nn as nn + +from ...composition import adjust_tensors_for_parallel +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import LoRALinear +from ...methods.prefix_tuning import PrefixTuningLayer +from ...model_mixin import ( + EmbeddingAdaptersMixin, + EmbeddingAdaptersWrapperMixin, + InvertibleAdaptersMixin, + InvertibleAdaptersWrapperMixin, + ModelBaseAdaptersMixin, + ModelWithHeadsAdaptersMixin, +) +from ...utils import patch_forward + + +class WhisperAttentionAdaptersMixin: + """Adds adapters to the WhisperAttention module.""" + + def init_adapters(self, model_config, adapters_config): + # Wrap layers for LoRA + self.k_proj = LoRALinear.wrap(self.k_proj, "selfattn", model_config, adapters_config, attn_key="k") + self.v_proj = LoRALinear.wrap(self.v_proj, "selfattn", model_config, adapters_config, attn_key="v") + self.q_proj = LoRALinear.wrap(self.q_proj, "selfattn", model_config, adapters_config, attn_key="q") + + self.prefix_tuning = PrefixTuningLayer( + self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config + ) + 
patch_forward(self) + + +class WhisperEncoderLayerAdaptersMixin: + """Adds adapters to the WhisperEncoderLayer module of WHISPER.""" + + def init_adapters(self, model_config, adapters_config): + self.adapters_config = adapters_config + # Wrap layers for LoRA + self.fc1 = LoRALinear.wrap(self.fc1, "intermediate", model_config, adapters_config) + self.fc2 = LoRALinear.wrap(self.fc2, "output", model_config, adapters_config) + + # Set attention layer location key for prefix tuning + self.self_attn.location_key = "encoder" + self.attention_adapters = BottleneckLayer("mh_adapter") + self.output_adapters = BottleneckLayer("output_adapter") + + patch_forward(self) + + +class WhisperDecoderLayerAdaptersMixin(WhisperEncoderLayerAdaptersMixin): + """Adds adapters to the WhisperDecoderLayer module of WHISPER.""" + + def init_adapters(self, model_config, adapters_config): + super().init_adapters(model_config, adapters_config) + # Set attention layer location key for prefix tuning + self.self_attn.location_key = "self" + self.encoder_attn.location_key = "cross" + self.cross_attention_adapters = BottleneckLayer("cross_adapter") + + +class WhisperEncoderAdaptersMixin(InvertibleAdaptersMixin): + """Adds adapters to the WhisperEncoder module of WHISPER.""" + + pass + + +class WhisperDecoderAdaptersMixin: + """Adds adapters to the WhisperDecoder module of WHISPER.""" + + def init_adapters(self, model_config, adapters_config): + patch_forward(self) + + def forward( + self, input_ids: torch.LongTensor = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, **kwargs + ): + (input_ids,) = adjust_tensors_for_parallel(encoder_hidden_states, input_ids) + return super().forward(input_ids=input_ids, encoder_hidden_states=encoder_hidden_states, **kwargs) + + +class WhisperModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMixin, ModelBaseAdaptersMixin): + """Adds adapters to the WhisperModel class.""" + + invertible_adapters_base_name = "encoder" + support_prompt_tuning = False + + def init_adapters(self, model_config, adapters_config, **kwargs): + super().init_adapters(model_config, adapters_config, **kwargs) + self.encoder.layer_norm.register_forward_hook(self.post_embedding_forward) + + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + if hasattr(self, "decoder"): + for i, layer in enumerate(self.encoder.layers): + yield i, layer + for i, layer in enumerate(self.decoder.layers, start=len(self.encoder.layers)): + yield i, layer + else: + for i, layer in enumerate(self.encoder.layers): + yield i, layer + + def post_embedding_forward(self, module, args, embedding_output): + embedding_output = self.invertible_adapters_forward(embedding_output) + # Prompt tuning not yet supported + return embedding_output + + +class WhisperDecoderWrapperAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelBaseAdaptersMixin): + """Adds adapters to the WhisperDecoderWrapper class. + + This wrapper class is a helper class to correctly load + pretrained checkpoints when the causal language model is used in combination with the [`EncoderDecoderModel`] + framework.""" + + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + for i, layer in enumerate(self.decoder.layers): + yield i, layer + + def get_input_embeddings(self): + return self.decoder.get_input_embeddings() + + +class WhisperForAudioClassificationWithHeadsMixin(ModelWithHeadsAdaptersMixin, WhisperModelAdaptersMixin): + """Adds adapters to the WhisperForAudioClassification class. 
+ This class is used to enable adapter capabilities for the static WhisperForAudioClassification model from the + 'transformers' library.""" + + def forward( + self, + *args, + input_features: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs, + ): + # Stating "input_features" and "labels" explicitly is required for training using Trainer class + return super().forward(*args, input_features=input_features, labels=labels, **kwargs) diff --git a/src/adapters/models/whisper/modeling_whisper.py b/src/adapters/models/whisper/modeling_whisper.py new file mode 100644 index 0000000000..455dac0b58 --- /dev/null +++ b/src/adapters/models/whisper/modeling_whisper.py @@ -0,0 +1,520 @@ +# coding=utf-8 +# Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Whisper model.""" + +from typing import Optional, Tuple + +import torch +from torch import nn + +from transformers import EncoderDecoderCache, StaticCache +from transformers.models.whisper.modeling_whisper import WhisperAttention, WhisperDecoderLayer, WhisperEncoderLayer +from transformers.utils import is_flash_attn_2_available, logging + +from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel +from .mixin_whisper import ( + WhisperAttentionAdaptersMixin, + WhisperDecoderLayerAdaptersMixin, + WhisperEncoderLayerAdaptersMixin, +) + + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + +logger = logging.get_logger(__name__) + + +class WhisperAttentionWithAdapters(WhisperAttention, WhisperAttentionAdaptersMixin): + # Copied from adapters/models/bart/modeling_bart.py + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[EncoderDecoderCache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + + # use key_value_states if cross attention + current_states = key_value_states if key_value_states is not None else hidden_states + if is_cross_attention and past_key_value and is_updated: + # 
reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + # Do self_attention + # -> use hidden_states as current_states for key and value multi-head projections + key_states = self._shape(self.k_proj(current_states), -1, bsz) + value_states = self._shape(self.v_proj(current_states), -1, bsz) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + + # >>> START AH Changes <<< + # get query proj + query_states = self.q_proj(hidden_states) + + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + # if we are in a parallel setting we need to adjust the batch size + # when reshaping the query_states to multi-head format + bsz = query_states.size(0) + + query_states = self._shape(query_states, tgt_len, bsz) + # >>> END AH Changes <<< + + # Compute the attention weights + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # Normalize the attention weights + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_probs, value_states) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. 
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights, past_key_value + + +class WhisperFlashAttention2WithAdapters(WhisperAttentionAdaptersMixin, WhisperAttention): + # Copied from adapters/models/bart/modeling_bart.py + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[EncoderDecoderCache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. " + "Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers" + ) + # WhisperFlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError("WhisperFlashAttention2 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim)) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + + # use key_value_states if cross attention + current_states = key_value_states if key_value_states is not None else hidden_states + if is_cross_attention and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = self._shape(self.k_proj(current_states), -1, bsz) + value_states = self._shape(self.v_proj(current_states), -1, bsz) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim] + # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view. 
+ key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # >>> START AH Changes <<< + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + # >>> END AH Changes <<< + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + "The input hidden states seems to be silently casted in float32, this might be related to the fact" + " you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + causal_mask, + tgt_len, + dropout=self.dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class WhisperSdpaAttentionWithAdapters(WhisperAttentionAdaptersMixin, WhisperAttention): + # Copied from adapters/models/bart/modeling_bart.py + # and transformers/models/whisper/modeling_whisper.py + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[EncoderDecoderCache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + if output_attentions or layer_head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "WhisperModel is using WhisperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" + ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states, + key_value_states=key_value_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + + # use key_value_states if cross attention + current_states = key_value_states if key_value_states is not None else hidden_states + if is_cross_attention and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = self._shape(self.k_proj(current_states), -1, bsz) + value_states = self._shape(self.v_proj(current_states), -1, bsz) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # >>> START AH Changes <<< + # get query proj + query_states = self.q_proj(hidden_states) + + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + # if we are in a parallel setting we need to adjust the batch size + # when reshaping the query_states to multi-head format + bsz = query_states.size(0) + + query_states = self._shape(query_states, tgt_len, bsz) + # >>> END AH Changes <<< + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False + + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, + # but we are fine here as `_shape` do call `.contiguous()`. 
Reference: https://github.com/pytorch/pytorch/issues/112577 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + +class WhisperEncoderLayerWithAdapters(WhisperEncoderLayer, WhisperEncoderLayerAdaptersMixin): + # Copied from adapters/models/mbart/modeling_mbart.py + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + adjust_tensors_for_parallel_(hidden_states, attention_mask) + + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.attention_adapters(hidden_states, residual, None) + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.output_adapters(hidden_states, residual, None) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class WhisperDecoderLayerWithAdapters(WhisperDecoderLayer, WhisperDecoderLayerAdaptersMixin): + # Copied from adapters/models/mbart/modeling_mbart.py + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[EncoderDecoderCache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + cache_position: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + adjust_tensors_for_parallel_(hidden_states, attention_mask, encoder_attention_mask) + + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.attention_adapters(hidden_states, residual, None) + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.cross_attention_adapters(hidden_states, residual, None) + + # add cross-attn to positions 1 of present_key_value tuple + present_key_value = (present_key_value, cross_attn_present_key_value) + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.output_adapters(hidden_states, residual, None) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs diff --git a/src/adapters/wrappers/configuration.py b/src/adapters/wrappers/configuration.py index ed224cd600..40dc421787 100644 --- a/src/adapters/wrappers/configuration.py +++ b/src/adapters/wrappers/configuration.py @@ -61,6 +61,12 @@ "attention_probs_dropout_prob": "dropout_rate", }, "vit": {}, + "whisper": { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + "hidden_dropout_prob": "dropout", + "attention_probs_dropout_prob": "attention_dropout", + }, "xlm_roberta": {}, } SUBMODEL_NAMES = {"clip": ["vision_config", "text_config"], "encoder-decoder": ["encoder", "decoder"]} diff --git a/tests/composition/test_parallel.py b/tests/composition/test_parallel.py index 538dd79ca0..31dce09969 100644 --- a/tests/composition/test_parallel.py +++ b/tests/composition/test_parallel.py @@ -131,7 +131,10 @@ def test_parallel_generate(self): seq_output_length = 32 # Finally, also check if generation works properly - input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"] + if self.is_speech_model: + input_ids = self.get_input_samples((1, 80, 3000), config=model1.config)["input_features"] + else: + input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"] input_ids = input_ids.to(torch_device) generated = model1.generate(input_ids, max_length=seq_output_length) self.assertLessEqual(generated.shape, (2, seq_output_length)) @@ -232,13 +235,21 @@ 
def run_parallel_training_equivalent_to_single(self, adapter_config): b1, b2 = self.create_twin_adapters(model, "b", adapter_config) dataset = [] - for i in range(3): - input_data = self.get_input_samples(config=model.config) - if isinstance(model, BertGenerationAdapterModel): - input_data["labels"] = torch.randint(0, 2, (3, 64)) - else: - input_data["labels"] = torch.randint(0, 2, (3, 1)) - dataset.append(input_data) + if self.is_speech_model: + dataset_batched = self.dataset() + dataset = [{} for _ in range(len(dataset_batched))] + # As this test uses a non-batched training, we need to wrap the samples by an additional dimension + for i in range(len(dataset_batched)): + for key, value in dataset_batched[i].items(): + dataset[i][key] = torch.unsqueeze(value, 0) + else: + for i in range(3): + input_data = self.get_input_samples(config=model.config) + if isinstance(model, BertGenerationAdapterModel): + input_data["labels"] = torch.randint(0, 2, (3, 64)) + else: + input_data["labels"] = torch.randint(0, 2, (3, 1)) + dataset.append(input_data) for adapter in [a1, b1]: model.active_head = adapter @@ -290,9 +301,13 @@ def test_parallel_training_single_forward_pass(self): if b1 in k: self.assertTrue(torch.equal(v, state_dict[k.replace(b1, b2)])) - input_data = self.get_input_samples(config=model.config) + input_data = self.get_input_samples( + config=model.config, + ) if isinstance(model, BertGenerationAdapterModel): input_data["labels"] = torch.randint(0, 2, (3, 64), device=torch_device) + elif self.is_speech_model: + input_data["labels"] = input_data["decoder_input_ids"] else: input_data["labels"] = torch.randint(0, 2, (3, 1), device=torch_device) diff --git a/tests/fixtures/audio_datasets/common_voice_encoded/dataset_dict.json b/tests/fixtures/audio_datasets/common_voice_encoded/dataset_dict.json new file mode 100644 index 0000000000..40d5c04b43 --- /dev/null +++ b/tests/fixtures/audio_datasets/common_voice_encoded/dataset_dict.json @@ -0,0 +1 @@ +{"splits": ["train"]} \ No newline at end of file diff --git a/tests/fixtures/audio_datasets/common_voice_encoded/train/data-00000-of-00001.arrow b/tests/fixtures/audio_datasets/common_voice_encoded/train/data-00000-of-00001.arrow new file mode 100644 index 0000000000..e25811b9b6 Binary files /dev/null and b/tests/fixtures/audio_datasets/common_voice_encoded/train/data-00000-of-00001.arrow differ diff --git a/tests/fixtures/audio_datasets/common_voice_encoded/train/dataset_info.json b/tests/fixtures/audio_datasets/common_voice_encoded/train/dataset_info.json new file mode 100644 index 0000000000..92f74d69d7 --- /dev/null +++ b/tests/fixtures/audio_datasets/common_voice_encoded/train/dataset_info.json @@ -0,0 +1,25 @@ +{ + "citation": "", + "description": "", + "features": { + "input_features": { + "feature": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "_type": "Sequence" + }, + "_type": "Sequence" + }, + "labels": { + "feature": { + "dtype": "int64", + "_type": "Value" + }, + "_type": "Sequence" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/tests/fixtures/audio_datasets/common_voice_encoded/train/state.json b/tests/fixtures/audio_datasets/common_voice_encoded/train/state.json new file mode 100644 index 0000000000..7900f451ae --- /dev/null +++ b/tests/fixtures/audio_datasets/common_voice_encoded/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "138e80d328a5e394", + "_format_columns": null, + "_format_kwargs": 
{}, + "_format_type": "torch", + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/tests/fixtures/audio_datasets/common_voice_org/dataset_dict.json b/tests/fixtures/audio_datasets/common_voice_org/dataset_dict.json new file mode 100644 index 0000000000..40d5c04b43 --- /dev/null +++ b/tests/fixtures/audio_datasets/common_voice_org/dataset_dict.json @@ -0,0 +1 @@ +{"splits": ["train"]} \ No newline at end of file diff --git a/tests/fixtures/audio_datasets/common_voice_org/train/data-00000-of-00001.arrow b/tests/fixtures/audio_datasets/common_voice_org/train/data-00000-of-00001.arrow new file mode 100644 index 0000000000..5b6ef0827a Binary files /dev/null and b/tests/fixtures/audio_datasets/common_voice_org/train/data-00000-of-00001.arrow differ diff --git a/tests/fixtures/audio_datasets/common_voice_org/train/dataset_info.json b/tests/fixtures/audio_datasets/common_voice_org/train/dataset_info.json new file mode 100644 index 0000000000..a97cc2df61 --- /dev/null +++ b/tests/fixtures/audio_datasets/common_voice_org/train/dataset_info.json @@ -0,0 +1,65 @@ +{ + "citation": "", + "description": "", + "features": { + "client_id": { + "dtype": "string", + "_type": "Value" + }, + "path": { + "dtype": "string", + "_type": "Value" + }, + "audio": { + "array": { + "feature": { + "dtype": "float64", + "_type": "Value" + }, + "_type": "Sequence" + }, + "path": { + "dtype": "string", + "_type": "Value" + }, + "sampling_rate": { + "dtype": "int64", + "_type": "Value" + } + }, + "sentence": { + "dtype": "string", + "_type": "Value" + }, + "up_votes": { + "dtype": "int64", + "_type": "Value" + }, + "down_votes": { + "dtype": "int64", + "_type": "Value" + }, + "age": { + "dtype": "string", + "_type": "Value" + }, + "gender": { + "dtype": "string", + "_type": "Value" + }, + "accent": { + "dtype": "string", + "_type": "Value" + }, + "locale": { + "dtype": "string", + "_type": "Value" + }, + "segment": { + "dtype": "string", + "_type": "Value" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/tests/fixtures/audio_datasets/common_voice_org/train/state.json b/tests/fixtures/audio_datasets/common_voice_org/train/state.json new file mode 100644 index 0000000000..0b73b1cef1 --- /dev/null +++ b/tests/fixtures/audio_datasets/common_voice_org/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "4db5194a70f28e75", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/tests/fixtures/audio_datasets/prepare_audio_datasets.py b/tests/fixtures/audio_datasets/prepare_audio_datasets.py new file mode 100644 index 0000000000..990de64b05 --- /dev/null +++ b/tests/fixtures/audio_datasets/prepare_audio_datasets.py @@ -0,0 +1,177 @@ +from typing import Dict, List, Union + +import torch +from datasets import Audio, Dataset, DatasetDict, load_dataset, load_from_disk + +from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer + + +def create_common_voice(): + """Creates a small abstract dataset of 10 samples from the common voice dataset in english.""" + common_voice = load_dataset("mozilla-foundation/common_voice_11_0", "en", split="validation", streaming=True) + common_voice = iter(common_voice) + + rows = [] + for i, sample in enumerate(common_voice): + rows.append(sample) + if i == 9: + break + + dataset_dict = DatasetDict({"train": 
Dataset.from_list(rows)}) + dataset_dict.save_to_disk("common_voice_org") + return dataset_dict + + +def create_common_voice_encoded(dataset_path="common_voice_org"): + """Preprocesses the common voice dataset and creates a new encoded version ready for training.""" + model_id = "openai/whisper-tiny" + feature_extractor = AutoFeatureExtractor.from_pretrained(model_id) + processor = AutoProcessor.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + sampling_rate = feature_extractor.sampling_rate + decoder_start_token_id = 50257 + + # Preprocessing adapted from this example notebook: + # https://github.com/huggingface/peft/blob/main/examples/int8_training/peft_bnb_whisper_large_v2_training.ipynb + + def _prepare_dataset(batch): + # load and resample audio data from 48 to 16kHz + audio = batch["audio"] + + # compute log-Mel input features from input audio array + batch["input_features"] = feature_extractor( + audio["array"], sampling_rate=audio["sampling_rate"] + ).input_features[0] + + # encode target text to label ids + batch["labels"] = tokenizer(batch["sentence"]).input_ids + return batch + + def _collate_dataset_with_padding( + features: List[Dict[str, Union[List[int], torch.Tensor]]], processor, decoder_start_token_id: int + ) -> Dict[str, torch.Tensor]: + # split inputs and labels since they have to be of different lengths and need different padding methods + # first treat the audio inputs by simply returning torch tensors + input_features = [{"input_features": feature} for feature in features["input_features"]] + batch = processor.feature_extractor.pad(input_features, return_tensors="pt") + + # get the tokenized label sequences + label_features = [{"input_ids": feature} for feature in features["labels"]] + # pad the labels to max length + labels_batch = processor.tokenizer.pad(label_features, return_tensors="pt") + + # replace padding with -100 to ignore loss correctly + labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) + + # if bos token is appended in previous tokenization step, + # cut bos token here as it's append later anyways + if (labels[:, 0] == decoder_start_token_id).all().cpu().item(): + labels = labels[:, 1:] + + batch["labels"] = labels + return batch + + dataset = load_from_disk(dataset_path) + dataset = dataset.remove_columns( + ["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"] + ) + dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate)) + + dataset = dataset.map(_prepare_dataset, remove_columns=dataset.column_names["train"]) + dataset = dataset.map( + lambda x: _collate_dataset_with_padding(x, processor, decoder_start_token_id), + batched=True, + batch_size=10, + ) + + dataset.set_format(type="torch") + dataset.save_to_disk("common_voice_encoded") + + +def create_speech_commands(): + """Creates a small abstract dataset of 10 samples from the speech commands dataset.""" + dataset = load_dataset("speech_commands", "v0.02", streaming=True, split="validation") + labels = [1, 2] + + rows = [] + for i, sample in enumerate(dataset): + # Assign one of the labels to the sample + sample["label"] = [labels[i % len(labels)]] + rows.append(sample) + if i == 9: + break + + dataset_dict = DatasetDict({"train": Dataset.from_list(rows)}) + dataset_dict.save_to_disk("speech_commands_org") + return dataset_dict + + +def create_speech_commands_encoded(dataset_path="speech_commands_org"): + """Preprocesses the speech commands dataset and creates a new encoded 
version ready for training.""" + dataset = load_from_disk(dataset_path) + dataset = dataset.select_columns(["audio", "label"]) + + # Preprocessing copied and adapted from: + # https://colab.research.google.com/drive/1nU6dlYamT32kfLe2t_AytmOPRjaOxOZn?usp=sharing#scrollTo=GF93pim6eo9e + + model_id = "openai/whisper-tiny" + feature_extractor = AutoFeatureExtractor.from_pretrained(model_id, do_normalize=True) + + sampling_rate = feature_extractor.sampling_rate + dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate)) + + max_duration = 30 + + def preprocess_function(examples): + audio_arrays = [x["array"] for x in examples["audio"]] + inputs = feature_extractor( + audio_arrays, + sampling_rate=feature_extractor.sampling_rate, + max_length=int(feature_extractor.sampling_rate * max_duration), + truncation=True, + ) + return inputs + + dataset_encoded = dataset.map( + preprocess_function, + remove_columns="audio", + batched=True, + batch_size=2, + num_proc=1, + ) + # convert to torch format + dataset_encoded.set_format(type="torch") + dataset_encoded.save_to_disk("speech_commands_encoded") + return dataset_encoded + + +if __name__ == "__main__": + + create_seq2seq = False + create_classification = False + + if create_seq2seq: + # Create and preprocess sequence classification dataset + create_common_voice() + create_common_voice_encoded() + + # Load and inspect the dataset + dataset = load_from_disk("common_voice_encoded") + for sample in dataset["train"]: + print(sample.keys()) + print(sample["input_features"].shape) + print(sample["labels"].shape) + break + + if create_classification: + # Create and preprocess audio classification dataset + create_speech_commands() + create_speech_commands_encoded() + + # Load and inspect the dataset + dataset = load_from_disk("speech_commands_encoded") + for sample in dataset["train"]: + print(sample.keys()) + print(sample["input_features"].shape) + print(sample["label"]) + break diff --git a/tests/fixtures/audio_datasets/speech_commands_org/dataset_dict.json b/tests/fixtures/audio_datasets/speech_commands_org/dataset_dict.json new file mode 100644 index 0000000000..40d5c04b43 --- /dev/null +++ b/tests/fixtures/audio_datasets/speech_commands_org/dataset_dict.json @@ -0,0 +1 @@ +{"splits": ["train"]} \ No newline at end of file diff --git a/tests/fixtures/audio_datasets/speech_commands_org/train/data-00000-of-00001.arrow b/tests/fixtures/audio_datasets/speech_commands_org/train/data-00000-of-00001.arrow new file mode 100644 index 0000000000..4643a43a2f Binary files /dev/null and b/tests/fixtures/audio_datasets/speech_commands_org/train/data-00000-of-00001.arrow differ diff --git a/tests/fixtures/audio_datasets/speech_commands_org/train/dataset_info.json b/tests/fixtures/audio_datasets/speech_commands_org/train/dataset_info.json new file mode 100644 index 0000000000..2d1e00e799 --- /dev/null +++ b/tests/fixtures/audio_datasets/speech_commands_org/train/dataset_info.json @@ -0,0 +1,48 @@ +{ + "citation": "", + "description": "", + "features": { + "file": { + "dtype": "string", + "_type": "Value" + }, + "audio": { + "array": { + "feature": { + "dtype": "float64", + "_type": "Value" + }, + "_type": "Sequence" + }, + "path": { + "dtype": "string", + "_type": "Value" + }, + "sampling_rate": { + "dtype": "int64", + "_type": "Value" + } + }, + "label": { + "feature": { + "dtype": "int64", + "_type": "Value" + }, + "_type": "Sequence" + }, + "is_unknown": { + "dtype": "bool", + "_type": "Value" + }, + "speaker_id": { + "dtype": "string", + "_type": 
"Value" + }, + "utterance_id": { + "dtype": "int64", + "_type": "Value" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/tests/fixtures/audio_datasets/speech_commands_org/train/state.json b/tests/fixtures/audio_datasets/speech_commands_org/train/state.json new file mode 100644 index 0000000000..6363acc305 --- /dev/null +++ b/tests/fixtures/audio_datasets/speech_commands_org/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "c0cc06d19bea0105", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/tests/methods/base.py b/tests/methods/base.py index 6ede68f2f3..0d20f32fef 100644 --- a/tests/methods/base.py +++ b/tests/methods/base.py @@ -317,6 +317,8 @@ def has_tied_embeddings(k): return tied_embeddings and is_tied_layer for (k1, v1), (k2, v2) in zip(state_dict_pre.items(), model.state_dict().items()): + # move both to the same device to avoid device mismatch errors + v1, v2 = v1.to(v2.device), v2 if "mrpc" in k1 and not has_tied_embeddings(k1): adapters_with_change |= not torch.equal(v1, v2) else: diff --git a/tests/methods/test_compacter.py b/tests/methods/test_compacter.py index 75716ffa61..292fab1efb 100644 --- a/tests/methods/test_compacter.py +++ b/tests/methods/test_compacter.py @@ -71,7 +71,10 @@ def test_compacter_generate(self): seq_output_length = 32 # Finally, also check if generation works properly - input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"] + if self.is_speech_model: + input_ids = self.get_input_samples((1, 80, 3000), config=model1.config)["input_features"] + else: + input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"] input_ids = input_ids.to(torch_device) generated = model1.generate(input_ids, max_length=seq_output_length) self.assertLessEqual(generated.shape, (1, seq_output_length)) diff --git a/tests/methods/test_prefix_tuning.py b/tests/methods/test_prefix_tuning.py index 2b351d0fcc..dd443c0d0b 100644 --- a/tests/methods/test_prefix_tuning.py +++ b/tests/methods/test_prefix_tuning.py @@ -94,7 +94,10 @@ def test_prefix_tuning_generate(self): seq_output_length = 32 # Finally, also check if generation works properly - input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"] + if self.is_speech_model: + input_ids = self.get_input_samples((1, 80, 3000), config=model1.config)["input_features"] + else: + input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"] input_ids = input_ids.to(torch_device) generated = model1.generate(input_ids, max_length=seq_output_length) self.assertLessEqual(generated.shape, (1, seq_output_length)) diff --git a/tests/models/test_whisper.py b/tests/models/test_whisper.py new file mode 100644 index 0000000000..bfeea5a508 --- /dev/null +++ b/tests/models/test_whisper.py @@ -0,0 +1,12 @@ +# flake8: noqa: F403,F405 +from adapters import WhisperAdapterModel +from hf_transformers.tests.models.whisper.test_modeling_whisper import * +from transformers.testing_utils import require_torch + +from .base import AdapterModelTesterMixin + + +@require_torch +class WhisperAdapterModelTest(AdapterModelTesterMixin, WhisperModelTest): + all_model_classes = (WhisperAdapterModel,) + fx_compatible = False diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 6414a18756..bafa7e65a9 100644 --- a/tests/test_adapter.py +++ 
b/tests/test_adapter.py @@ -39,6 +39,12 @@ class AdapterTestBase: default_input_samples_shape = (3, 64) leave_out_layers = [0, 1] do_run_train_tests = True + # default arguments for test_adapter_heads + batch_size = 1 + seq_length = 128 + is_speech_model = ( + False # Flag for tests to determine if the model is a speech model due to input format difference + ) def get_model(self): if self.model_class == AutoAdapterModel: @@ -134,3 +140,71 @@ def transform(example_batch): dataset = dataset.with_transform(transform) return dataset + + +class SpeechAdapterTestBase(AdapterTestBase): + """Base class for speech adapter tests.""" + + default_input_samples_shape = (3, 80, 3000) # (batch_size, n_mels, enc_seq_len) + is_speech_model = True # Flag for tests to determine if the model is a speech model due to input format difference + time_window = 3000 # Time window for audio samples + seq_length = 80 + + def add_head(self, model, name, head_type="seq2seq_lm", **kwargs): + """Adds a head to the model.""" + if head_type == "audio_classification": + model.add_audio_classification_head(name, **kwargs) + return model.heads[name].config["num_labels"] + elif head_type == "seq2seq_lm": + kwargs.pop("num_labels", 1) # Remove num_labels from kwargs if present in the tests + model.add_seq2seq_lm_head(name, **kwargs) + return self.default_input_samples_shape[1] # Return the number of mel features + else: + raise ValueError(f"Head type {head_type} not supported.") + + def get_input_samples(self, shape=None, config=None, **kwargs): + """Creates a dummy batch of samples in the format required for speech models.""" + shape = shape or self.default_input_samples_shape + + # Input features + total_dims = 1 + for dim in shape: + total_dims *= dim + values = [] + for _ in range(total_dims): + values.append(random.random()) + input_features = torch.tensor(data=values, dtype=torch.float, device=torch_device).view(shape).contiguous() + in_data = {"input_features": input_features} + + # Decoder input ids + if config and config.is_encoder_decoder: + in_data["decoder_input_ids"] = ids_tensor((shape[:-1]), config.vocab_size) + return in_data + + _TASK_DATASET_MAPPING = { + "seq2seq_lm": "./tests/fixtures/audio_datasets/common_voice_encoded", + "audio_classification": "./tests/fixtures/audio_datasets/speech_commands_encoded", + } + + def dataset(self, feature_extractor=None, processor=None, tokenizer=None, task_type: str = "seq2seq_lm", **kwargs): + """Returns a dataset to test speech model training. 
Standard dataset is for seq2seq_lm.""" + if task_type == "seq2seq_lm": + return self._prep_seq2seq_lm_dataset(task_type, **kwargs) + elif task_type == "audio_classification": + return self._prep_audio_classification_dataset(task_type, **kwargs) + + def _prep_seq2seq_lm_dataset(self, task_type, **kwargs): + """Prepares a dataset for conditional generation.""" + # The dataset is already processed and saved to disk, to save time during testing + # Preparation script can be found in tests/fixtures/audio_datasets/prepare_audio_datasets.py + dataset_path = self._TASK_DATASET_MAPPING[task_type] + dataset = datasets.load_from_disk(dataset_path) + return dataset["train"] + + def _prep_audio_classification_dataset(self, task_type, **kwargs): + """Prepares a dataset for audio classification.""" + # The dataset is already processed and saved to disk, to save time during testing + # Preparation script can be found in tests/fixtures/audio_datasets/prepare_audio_datasets.py + dataset_path = self._TASK_DATASET_MAPPING[task_type] + dataset = datasets.load_from_disk(dataset_path) + return dataset["train"] diff --git a/tests/test_adapter_conversion.py b/tests/test_adapter_conversion.py index df209c12ba..9653b3f340 100644 --- a/tests/test_adapter_conversion.py +++ b/tests/test_adapter_conversion.py @@ -14,6 +14,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, AlbertPreTrainedModel, BertPreTrainedModel, @@ -26,10 +27,6 @@ @require_torch class ModelClassConversionTestMixin: - - batch_size = 1 - seq_length = 128 - def run_test(self, static_model, input_shape=None, label_dict=None): flex_model = AutoAdapterModel.from_pretrained(None, config=self.config(), state_dict=static_model.state_dict()) static_model.eval() @@ -107,14 +104,30 @@ def test_conversion_masked_lm_model(self): self.run_test(model, label_dict=label_dict) def test_conversion_seq2seq_lm_model(self): - if self.config_class not in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING: + if ( + self.config_class not in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING + and self.config_class not in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING + ): self.skipTest("No seq2seq language modeling class.") - model = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[self.config_class](self.config()) - adapters.init(model) label_dict = {} - label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device) - label_dict["decoder_input_ids"] = label_dict["labels"].clone() + if self.is_speech_model: + # speech models require input_features + model = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING[self.config_class](self.config()) + label_dict["input_features"] = torch.randn( + (self.default_input_samples_shape), dtype=torch.float32, device=torch_device + ) + label_dict["decoder_input_ids"] = torch.randint( + 0, model.config.vocab_size, size=self.default_input_samples_shape[:-1], device=torch_device + ) + label_dict["labels"] = label_dict["decoder_input_ids"] + else: + model = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[self.config_class](self.config()) + label_dict["labels"] = torch.zeros( + (self.batch_size, self.seq_length), dtype=torch.long, device=torch_device + ) + label_dict["decoder_input_ids"] = label_dict["labels"].clone() + adapters.init(model) self.run_test(model, label_dict=label_dict) def test_conversion_classification_model(self): diff --git a/tests/test_adapter_embeddings.py b/tests/test_adapter_embeddings.py index 
393331c33d..160828c776 100644 --- a/tests/test_adapter_embeddings.py +++ b/tests/test_adapter_embeddings.py @@ -47,8 +47,7 @@ def test_delete_embeddings(self): def test_save_load_embedding(self): model = self.get_model() - tokenizer = AutoTokenizer.from_pretrained("tests/fixtures/SiBERT") - input_data = self.get_input_samples((1, 128), vocab_size=tokenizer.vocab_size, config=model.config) + tokenizer, input_data = self._instantiate_tokenizer(model) model.add_embeddings("test", tokenizer) model.eval() model.to(torch_device) @@ -71,9 +70,8 @@ def test_save_load_embedding(self): def test_back_to_default(self): model = self.get_model() model.eval() - input_data = self.get_input_samples((1, 128), config=model.config) + tokenizer, input_data = self._instantiate_tokenizer(model) output1 = model(**input_data) - tokenizer = AutoTokenizer.from_pretrained("tests/fixtures/SiBERT") model.add_embeddings("test", tokenizer) self.assertEqual(model.active_embeddings, "test") model.set_active_embeddings("default") @@ -176,3 +174,14 @@ def test_reference_embedding(self): # activate for training model.add_adapter("test") model.train_adapter("test", train_embeddings=True) + + def _instantiate_tokenizer(self, model): + """Depending on the model type, instantiate a tokenizer and input data. + Speech models require a different tokenizer and sample size.""" + if self.is_speech_model: + tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) + input_data = self.get_input_samples(config=self.config()) + else: + tokenizer = AutoTokenizer.from_pretrained("tests/fixtures/SiBERT") + input_data = self.get_input_samples((1, 128), vocab_size=tokenizer.vocab_size, config=model.config) + return tokenizer, input_data diff --git a/tests/test_adapter_heads.py b/tests/test_adapter_heads.py index 0de9134c05..541debf35b 100644 --- a/tests/test_adapter_heads.py +++ b/tests/test_adapter_heads.py @@ -15,12 +15,16 @@ @require_torch class PredictionHeadModelTestMixin: - - batch_size = 1 - seq_length = 128 - def run_prediction_head_test( - self, model, compare_model, head_name, input_shape=None, output_shape=(1, 2), label_dict=None + self, + model, + compare_model, + head_name, + input_shape=None, + output_shape=(1, 2), + label_dict=None, + num_labels=None, + with_labels=False, ): # first, check if the head is actually correctly registered as part of the pt module self.assertTrue(f"heads.{head_name}" in dict(model.named_modules())) @@ -39,8 +43,10 @@ def run_prediction_head_test( # make a forward pass model.active_head = head_name - input_shape = input_shape or (self.batch_size, self.seq_length) - in_data = self.get_input_samples(input_shape, config=model.config) + input_shape = input_shape if input_shape is not None else self._get_input_shape() + in_data = self.get_input_samples( + input_shape, config=model.config, num_labels=num_labels, with_labels=with_labels + ) if label_dict: for k, v in label_dict.items(): in_data[k] = v @@ -168,7 +174,11 @@ def test_seq2seq_lm_head(self): ) # Finally, also check if generation works properly - input_ids = self.get_input_samples((1, self.seq_length), config=model1.config)["input_ids"] + input_shape = self._get_input_shape() + if self.is_speech_model: + input_ids = self.get_input_samples(input_shape, config=model1.config)["input_features"] + else: + input_ids = self.get_input_samples(input_shape, config=model1.config)["input_ids"] input_ids = input_ids.to(torch_device) # Use a different length for the seq2seq output seq_output_length = self.seq_length + 30 @@ -249,7 +259,7 @@ def 
test_delete_head(self): self.assertNotEqual(name, model.active_head) def test_adapter_with_head(self): - model1, model2 = create_twin_models(AutoAdapterModel, self.config) + model1, model2 = create_twin_models(self.model_class, self.config) name = "dummy" model1.add_adapter(name) @@ -271,7 +281,7 @@ def test_adapter_with_head(self): self.assertEqual(output_size, output1[0].size()[1]) def test_adapter_with_head_load_as(self): - model1, model2 = create_twin_models(AutoAdapterModel, self.config) + model1, model2 = create_twin_models(self.model_class, self.config) name = "dummy" model1.add_adapter(name) @@ -408,7 +418,8 @@ def forward_pre_hook(module, input): self.assertIsNotNone(inv_adapter) inv_adapter.register_forward_pre_hook(forward_pre_hook) - in_data = self.get_input_samples((self.batch_size, self.seq_length), config=model.config) + input_shape = self._get_input_shape() + in_data = self.get_input_samples(input_shape, config=model.config) model.to(torch_device) out = model(**in_data) @@ -457,6 +468,14 @@ def test_save_all_adapters_with_head(self): model.save_all_adapters(tmp_dir, with_head=False) self.assertFalse(os.path.isfile(os.path.join(tmp_dir, "test", "head_config.json"))) + def _get_input_shape(self): + # speech models require a different input dimensions compared to text models + if self.is_speech_model: + input_shape = (self.batch_size, self.seq_length, self.time_window) + else: + input_shape = (self.batch_size, self.seq_length) + return input_shape + def test_average_head(self): # Test the average_head method model = AutoAdapterModel.from_config(self.config()) diff --git a/tests/test_whisper.py b/tests/test_whisper.py new file mode 100644 index 0000000000..c3cd3d2206 --- /dev/null +++ b/tests/test_whisper.py @@ -0,0 +1,71 @@ +import unittest + +from tests.methods.test_config_union import ConfigUnionAdapterTest +from transformers import WhisperConfig +from transformers.testing_utils import require_torch + +from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin +from .methods import ( + BottleneckAdapterTestMixin, + CompacterTestMixin, + IA3TestMixin, + LoRATestMixin, + PrefixTuningTestMixin, + ReftTestMixin, + UniPELTTestMixin, +) +from .test_adapter import SpeechAdapterTestBase, make_config +from .test_adapter_backward_compability import CompabilityTestMixin +from .test_adapter_conversion import ModelClassConversionTestMixin +from .test_adapter_embeddings import EmbeddingTestMixin +from .test_adapter_fusion_common import AdapterFusionModelTestMixin +from .test_adapter_heads import PredictionHeadModelTestMixin + + +class WhisperAdapterTestBase(SpeechAdapterTestBase): + config_class = WhisperConfig + config = make_config( + WhisperConfig, + d_model=16, + encoder_layers=2, + decoder_layers=2, + encoder_attention_heads=4, + decoder_attention_heads=4, + encoder_ffn_dim=4, + decoder_ffn_dim=4, + vocab_size=51865, + ) + tokenizer_name = "openai/whisper-small" + sampling_rate = 16000 + decoder_start_token_id = 50257 + + +@require_torch +class WhisperAdapterTest( + BottleneckAdapterTestMixin, + CompacterTestMixin, + IA3TestMixin, + LoRATestMixin, + PrefixTuningTestMixin, + ReftTestMixin, + UniPELTTestMixin, + AdapterFusionModelTestMixin, + CompabilityTestMixin, + EmbeddingTestMixin, + PredictionHeadModelTestMixin, + ParallelAdapterInferenceTestMixin, + ParallelTrainingMixin, + ConfigUnionAdapterTest, + WhisperAdapterTestBase, + unittest.TestCase, +): + pass + + +@require_torch +class WhisperClassConversionTest( + ModelClassConversionTestMixin, + 
WhisperAdapterTestBase, + unittest.TestCase, +): + pass
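
# --- Usage sketch (editor's addition, not part of the diff above) ---
# A minimal, hypothetical example of how the pieces introduced in this PR fit together:
# it loads the new WhisperAdapterModel, attaches a bottleneck adapter plus a seq2seq LM head,
# and runs generation on a dummy log-Mel input of the (batch, n_mels, enc_seq_len) shape used
# by the tests above. The adapter name "asr_adapter" is illustrative; the calls themselves
# (add_adapter, add_seq2seq_lm_head, train_adapter, generate) mirror the test code in this PR.

import torch
from adapters import WhisperAdapterModel

# Load the adapter-enabled Whisper model class added in this PR
model = WhisperAdapterModel.from_pretrained("openai/whisper-tiny")

# Add a bottleneck adapter and a matching seq2seq LM head, then freeze the base
# model and activate adapter + head for training
model.add_adapter("asr_adapter")
model.add_seq2seq_lm_head("asr_adapter")
model.train_adapter("asr_adapter")

# Dummy batch of log-Mel features: (batch_size, n_mels, enc_seq_len) = (1, 80, 3000),
# matching SpeechAdapterTestBase.default_input_samples_shape
input_features = torch.randn(1, 80, 3000)

# Speech models generate from input_features rather than input_ids,
# as exercised in test_seq2seq_lm_head / test_prefix_tuning_generate above
generated = model.generate(input_features, max_length=32)
print(generated.shape)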