
add tests
junshijun committed Oct 7, 2024
1 parent 05c4d1a commit 3b242a6
Showing 2 changed files with 83 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tests/finetune_all_case_chatglm.sh
@@ -0,0 +1,20 @@
#!/bin/bash
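# Usage: bash tests/finetune_all_case_chatglm.sh <path-to-chatglm-base-model>
# Runs mlora_train.py once for each test-case YAML listed below.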

declare -a test_case_yamls=(
"demo/lora/lora_case_2.yaml"
"demo/checkpoint/checkpoint_case_3.yaml"
"demo/checkpoint/checkpoint_case_4.yaml"
"demo/loraplus/loraplus_case_2.yaml"
"demo/vera/vera_case_2.yaml"
"demo/dora/dora_case_2.yaml"
"demo/dpo/dpo_case_4.yaml"
"demo/dpo/dpo_case_5.yaml"
"demo/dpo/dpo_case_6.yaml"
"demo/cit/cit_case_2.yaml"
)

set -x  # echo each command as it runs
for test_case in "${test_case_yamls[@]}"; do
    python mlora_train.py --base_model "$1" --config "${test_case}" --precision bf16 --model_type chatglm
done
63 changes: 63 additions & 0 deletions tests/inference_all_case_chatglm.py
@@ -0,0 +1,63 @@
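"""Run inference with every adapter produced by the fine-tuning test cases.

Loads the ChatGLM base model once, attaches each adapter listed in
G_TEST_ADAPTERS via peft, and prints one generation per adapter.

Usage:
    python tests/inference_all_case_chatglm.py --base_model <path-or-name>
"""
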
import argparse

import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

G_TEST_ADAPTERS = [
# lora adapter
"adapters/lora_sft_0",
"adapters/lora_sft_1",
# loraplus adapter
"adapters/loraplus_sft_0",
"adapters/loraplus_sft_1",
# dpo adapter
"adapters/lora_base_dpo",
"adapters/lora_sft_dpo",
"adapters/loraplus_sft_dpo",
# cit adapter
"adapters/lora_cit",
"adapters/loraplus_cit"
]


def get_cmd_args():
    parser = argparse.ArgumentParser(description='mLoRA test function')
    parser.add_argument('--base_model', type=str, required=True,
                        help='Path to or name of base model')
    return parser.parse_args()


if __name__ == "__main__":
    args = get_cmd_args()

    model_path = args.base_model

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)

    query = "What is mLoRA?"
    device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no GPU is present

    inputs = tokenizer.apply_chat_template(
        [{"role": "user", "content": query}],
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )
    inputs = inputs.to(device)

    model.to(device).eval()
    # do_sample with top_k=1 always picks the single most likely token,
    # so decoding is effectively greedy and runs are deterministic
    gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}

    for adapter in G_TEST_ADAPTERS:
        peft_model = PeftModel.from_pretrained(model, adapter)

        with torch.no_grad():
            outputs = peft_model.generate(**inputs, **gen_kwargs)
            # strip the prompt tokens so only the generated reply is printed
            outputs = outputs[:, inputs['input_ids'].shape[1]:]
            print(tokenizer.decode(outputs[0], skip_special_tokens=True))

        # unload() removes the injected adapter modules, so the next iteration
        # wraps a clean base model instead of stacking adapters onto it
        model = peft_model.unload()
