
Inference with different LoRA adapters in the same batch does not use the correct module_to_save classifier #1960

Open
saeid93 opened this issue Jul 26, 2024 · 10 comments


@saeid93
Contributor

saeid93 commented Jul 26, 2024

System Info

Python 3.11.9
transformers==4.40.2
peft==0.11.2

Who can help?

@BenjaminBossan
I'm interested in using the Inference with different LoRA adapters in the same batch feature with a separate last-layer classifier for each LoRA adapter. However, during inference, when a single batch contains requests destined for different adapters, the peft library uses the active adapter for every request instead of the appropriate per-request LoRA weights.
I should note that this problem only happens for modules_to_save layers; the other layers (e.g. the base model's LoRA layers) use the correct LoRA weights for each request.
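
For context, the mixed-batch call looks roughly like this (the adapter names and batch are illustrative, not my exact setup):

# One adapter name per sample in the batch selects which LoRA adapter handles it.
outputs = peft_model(
    pixel_values=pixel_values,
    adapter_names=["adapter_a", "adapter_b", "adapter_a"],
)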

This is because the per-request adapter names are not passed down to the ModulesToSaveWrapper class, so inference is always done with the active_adapter's module to save:

return self.modules_to_save[self.active_adapter](*args, **kwargs)

The solution would be to pass adapter_names all the way down to the **kwargs of the ModulesToSaveWrapper forward function and write logic similar to

def _mixed_batch_forward(

so that sub-batches belonging to the same adapter are sent together to the appropriate classifier.

However, I see that you are excluding the special_peft_forward_args:

kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}

possibly to avoid interfering with the base model's forward function, e.g.
https://github.com/huggingface/transformers/blob/5f841c74b62754f186a8c06a684d491524b7bc03/src/transformers/models/vit/modeling_vit.py#L813

I was able to work around this by modifying the mentioned functions, but since it is a bug I think it should either be fixed upstream or be mentioned in the documentation as another caveat. @stevhliu

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

The case is pretty similar to the documentation example, but with a classifier module on top. The code raises no error; I only noticed the problem by observing differences in accuracy and tracing the root cause in the peft library as described above. Please let me know if more information is needed.
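
A minimal sketch of such a setup (the checkpoint, number of labels, and adapter names are placeholders, not my exact configuration):

import torch
from transformers import ViTForImageClassification
from peft import LoraConfig, get_peft_model

# Base ViT with a classification head; checkpoint and num_labels are illustrative.
base = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k", num_labels=10
)
# modules_to_save gives every adapter its own trainable copy of the classifier.
config = LoraConfig(target_modules=["query", "value"], modules_to_save=["classifier"])
model = get_peft_model(base, config, adapter_name="adapter_a")
model.add_adapter("adapter_b", config)

pixel_values = torch.randn(2, 3, 224, 224)
# Sample 0 should go through adapter_a's classifier and sample 1 through adapter_b's,
# but both end up using the classifier of the currently active adapter.
outputs = model(pixel_values=pixel_values, adapter_names=["adapter_a", "adapter_b"])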

Expected behavior

Each adapter should use its own classifier rather than the active adapter's classifier.

@BenjaminBossan
Member

Thanks a lot for reporting this error and your great investigation. Indeed, this should ideally work out of the box. You mentioned:

I was able to solve this by modifying the mentioned functions

Could you please share the code to achieve this?

@saeid93
Contributor Author

saeid93 commented Jul 26, 2024

Glad to be of help!
Please find the code below; it is just a hack that dynamically patches the modifications into the library.
The rest of the code simply uses the classes and functions below instead of the peft and transformers classes. I have marked the changes with HERE comments in the code.

BTW, I'm happy to investigate further and fix this with a pull request, but it will take some time. If there is a timeline for fixing it, I'll leave it to you.

from typing import Any, Optional, Union
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from peft.peft_model import PeftModel
from transformers.modeling_outputs import ImageClassifierOutput
from transformers import ViTForImageClassification, MobileViTForImageClassification
from functools import partial

class PeftModelFixed(PeftModel):
    def forward(self, *args: Any, **kwargs: Any):
        """
        Forward pass of the model.
        """
        with self._enable_peft_forward_hooks(*args, **kwargs):
            # HERE removed this to avoid mixing
            # kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
            return self.get_base_model()(*args, **kwargs)


class ViTForImageClassificationFixed(ViTForImageClassification):
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs # HERE added kwargs
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # HERE this was changed with having adapters
        logits = self.classifier(sequence_output[:, 0, :], adapter_names=kwargs["adapter_names"])

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

def peftforward(self, *args, **kwargs):
    if self.disable_adapters or (self.active_adapter not in self.modules_to_save):
        return self.original_module(*args, **kwargs)

    # HERE changed to support per-sample LoRA adapters: pull out the adapter names
    # and drop the remaining kwargs so they are not forwarded to the wrapped classifier
    adapter_names = kwargs["adapter_names"]
    kwargs = {}
    batch = args[0]
    unique_adapters = set(adapter_names)
    sub_batch_indices_list = []
    for adapter in unique_adapters:
        sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])

    results = [0 for i in range(len(batch))]
    for i, active_adapter in enumerate(unique_adapters):
        sub_batch = batch[sub_batch_indices_list[i]]
        output = self.modules_to_save[active_adapter](*(sub_batch,), **kwargs)
        for index, j in enumerate(sub_batch_indices_list[i]):
            results[j] = output[index]
    return torch.stack(results)

def change_forward_dynamically(model: PeftModel):
    # HERE model is passed here to dynamically change the last layer
    model.classifier.forward = partial(peftforward, model.classifier)
    return model
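
And a rough example of how I wire the patched pieces together (the checkpoint and adapter names are again placeholders):

from peft import LoraConfig

base = ViTForImageClassificationFixed.from_pretrained(
    "google/vit-base-patch16-224-in21k", num_labels=10
)
lora_config = LoraConfig(target_modules=["query", "value"], modules_to_save=["classifier"])
peft_model = PeftModelFixed(base, lora_config, adapter_name="adapter_a")
peft_model.add_adapter("adapter_b", lora_config)
peft_model = change_forward_dynamically(peft_model)

pixel_values = torch.randn(2, 3, 224, 224)
# Each sample in the batch is now routed through its own adapter's classifier.
outputs = peft_model(pixel_values=pixel_values, adapter_names=["adapter_a", "adapter_b"])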

@BenjaminBossan
Member

Thanks a lot. I would gladly accept a PR for this fix; that would be fantastic. The proper fix should probably be much simpler than your workaround, as it can be applied directly where needed instead of patching.

There is no strict timeline; we just had a release, so the next one is still a bit in the future.

@saeid93
Contributor Author

saeid93 commented Jul 26, 2024

Awesome, I'll work on it when I get a chance.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@saeid93
Contributor Author

saeid93 commented Aug 25, 2024

This is pending approval of #1990. Commenting to remove the automatic stale mark.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@BenjaminBossan
Member

The ViT part is still open, so not stale.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@BenjaminBossan
Member

not stale
