Commit
- deberta
- debertav2
- distilbert
- electra
- encoder-decoder
- llama
- mbart
- mistral
- mt5
- plbart
- roberta
Showing 12 changed files with 323 additions and 1 deletion.
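All of the new test modules below follow the same pattern: a per-model `*AdapterTestBase` class bundles a small model config (built with `make_config`) and a tokenizer name, then `generate_method_tests` from the shared test utilities (`.utils`) appears to turn that base class into a dictionary of concrete test classes, one per test category judging by the `excluded_tests` names, which are injected into the module's globals so test discovery collects them. The following is a minimal sketch of that pattern only; `generate_method_tests_sketch`, `DummyLoRATestMixin`, and `ExampleAdapterTestBase` are illustrative stand-ins, not the library's actual implementation or class names.

    import unittest

    # Hypothetical mixin standing in for the real per-category test mixins
    # (LoRA, prefix tuning, ...) provided by the shared test utilities.
    class DummyLoRATestMixin:
        def test_add_lora(self):
            self.assertTrue(True)  # placeholder assertion

    def generate_method_tests_sketch(base_class, excluded_tests=()):
        """Build one concrete TestCase per method mixin, keyed by class name."""
        mixins = {"LoRA": DummyLoRATestMixin}
        classes = {}
        for name, mixin in mixins.items():
            if name in excluded_tests:
                continue
            class_name = f"{name}{base_class.__name__}"
            classes[class_name] = type(class_name, (mixin, base_class, unittest.TestCase), {})
        return classes

    class ExampleAdapterTestBase:  # stand-in for TextAdapterTestBase
        pass

    # Injecting into globals() makes the generated classes visible to unittest/pytest
    # discovery, exactly as if they had been written out by hand.
    for test_class_name, test_class in generate_method_tests_sketch(ExampleAdapterTestBase).items():
        globals()[test_class_name] = test_class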
@@ -0,0 +1,27 @@

from transformers import DebertaConfig

from .utils import *


class DebertaAdapterTestBase(TextAdapterTestBase):
    config_class = DebertaConfig
    config = make_config(
        DebertaConfig,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        relative_attention=True,
        pos_att_type="p2c|c2p",
    )
    tokenizer_name = "microsoft/deberta-base"

    def test_parallel_training_lora(self):
        self.skipTest("Not supported for DeBERTa")


method_tests = generate_method_tests(DebertaAdapterTestBase)

for test_class_name, test_class in method_tests.items():
    globals()[test_class_name] = test_class
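Because the generated classes end up in the module's global namespace, test discovery collects them just like hand-written test cases. A hedged example of loading one of them by name follows; both the dotted module path and the generated class name are assumptions and do not appear in this diff.

    import unittest

    # Hypothetical path and class name: the real values depend on where this
    # file lives and on the names generate_method_tests actually produces.
    suite = unittest.defaultTestLoader.loadTestsFromName(
        "tests.test_methods.test_deberta.LoRADebertaAdapterTestBase"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)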
@@ -0,0 +1,24 @@

from transformers import DebertaV2Config

from .utils import *


class DebertaV2AdapterTestBase(TextAdapterTestBase):
    config_class = DebertaV2Config
    config = make_config(
        DebertaV2Config,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        relative_attention=True,
        pos_att_type="p2c|c2p",
    )
    tokenizer_name = "microsoft/deberta-v3-base"


method_tests = generate_method_tests(DebertaV2AdapterTestBase)

for test_class_name, test_class in method_tests.items():
    globals()[test_class_name] = test_class
@@ -0,0 +1,21 @@

from transformers import DistilBertConfig

from .utils import *


class DistilBertAdapterTestBase(TextAdapterTestBase):
    config_class = DistilBertConfig
    config = make_config(
        DistilBertConfig,
        dim=32,
        n_layers=4,
        n_heads=4,
        hidden_dim=37,
    )
    tokenizer_name = "distilbert-base-uncased"


method_tests = generate_method_tests(DistilBertAdapterTestBase)

for test_class_name, test_class in method_tests.items():
    globals()[test_class_name] = test_class
@@ -0,0 +1,22 @@

from transformers import ElectraConfig

from .utils import *


class ElectraAdapterTestBase(TextAdapterTestBase):
    config_class = ElectraConfig
    config = make_config(
        ElectraConfig,
        # vocab_size=99,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
    )
    tokenizer_name = "google/electra-base-generator"


method_tests = generate_method_tests(ElectraAdapterTestBase)

for test_class_name, test_class in method_tests.items():
    globals()[test_class_name] = test_class
@@ -0,0 +1,72 @@

from adapters import init
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BertConfig, EncoderDecoderConfig, EncoderDecoderModel

from .utils import *


class EncoderDecoderAdapterTestBase(TextAdapterTestBase):
    model_class = EncoderDecoderModel
    config_class = EncoderDecoderConfig
    config = staticmethod(
        lambda: EncoderDecoderConfig.from_encoder_decoder_configs(
            BertConfig(
                hidden_size=32,
                num_hidden_layers=4,
                num_attention_heads=4,
                intermediate_size=37,
            ),
            BertConfig(
                hidden_size=32,
                num_hidden_layers=4,
                num_attention_heads=4,
                intermediate_size=37,
                is_decoder=True,
                add_cross_attention=True,
            ),
        )
    )
    tokenizer_name = "bert-base-uncased"
    do_run_train_tests = False

    def test_generation(self):
        model = AutoModelForSeq2SeqLM.from_config(self.config())
        init(model)
        model.add_adapter("test", config="pfeiffer")
        model.set_active_adapters("test")
        tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=False)

        text = "This is a test sentence."
        input_ids = tokenizer(text, return_tensors="pt").input_ids

        generated_ids = model.generate(input_ids, bos_token_id=100)
        generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        self.assertNotEqual("", generated_text)

    def test_invertible_adapter_with_head(self):
        """This test is copied and adapted from the identically-named test in test_adapter_heads.py."""
        raise self.skipTest("AutoModelForSeq2SeqLM does not support using invertible adapters.")

    def test_adapter_fusion_save_with_head(self):
        # This test is not applicable to the encoder-decoder model since it has no heads.
        self.skipTest("Not applicable to the encoder-decoder model.")

    def test_forward_with_past(self):
        # This test is not applicable to the encoder-decoder model since it has no heads.
        self.skipTest("Not applicable to the encoder-decoder model.")

    def test_output_adapter_gating_scores_unipelt(self):
        # TODO currently not supported
        self.skipTest("Not implemented.")

    def test_output_adapter_fusion_attentions(self):
        # TODO currently not supported
        self.skipTest("Not implemented.")


test_methods = generate_method_tests(
    EncoderDecoderAdapterTestBase,
    excluded_tests=["Heads", "ConfigUnion", "Embeddings", "Composition", "PromptTuning", "ClassConversion"],
)

for test_class_name, test_class in test_methods.items():
    globals()[test_class_name] = test_class
@@ -0,0 +1,33 @@

from transformers.models.llama.configuration_llama import LlamaConfig

from .utils import *


class LlamaAdapterTestBase(TextAdapterTestBase):
    config_class = LlamaConfig
    config = make_config(
        LlamaConfig,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        pad_token_id=0,
    )
    tokenizer_name = "openlm-research/open_llama_13b"


method_tests = generate_method_tests(LlamaAdapterTestBase, excluded_tests=["PromptTuning"])

for test_class_name, test_class in method_tests.items():
    globals()[test_class_name] = test_class


@require_torch
class ClassConversion(
    ModelClassConversionTestMixin,
    LlamaAdapterTestBase,
    unittest.TestCase,
):
    def test_conversion_question_answering_model(self):
        raise self.skipTest("We don't support the Llama QA model.")
@@ -0,0 +1,26 @@

from transformers import MBartConfig

from .utils import *


class MBartAdapterTestBase(TextAdapterTestBase):
    config_class = MBartConfig
    config = make_config(
        MBartConfig,
        d_model=16,
        encoder_layers=2,
        decoder_layers=2,
        encoder_attention_heads=4,
        decoder_attention_heads=4,
        encoder_ffn_dim=4,
        decoder_ffn_dim=4,
        vocab_size=250027,
    )
    tokenizer_name = "facebook/mbart-large-cc25"


method_tests = generate_method_tests(
    MBartAdapterTestBase, excluded_tests=["ConfigUnion", "Embeddings", "PromptTuning"]
)
for test_class_name, test_class in method_tests.items():
    globals()[test_class_name] = test_class
@@ -0,0 +1,24 @@

from transformers.models.mistral.configuration_mistral import MistralConfig

from .utils import *


class MistralAdapterTestBase(TextAdapterTestBase):
    config_class = MistralConfig
    config = make_config(
        MistralConfig,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=8,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        pad_token_id=0,
    )
    tokenizer_name = "HuggingFaceH4/zephyr-7b-beta"


test_methods = generate_method_tests(MistralAdapterTestBase, excluded_tests=["PromptTuning", "ConfigUnion"])

for test_class_name, test_class in test_methods.items():
    globals()[test_class_name] = test_class
@@ -0,0 +1,26 @@

from transformers import MT5Config

from .utils import *


@require_torch
class MT5AdapterTestBase(TextAdapterTestBase):
    config_class = MT5Config
    config = make_config(
        MT5Config,
        d_model=16,
        num_layers=2,
        num_decoder_layers=2,
        num_heads=4,
        d_ff=4,
        d_kv=16 // 4,
        tie_word_embeddings=False,
        decoder_start_token_id=0,
    )
    tokenizer_name = "google/mt5-base"


method_tests = generate_method_tests(MT5AdapterTestBase, excluded_tests=["PromptTuning", "ConfigUnion"])

for test_name, test_class in method_tests.items():
    globals()[test_name] = test_class
@@ -0,0 +1,25 @@

from transformers import PLBartConfig

from .utils import *


class PLBartAdapterTestBase(TextAdapterTestBase):
    config_class = PLBartConfig
    config = make_config(
        PLBartConfig,
        d_model=16,
        encoder_layers=2,
        decoder_layers=2,
        encoder_attention_heads=4,
        decoder_attention_heads=4,
        encoder_ffn_dim=4,
        decoder_ffn_dim=4,
        scale_embedding=False,  # Required for embedding tests
    )
    tokenizer_name = "uclanlp/plbart-base"


method_tests = generate_method_tests(PLBartAdapterTestBase, excluded_tests=["PromptTuning"])

for test_name, test_class in method_tests.items():
    globals()[test_name] = test_class
@@ -0,0 +1,22 @@

from transformers import RobertaConfig

from .utils import *


class RobertaAdapterTestBase(TextAdapterTestBase):
    config_class = RobertaConfig
    config = make_config(
        RobertaConfig,
        hidden_size=32,
        num_hidden_layers=4,
        num_attention_heads=4,
        intermediate_size=37,
        vocab_size=50265,
    )
    tokenizer_name = "roberta-base"


method_tests = generate_method_tests(RobertaAdapterTestBase)

for test_name, test_class in method_tests.items():
    globals()[test_name] = test_class