Skip to content

Commit

Permalink
Add: Added PSNR and SSIM calculators to verify super resolution image…
Browse files Browse the repository at this point in the history
… quality and optimize super resolution train.py.
  • Loading branch information
chairc committed Nov 26, 2024
1 parent 3302511 commit 399ddea
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 15 deletions.
64 changes: 51 additions & 13 deletions sr/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@
from tqdm import tqdm

sys.path.append(os.path.dirname(sys.path[0]))
from config.choices import loss_func_choices, sr_network_choices, optim_choices
from config.choices import sr_network_choices, optim_choices
from config.setting import MASTER_ADDR, MASTER_PORT, EMA_BETA
from config.version import get_version_banner
from model.modules.ema import EMA
from utils.initializer import device_initializer, seed_initializer, sr_network_initializer, optimizer_initializer, \
lr_initializer, amp_initializer, loss_initializer
from utils.utils import save_images, setup_logging, save_train_logging, check_and_create_dir
from utils.checkpoint import load_ckpt, save_ckpt
from utils.metrics import compute_psnr, compute_ssim
from sr.interface import post_image
from sr.dataset import get_sr_dataset

Expand Down Expand Up @@ -63,8 +64,8 @@ def train(rank=None, args=None):
num_workers = args.num_workers
# Select optimizer
optim = args.optim
# Select activation function
loss_func = args.loss
# Loss function only mse
loss_func = "mse"
# Select activation function
act = args.act
# Learning rate
Expand Down Expand Up @@ -115,7 +116,9 @@ def train(rank=None, args=None):
# Dataloader
train_dataloader = get_sr_dataset(image_size=image_size, dataset_path=train_dataset_path, batch_size=batch_size,
num_workers=num_workers, distributed=distributed)
val_dataloader = get_sr_dataset(image_size=image_size, dataset_path=val_dataset_path, batch_size=batch_size,
# Quick eval batch size
val_batch_size = batch_size if args.quick_eval else 1
val_dataloader = get_sr_dataset(image_size=image_size, dataset_path=val_dataset_path, batch_size=val_batch_size,
num_workers=num_workers, distributed=distributed)
# Resume training
resume = args.resume
Expand All @@ -141,17 +144,27 @@ def train(rank=None, args=None):
# Parameter 'ckpt_path' is None in the train mode
if ckpt_path is None:
ckpt_path = os.path.join(results_dir, "ckpt_last.pt")
# The best model
ckpt_best_path = os.path.join(results_dir, "ckpt_best.pt")
# Get model state
start_epoch = load_ckpt(ckpt_path=ckpt_path, model=model, device=device, optimizer=optimizer,
is_distributed=distributed)
# Get best ssim and psnr
best_ssim, best_psnr = load_ckpt(ckpt_path=ckpt_best_path, device=device, ckpt_type="sr")
logger.info(msg=f"[{device}]: Successfully load resume model checkpoint.")
logger.info(msg=f"[{device}]: The start epoch is {start_epoch}, best ssim is {best_ssim}, "
f"best psnr is {best_psnr}.")
else:
# Pretrain mode
if pretrain:
pretrain_path = args.pretrain_path
load_ckpt(ckpt_path=pretrain_path, model=model, device=device, is_pretrain=pretrain,
is_distributed=distributed)
logger.info(msg=f"[{device}]: Successfully load pretrain model checkpoint.")
start_epoch = 0
# Init
start_epoch, best_ssim, best_psnr = 0, 0, 0
logger.info(msg=f"[{device}]: The start epoch is {start_epoch}, best ssim is {best_ssim}, "
f"best psnr is {best_psnr}.")
# Set harf-precision
scaler = amp_initializer(amp=amp, device=device)
# Loss function
Expand Down Expand Up @@ -180,7 +193,7 @@ def train(rank=None, args=None):
save_val_vis_dir = os.path.join(results_vis_dir, str(epoch))
check_and_create_dir(save_val_vis_dir)
# Initialize images and labels
train_loss_list, val_loss_list = [], []
train_loss_list, val_loss_list, ssim_list, psnr_list = [], [], [], []

# Train
model.train()
Expand Down Expand Up @@ -241,6 +254,17 @@ def train(rank=None, args=None):
tb_logger.add_scalar(tag=f"[{device}]: Val loss({loss_func})", scalar_value=val_loss.item(),
global_step=epoch * len_val_dataloader + i)
val_loss_list.append(val_loss.item())

# Metric
ssim_res = compute_ssim(image_outputs=output, image_sources=hr_images)
psnr_res = compute_psnr(mse=val_loss.item())
tb_logger.add_scalar(tag=f"[{device}]: SSIM({loss_func})", scalar_value=ssim_res,
global_step=epoch * len_val_dataloader + i)
tb_logger.add_scalar(tag=f"[{device}]: PSNR({loss_func})", scalar_value=psnr_res,
global_step=epoch * len_val_dataloader + i)
ssim_list.append(ssim_res)
psnr_list.append(psnr_res)

# Save super resolution image and high resolution image
lr_images = post_image(lr_images, device=device)
sr_images = post_image(output, device=device)
Expand All @@ -252,9 +276,14 @@ def train(rank=None, args=None):
save_images(images=sr_image, path=os.path.join(save_val_vis_dir, f"{i}_{image_name}_{sr_index}_sr.jpg"))
for hr_index, hr_image in enumerate(hr_images):
save_images(images=hr_image, path=os.path.join(save_val_vis_dir, f"{i}_{image_name}_{hr_index}_hr.jpg"))
# Loss per epoch
tb_logger.add_scalar(tag=f"[{device}]: Val loss", scalar_value=sum(val_loss_list) / len(val_loss_list),
global_step=epoch)
# Loss, ssim and psnr per epoch
avg_val_loss = sum(val_loss_list) / len(val_loss_list)
avg_ssim = sum(ssim_list) / len(ssim_list)
avg_psnr = sum(psnr_list) / len(psnr_list)
tb_logger.add_scalar(tag=f"[{device}]: Val loss", scalar_value=avg_val_loss, global_step=epoch)
tb_logger.add_scalar(tag=f"[{device}]: Avg ssim", scalar_value=avg_ssim, global_step=epoch)
tb_logger.add_scalar(tag=f"[{device}]: Avg psnr", scalar_value=avg_psnr, global_step=epoch)
logger.info(f"Val loss: {avg_val_loss}, SSIM: {avg_ssim}, PSNR: {avg_psnr}")
logger.info(msg="Finish val mode.")

# Saving and validating models in the main process
Expand All @@ -265,10 +294,18 @@ def train(rank=None, args=None):
ckpt_model = model.state_dict()
ckpt_ema_model = ema_model.state_dict()
ckpt_optimizer = optimizer.state_dict()
# Save the best model
if (avg_ssim > best_ssim) and (avg_psnr > best_psnr):
is_best = True
best_ssim = avg_ssim
best_psnr = avg_psnr
else:
is_best = False
# Save checkpoint
save_ckpt(epoch=epoch, save_name=save_name, ckpt_model=ckpt_model, ckpt_ema_model=ckpt_ema_model,
ckpt_optimizer=ckpt_optimizer, results_dir=results_dir, save_model_interval=save_model_interval,
start_model_interval=start_model_interval, image_size=image_size, network=network, act=act)
save_model_interval_epochs=None, start_model_interval=start_model_interval, image_size=image_size,
network=network, act=act, is_sr=True, is_best=is_best, ssim=avg_ssim, psnr=avg_psnr)
logger.info(msg=f"[{device}]: Finish epoch {epoch}:")

