train_RAPID.py
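"""Training entry point for the RAPID model.

Loads a YAML config, sets up logging and checkpointing with PyTorch
Lightning, and trains RAPID on the configured dataset split.
"""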
import argparse
import os
import shutil

import pytorch_lightning as pl
import yaml
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.strategies import DDPStrategy

from models.rapid import RAPID
from utils.argparser import init_args
from utils.dataset import get_dataset_and_loader
from utils.ema import EMACallback

if __name__ == '__main__':
    # Parse the config path from the command line, then load the YAML config
    parser = argparse.ArgumentParser(description='Pose_AD_Experiment')
    parser.add_argument('-c', '--config', type=str, required=True,
                        help='path to the YAML config file')
    cli_args = parser.parse_args()
    config_path = cli_args.config
    with open(config_path) as f:
        args = argparse.Namespace(**yaml.safe_load(f))
    args = init_args(args)
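    # `args` now holds the merged YAML settings (ckpt_dir, seed, devices, ...)
    # as a Namespace, post-processed by init_args.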
    # Save a copy of the config file to ckpt_dir for reproducibility
    shutil.copy(config_path, os.path.join(args.ckpt_dir, 'config.yaml'))
    # Set seeds (seed_everything covers python, numpy and torch)
    pl.seed_everything(args.seed)
    # Set up callbacks and loggers
    tensorboard_logger = TensorBoardLogger(save_dir='driver_distraction/logs/',
                                           name='temp_experiment')
    loggers = [tensorboard_logger]
    monitored_metric = 'loss_noise'
    metric_mode = 'min'
    callbacks = [ModelCheckpoint(dirpath=args.ckpt_dir, save_top_k=2,
                                 monitor=monitored_metric,
                                 mode=metric_mode)]
    if args.use_ema:
        # Track an exponential moving average of the model weights
        callbacks.append(EMACallback())
    if args.use_wandb:
        callbacks.append(LearningRateMonitor(logging_interval='step'))
        wandb_logger = WandbLogger(project=args.project_name, group=args.group_name,
                                   entity=args.wandb_entity, name=args.dir_name,
                                   config=vars(args), log_model='all')
        # Log to both TensorBoard and W&B
        loggers.append(wandb_logger)
    # Get dataset and loader
    _, train_loader = get_dataset_and_loader(args, split=args.split)
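    # Only a training loader is used (no validation split), which is why the
    # checkpoint monitors a training metric and sanity validation is disabled.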
    # Initialize model and trainer
    # model = RAPIDlatent(args) if hasattr(args, 'diffusion_on_latent') else RAPID(args)
    model = RAPID(args)
    trainer = pl.Trainer(
        accelerator=args.accelerator,
        devices=args.devices,
        default_root_dir=args.ckpt_dir,
        max_epochs=args.n_epochs,
        logger=loggers,
        callbacks=callbacks,
        strategy=DDPStrategy(find_unused_parameters=False),
        log_every_n_steps=20,
        num_sanity_val_steps=0,
        deterministic=True,
    )
    # Train the model
    trainer.fit(model=model, train_dataloaders=train_loader)
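
    # Example invocation (the config path is illustrative, not shipped with the repo):
    #   python train_RAPID.py --config configs/rapid.yaml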