From 1c9ecca7b5a10420929cdc36dc134e43cc018df1 Mon Sep 17 00:00:00 2001
From: calpt
Date: Thu, 21 Dec 2023 13:16:41 +0100
Subject: [PATCH] [Bart] Move CLS rep extraction from EOS tokens to head classes

---
 src/adapters/heads/base.py                | 48 +++++++++++++++--------
 src/adapters/models/bart/adapter_model.py | 15 ++-----
 2 files changed, 34 insertions(+), 29 deletions(-)

diff --git a/src/adapters/heads/base.py b/src/adapters/heads/base.py
index d45897df10..fc4f6990b7 100644
--- a/src/adapters/heads/base.py
+++ b/src/adapters/heads/base.py
@@ -18,7 +18,13 @@
 )
 from transformers.utils import ModelOutput
 
-from ..composition import AdapterCompositionBlock, BatchSplit, Parallel, parse_heads_from_composition
+from ..composition import (
+    AdapterCompositionBlock,
+    BatchSplit,
+    Parallel,
+    adjust_tensors_for_parallel,
+    parse_heads_from_composition,
+)
 from ..context import AdapterSetup, ForwardContext
 from ..loading import PredictionHeadLoader
 from ..methods.modeling import Activation_Function_Class
@@ -105,6 +111,21 @@ def get_output_embeddings(self):
     def get_label_names(self):
         return ["labels"]
 
+    def _get_cls_output(self, outputs, **kwargs):
+        if self.config["use_pooler"]:
+            cls_output = kwargs.pop("pooled_output")
+        elif kwargs.get("get_cls_from_eos_tokens", False):
+            x = outputs[0]  # last hidden state
+            eos_mask = kwargs.get("eos_mask")
+            (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
+            if len(torch.unique(eos_mask.sum(1))) > 1:
+                raise ValueError("All examples must have the same number of <eos> tokens.")
+            cls_output = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
+        else:
+            cls_output = outputs[0][:, 0]
+
+        return cls_output
+
 
 class ClassificationHead(PredictionHead):
     def __init__(
@@ -134,10 +155,7 @@ def __init__(
 
     def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs):
         if cls_output is None:
-            if self.config["use_pooler"]:
-                cls_output = kwargs.pop("pooled_output")
-            else:
-                cls_output = outputs[0][:, 0]
+            cls_output = self._get_cls_output(outputs, **kwargs)
         logits = super().forward(cls_output)
         loss = None
         labels = kwargs.pop("labels", None)
@@ -205,10 +223,7 @@ def __init__(
 
     def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs):
         if cls_output is None:
-            if self.config["use_pooler"]:
-                cls_output = kwargs.pop("pooled_output")
-            else:
-                cls_output = outputs[0][:, 0]
+            cls_output = self._get_cls_output(outputs, **kwargs)
         logits = super().forward(cls_output)
         loss = None
         labels = kwargs.pop("labels", None)
@@ -271,10 +286,7 @@ def __init__(
 
     def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=None, **kwargs):
         if cls_output is None:
-            if self.config["use_pooler"]:
-                cls_output = kwargs.pop("pooled_output")
-            else:
-                cls_output = outputs[0][:, 0]
+            cls_output = self._get_cls_output(outputs, **kwargs)
         logits = super().forward(cls_output)
         logits = logits.view(-1, self.config["num_choices"])
         loss = None
@@ -476,10 +488,7 @@ def __init__(
 
     def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs):
         if cls_output is None:
-            if self.config["use_pooler"]:
-                cls_output = kwargs.pop("pooled_output")
-            else:
-                cls_output = outputs[0][:, 0]
+            cls_output = self._get_cls_output(outputs, **kwargs)
         logits = super().forward(cls_output)
         loss = None
         labels = kwargs.pop("labels", None)
@@ -800,6 +809,9 @@ def forward_head(
             cls_output (torch.Tensor, optional): The classification output of the model.
             attention_mask (torch.Tensor, optional): The attention mask of the model.
             return_dict (bool): Whether or not to return a ``ModelOutput`` instead of a plain tuple.
+            get_cls_from_eos_tokens (bool):
+                If set to True, retrieve classifier token representations from the last <eos> token in the sequence.
+                Setting to True requires `eos_mask` to be passed as well.
             **kwargs: Additional keyword arguments passed to the forward pass of the head.
         """
         used_head_modules = self._get_used_heads(head_name)
@@ -846,10 +858,12 @@ def _get_head_input(outputs, cls_out, batch):
                 )
             head_outputs = []
             labels = kwargs.pop("labels", None)
+            eos_mask = kwargs.pop("eos_mask", None)
             for i, head in enumerate(self.active_head):
                 head_module = self.heads[head]
                 batch_idx = range(sum(self.active_head.batch_sizes[:i]), sum(self.active_head.batch_sizes[: i + 1]))
                 kwargs["labels"] = labels[batch_idx] if labels is not None else None
+                kwargs["eos_mask"] = eos_mask[batch_idx] if eos_mask is not None else None
                 head_inputs, head_cls_input = _get_head_input(all_outputs, cls_output, batch_idx)
                 # head_attention = attention_mask[batch_idx] if attention_mask is not None else None
                 head_output = head_module(head_inputs, head_cls_input, attention_mask, return_dict, **kwargs)
diff --git a/src/adapters/models/bart/adapter_model.py b/src/adapters/models/bart/adapter_model.py
index ddb94e6fe9..ad75324fd1 100644
--- a/src/adapters/models/bart/adapter_model.py
+++ b/src/adapters/models/bart/adapter_model.py
@@ -10,7 +10,6 @@
 )
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
 
-from ...composition import adjust_tensors_for_parallel
 from ...heads import (
     ClassificationHead,
     ModelWithFlexibleHeadsAdaptersMixin,
@@ -102,23 +101,15 @@ def forward(
         )
         # required e.g. for prompt tuning in all models
         kwargs["context"] = context
 
-        # sequence classification based on last token in sequence
-        x = outputs[0]  # last hidden state
-        if input_ids is not None and x.shape[1] == input_ids.shape[1]:
-            eos_mask = input_ids.eq(self.config.eos_token_id)
-            (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
-            if len(torch.unique(eos_mask.sum(1))) > 1:
-                raise ValueError("All examples must have the same number of <eos> tokens.")
-            cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
-        else:
-            cls_representation = x
 
         head_outputs = self.forward_head(
             outputs,
             head_name=head,
-            cls_output=cls_representation,
             attention_mask=attention_mask,
             return_dict=return_dict,
+            get_cls_from_eos_tokens=True,
+            # `get_cls_from_eos_tokens` requires passing eos mask
+            eos_mask=input_ids.eq(self.config.eos_token_id) if input_ids is not None else None,
             **kwargs,
         )
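For reviewers unfamiliar with the BART-style EOS pooling that `_get_cls_output` now encapsulates, here is a minimal, self-contained sketch of what the boolean-mask indexing does. The tensor shapes, token ids, and values are made up for illustration and are not part of this patch.

```python
# Hypothetical toy example (not part of this patch) showing how the EOS-mask
# indexing in `_get_cls_output` picks one vector per sequence.
import torch

batch_size, seq_len, hidden_size = 2, 5, 4
# Fake "last hidden state" with recognizable values.
x = torch.arange(batch_size * seq_len * hidden_size, dtype=torch.float).view(
    batch_size, seq_len, hidden_size
)

# Fake input ids: eos_token_id = 2 appears once per sequence (positions 3 and 4).
input_ids = torch.tensor([[5, 6, 7, 2, 1],
                          [5, 6, 7, 8, 2]])
eos_mask = input_ids.eq(2)

# The guard in the patch: every example needs the same number of EOS tokens,
# otherwise the reshape below cannot produce a rectangular per-example view.
assert len(torch.unique(eos_mask.sum(1))) == 1

# Boolean indexing flattens the selected vectors to (total_eos, hidden);
# .view() restores (batch, eos_per_example, hidden); [:, -1, :] keeps the
# representation of the *last* EOS token in each sequence.
cls_output = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
print(cls_output.shape)  # torch.Size([2, 4])
```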
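As a sanity check of the new call path (`BartAdapterModel.forward` now passes `get_cls_from_eos_tokens=True` plus an `eos_mask` into `forward_head`, and the head resolves the CLS representation itself), a rough usage sketch follows. The checkpoint name and the `add_classification_head` helper are assumptions about the public `adapters` API, not part of this change.

```python
# Rough usage sketch (assumed API, not part of this patch): classification with
# a Bart flex-head model, where the head now extracts the EOS representation.
import torch
from transformers import AutoTokenizer
from adapters import BartAdapterModel  # assumed import path of the flex-head class

model = BartAdapterModel.from_pretrained("facebook/bart-base")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model.add_classification_head("imdb", num_labels=2)  # assumed head-adding helper

inputs = tokenizer("a rather enjoyable film", return_tensors="pt")
with torch.no_grad():
    # BartAdapterModel.forward now hands `get_cls_from_eos_tokens=True` and an
    # `eos_mask` derived from `input_ids` to `forward_head`; the head's
    # `_get_cls_output` then selects the last <eos> token representation.
    outputs = model(**inputs)

print(outputs.logits.shape)  # expected: torch.Size([1, 2])
```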