Skip to content

Commit

Permalink
添加EMA功能集成
Browse files Browse the repository at this point in the history
  • Loading branch information
Puiching-Memory committed Nov 11, 2024
1 parent a8c4dc1 commit c2d8e1f
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 12 deletions.
2 changes: 1 addition & 1 deletion experiment/monolite_YOLO11_centernet/trainner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
class trainner(TrainerBase):
def __init__(self):
self.start_epoch = 0
self.end_epoch = 5
self.end_epoch = 1
self.save_path = r"C:\workspace\github\monolite\experiment\monolite_YOLO11_centernet\checkpoint"
self.resume_checkpoint = None
#self.resume_checkpoint = "C:\workspace\github\monolite\experiment\monolite_YOLO11_centernet\checkpoint\model.pth"
Expand Down
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ albumentations==1.4.21
# log tools (must be installed)
loguru==0.7.2
rich==13.9.4
swanlab==0.3.23
swanlab==0.3.25
#tqdm==4.66.5

# export tools (optional)
Expand All @@ -32,6 +32,9 @@ monkeytype==23.3.0
# optimizer tools (optional)
prodigyopt==1.0

# hyperparameter (optional)
optuna==4.0.0

# doc tools (optional)
mkdocs-material[imaging]==9.5.44
mkdocs-git-revision-date-localized-plugin==1.3.0
Expand Down
47 changes: 37 additions & 10 deletions tools/train.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import sys
import os

import torch.nn.intrinsic

sys.path.append(os.path.abspath("./"))

from lib.utils.logger import logger, build_progress
Expand Down Expand Up @@ -35,12 +33,13 @@
import swanlab
import datetime
import psutil
from typing import Optional

# Read the process rank from the LOCAL_RANK environment variable
# (presumably set by a distributed launcher such as torchrun -- confirm).
# Fall back to -1 when the variable is absent or is not a valid integer.
try:
    local_rank = int(os.environ["LOCAL_RANK"])
except (KeyError, ValueError):
    # Catch only the two expected failures: a bare ``except`` would also
    # swallow KeyboardInterrupt / SystemExit and hide unrelated bugs.
    local_rank = -1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pid = os.getpid()
pcontext = psutil.Process(pid)
Expand All @@ -49,6 +48,7 @@

def train(
model: torch.nn.Module,
ema_model: Optional[torch.nn.Module],
trainner: TrainerBase,
device,
train_loader: torch.utils.data.DataLoader,
Expand Down Expand Up @@ -84,11 +84,12 @@ def train(
model.train()
for i, (inputs, targets, data_info) in enumerate(train_loader):
optimizer.zero_grad()
#inputs = inputs.to(device,memory_format=torch.channels_last)
# inputs = inputs.to(device,memory_format=torch.channels_last)
inputs = inputs.to(device)
targets = {key: value.to(device) for key, value in targets.items()}
with torch.autocast(
device_type="cuda" if torch.cuda.is_available() else "cpu", enabled=trainner.is_amp()
device_type="cuda" if torch.cuda.is_available() else "cpu",
enabled=trainner.is_amp(),
):
forward_time = time.time_ns()
outputs = model(inputs)
Expand All @@ -102,6 +103,9 @@ def train(
scaler.step(optimizer)
scaler.update()

if ema_model is not None:
ema_model.update_parameters(model)

info = {
"epoch": epoch_now,
"micostep": i,
Expand All @@ -115,7 +119,10 @@ def train(
**loss_info,
"cpu(%)": round(pcontext.cpu_percent(), 2),
"ram(%)": round(pcontext.memory_percent(), 2),
**{f"cuda/{k}": v for k, v in torch.cuda.memory_stats(device=device).items()}, # cuda信息
**{
f"cuda/{k}": v
for k, v in torch.cuda.memory_stats(device=device).items()
}, # cuda信息
}
swanlab.log(info)

Expand Down Expand Up @@ -148,6 +155,7 @@ def train(
progress["System"].update(
task_ids["jobId_ram_info"], completed=info["ram(%)"]
)
# break

scheduler.step()

Expand Down Expand Up @@ -189,6 +197,19 @@ def train(
total=trainner.get_end_epoch(),
)

if ema_model is not None:
logger.info("update bn with ema model, it may takes few minutes ...")
torch.optim.swa_utils.update_bn(train_loader, ema_model, device=device)
logger.info(
f"ema checkpoint saved to {os.path.join(trainner.get_save_path(), "model_ema.pth")}"
)
torch.save(
{
"model": ema_model.state_dict(),
},
os.path.join(trainner.get_save_path(), "model_ema.pth"),
)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Monolite training script")
Expand Down Expand Up @@ -217,7 +238,7 @@ def train(
# Import the model
model: torch.nn.Module = importlib.import_module("model").model()
# model = torch.compile(model) # Not support in windows

# Import the dataset
data_set: DataSetBase = importlib.import_module("dataset").data_set()

Expand Down Expand Up @@ -249,20 +270,26 @@ def train(
optimizer.load_state_dict(checkpoint_dict["optimizer"])
scheduler.load_state_dict(checkpoint_dict["scheduler"])
trainner.set_start_epoch(checkpoint_dict["epoch"])
#model = model.to(device,memory_format=torch.channels_last)

# model = model.to(device,memory_format=torch.channels_last)
model = model.to(device)

# Uncomment the lines below to use the EMA model, then pass ema_model to train() instead of None
# ema_model = torch.optim.swa_utils.AveragedModel(
# model, multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999)
# )

# Print basic information
print(
f"\n{summary(model, input_size=(data_set.get_bath_size(),3,384,1280),mode='train',verbose=0,depth=2)}"
)
logger.info(data_set)
logger.info(optimizer)
logger.info(scheduler)

train(
model,
None,
trainner,
device,
data_set.get_train_loader(),
Expand Down

0 comments on commit c2d8e1f

Please sign in to comment.