Prefix Tuning does not work with T5-3B #485

Closed
saparina opened this issue Jan 25, 2023 · 4 comments · Fixed by #621
Labels
bug (Something isn't working), do-not-stale (This issue won't be automatically staled and closed after 90 days)

Comments

saparina commented Jan 25, 2023

Environment info

  • adapter-transformers version: 3.1.0

Information

Model I am using: T5-3B

Adapter setup I am using (if any): Prefix Tuning

To reproduce

I believe this bug occurs every time T5-3B and Prefix Tuning are used together. For example, it can be reproduced by running the summarization example with `t5-3b` as the model:

python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path t5-3b \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate \
    --train_adapter \
    --adapter_config prefix_tuning

The error is raised when the prefix embeddings are concatenated with the keys and values:

  File "/home/hpcsapa1/.conda/envs/adapters-torch1.9/lib/python3.9/site-packages/transformers/adapters/prefix_tuning.py", line 329, in forward
    key_states = torch.cat([prefix_keys, key_states], dim=2)
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 32 but got size 128 for tensor number 1 in the list.
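For illustration, here is a minimal standalone sketch (not the library code; the batch, sequence, and prefix lengths are made up) of why the concatenation fails when the prefix is sized from `d_model` rather than from `num_heads * d_kv`:

```python
import torch

# T5-3B attention dimensions; batch/sequence/prefix lengths are arbitrary.
batch, num_heads, seq_len, prefix_len = 2, 32, 16, 30
d_model, d_kv = 1024, 128

# T5's attention projects keys to the true per-head dim d_kv = 128 ...
key_states = torch.rand(batch, num_heads, seq_len, d_kv)

# ... but a prefix sized from d_model // num_heads has only 1024 // 32 = 32 per head.
prefix_keys = torch.rand(batch, num_heads, prefix_len, d_model // num_heads)

# RuntimeError: Sizes of tensors must match except in dimension 2.
# Expected size 32 but got size 128 for tensor number 1 in the list.
key_states = torch.cat([prefix_keys, key_states], dim=2)
```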

Expected behavior

There is a mismatch between the sizes of the prefix embeddings and the keys and values. I think this is because the prefix size is based on `hidden_size`, which is defined as `d_model` for T5 models, while it should be `num_heads * d_kv`. For T5-Small, T5-Base, and T5-Large, `d_model == num_heads * d_kv` (and these models work fine), but this does not hold for T5-3B, where `d_model=1024`, `num_heads=32`, and `d_kv=128`.
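As a quick check (requires downloading the configs from the Hugging Face Hub), the assumption can be verified directly from `T5Config`; the expected output below follows from the numbers stated above and the standard T5 configs:

```python
from transformers import T5Config

# Check the d_model == num_heads * d_kv assumption for the public T5 checkpoints.
for name in ["t5-small", "t5-base", "t5-large", "t5-3b"]:
    cfg = T5Config.from_pretrained(name)
    print(f"{name:9s} d_model={cfg.d_model:4d}  "
          f"num_heads*d_kv={cfg.num_heads * cfg.d_kv:4d}  "
          f"equal={cfg.d_model == cfg.num_heads * cfg.d_kv}")

# t5-small  d_model= 512  num_heads*d_kv= 512  equal=True
# t5-base   d_model= 768  num_heads*d_kv= 768  equal=True
# t5-large  d_model=1024  num_heads*d_kv=1024  equal=True
# t5-3b     d_model=1024  num_heads*d_kv=4096  equal=False
```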

@saparina added the bug label on Jan 25, 2023

calpt commented Jan 26, 2023

Hey @saparina, thanks for bringing this up.

Your explanation of the issue makes a lot of sense; we'll look into a possible fix for prefix tuning. It's a bit odd, however, that the constraint `d_model == num_heads * d_kv` is not fulfilled by all models, as the Hugging Face documentation for T5 explicitly states that this is required: https://huggingface.co/docs/transformers/main/en/model_doc/t5#transformers.T5Config.d_kv

adapter-hub-bert commented

This issue has been automatically marked as stale because it has been without activity for 90 days. This issue will be closed in 14 days unless you comment or remove the stale label.

@calpt removed the Stale label on Apr 27, 2023
@lenglaender added the do-not-stale label on May 15, 2023
vijetadeshpande commented

Is there any update on this issue?


calpt commented Dec 20, 2023

Sorry for the delay on this; it should be fixed once #621 is merged.

@calpt closed this as completed in #621 on Jan 5, 2024
calpt added a commit that referenced this issue Jan 5, 2024
Fixes #485.

Allows passing the head dimension (`n_embd_per_head`) explicitly to prefix tuning to accommodate models where the head dimension is not equal to hidden dim / n_heads.
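For reference, a rough sketch of the sizing logic the fix describes (illustrative only, not the actual adapters code; the class and parameter layout here are assumptions):

```python
import torch
import torch.nn as nn

class PrefixTuningSketch(nn.Module):
    """Illustrative sketch: size the prefix by an explicit per-head dim when given."""

    def __init__(self, n_layers, n_heads, input_size, prefix_length=30,
                 n_embd_per_head=None):
        super().__init__()
        # Default to hidden_dim // n_heads, but let models like T5-3B
        # (where d_kv != d_model / num_heads) pass the true head dim explicitly.
        self.n_embd_per_head = n_embd_per_head or input_size // n_heads
        # One prefix vector per position, covering keys and values for every layer/head.
        self.prefix = nn.Parameter(
            torch.randn(prefix_length, n_layers * 2 * n_heads * self.n_embd_per_head)
        )

# For T5-3B, the model would pass n_embd_per_head=config.d_kv (128) instead of
# relying on d_model // num_heads (32), so the prefix matches the key/value shapes.
```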