# pretrain.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import os
import argparse
import datetime
import json
import os
import time
import torch
import torch.backends.cudnn as cudnn
from iopath.common.file_io import g_pathmgr as pathmgr
from pathlib import Path
from torch.utils.data import DistributedSampler, DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
import util.misc as misc
import models_mae
from engine_pretrain import train_one_epoch
from util.kinetics import Kinetics
from util.misc import NativeScalerWithGradNormCount as NativeScaler
def _str2bool(value):
    """Convert a command-line token such as ``true``/``false`` or ``1``/``0`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"`` and ``"0"``) becomes ``True``. This
    converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args_parser():
    """Build the command-line parser for spatiotemporal MAE pre-training.

    The parser is created with ``add_help=False``, so it can be used directly or
    passed as a parent to another ``argparse.ArgumentParser``.

    Returns:
        argparse.ArgumentParser: parser holding every training option.
    """
    parser = argparse.ArgumentParser("Spatiotemporal MAE pre-training", add_help=False)
    parser.add_argument("--batch_size_per_gpu", default=4, type=int, help="Batch size per GPU (effective batch size is batch_size_per_gpu * accum_iter * # gpus)")
    parser.add_argument("--epochs", default=100, type=int)
    parser.add_argument("--accum_iter", default=1, type=int, help="Accumulate gradient iterations")
    parser.add_argument("--save_prefix", default="", type=str, help="Prefix for saving checkpoint and log files")

    # Data args
    parser.add_argument("--data_dirs", type=str, default=[""], nargs="+", help="Data paths")
    parser.add_argument("--datafile_dir", type=str, default="./datafiles", help="Store data files here")
    parser.add_argument("--output_dir", default="./output_dir", help="Path where to save, empty for no saving")
    parser.add_argument("--data_frac", default=1.0, type=float, help="Fraction of data to be used for training")

    # Model parameters
    parser.add_argument("--model", default="mae_vit_large_patch16", type=str, help="Name of model to train")
    parser.add_argument("--img_size", default=224, type=int, help="Image size")
    parser.add_argument("--mask_ratio", default=0.9, type=float, help="Masking ratio (percentage of removed patches).")
    parser.add_argument("--norm_pix_loss", action="store_true", help="Use (per-patch) normalized pixels as targets for computing loss")
    parser.add_argument("--resume", default="", help="Resume from checkpoint")
    parser.add_argument("--compile", action="store_true", help="whether to compile the model for improved efficiency (default: false)")
    parser.set_defaults(norm_pix_loss=False)

    # LR related parameters
    parser.add_argument("--weight_decay", type=float, default=0.05, help="Weight decay")
    parser.add_argument("--lr", type=float, default=None, help="Learning rate (absolute lr)")
    parser.add_argument("--blr", type=float, default=1e-3, metavar="LR", help="base learning rate: absolute_lr = base_lr * total_batch_size / 256")
    parser.add_argument("--min_lr", type=float, default=1e-5, metavar="LR", help="lower lr bound for cyclic schedulers that hit 0")
    parser.add_argument("--warmup_epochs", type=int, default=0, metavar="N", help="epochs to warmup LR")

    # Misc
    parser.add_argument("--device", default="cuda", help="Device to use for training / testing")
    parser.add_argument("--clip_grad", type=float, default=None)
    parser.add_argument("--start_epoch", default=0, type=int, help="Start epoch")
    parser.add_argument("--num_workers", default=16, type=int)
    parser.add_argument("--pin_mem", action="store_true", help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.")
    parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem")
    parser.set_defaults(pin_mem=True)

    # Distributed training parameters
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument("--dist_url", default="env://", help="url used to set up distributed training")

    # Video related configs
    parser.add_argument("--decoder_embed_dim", default=512, type=int)
    parser.add_argument("--decoder_depth", default=8, type=int)
    parser.add_argument("--decoder_num_heads", default=16, type=int)
    parser.add_argument("--t_patch_size", default=2, type=int)
    parser.add_argument("--num_frames", default=16, type=int)
    parser.add_argument("--checkpoint_period", default=1, type=int)
    parser.add_argument("--sampling_rate", default=4, type=int)
    parser.add_argument("--repeat_aug", default=4, type=int)
    parser.add_argument("--no_qkv_bias", action="store_true")
    parser.add_argument("--bias_wd", action="store_true")
    parser.add_argument("--num_checkpoint_del", default=20, type=int)
    parser.add_argument("--trunc_init", action="store_true")
    parser.add_argument("--target_fps", default=30, type=int)
    parser.add_argument("--jitter_scales_relative", default=[0.5, 1.0], type=float, nargs="+")
    parser.add_argument("--jitter_aspect_relative", default=[0.75, 1.3333], type=float, nargs="+")
    parser.add_argument("--train_jitter_scales", default=[256, 320], type=int, nargs="+")
    # type=bool would turn "--color_jitter False" into True; parse the text instead.
    parser.add_argument("--color_jitter", type=_str2bool, default=False, help="Color augmentation during training")
    parser.add_argument("--beta", default=None, type=float, nargs="+")
    parser.add_argument("--pred_t_dim", type=int, default=16)

    # Both flags default to True, so the "--no_*" counterparts are the only way
    # to switch them off (same pattern as --pin_mem / --no_pin_mem).
    parser.add_argument("--cls_embed", action="store_true")
    parser.add_argument("--no_cls_embed", action="store_false", dest="cls_embed")
    parser.set_defaults(cls_embed=True)
    parser.add_argument("--sep_pos_embed", action="store_true")
    parser.add_argument("--no_sep_pos_embed", action="store_false", dest="sep_pos_embed")
    parser.set_defaults(sep_pos_embed=True)

    return parser
def find_mp4_files(directories):
    """Recursively collect video files below each directory in ``directories``.

    Despite the name, files ending in ``.mp4``, ``.mkv`` or ``.webm`` are all
    matched, and the extension check is case-insensitive.

    Args:
        directories: iterable of directory paths to walk.

    Returns:
        list of ``(file_path, parent_dir_name)`` tuples, where ``parent_dir_name``
        is the basename of the directory that directly contains the file.
        Directories are visited in sorted order and files are sorted within
        each directory, so the result is reproducible for a given tree.
    """
    video_extensions = (".mp4", ".mkv", ".webm")
    mp4_files = []
    for directory in directories:
        for root, dirs, files in os.walk(directory):
            # Sorting in place fixes the order in which os.walk descends;
            # otherwise it depends on the filesystem and the prefix subset
            # taken for data_frac < 1.0 would not be stable.
            dirs.sort()
            parent_name = os.path.basename(root)
            mp4_files.extend(
                (os.path.join(root, file), parent_name)
                for file in sorted(files)
                if file.lower().endswith(video_extensions)
            )
    return mp4_files
def write_csv(video_files, save_dir, save_name):
    """Write ``<save_dir>/<save_name>.csv`` with one ``"<video path>, 1"`` line per video.

    The second column is always the constant ``1``; the second element of each
    input tuple is ignored.

    Args:
        video_files: iterable of ``(video_path, _)`` tuples.
        save_dir: directory for the file; created if it does not exist.
        save_name: file name without the ``.csv`` extension.
    """
    # An empty save_dir means the current directory, which needs no creation
    # (os.makedirs("") would raise).
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, f"{save_name}.csv"), "w", newline="") as csvfile:
        csvfile.writelines(f"{video_file}, {1}\n" for video_file, _ in video_files)
def main(args):
    """Run MAE pre-training end to end: data, model, optimizer, epoch loop.

    Args:
        args: namespace produced by ``get_args_parser()``. ``args.lr`` is filled
            in from ``args.blr`` and the effective batch size when it is None.
            The whole namespace is also forwarded to the model constructor as
            keyword arguments.

    Returns:
        list[str]: one-element list with the value returned by the most recent
        ``misc.save_model`` call, or ``""`` if no checkpoint was saved.
    """
    misc.init_distributed_mode(args)
    print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(", ", ",\n"))

    device = torch.device(args.device)
    cudnn.benchmark = True

    # data pipeline
    dataset_train = Kinetics(
        mode="pretrain",
        datafile_dir=args.datafile_dir,
        sampling_rate=args.sampling_rate,
        num_frames=args.num_frames,
        target_fps=args.target_fps,
        train_color_jitter=args.color_jitter,
        train_jitter_scales=tuple(args.train_jitter_scales),
        train_crop_size=args.img_size,
        repeat_aug=args.repeat_aug,
        jitter_aspect_relative=args.jitter_aspect_relative,
        jitter_scales_relative=args.jitter_scales_relative,
    )

    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    sampler_train = DistributedSampler(dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
    data_loader_train = DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    print(f"Sampler_train = {sampler_train}")

    # effective batch size
    eff_batch_size = args.batch_size_per_gpu * args.accum_iter * num_tasks
    print(f"Effective batch size: {eff_batch_size} = {args.batch_size_per_gpu} batch_size_per_gpu * {args.accum_iter} accum_iter * {num_tasks} GPUs")

    # effective lr
    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256
    print(f"Effective lr: {args.lr}")

    # define model
    model = models_mae.__dict__[args.model](**vars(args))
    model.to(device)
    # Keep a handle on the bare module; it is used below for the parameter
    # groups and for checkpoint load/save, while `model` is wrapped.
    model_without_ddp = model
    print(f"Model: {model_without_ddp}")
    n_trainable_m = sum(p.numel() for p in model_without_ddp.parameters() if p.requires_grad) / 1.e6
    print(f"Number of params (M): {n_trainable_m}")

    # optionally compile model
    if args.compile:
        model = torch.compile(model)

    # wrap in ddp
    model = DDP(model, device_ids=[torch.cuda.current_device()])

    # following timm: set wd as 0 for bias and norm layers
    param_groups = misc.add_weight_decay(model_without_ddp, args.weight_decay, bias_wd=args.bias_wd)
    # Use the public AdamW: the private torch.optim._multi_tensor variant binds
    # foreach=True, which cannot be combined with fused=True.
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95), fused=True)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=90, gamma=0.1)  # can use any other scheduler here
    loss_scaler = NativeScaler()

    # NOTE(review): the scheduler is not handed to load_model, so its state is
    # presumably not restored when resuming - confirm against util.misc.
    misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, with_optim_sched=True)

    checkpoint_path = ""
    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        # Reseed the distributed sampler so each epoch gets a different shuffle.
        data_loader_train.sampler.set_epoch(epoch)
        train_stats = train_one_epoch(model, data_loader_train, optimizer, device, epoch, loss_scaler, args=args)

        is_last_epoch = epoch + 1 == args.epochs
        if args.output_dir and (epoch % args.checkpoint_period == 0 or is_last_epoch):
            checkpoint_path = misc.save_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch)

        log_stats = {**{f"train_{k}": v for k, v in train_stats.items()}, "epoch": epoch}
        if args.output_dir and misc.is_main_process():
            # One JSON record per epoch, appended by the main process only.
            with pathmgr.open(f"{args.output_dir}/{args.save_prefix}_log.txt", "a") as f:
                f.write(json.dumps(log_stats) + "\n")

        # increment lr scheduler
        scheduler.step()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
    print(torch.cuda.memory_allocated())

    return [checkpoint_path]
if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    # Make sure the data file directory exists before the CSV is written there.
    Path(args.datafile_dir).mkdir(parents=True, exist_ok=True)

    # prepare data files
    video_files = find_mp4_files(directories=args.data_dirs)
    n_vids = len(video_files)
    if args.data_frac < 1.0:
        from math import ceil

        # Keep the leading fraction of the discovered files, rounding up.
        n_vids_keep = ceil(n_vids * args.data_frac)
        video_files = video_files[:n_vids_keep]
        print(f"Training on {n_vids_keep} of {n_vids} video files.")
    else:
        print(f"Training on all {n_vids} video files.")

    # NOTE(review): under a multi-process launcher every process appears to run
    # this block and rewrite the same CSV - confirm this is intended.
    write_csv(video_files=video_files, save_dir=args.datafile_dir, save_name='train')

    # train
    main(args)