-
Notifications
You must be signed in to change notification settings - Fork 221
/
train.py
400 lines (347 loc) · 17.3 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
import json
from time import time
import argparse
import logging
import os
from pathlib import Path
import math
import numpy as np
from PIL import Image
from copy import deepcopy
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, set_seed
from diffusers.optimization import get_scheduler
from accelerate.utils import DistributedType
from peft import LoraConfig, set_peft_model_state_dict, PeftModel, get_peft_model
from peft.utils import get_peft_model_state_dict
from huggingface_hub import snapshot_download
from safetensors.torch import save_file
from diffusers.models import AutoencoderKL
from OmniGen import OmniGen, OmniGenProcessor
from OmniGen.train_helper import DatasetFromJson, TrainDataCollator
from OmniGen.train_helper import training_losses
from OmniGen.utils import (
create_logger,
update_ema,
requires_grad,
center_crop_arr,
crop_arr,
vae_encode,
vae_encode_list
)
def _load_vae(args, device, logger):
    """Load the AutoencoderKL used to encode images into latents and move it to ``device``.

    Resolution order: ``args.vae_path`` if given; otherwise a ``vae`` subfolder of
    ``args.model_name_or_path`` if it exists; otherwise ``stabilityai/sdxl-vae`` from the HF Hub.
    """
    if args.vae_path is not None:
        return AutoencoderKL.from_pretrained(args.vae_path).to(device)
    vae_path = os.path.join(args.model_name_or_path, "vae")
    if os.path.exists(vae_path):
        return AutoencoderKL.from_pretrained(vae_path).to(device)
    logger.info("No VAE found in model, downloading stabilityai/sdxl-vae from HF")
    logger.info("If you have VAE in local folder, please specify the path with --vae_path")
    return AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)


def _encode_batch(vae, data, weight_dtype):
    """Encode a collated batch's target and conditioning images into VAE latents.

    ``vae_encode_list`` is used when ``data['output_images']`` is a list, ``vae_encode``
    otherwise; the same encoder is applied to ``data['input_pixel_values']``, which may be
    None (passed through unchanged).

    Returns:
        tuple: (output latents, input latents or None).
    """
    encode = vae_encode_list if isinstance(data['output_images'], list) else vae_encode
    # The VAE is frozen, so no autograd graph is needed for the latents.
    with torch.no_grad():
        output_latents = encode(vae, data['output_images'], weight_dtype)
        input_latents = data['input_pixel_values']
        if input_latents is not None:
            input_latents = encode(vae, input_latents, weight_dtype)
    return output_latents, input_latents


def _save_checkpoint(accelerator, model, ema, processor, args, checkpoint_dir, step, logger):
    """Write a checkpoint for optimizer step ``step`` under ``checkpoint_dir``.

    Must be called on every process: ``accelerator.get_state_dict`` is invoked on all ranks
    (the FSDP path needs this to gather weights); only the main process writes files.

    Layout:
        LoRA: ``{step:07d}/`` written by PEFT ``save_pretrained``.
        Full: ``{step:07d}/model.pt`` plus tokenizer and LLM config.
        EMA:  ``{step:07d}_ema/model.pt`` plus tokenizer and LLM config.
    """
    # The prepared model may be wrapped (DDP/FSDP/...). A DDP wrapper does not forward
    # attributes such as `.llm` or `.save_pretrained`, so go through the unwrapped module.
    unwrapped = accelerator.unwrap_model(model)
    state_dict = None
    if accelerator.distributed_type == DistributedType.FSDP:
        state_dict = accelerator.get_state_dict(model)
    elif not args.use_lora:
        state_dict = unwrapped.state_dict()
    ema_state_dict = accelerator.get_state_dict(ema) if ema is not None else None

    if not accelerator.is_main_process:
        return

    checkpoint_path = f"{checkpoint_dir}/{step:07d}/"
    os.makedirs(checkpoint_path, exist_ok=True)
    if args.use_lora:
        unwrapped.save_pretrained(checkpoint_path)
    else:
        torch.save(state_dict, os.path.join(checkpoint_path, "model.pt"))
        processor.text_tokenizer.save_pretrained(checkpoint_path)
        unwrapped.llm.config.save_pretrained(checkpoint_path)
    logger.info(f"Saved checkpoint to {checkpoint_path}")

    if ema_state_dict is not None:
        ema_path = f"{checkpoint_dir}/{step:07d}_ema"
        os.makedirs(ema_path, exist_ok=True)
        torch.save(ema_state_dict, os.path.join(ema_path, "model.pt"))
        processor.text_tokenizer.save_pretrained(ema_path)
        unwrapped.llm.config.save_pretrained(ema_path)
        logger.info(f"Saved checkpoint to {ema_path}")


