feat(train): add Automatic Mixed Precision (AMP)
zjykzj committed Oct 3, 2023
1 parent e4a9555 commit 6608e7b
Showing 2 changed files with 26 additions and 25 deletions.
26 changes: 13 additions & 13 deletions train_rpnet.py
@@ -67,6 +67,9 @@ def adjust_learning_rate(lr, warmup_epoch, optimizer, epoch: int, step: int, len


 def train(train_root, val_root, batch_size, output, device, wr2_pretrained):
+    if RANK in {-1, 0} and not os.path.exists(output):
+        os.makedirs(output)
+
     LOGGER.info("=> Create Model")
     # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     model = RPNet(device=device, wr2_pretrained=wr2_pretrained).to(device)
@@ -97,6 +100,8 @@ def train(train_root, val_root, batch_size, output, device, wr2_pretrained):

     LOGGER.info("=> Start training")
     t0 = time.time()
+    amp = True
+    scaler = torch.cuda.amp.GradScaler(enabled=amp)

     # DDP mode
     cuda = device.type != 'cpu'
@@ -120,17 +125,16 @@ def train(train_root, val_root, batch_size, output, device, wr2_pretrained):
             images = images.to(device)
             targets = targets.to(device)

-            outputs = model(images)
-
-            loss = criterion(outputs, targets)
-            # if RANK != -1:
-            #     loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
-            loss.backward()
+            with torch.cuda.amp.autocast(amp):
+                outputs = model(images)
+                loss = criterion(outputs, targets)
+            scaler.scale(loss).backward()

             if epoch <= warmup_epoch:
                 adjust_learning_rate(learn_rate, warmup_epoch, optimizer, epoch - 1, idx, len(train_dataloader))

-            optimizer.step()
+            scaler.step(optimizer)  # optimizer.step
+            scaler.update()
             optimizer.zero_grad()

             if RANK in {-1, 0}:
@@ -157,14 +161,11 @@ def train(train_root, val_root, batch_size, output, device, wr2_pretrained):
             ap, acc = ccpd_evaluator.result()
             LOGGER.info(f"AP:{ap * 100:.3f} ACC: {acc * 100:.3f}")
         scheduler.step()
+        torch.cuda.empty_cache()
     LOGGER.info(f'\n{epochs} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')


 def main(opt):
-    output = opt.output
-    if not os.path.exists(output):
-        os.makedirs(output)
-
     # DDP mode
     device = select_device(opt.device, batch_size=opt.batch_size)
     if LOCAL_RANK != -1:
@@ -176,10 +177,9 @@ def main(opt):
         device = torch.device('cuda', LOCAL_RANK)
         dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")

-    # init_seeds(opt.seed + 1 + RANK, deterministic=True)
     init_seeds(opt.seed + 1 + RANK, deterministic=False)
     # LOGGER.info(f"LOCAL_RANK: {LOCAL_RANK} RANK: {RANK} WORLD_SIZE: {WORLD_SIZE}")
-    train(opt.train_root, opt.val_root, opt.batch_size, output, device, opt.wr2_pretrained)
+    train(opt.train_root, opt.val_root, opt.batch_size, opt.output, device, opt.wr2_pretrained)


 if __name__ == '__main__':
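The change above is the standard PyTorch AMP recipe: run the forward pass and the loss under torch.cuda.amp.autocast, scale the loss through torch.cuda.amp.GradScaler before backward, and let the scaler drive the optimizer step and the scale update. A minimal, self-contained sketch of that pattern follows; the Linear model, MSELoss, and random tensors are illustrative placeholders, not the repository's RPNet/wR2 networks or CCPD data.

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.nn.Linear(128, 4).to(device)       # stand-in for RPNet / wR2
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

amp = device.type == 'cuda'                      # mixed precision is only meaningful on CUDA
scaler = torch.cuda.amp.GradScaler(enabled=amp)

images = torch.randn(8, 128, device=device)      # dummy batch
targets = torch.randn(8, 4, device=device)

with torch.cuda.amp.autocast(enabled=amp):       # forward + loss in mixed precision
    outputs = model(images)
    loss = criterion(outputs, targets)

scaler.scale(loss).backward()                    # backward on the scaled loss
scaler.step(optimizer)                           # unscales gradients; skips the step on inf/NaN
scaler.update()                                  # adjusts the loss-scale factor
optimizer.zero_grad()

Because GradScaler unscales the gradients inside step(), the update applied by SGD matches what full-precision training would apply, while most of the forward/backward arithmetic runs in float16.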
25 changes: 13 additions & 12 deletions train_wr2.py
@@ -66,6 +66,9 @@ def adjust_learning_rate(lr, warmup_epoch, optimizer, epoch: int, step: int, len


 def train(train_root, val_root, batch_size, output, device):
+    if RANK in {-1, 0} and not os.path.exists(output):
+        os.makedirs(output)
+
     LOGGER.info("=> Create Model")
     # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     model = wR2(num_classes=4).to(device)
@@ -96,6 +99,8 @@ def train(train_root, val_root, batch_size, output, device):

     LOGGER.info("=> Start training")
     t0 = time.time()
+    amp = True
+    scaler = torch.cuda.amp.GradScaler(enabled=amp)

     # DDP mode
     cuda = device.type != 'cpu'
@@ -119,17 +124,16 @@ def train(train_root, val_root, batch_size, output, device):
             images = images.to(device)
             targets = targets.to(device)

-            outputs = model(images)
-
-            loss = criterion(outputs, targets)
-            # if RANK != -1:
-            #     loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
-            loss.backward()
+            with torch.cuda.amp.autocast(amp):
+                outputs = model(images)
+                loss = criterion(outputs, targets)
+            scaler.scale(loss).backward()

             if epoch <= warmup_epoch:
                 adjust_learning_rate(learn_rate, warmup_epoch, optimizer, epoch - 1, idx, len(train_dataloader))

-            optimizer.step()
+            scaler.step(optimizer)  # optimizer.step
+            scaler.update()
             optimizer.zero_grad()

             if RANK in {-1, 0}:
@@ -156,14 +160,11 @@ def train(train_root, val_root, batch_size, output, device):
             ap, _ = ccpd_evaluator.result()
             LOGGER.info(f"AP: {ap * 100:.3f}")
         scheduler.step()
+        torch.cuda.empty_cache()
     LOGGER.info(f'\n{epochs} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')


 def main(opt):
-    output = opt.output
-    if not os.path.exists(output):
-        os.makedirs(output)
-
     # DDP mode
     device = select_device(opt.device, batch_size=opt.batch_size)
     if LOCAL_RANK != -1:
@@ -177,7 +178,7 @@ def main(opt):

     init_seeds(opt.seed + 1 + RANK, deterministic=True)
     # LOGGER.info(f"LOCAL_RANK: {LOCAL_RANK} RANK: {RANK} WORLD_SIZE: {WORLD_SIZE}")
-    train(opt.train_root, opt.val_root, opt.batch_size, output, device)
+    train(opt.train_root, opt.val_root, opt.batch_size, opt.output, device)


 if __name__ == '__main__':
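Both scripts hardcode amp = True. Since GradScaler and autocast each take an enabled flag, the same training step degrades gracefully to plain FP32 training when that flag is False, which is a convenient fallback on CPUs or GPUs without fast float16. A small hypothetical sketch follows; the train_step helper and the toy model are illustrations, not code from this repository.

import torch

def train_step(model, criterion, optimizer, scaler, images, targets, amp):
    # With enabled=False, autocast and GradScaler act as pass-throughs,
    # so this is exactly an ordinary FP32 forward/backward/step.
    with torch.cuda.amp.autocast(enabled=amp):
        outputs = model(images)
        loss = criterion(outputs, targets)
    scaler.scale(loss).backward()   # == loss.backward() when disabled
    scaler.step(optimizer)          # == optimizer.step() when disabled
    scaler.update()                 # no-op when disabled
    optimizer.zero_grad()
    return loss.item()

model = torch.nn.Linear(16, 2)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=False)   # FP32 fallback; safe without CUDA
print(train_step(model, criterion, optimizer, scaler,
                 torch.randn(4, 16), torch.randn(4, 2), amp=False))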
