From c2d8e1fbc0ecf037986568878bff7fe509122ad8 Mon Sep 17 00:00:00 2001
From: 梦归云帆 <1138663075@qq.com>
Date: Mon, 11 Nov 2024 23:17:16 +0800
Subject: [PATCH] Add EMA feature integration
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../monolite_YOLO11_centernet/trainner.py |  2 +-
 requirements.txt                          |  5 +-
 tools/train.py                            | 47 +++++++++++++++----
 3 files changed, 42 insertions(+), 12 deletions(-)

diff --git a/experiment/monolite_YOLO11_centernet/trainner.py b/experiment/monolite_YOLO11_centernet/trainner.py
index fd55dff..60a5177 100644
--- a/experiment/monolite_YOLO11_centernet/trainner.py
+++ b/experiment/monolite_YOLO11_centernet/trainner.py
@@ -13,7 +13,7 @@ class trainner(TrainerBase):
     def __init__(self):
         self.start_epoch = 0
-        self.end_epoch = 5
+        self.end_epoch = 1
         self.save_path = r"C:\workspace\github\monolite\experiment\monolite_YOLO11_centernet\checkpoint"
         self.resume_checkpoint = None
         #self.resume_checkpoint = "C:\workspace\github\monolite\experiment\monolite_YOLO11_centernet\checkpoint\model.pth"
 
diff --git a/requirements.txt b/requirements.txt
index 8bcf16e..78dffa1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,7 +13,7 @@ albumentations==1.4.21
 # log tools (must be installed)
 loguru==0.7.2
 rich==13.9.4
-swanlab==0.3.23
+swanlab==0.3.25
 #tqdm==4.66.5
 
 # export tools (optional)
@@ -32,6 +32,9 @@ monkeytype==23.3.0
 # optimizer tools (optional)
 prodigyopt==1.0
 
+# hyperparameter tuning (optional)
+optuna==4.0.0
+
 # doc tools (optional)
 mkdocs-material[imaging]==9.5.44
 mkdocs-git-revision-date-localized-plugin==1.3.0
diff --git a/tools/train.py b/tools/train.py
index 1c9647d..ed2f1e2 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -1,8 +1,6 @@
 import sys
 import os
 
-import torch.nn.intrinsic
-
 sys.path.append(os.path.abspath("./"))
 
 from lib.utils.logger import logger, build_progress
@@ -35,12 +33,13 @@
 import swanlab
 import datetime
 import psutil
+from typing import Optional
 
 try:
     local_rank = int(os.environ["LOCAL_RANK"])
 except:
     local_rank = -1
-    
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 pid = os.getpid()
 pcontext = psutil.Process(pid)
@@ -49,6 +48,7 @@
 
 def train(
     model: torch.nn.Module,
+    ema_model: Optional[torch.nn.Module],
     trainner: TrainerBase,
     device,
     train_loader: torch.utils.data.DataLoader,
@@ -84,11 +84,12 @@ def train(
         model.train()
         for i, (inputs, targets, data_info) in enumerate(train_loader):
             optimizer.zero_grad()
-            #inputs = inputs.to(device,memory_format=torch.channels_last)
+            # inputs = inputs.to(device,memory_format=torch.channels_last)
            inputs = inputs.to(device)
             targets = {key: value.to(device) for key, value in targets.items()}
             with torch.autocast(
-                device_type="cuda" if torch.cuda.is_available() else "cpu", enabled=trainner.is_amp()
+                device_type="cuda" if torch.cuda.is_available() else "cpu",
+                enabled=trainner.is_amp(),
             ):
                 forward_time = time.time_ns()
                 outputs = model(inputs)
@@ -102,6 +103,9 @@ def train(
             scaler.step(optimizer)
             scaler.update()
 
+            if ema_model is not None:
+                ema_model.update_parameters(model)
+
             info = {
                 "epoch": epoch_now,
                 "micostep": i,
@@ -115,7 +119,10 @@ def train(
                 **loss_info,
                 "cpu(%)": round(pcontext.cpu_percent(), 2),
                 "ram(%)": round(pcontext.memory_percent(), 2),
-                **{f"cuda/{k}": v for k, v in torch.cuda.memory_stats(device=device).items()},  # CUDA info
+                **{
+                    f"cuda/{k}": v
+                    for k, v in torch.cuda.memory_stats(device=device).items()
+                },  # CUDA info
             }
             swanlab.log(info)
 
@@ -148,6 +155,7 @@ def train(
             progress["System"].update(
                 task_ids["jobId_ram_info"], completed=info["ram(%)"]
             )
+            # break
 
         scheduler.step()
 
@@ -189,6 +197,19 @@ def train(
                 total=trainner.get_end_epoch(),
             )
 
+    if ema_model is not None:
+        logger.info("update bn with the ema model, this may take a few minutes ...")
+        torch.optim.swa_utils.update_bn(train_loader, ema_model, device=device)
+        logger.info(
+            f"ema checkpoint saved to {os.path.join(trainner.get_save_path(), 'model_ema.pth')}"
+        )
+        torch.save(
+            {
+                "model": ema_model.state_dict(),
+            },
+            os.path.join(trainner.get_save_path(), "model_ema.pth"),
+        )
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Monolite training script")
@@ -217,7 +238,7 @@ def train(
     # Import the model
     model: torch.nn.Module = importlib.import_module("model").model()
     # model = torch.compile(model)  # Not supported on Windows
-    
+
     # Import the dataset
     data_set: DataSetBase = importlib.import_module("dataset").data_set()
 
@@ -249,10 +270,15 @@ def train(
         optimizer.load_state_dict(checkpoint_dict["optimizer"])
         scheduler.load_state_dict(checkpoint_dict["scheduler"])
         trainner.set_start_epoch(checkpoint_dict["epoch"])
-    
-    #model = model.to(device,memory_format=torch.channels_last)
+
+    # model = model.to(device,memory_format=torch.channels_last)
     model = model.to(device)
 
+    # Enable these lines to use the ema model
+    # ema_model = torch.optim.swa_utils.AveragedModel(
+    #     model, multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999)
+    # )
+
     # Print basic info
     print(
         f"\n{summary(model, input_size=(data_set.get_bath_size(),3,384,1280),mode='train',verbose=0,depth=2)}"
@@ -259,10 +285,11 @@ def train(
     )
     logger.info(data_set)
     logger.info(optimizer)
     logger.info(scheduler)
-    
+
     train(
         model,
+        None,
         trainner,
         device,
         data_set.get_train_loader(),
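
Note for reviewers: below is a minimal, self-contained sketch of the EMA flow this
patch wires into tools/train.py, using the same torch.optim.swa_utils pieces
(AveragedModel with get_ema_multi_avg_fn(0.999), update_parameters after each
optimizer step, update_bn before saving). The toy model, synthetic loader, and
output path are illustrative placeholders, not names from this repository:

    import torch
    from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn, update_bn

    # Toy stand-ins for the repository's model, optimizer, and train_loader.
    model = torch.nn.Sequential(
        torch.nn.Linear(8, 16),
        torch.nn.BatchNorm1d(16),
        torch.nn.ReLU(),
        torch.nn.Linear(16, 1),
    )
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.randn(64, 8), torch.randn(64, 1)),
        batch_size=16,
    )

    # EMA shadow of the model; decay 0.999 matches the commented-out example above.
    ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.999))

    model.train()
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
        ema_model.update_parameters(model)  # EMA weights track the raw weights

    # AveragedModel averages parameters only, so BatchNorm running statistics
    # must be recomputed for the EMA weights with one pass over the data.
    update_bn(loader, ema_model)
    torch.save({"model": ema_model.state_dict()}, "model_ema.pth")

This is also why train() calls torch.optim.swa_utils.update_bn before writing
model_ema.pth: the averaged parameters do not carry valid BatchNorm running
stats, so they are recomputed with one extra pass over train_loader.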