train.py (forked from besarthoxhaj/lora)
import os

import torch
import transformers
import wandb

import data_onion as data
import model

# Use the GPU if one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
is_ddp = int(os.environ.get("WORLD_SIZE", 1)) != 1
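# Note: DDP is typically enabled by launching this script with torchrun,
# e.g. `torchrun --nproc_per_node=<num_gpus> train.py`, which sets WORLD_SIZE.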
m = model.get_model().to(device)  # LoRA-wrapped model defined in model.py
ds = data.TrainDataset()
# Dynamic padding; padding to a multiple of 8 keeps fp16 tensor cores efficient.
collator = transformers.DataCollatorForSeq2Seq(
    ds.tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
)
# To resume from a previously saved LoRA adapter, uncomment the lines below
# (requires `import peft` at the top of the file):
# adapters_weights = torch.load("./output/checkpoint-800/adapter_model.bin")
# peft.set_peft_model_state_dict(m, adapters_weights)
# Run name and training hyperparameters (also logged to wandb below).
title = input("Title: ")
num_train_epochs = 1
learning_rate = 3e-4
optim = "adamw_torch"
batch_size = 4
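# wandb.init below assumes you are already authenticated with Weights & Biases
# (via `wandb login` or the WANDB_API_KEY environment variable).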
wandb.init(
    project="Llama",
    name=title,
    config={
        "epochs": num_train_epochs,
        "learning_rate": learning_rate,
        "optim": optim,
        "batch_size": batch_size,
        "device": device.type,
    },
)
trainer = transformers.Trainer(
    model=m,
    train_dataset=ds,
    data_collator=collator,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=batch_size,
        num_train_epochs=num_train_epochs,
        learning_rate=learning_rate,
        fp16=True,
        logging_steps=10,
        optim=optim,
        evaluation_strategy="no",
        save_strategy="steps",
        eval_steps=None,
        save_steps=200,
        output_dir="./output",
        save_total_limit=3,
        report_to="wandb",
        ddp_find_unused_parameters=False if is_ddp else None,
    ),
)
# Disable the KV cache during training; it is only useful for generation.
m.config.use_cache = False
trainer.train()
# For a PEFT-wrapped model this saves only the LoRA adapter weights.
m.save_pretrained("./weights")
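# Sketch: the saved adapter can later be restored the same way as the
# checkpoint-resume block above (requires `import peft`; depending on the peft
# version the file may be adapter_model.bin or adapter_model.safetensors):
#   adapters_weights = torch.load("./weights/adapter_model.bin")
#   peft.set_peft_model_state_dict(m, adapters_weights)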