From 5065d27c0219f2d36d13891e0ccee4cda1c10881 Mon Sep 17 00:00:00 2001
From: calpt
Date: Sat, 16 Sep 2023 16:29:20 +0200
Subject: [PATCH] Use seq. classification head in T5 tests. Move used heads retrieval to new method.

---
 src/adapters/heads/base.py                    | 43 +++++++++++--------
 src/adapters/models/t5/adapter_model.py       | 32 +++++++-------
 src/adapters/models/t5/modeling_t5.py         | 18 +++++---
 tests_adapters/composition/test_parallel.py   |  4 +-
 tests_adapters/methods/test_adapter_common.py |  4 +-
 tests_adapters/methods/test_prefix_tuning.py  |  3 +-
 tests_adapters/test_t5.py                     | 43 +------------------
 7 files changed, 61 insertions(+), 86 deletions(-)

diff --git a/src/adapters/heads/base.py b/src/adapters/heads/base.py
index 2a097c74ad..dd43a4e658 100644
--- a/src/adapters/heads/base.py
+++ b/src/adapters/heads/base.py
@@ -730,6 +730,27 @@ def delete_head(self, head_name: str):
         if self.active_head == head_name:
             self.active_head = None

+    def _get_used_heads(self, head_name: str = None):
+        if head_name:
+            used_heads = [head_name]
+        # together with context, check if we have heads at all to allow for models without heads
+        elif len(self.heads) > 0 and AdapterSetup.get_context_head_setup():
+            used_heads = AdapterSetup.get_context_head_setup()
+            if isinstance(used_heads, str):
+                used_heads = [used_heads]
+        elif self._active_heads:
+            used_heads = self._active_heads
+        else:
+            return []
+
+        head_modules = []
+        for head in used_heads:
+            if head not in self.heads:
+                raise ValueError("Unknown head_name '{}'".format(head))
+            head_modules.append(self.heads[head])
+
+        return head_modules
+
     def forward_head(
         self, all_outputs, head_name=None, cls_output=None, attention_mask=None, return_dict=False, **kwargs
     ):
@@ -750,16 +771,8 @@ def forward_head(
             return_dict (bool): Whether or not to return a ``ModelOutput`` instead of a plain tuple.
             **kwargs: Additional keyword arguments passed to the forward pass of the head.
         """
-        if head_name:
-            used_heads = [head_name]
-        # together with context, check if we have heads at all to allow for models without heads
-        elif len(self.heads) > 0 and AdapterSetup.get_context_head_setup():
-            used_heads = AdapterSetup.get_context_head_setup()
-            if isinstance(used_heads, str):
-                used_heads = [used_heads]
-        elif self._active_heads:
-            used_heads = self._active_heads
-        else:
+        used_head_modules = self._get_used_heads(head_name)
+        if len(used_head_modules) == 0:
             logger.debug("No prediction head is used.")
             return all_outputs

@@ -787,9 +800,6 @@ def _get_head_input(outputs, cls_out, batch):
         if inv_adapter:
             kwargs["invertible_adapter"] = inv_adapter

-        for head in used_heads:
-            if head not in self.heads:
-                raise ValueError("Unknown head_name '{}'".format(head))
         if isinstance(self.active_head, BatchSplit):
             if sum(self.active_head.batch_sizes) != all_outputs[0].size()[0]:
                 raise ValueError(
@@ -830,14 +840,13 @@ def _get_head_input(outputs, cls_out, batch):
                 else None
             )
             return_output = MultiHeadOutput(head_outputs=head_outputs, loss=combined_loss)
-        elif len(used_heads) > 1:
+        elif len(used_head_modules) > 1:
             head_outputs = []
-            for head in used_heads:
-                head_module = self.heads[head]
+            for head_module in used_head_modules:
                 head_outputs.append(head_module(all_outputs, cls_output, attention_mask, return_dict, **kwargs))
             return_output = MultiHeadOutput(head_outputs=head_outputs)
         else:
-            head_module = self.heads[used_heads[0]]
+            head_module = used_head_modules[0]
             return_output = head_module(all_outputs, cls_output, attention_mask, return_dict, **kwargs)

         if isinstance(return_output, ModelOutput):
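For illustration, the resolution order that the new `_get_used_heads` encapsulates (an explicit `head_name` argument first, then a head selected via the `AdapterSetup` context, then the model's active heads) can be sketched as a small standalone function. This is a simplified sketch, not the library's API; `resolve_used_heads` and its parameters are hypothetical stand-ins for the attributes used in `base.py`.

```python
# Simplified, standalone sketch of the resolution order implemented by the
# _get_used_heads method added above. All names here (heads, active_heads,
# context_head_setup) are illustrative stand-ins, not library attributes.

def resolve_used_heads(heads, head_name=None, context_head_setup=None, active_heads=None):
    """Return the list of head modules to run, or [] if no head applies."""
    if head_name:  # 1) an explicit head name always wins
        used_heads = [head_name]
    elif len(heads) > 0 and context_head_setup:  # 2) a head set via an AdapterSetup-style context
        used_heads = [context_head_setup] if isinstance(context_head_setup, str) else list(context_head_setup)
    elif active_heads:  # 3) fall back to the model's active head(s)
        used_heads = list(active_heads)
    else:
        return []

    modules = []
    for name in used_heads:
        if name not in heads:
            raise ValueError("Unknown head_name '{}'".format(name))
        modules.append(heads[name])
    return modules


# Example: an explicit head name takes precedence over the active head.
heads = {"classification": object(), "lm": object()}
assert resolve_used_heads(heads, head_name="lm", active_heads=["classification"]) == [heads["lm"]]
```

Keeping this logic in one place lets the T5 adapter model reuse the same resolution when deciding how to build decoder inputs, as in the `adapter_model.py` change below.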
""" - if head_name: - used_heads = [head_name] - # together with context, check if we have heads at all to allow for models without heads - elif len(self.heads) > 0 and AdapterSetup.get_context_head_setup(): - used_heads = AdapterSetup.get_context_head_setup() - if isinstance(used_heads, str): - used_heads = [used_heads] - elif self._active_heads: - used_heads = self._active_heads - else: + used_head_modules = self._get_used_heads(head_name) + if len(used_head_modules) == 0: logger.debug("No prediction head is used.") return all_outputs @@ -787,9 +800,6 @@ def _get_head_input(outputs, cls_out, batch): if inv_adapter: kwargs["invertible_adapter"] = inv_adapter - for head in used_heads: - if head not in self.heads: - raise ValueError("Unknown head_name '{}'".format(head)) if isinstance(self.active_head, BatchSplit): if sum(self.active_head.batch_sizes) != all_outputs[0].size()[0]: raise ValueError( @@ -830,14 +840,13 @@ def _get_head_input(outputs, cls_out, batch): else None ) return_output = MultiHeadOutput(head_outputs=head_outputs, loss=combined_loss) - elif len(used_heads) > 1: + elif len(used_head_modules) > 1: head_outputs = [] - for head in used_heads: - head_module = self.heads[head] + for head_module in used_head_modules: head_outputs.append(head_module(all_outputs, cls_output, attention_mask, return_dict, **kwargs)) return_output = MultiHeadOutput(head_outputs=head_outputs) else: - head_module = self.heads[used_heads[0]] + head_module = used_head_modules[0] return_output = head_module(all_outputs, cls_output, attention_mask, return_dict, **kwargs) if isinstance(return_output, ModelOutput): diff --git a/src/adapters/models/t5/adapter_model.py b/src/adapters/models/t5/adapter_model.py index 5522748291..66441727c7 100644 --- a/src/adapters/models/t5/adapter_model.py +++ b/src/adapters/models/t5/adapter_model.py @@ -73,9 +73,14 @@ def forward( **kwargs ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: - # get decoder inputs from shifting lm labels to the right - decoder_input_ids = self._shift_right(labels) + if decoder_input_ids is None and decoder_inputs_embeds is None: + # Check if we're using a LM head + if labels is not None and any([isinstance(head, Seq2SeqLMHead) for head in self._get_used_heads(head)]): + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + else: + # decoder_input_ids from input_ids if no decoder_input_ids are provided + decoder_input_ids = self._shift_right(input_ids) model_output = self.transformer( input_ids=input_ids, @@ -121,18 +126,15 @@ def forward( else: cls_representation = sequence_output - if head or self.active_head: - kwargs["labels"] = labels - head_outputs = self.forward_head( - model_output, - head_name=head, - cls_output=cls_representation, - return_dict=return_dict, - **kwargs, - ) - return head_outputs - else: - return model_output + kwargs["labels"] = labels + head_outputs = self.forward_head( + model_output, + head_name=head, + cls_output=cls_representation, + return_dict=return_dict, + **kwargs, + ) + return head_outputs # Copied from T5ForConditionalGeneration def prepare_inputs_for_generation( diff --git a/src/adapters/models/t5/modeling_t5.py b/src/adapters/models/t5/modeling_t5.py index 820c092027..7d7e467f0a 100644 --- a/src/adapters/models/t5/modeling_t5.py +++ b/src/adapters/models/t5/modeling_t5.py @@ -292,7 +292,8 @@ def forward( raise 
ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") if inputs_embeds is None: - assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" + if self.embed_tokens is None: + raise ValueError("You have to initialize the model with valid token embeddings") inputs_embeds = self.embed_tokens(input_ids) batch_size, seq_length = input_shape @@ -301,7 +302,8 @@ def forward( mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length if use_cache is True: - assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder" + if not self.is_decoder: + raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") if attention_mask is None: attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) @@ -330,6 +332,13 @@ def forward( else: encoder_extended_attention_mask = None + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) @@ -369,11 +378,6 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False def create_custom_forward(module): def custom_forward(*inputs): diff --git a/tests_adapters/composition/test_parallel.py b/tests_adapters/composition/test_parallel.py index c33d5e362f..56ea422308 100644 --- a/tests_adapters/composition/test_parallel.py +++ b/tests_adapters/composition/test_parallel.py @@ -234,7 +234,7 @@ def run_parallel_training_equivalent_to_single(self, adapter_config): dataset = [] for i in range(3): input_data = self.get_input_samples(config=model.config) - if isinstance(model, T5AdapterModel) or isinstance(model, BertGenerationAdapterModel): + if isinstance(model, BertGenerationAdapterModel): input_data["labels"] = torch.randint(0, 2, (3, 64)) else: input_data["labels"] = torch.randint(0, 2, (3, 1)) @@ -291,7 +291,7 @@ def test_parallel_training_single_forward_pass(self): self.assertTrue(torch.equal(v, state_dict[k.replace(b1, b2)])) input_data = self.get_input_samples(config=model.config) - if isinstance(model, T5AdapterModel) or isinstance(model, BertGenerationAdapterModel): + if isinstance(model, BertGenerationAdapterModel): input_data["labels"] = torch.randint(0, 2, (3, 64), device=torch_device) else: input_data["labels"] = torch.randint(0, 2, (3, 1), device=torch_device) diff --git a/tests_adapters/methods/test_adapter_common.py b/tests_adapters/methods/test_adapter_common.py index 81033924a0..616e6a99e8 100644 --- a/tests_adapters/methods/test_adapter_common.py +++ b/tests_adapters/methods/test_adapter_common.py @@ -19,7 +19,7 @@ SeqBnInvConfig, ) from adapters.heads.language_modeling import CausalLMHead -from transformers import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING +from transformers import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, CLIPConfig from transformers.testing_utils import require_torch, torch_device from .base import AdapterMethodBaseTestMixin, create_twin_models @@ -148,7 +148,7 @@ def 
diff --git a/tests_adapters/methods/test_prefix_tuning.py b/tests_adapters/methods/test_prefix_tuning.py
index f08a9a492f..798f4b19d4 100644
--- a/tests_adapters/methods/test_prefix_tuning.py
+++ b/tests_adapters/methods/test_prefix_tuning.py
@@ -1,6 +1,7 @@
 import torch

 from adapters import ADAPTER_MODEL_MAPPING, AutoAdapterModel, PrefixTuningConfig
+from transformers import CLIPConfig
 from transformers.testing_utils import require_torch, torch_device

 from .base import AdapterMethodBaseTestMixin
@@ -24,7 +25,7 @@ def test_get_prefix_tuning(self):
         model = self.get_model()
         if model.config.is_encoder_decoder:
             n_prefix_layers = 3
-        elif model.config.is_composition:
+        elif model.config.is_composition or isinstance(model.config, CLIPConfig):
             n_prefix_layers = 2
         else:
             n_prefix_layers = 1
diff --git a/tests_adapters/test_t5.py b/tests_adapters/test_t5.py
index 7061f68ed3..c8717d8b54 100644
--- a/tests_adapters/test_t5.py
+++ b/tests_adapters/test_t5.py
@@ -1,8 +1,6 @@
 import unittest

-from datasets import load_dataset
-
-from transformers import AutoTokenizer, T5Config
+from transformers import T5Config
 from transformers.testing_utils import require_torch

 from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
@@ -38,45 +36,6 @@ class T5AdapterTestBase(AdapterTestBase):
     )
     tokenizer_name = "t5-base"

-    def dataset(self, tokenizer=None):
-        # setup tokenizer
-        if tokenizer is None:
-            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=False)
-            if tokenizer.pad_token is None:
-                tokenizer.pad_token = tokenizer.eos_token
-
-        def preprocess_function(examples):
-            inputs = examples["document"]
-            targets = examples["summary"]
-            inputs = ["Summarize: " + inp for inp in inputs]
-            model_inputs = tokenizer(inputs, padding=True, truncation=True)
-
-            # Setup the tokenizer for targets
-            with tokenizer.as_target_tokenizer():
-                labels = tokenizer(targets, padding=True, truncation=True)
-
-            # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
-            # padding in the loss.
-            labels["input_ids"] = [
-                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
-            ]
-
-            model_inputs["labels"] = labels["input_ids"]
-            return model_inputs
-
-        data_args = {
-            "task_name": "xsum",
-            "path": "./hf_transformers/tests/fixtures/tests_samples/xsum/sample.json",
-        }
-        dataset = load_dataset("json", data_files=data_args["path"])
-        train_dataset = dataset["train"]
-        train_dataset = train_dataset.map(
-            preprocess_function,
-            batched=True,
-            desc="Running tokenizer on train dataset",
-        )
-        return train_dataset
-

 @require_torch
 class T5AdapterTest(