From 6608e7b22805ca8774168954b27c0bf93c163174 Mon Sep 17 00:00:00 2001 From: zjkjzj Date: Tue, 3 Oct 2023 22:40:08 +0800 Subject: [PATCH] feat(train): add Automatic Mixed Precision (AMP) --- train_rpnet.py | 26 +++++++++++++------------- train_wr2.py | 25 +++++++++++++------------ 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/train_rpnet.py b/train_rpnet.py index 2251e97..0c4be1c 100644 --- a/train_rpnet.py +++ b/train_rpnet.py @@ -67,6 +67,9 @@ def adjust_learning_rate(lr, warmup_epoch, optimizer, epoch: int, step: int, len def train(train_root, val_root, batch_size, output, device, wr2_pretrained): + if RANK in {-1, 0} and not os.path.exists(output): + os.makedirs(output) + LOGGER.info("=> Create Model") # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = RPNet(device=device, wr2_pretrained=wr2_pretrained).to(device) @@ -97,6 +100,8 @@ def train(train_root, val_root, batch_size, output, device, wr2_pretrained): LOGGER.info("=> Start training") t0 = time.time() + amp = True + scaler = torch.cuda.amp.GradScaler(enabled=amp) # DDP mode cuda = device.type != 'cpu' @@ -120,17 +125,16 @@ def train(train_root, val_root, batch_size, output, device, wr2_pretrained): images = images.to(device) targets = targets.to(device) - outputs = model(images) - - loss = criterion(outputs, targets) - # if RANK != -1: - # loss *= WORLD_SIZE # gradient averaged between devices in DDP mode - loss.backward() + with torch.cuda.amp.autocast(amp): + outputs = model(images) + loss = criterion(outputs, targets) + scaler.scale(loss).backward() if epoch <= warmup_epoch: adjust_learning_rate(learn_rate, warmup_epoch, optimizer, epoch - 1, idx, len(train_dataloader)) - optimizer.step() + scaler.step(optimizer) # optimizer.step + scaler.update() optimizer.zero_grad() if RANK in {-1, 0}: @@ -157,14 +161,11 @@ def train(train_root, val_root, batch_size, output, device, wr2_pretrained): ap, acc = ccpd_evaluator.result() LOGGER.info(f"AP:{ap * 100:.3f} ACC: {acc * 100:.3f}") scheduler.step() + torch.cuda.empty_cache() LOGGER.info(f'\n{epochs} epochs completed in {(time.time() - t0) / 3600:.3f} hours.') def main(opt): - output = opt.output - if not os.path.exists(output): - os.makedirs(output) - # DDP mode device = select_device(opt.device, batch_size=opt.batch_size) if LOCAL_RANK != -1: @@ -176,10 +177,9 @@ def main(opt): device = torch.device('cuda', LOCAL_RANK) dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo") - # init_seeds(opt.seed + 1 + RANK, deterministic=True) init_seeds(opt.seed + 1 + RANK, deterministic=False) # LOGGER.info(f"LOCAL_RANK: {LOCAL_RANK} RANK: {RANK} WORLD_SIZE: {WORLD_SIZE}") - train(opt.train_root, opt.val_root, opt.batch_size, output, device, opt.wr2_pretrained) + train(opt.train_root, opt.val_root, opt.batch_size, opt.output, device, opt.wr2_pretrained) if __name__ == '__main__': diff --git a/train_wr2.py b/train_wr2.py index 99d5d18..63457b0 100644 --- a/train_wr2.py +++ b/train_wr2.py @@ -66,6 +66,9 @@ def adjust_learning_rate(lr, warmup_epoch, optimizer, epoch: int, step: int, len def train(train_root, val_root, batch_size, output, device): + if RANK in {-1, 0} and not os.path.exists(output): + os.makedirs(output) + LOGGER.info("=> Create Model") # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = wR2(num_classes=4).to(device) @@ -96,6 +99,8 @@ def train(train_root, val_root, batch_size, output, device): LOGGER.info("=> Start training") t0 = time.time() + amp = True + scaler = torch.cuda.amp.GradScaler(enabled=amp) # DDP mode cuda = device.type != 'cpu' @@ -119,17 +124,16 @@ def train(train_root, val_root, batch_size, output, device): images = images.to(device) targets = targets.to(device) - outputs = model(images) - - loss = criterion(outputs, targets) - # if RANK != -1: - # loss *= WORLD_SIZE # gradient averaged between devices in DDP mode - loss.backward() + with torch.cuda.amp.autocast(amp): + outputs = model(images) + loss = criterion(outputs, targets) + scaler.scale(loss).backward() if epoch <= warmup_epoch: adjust_learning_rate(learn_rate, warmup_epoch, optimizer, epoch - 1, idx, len(train_dataloader)) - optimizer.step() + scaler.step(optimizer) # optimizer.step + scaler.update() optimizer.zero_grad() if RANK in {-1, 0}: @@ -156,14 +160,11 @@ def train(train_root, val_root, batch_size, output, device): ap, _ = ccpd_evaluator.result() LOGGER.info(f"AP: {ap * 100:.3f}") scheduler.step() + torch.cuda.empty_cache() LOGGER.info(f'\n{epochs} epochs completed in {(time.time() - t0) / 3600:.3f} hours.') def main(opt): - output = opt.output - if not os.path.exists(output): - os.makedirs(output) - # DDP mode device = select_device(opt.device, batch_size=opt.batch_size) if LOCAL_RANK != -1: @@ -177,7 +178,7 @@ def main(opt): init_seeds(opt.seed + 1 + RANK, deterministic=True) # LOGGER.info(f"LOCAL_RANK: {LOCAL_RANK} RANK: {RANK} WORLD_SIZE: {WORLD_SIZE}") - train(opt.train_root, opt.val_root, opt.batch_size, output, device) + train(opt.train_root, opt.val_root, opt.batch_size, opt.output, device) if __name__ == '__main__':