def main(args):
    """Fine-tune OmniGen on the dataset described by ``args.json_file``.

    Supports full fine-tuning or LoRA (``args.use_lora``), an optional EMA copy of the
    weights (``args.use_ema``), gradient accumulation and mixed precision through HF
    Accelerate. Logs go to ``args.results_dir``; checkpoints are written to
    ``{results_dir}/checkpoints`` every ``args.ckpt_every`` optimizer steps.
    """
    # Setup accelerator:
    from accelerate import DistributedDataParallelKwargs as DDPK
    ddp_kwargs = DDPK(find_unused_parameters=False)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_dir=args.results_dir,
        kwargs_handlers=[ddp_kwargs],
    )
    device = accelerator.device
    accelerator.init_trackers("tensorboard_log", config=args.__dict__)

    # Setup an experiment folder:
    os.makedirs(args.results_dir, exist_ok=True)
    logger = create_logger(args.results_dir)
    checkpoint_dir = f"{args.results_dir}/checkpoints"  # Stores saved model checkpoints
    if accelerator.is_main_process:
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger.info(f"Experiment directory created at {args.results_dir}")
        # Context manager so the handle is flushed and closed deterministically.
        with open(os.path.join(args.results_dir, 'train_args.json'), 'w') as args_file:
            json.dump(args.__dict__, args_file)

    # Create model (download from the HF Hub if the path is not local):
    if not os.path.exists(args.model_name_or_path):
        cache_folder = os.getenv('HF_HUB_CACHE')
        args.model_name_or_path = snapshot_download(repo_id=args.model_name_or_path,
                                                    cache_dir=cache_folder,
                                                    ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
        logger.info(f"Downloaded model to {args.model_name_or_path}")
    model = OmniGen.from_pretrained(args.model_name_or_path)
    model.llm.config.use_cache = False
    model.llm.gradient_checkpointing_enable()
    model = model.to(device)

    vae = _load_vae(args, device, logger)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    # The VAE is kept in fp32; the trainable model is cast to the mixed-precision dtype.
    vae.to(dtype=torch.float32)
    model.to(weight_dtype)

    processor = OmniGenProcessor.from_pretrained(args.model_name_or_path)

    requires_grad(vae, False)
    if args.use_lora:
        if accelerator.distributed_type == DistributedType.FSDP:
            raise NotImplementedError("FSDP does not support LoRA")
        requires_grad(model, False)
        transformer_lora_config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_rank,
            init_lora_weights="gaussian",
            target_modules=["qkv_proj", "o_proj"],
        )
        # The backbone is frozen and gradient-checkpointed; make the input embeddings'
        # outputs require grad so gradients can reach the adapters.
        model.llm.enable_input_require_grads()
        model = get_peft_model(model, transformer_lora_config)
        model.to(weight_dtype)
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        opt = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=args.adam_weight_decay)
    else:
        opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.adam_weight_decay)

    ema = None
    if args.use_ema:
        ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
        requires_grad(ema, False)

    # Setup data:
    crop_func = crop_arr if args.keep_raw_resolution else center_crop_arr
    image_transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: crop_func(pil_image, args.max_image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])
    dataset = DatasetFromJson(json_file=args.json_file,
                              image_path=args.image_path,
                              processer=processor,
                              image_transform=image_transform,
                              max_input_length_limit=args.max_input_length_limit,
                              condition_dropout_prob=args.condition_dropout_prob,
                              keep_raw_resolution=args.keep_raw_resolution
                              )
    collate_fn = TrainDataCollator(pad_token_id=processor.text_tokenizer.eos_token_id,
                                   hidden_size=model.llm.config.hidden_size,
                                   keep_raw_resolution=args.keep_raw_resolution)
    loader_kwargs = {}
    if args.num_workers > 0:
        # DataLoader rejects `prefetch_factor` when num_workers == 0 (no worker processes).
        loader_kwargs["prefetch_factor"] = 2
    loader = DataLoader(
        dataset,
        collate_fn=collate_fn,
        batch_size=args.batch_size_per_device,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
        **loader_kwargs,
    )
    if accelerator.is_main_process:
        logger.info(f"Dataset contains {len(dataset):,}")

    num_update_steps_per_epoch = math.ceil(len(loader) / args.gradient_accumulation_steps)
    max_train_steps = args.epochs * num_update_steps_per_epoch
    # NOTE(review): both counts are scaled by gradient_accumulation_steps, but accelerate's
    # prepared scheduler may only advance on gradient-sync steps (once per optimizer update).
    # If so, warmup/decay run gradient_accumulation_steps times slower than intended —
    # confirm against the installed accelerate version before changing.
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=opt,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=max_train_steps * args.gradient_accumulation_steps,
    )

    # Prepare models for training:
    model.train()  # important! This enables embedding dropout for classifier-free guidance
    if ema is not None:
        update_ema(ema, model, decay=0)  # Ensure EMA is initialized with synced weights
        ema.eval()  # EMA model should always be in eval mode
        model, ema = accelerator.prepare(model, ema)
    else:
        model = accelerator.prepare(model)
    opt, loader, lr_scheduler = accelerator.prepare(opt, loader, lr_scheduler)

    # Variables for monitoring/logging purposes:
    train_steps, log_steps = 0, 0
    running_loss = 0
    start_time = time()

    if accelerator.is_main_process:
        logger.info(f"Training for {args.epochs} epochs...")
    for epoch in range(args.epochs):
        if accelerator.is_main_process:
            logger.info(f"Beginning epoch {epoch}...")

        for data in loader:
            with accelerator.accumulate(model):
                output_images, input_pixel_values = _encode_batch(vae, data, weight_dtype)

                # TODO: weighted loss for image editing
                # patch_weight = []
                # for i in range(len(output_images)):
                #     temp_x = output_images[i]
                #     w = torch.ones_like(temp_x).detach()
                #     if temp_x is for editing task:
                #         # Find the input image corresponding to the output image. We store the index in need_edit_imgs
                #         input_x = input_pixel_values[need_edit_imgs[i]]
                #         diff = torch.abs(temp_x - input_x).detach()  # no gradient for weight
                #         diff_mean = torch.mean(diff)
                #         if diff_mean < 0.001:
                #             # The difference between the input and output images is too small, so we suspect there might be an issue with this data. We discard the image by setting its weight to zero.
                #             w = w * 0
                #         elif diff_mean <= 0.8:
                #             weight = 1 / (diff_mean + 1e-6)
                #             weight = max(min(weight, 64), 5)  # clamp the weight
                #             w[diff > 0.3] = weight  # assign the weight to the pixels which differ between input and output
                #         else:
                #             # The difference between the input and output images is significant enough, so there's no need to reinforce the loss.
                #             pass
                #     patch_weight.append(w)

                model_kwargs = dict(input_ids=data['input_ids'], input_img_latents=input_pixel_values,
                                    input_image_sizes=data['input_image_sizes'], attention_mask=data['attention_mask'],
                                    position_ids=data['position_ids'], padding_latent=data['padding_images'],
                                    past_key_values=None, return_past_key_values=False)

                loss_dict = training_losses(model, output_images, model_kwargs)
                loss = loss_dict["loss"].mean()

                running_loss += loss.item()
                accelerator.backward(loss)
                if args.max_grad_norm is not None and accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                opt.step()
                lr_scheduler.step()
                opt.zero_grad()

                log_steps += 1
                train_steps += 1

                accelerator.log({"training_loss": loss.item()}, step=train_steps)

                if ema is not None and accelerator.sync_gradients and train_steps % args.gradient_accumulation_steps == 0:
                    if accelerator.distributed_type == DistributedType.FSDP:
                        update_ema(ema, model)
                    else:
                        # accelerate may wrap the trainable `model` (e.g. DDP adds a "module." name
                        # prefix) while leaving the frozen `ema` unwrapped; unwrap both so parameter
                        # names line up either way.
                        update_ema(accelerator.unwrap_model(ema), accelerator.unwrap_model(model))

                if train_steps % (args.log_every * args.gradient_accumulation_steps) == 0 and train_steps > 0:
                    # Drain queued GPU work so the throughput measurement is accurate (no-op need on CPU).
                    if torch.cuda.is_available():
                        torch.cuda.synchronize()
                    end_time = time()
                    steps_per_sec = log_steps / args.gradient_accumulation_steps / (end_time - start_time)
                    # Reduce loss history over all processes:
                    avg_loss = torch.tensor(running_loss / log_steps, device=device)
                    if dist.is_available() and dist.is_initialized():
                        dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                    avg_loss = avg_loss.item() / accelerator.num_processes

                    if accelerator.is_main_process:
                        cur_lr = opt.param_groups[0]["lr"]
                        logger.info(f"(step={train_steps // args.gradient_accumulation_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}, Epoch: {train_steps/len(loader)}, LR: {cur_lr}")

                    # Reset monitoring variables:
                    running_loss = 0
                    log_steps = 0
                    start_time = time()

                if train_steps % (args.ckpt_every * args.gradient_accumulation_steps) == 0 and train_steps > 0:
                    _save_checkpoint(accelerator, model, ema, processor, args, checkpoint_dir,
                                     train_steps // args.gradient_accumulation_steps, logger)
                    # Keep ranks in lock-step until the main process has finished writing.
                    accelerator.wait_for_everyone()

    accelerator.end_training()
    model.eval()

    if accelerator.is_main_process:
        logger.info("Done!")
if __name__ == "__main__":
    # Command-line entry point: parse and validate training options, then run `main`.
    parser = argparse.ArgumentParser()
    parser.add_argument("--results_dir", type=str, default="results")
    parser.add_argument("--model_name_or_path", type=str, default="OmniGen")
    # Required: the dataset cannot be built without it.
    parser.add_argument("--json_file", type=str, required=True)
    parser.add_argument("--image_path", type=str, default=None)
    parser.add_argument("--epochs", type=int, default=1400)
    parser.add_argument("--batch_size_per_device", type=int, default=1)
    parser.add_argument("--vae_path", type=str, default=None)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--log_every", type=int, default=100)
    parser.add_argument("--ckpt_every", type=int, default=20000)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--max_input_length_limit", type=int, default=1024)
    parser.add_argument("--condition_dropout_prob", type=float, default=0.1)
    parser.add_argument("--adam_weight_decay", type=float, default=0.0)
    parser.add_argument(
        "--keep_raw_resolution",
        action="store_true",
        help="multiple_resolutions",
    )
    parser.add_argument("--max_image_size", type=int, default=1344)
    parser.add_argument(
        "--use_lora",
        action="store_true",
    )
    parser.add_argument(
        "--lora_rank",
        type=int,
        default=8
    )
    parser.add_argument(
        "--use_ema",
        action="store_true",
        help="Whether or not to use ema.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=1000, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to bf16; the value is passed explicitly to Accelerator and"
            " therefore overrides the accelerate config and the flag passed with the `accelerate.launch` command."
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of forward/backward passes to accumulate gradients over before each optimizer update.",
    )
    args = parser.parse_args()

    # Validate with parser.error (prints usage and exits) rather than `assert`,
    # which is skipped entirely under `python -O`.
    if args.max_image_size % 16 != 0:
        parser.error("Image size must be divisible by 16.")
    # These are used as modulo divisors in the training loop; 0 would crash mid-run.
    for option_name in ("gradient_accumulation_steps", "log_every", "ckpt_every"):
        if getattr(args, option_name) < 1:
            parser.error(f"--{option_name} must be a positive integer.")

    main(args)