# train.py
import torch
torch.backends.cuda.matmul.allow_tf32 = True  # enable TF32 matmuls (Ampere+ GPUs) before any other CUDA work
import torch.nn as nn
import transformers
from peft import PeftConfig, PeftModel, LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType
from utils import get_local_dir, get_local_run_dir, disable_dropout, init_distributed
import os
import hydra
import torch.distributed as dist
import torch.multiprocessing as mp
from omegaconf import OmegaConf, DictConfig
import trainers
import wandb
import json
import socket
from typing import Optional, Set

OmegaConf.register_new_resolver("get_local_run_dir", lambda exp_name, local_dirs: get_local_run_dir(exp_name, local_dirs))
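
# The resolver above lets the YAML config derive the run directory lazily, e.g.
# (illustrative snippet; the real keys live in config/config.yaml):
#   local_run_dir: ${get_local_run_dir:${exp_name},${local_dirs}}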

def worker_main(rank: int, world_size: int, config: DictConfig, policy: nn.Module, reference_model: Optional[nn.Module] = None, start_step: int = 0):
    """Main function for each worker process (may be only 1 for BasicTrainer/TensorParallelTrainer)."""
    if 'FSDP' in config.trainer:
        init_distributed(rank, world_size, port=config.fsdp_port)

    if config.debug:
        wandb.init = lambda *args, **kwargs: None
        wandb.log = lambda *args, **kwargs: None

    if rank == 0 and config.wandb.enabled:
        # Read credentials from the WANDB_API_KEY environment variable (or ~/.netrc)
        # instead of hard-coding an API key in source control.
        wandb.login()
        os.environ['WANDB_CACHE_DIR'] = get_local_dir(config.local_dirs)
        wandb.init(
            entity=config.wandb.entity,
            project=config.wandb.project,
            config=OmegaConf.to_container(config),
            dir=get_local_dir(config.local_dirs),
            name=config.exp_name,
        )

    TrainerClass = getattr(trainers, config.trainer)
    print(f'Creating trainer on process {rank} with world size {world_size}')
    trainer = TrainerClass(policy, config, config.seed, config.local_run_dir, reference_model=reference_model, rank=rank, world_size=world_size, start_step=start_step)

    trainer.train()
    trainer.save()

@hydra.main(version_base=None, config_path="config", config_name="config")
def main(config: DictConfig):
    """Main entry point for training. Validates config, creates/initializes model(s), and kicks off worker process(es)."""
    # Resolve Hydra references now, e.g. so we don't recompute the run directory later.
    OmegaConf.resolve(config)

    missing_keys: Set[str] = OmegaConf.missing_keys(config)
    if missing_keys:
        raise ValueError(f"Got missing keys in config:\n{missing_keys}")

    if config.eval_every % config.batch_size != 0:
        print('WARNING: eval_every must be divisible by batch_size')
        print('Setting eval_every to', config.eval_every - config.eval_every % config.batch_size)
        config.eval_every = config.eval_every - config.eval_every % config.batch_size
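        # e.g. (illustrative numbers): eval_every=10000 with batch_size=64 is
        # rounded down to 10000 - (10000 % 64) = 9984.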

    print(OmegaConf.to_yaml(config))

    config_path = os.path.join(config.local_run_dir, 'config.yaml')
    with open(config_path, 'w') as f:
        OmegaConf.save(config, f)

    print('=' * 80)
    print(f'Writing to {socket.gethostname()}:{config.local_run_dir}')
    print('=' * 80)

    os.environ['XDG_CACHE_HOME'] = get_local_dir(config.local_dirs)

    print('building policy')
    model_kwargs = {'device_map': 'balanced'} if config.trainer == 'BasicTrainer' else {}
    if config.quantization == 4:
        print("applying quantization...")
        model_kwargs["quantization_config"] = transformers.BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
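        # This is the QLoRA-style setup: NF4 4-bit storage with double quantization
        # (the quantization constants are themselves quantized, saving roughly 0.4
        # bits per parameter), while matmuls are still computed in bfloat16.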

    policy_dtype = getattr(torch, config.model.policy_dtype)
    policy = transformers.AutoModelForCausalLM.from_pretrained(
        config.model.name_or_path, cache_dir=get_local_dir(config.local_dirs), low_cpu_mem_usage=True,
        torch_dtype=policy_dtype, trust_remote_code=True, **model_kwargs)

    if config.model.peft.enabled:
        policy = PeftModel.from_pretrained(policy, config.model.peft.model_name, is_trainable=config.model.peft.trainable)

    if config.lora.enabled:
        lora_config = LoraConfig(
            r=config.lora.r,
            lora_alpha=config.lora.alpha,
            target_modules=config.lora.target_modules,
            lora_dropout=config.lora.dropout,
            bias=config.lora.bias,
            task_type=TaskType.CAUSAL_LM,
        )
        # Note: prepare_model_for_int8_training is deprecated in recent peft releases
        # in favor of prepare_model_for_kbit_training, which also covers 4-bit models.
        policy = prepare_model_for_int8_training(policy)
        policy = get_peft_model(policy, lora_config)
        policy.print_trainable_parameters()
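        # Illustrative `lora` config values (examples only, not the repo's defaults):
        #   lora: {enabled: true, r: 8, alpha: 16, dropout: 0.05, bias: none,
        #          target_modules: [q_proj, v_proj]}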

    disable_dropout(policy)

    if config.loss.name == 'dpo':
        print('building reference model')
        reference_model_dtype = getattr(torch, config.model.reference_dtype)
        if config.quantization != 0 and config.quantization_reference == 4:
            print("applying quantization...")
            model_kwargs["quantization_config"] = transformers.BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
        else:
            # Don't let the policy's quantization_config leak into the reference
            # model when reference quantization was not requested.
            model_kwargs.pop("quantization_config", None)
        reference_model = transformers.AutoModelForCausalLM.from_pretrained(
            config.model.name_or_path, cache_dir=get_local_dir(config.local_dirs), low_cpu_mem_usage=True,
            torch_dtype=reference_model_dtype, trust_remote_code=True, **model_kwargs)
        # No LoRA config is needed here: the reference model is only used for
        # inference with its current weights.
        if config.model.peft.enabled:
            reference_model = PeftModel.from_pretrained(reference_model, config.model.peft.model_name)
        # Dropout is disabled so the reference log-probabilities are deterministic.
        disable_dropout(reference_model)
    else:
        reference_model = None

    start_step = 0
    if not config.model.peft.enabled and config.model.archive is not None:
        state_dict = torch.load(config.model.archive, map_location='cpu')
        step, metrics = state_dict['step_idx'], state_dict['metrics']
        print(f'loading pre-trained weights at step {step} from {config.model.archive} with metrics {json.dumps(metrics, indent=2)}')
        policy.load_state_dict(state_dict['state'])
        if config.loss.name == 'dpo':
            if config.reference_model_path:
                ref_state_dict = torch.load(config.reference_model_path, map_location='cpu')
                reference_model.load_state_dict(ref_state_dict["state"])
            else:
                reference_model.load_state_dict(state_dict['state'])
        if config.clean_chkpt_after_load:
            os.remove(config.model.archive)
        # If neither optimizer_path nor scheduler_path is set, we are not resuming;
        # we are simply starting DPO training from these weights. If either is set,
        # we are resuming and must restore the starting step.
        if config.optimizer_path is not None or config.scheduler_path is not None:
            start_step = step
        print('loaded pre-trained weights')
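        # Archive layout, inferred from the reads above (the writer side is an
        # assumption), as a minimal sketch:
        #   torch.save({'step_idx': step, 'state': policy.state_dict(), 'metrics': {}},
        #              'checkpoint.pt')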

    if 'FSDP' in config.trainer:
        world_size = torch.cuda.device_count()
        print('starting', world_size, 'processes for FSDP training')
        mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, policy, reference_model, start_step), join=True)
    else:
        print('starting single-process worker')
        worker_main(0, 1, config, policy, reference_model, start_step)


if __name__ == '__main__':
    main()
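
# Example invocation (illustrative keys; the real config groups live under config/):
#   python -u train.py exp_name=my_dpo_run loss=dpo trainer=FSDPTrainer \
#       model.name_or_path=EleutherAI/pythia-2.8b batch_size=64 eval_every=19968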