# Synchronization during distributed training
Expand Down Expand Up @@ -331,12 +368,13 @@ def main(args):
# Enable automatic mixed precision training (needed)
# Effectively reducing GPU memory usage may lead to lower training accuracy and results.
parser.add_argument("--amp", default=False, action="store_true")
# It is recommended that the batch size of the evaluation data be set to 1,
# which can accurately evaluate each picture and reduce the evaluation error of each group of pictures.
# If you want to evaluate quickly, set it to batch size.
parser.add_argument("--quick_eval", default=False, action="store_true")
# Set optimizer (needed)
# Option: adam/adamw/sgd
parser.add_argument("--optim", type=str, default="sgd", choices=optim_choices)
# Set loss function (needed)
# Option: mse/l1/huber/smooth_l1
parser.add_argument("--loss", type=str, default="mse", choices=loss_func_choices)
# Set activation function (needed)
# Option: gelu/silu/relu/relu6/lrelu
parser.add_argument("--act", type=str, default="silu")
Expand Down
23 changes: 21 additions & 2 deletions utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@

from collections import OrderedDict

from utils.check import check_path_is_exist

logger = logging.getLogger(__name__)
coloredlogs.install(level="INFO")


def load_ckpt(ckpt_path, model, device, optimizer=None, is_train=True, is_pretrain=False, is_distributed=False,
is_use_ema=False, conditional=False):
def load_ckpt(ckpt_path, model=None, device="cpu", optimizer=None, is_train=True, is_pretrain=False,
is_distributed=False, is_use_ema=False, conditional=False, ckpt_type="df"):
"""
Load checkpoint weight files
:param ckpt_path: Checkpoint path
Expand All @@ -31,10 +33,17 @@ def load_ckpt(ckpt_path, model, device, optimizer=None, is_train=True, is_pretra
:param is_distributed: Whether to distribute training
:param is_use_ema: Whether to use ema model or model
:param conditional: Whether conditional training
:param ckpt_type: Type of checkpoint
:return: start_epoch + 1
"""
# Check path
check_path_is_exist(path=ckpt_path)
# Load checkpoint
ckpt_state = torch.load(f=ckpt_path, map_location=device)
# Load the best model params as ssim and psnr
if ckpt_type == "sr":
logger.info(msg=f"[{device}]: Successfully load the best sr checkpoint from {ckpt_path}.")
return ckpt_state["ssim"], ckpt_state["psnr"]
if is_pretrain:
logger.info(msg=f"[{device}]: Successfully load pretrain checkpoint, path is '{ckpt_path}'.")
else:
Expand Down Expand Up @@ -148,6 +157,16 @@ def save_ckpt(epoch, save_name, ckpt_model, ckpt_ema_model, ckpt_optimizer, resu
"num_classes": num_classes if conditional else 1, "classes_name": classes_name, "conditional": conditional,
"image_size": image_size, "sample": sample, "network": network, "act": act,
}
# Check is sr mode
if kwargs.get("is_sr", False):
best_ssim, best_psnr = kwargs.get("ssim"), kwargs.get("psnr")
ckpt_state["ssim"] = best_ssim
ckpt_state["psnr"] = best_psnr
# Check is the best sr model?
if kwargs.get("is_best", False):
last_filename = os.path.join(results_dir, f"ckpt_best.pt")
torch.save(obj=ckpt_state, f=last_filename)
logger.info(msg=f"Save the ckpt_best.pt, best ssim is {best_ssim}, best psnr is {best_psnr}")
# Save last checkpoint, it must be done
last_filename = os.path.join(results_dir, f"ckpt_last.pt")
torch.save(obj=ckpt_state, f=last_filename)
Expand Down
39 changes: 39 additions & 0 deletions utils/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@Date : 2024/11/3 17:52
@Author : chairc
@Site : https://github.com/chairc
"""
import math

from skimage.metrics import structural_similarity


def compute_psnr(mse):
"""
PSNR
"""
# Results
if mse == 0:
return 100
else:
return 20 * math.log10(255.0 / math.sqrt(mse))


def compute_ssim(image_outputs, image_sources):
"""
SSIM
"""
# Transfer to numpy
image_outputs = image_outputs.to("cpu").numpy()
image_sources = image_sources.to("cpu").numpy()
ssim_list = []
if image_outputs.shape != image_sources.shape or image_outputs.shape[0] != image_sources.shape[0]:
raise AssertionError("Image outputs and image sources shape mismatch.")
# image_outputs.shape[0] and image_sources.shape[0] are equal
length = image_outputs.shape[0]
for i in range(length):
ssim = structural_similarity(image_outputs[i], image_sources[i], channel_axis=0, data_range=255)
ssim_list.append(ssim)
return sum(ssim_list) / length

0 comments on commit 399ddea

Please sign in to comment.