diff --git a/code/train_llama2_lowrank_lora.py b/code/train_llama2_lowrank_lora.py index d5814a7..61026e5 100644 --- a/code/train_llama2_lowrank_lora.py +++ b/code/train_llama2_lowrank_lora.py @@ -156,8 +156,6 @@ def train(model, tokenizer, dataset_train, dataset_val, max_tokens=256, batch_si model.eval() -print(dataset_name.value, '| perplexity', compute_perplexity(model=model, tokenizer=tokenizer, - predictions=[s['text'] for s in dataset_val], - batch_size=1, max_length=512)) +print(dataset_name.value, '| perplexity', compute_perplexity(model=model, tokenizer=tokenizer, predictions=[s['text'] for s in dataset_val], batch_size=1, max_length=512))