diff --git a/sr/train.py b/sr/train.py index 3e8f783..4605694 100644 --- a/sr/train.py +++ b/sr/train.py @@ -23,7 +23,7 @@ from tqdm import tqdm sys.path.append(os.path.dirname(sys.path[0])) -from config.choices import loss_func_choices, sr_network_choices, optim_choices +from config.choices import sr_network_choices, optim_choices from config.setting import MASTER_ADDR, MASTER_PORT, EMA_BETA from config.version import get_version_banner from model.modules.ema import EMA @@ -31,6 +31,7 @@ lr_initializer, amp_initializer, loss_initializer from utils.utils import save_images, setup_logging, save_train_logging, check_and_create_dir from utils.checkpoint import load_ckpt, save_ckpt +from utils.metrics import compute_psnr, compute_ssim from sr.interface import post_image from sr.dataset import get_sr_dataset @@ -63,8 +64,8 @@ def train(rank=None, args=None): num_workers = args.num_workers # Select optimizer optim = args.optim - # Select activation function - loss_func = args.loss + # Loss function only mse + loss_func = "mse" # Select activation function act = args.act # Learning rate @@ -115,7 +116,9 @@ def train(rank=None, args=None): # Dataloader train_dataloader = get_sr_dataset(image_size=image_size, dataset_path=train_dataset_path, batch_size=batch_size, num_workers=num_workers, distributed=distributed) - val_dataloader = get_sr_dataset(image_size=image_size, dataset_path=val_dataset_path, batch_size=batch_size, + # Quick eval batch size + val_batch_size = batch_size if args.quick_eval else 1 + val_dataloader = get_sr_dataset(image_size=image_size, dataset_path=val_dataset_path, batch_size=val_batch_size, num_workers=num_workers, distributed=distributed) # Resume training resume = args.resume @@ -141,9 +144,16 @@ def train(rank=None, args=None): # Parameter 'ckpt_path' is None in the train mode if ckpt_path is None: ckpt_path = os.path.join(results_dir, "ckpt_last.pt") + # The best model + ckpt_best_path = os.path.join(results_dir, "ckpt_best.pt") + # Get model 
state start_epoch = load_ckpt(ckpt_path=ckpt_path, model=model, device=device, optimizer=optimizer, is_distributed=distributed) + # Get best ssim and psnr + best_ssim, best_psnr = load_ckpt(ckpt_path=ckpt_best_path, device=device, ckpt_type="sr") logger.info(msg=f"[{device}]: Successfully load resume model checkpoint.") + logger.info(msg=f"[{device}]: The start epoch is {start_epoch}, best ssim is {best_ssim}, " + f"best psnr is {best_psnr}.") else: # Pretrain mode if pretrain: @@ -151,7 +161,10 @@ def train(rank=None, args=None): load_ckpt(ckpt_path=pretrain_path, model=model, device=device, is_pretrain=pretrain, is_distributed=distributed) logger.info(msg=f"[{device}]: Successfully load pretrain model checkpoint.") - start_epoch = 0 + # Init + start_epoch, best_ssim, best_psnr = 0, 0, 0 + logger.info(msg=f"[{device}]: The start epoch is {start_epoch}, best ssim is {best_ssim}, " + f"best psnr is {best_psnr}.") # Set harf-precision scaler = amp_initializer(amp=amp, device=device) # Loss function @@ -180,7 +193,7 @@ def train(rank=None, args=None): save_val_vis_dir = os.path.join(results_vis_dir, str(epoch)) check_and_create_dir(save_val_vis_dir) # Initialize images and labels - train_loss_list, val_loss_list = [], [] + train_loss_list, val_loss_list, ssim_list, psnr_list = [], [], [], [] # Train model.train() @@ -241,6 +254,17 @@ def train(rank=None, args=None): tb_logger.add_scalar(tag=f"[{device}]: Val loss({loss_func})", scalar_value=val_loss.item(), global_step=epoch * len_val_dataloader + i) val_loss_list.append(val_loss.item()) + + # Metric + ssim_res = compute_ssim(image_outputs=output, image_sources=hr_images) + psnr_res = compute_psnr(mse=val_loss.item()) + tb_logger.add_scalar(tag=f"[{device}]: SSIM({loss_func})", scalar_value=ssim_res, + global_step=epoch * len_val_dataloader + i) + tb_logger.add_scalar(tag=f"[{device}]: PSNR({loss_func})", scalar_value=psnr_res, + global_step=epoch * len_val_dataloader + i) + ssim_list.append(ssim_res) + 
psnr_list.append(psnr_res) + # Save super resolution image and high resolution image lr_images = post_image(lr_images, device=device) sr_images = post_image(output, device=device) @@ -252,9 +276,14 @@ def train(rank=None, args=None): save_images(images=sr_image, path=os.path.join(save_val_vis_dir, f"{i}_{image_name}_{sr_index}_sr.jpg")) for hr_index, hr_image in enumerate(hr_images): save_images(images=hr_image, path=os.path.join(save_val_vis_dir, f"{i}_{image_name}_{hr_index}_hr.jpg")) - # Loss per epoch - tb_logger.add_scalar(tag=f"[{device}]: Val loss", scalar_value=sum(val_loss_list) / len(val_loss_list), - global_step=epoch) + # Loss, ssim and psnr per epoch + avg_val_loss = sum(val_loss_list) / len(val_loss_list) + avg_ssim = sum(ssim_list) / len(ssim_list) + avg_psnr = sum(psnr_list) / len(psnr_list) + tb_logger.add_scalar(tag=f"[{device}]: Val loss", scalar_value=avg_val_loss, global_step=epoch) + tb_logger.add_scalar(tag=f"[{device}]: Avg ssim", scalar_value=avg_ssim, global_step=epoch) + tb_logger.add_scalar(tag=f"[{device}]: Avg psnr", scalar_value=avg_psnr, global_step=epoch) + logger.info(f"Val loss: {avg_val_loss}, SSIM: {avg_ssim}, PSNR: {avg_psnr}") logger.info(msg="Finish val mode.") # Saving and validating models in the main process @@ -265,10 +294,18 @@ def train(rank=None, args=None): ckpt_model = model.state_dict() ckpt_ema_model = ema_model.state_dict() ckpt_optimizer = optimizer.state_dict() + # Save the best model + if (avg_ssim > best_ssim) and (avg_psnr > best_psnr): + is_best = True + best_ssim = avg_ssim + best_psnr = avg_psnr + else: + is_best = False # Save checkpoint save_ckpt(epoch=epoch, save_name=save_name, ckpt_model=ckpt_model, ckpt_ema_model=ckpt_ema_model, ckpt_optimizer=ckpt_optimizer, results_dir=results_dir, save_model_interval=save_model_interval, - start_model_interval=start_model_interval, image_size=image_size, network=network, act=act) + save_model_interval_epochs=None, start_model_interval=start_model_interval, 
image_size=image_size, + network=network, act=act, is_sr=True, is_best=is_best, ssim=avg_ssim, psnr=avg_psnr) logger.info(msg=f"[{device}]: Finish epoch {epoch}:") # Synchronization during distributed training @@ -331,12 +368,13 @@ def main(args): # Enable automatic mixed precision training (needed) # Effectively reducing GPU memory usage may lead to lower training accuracy and results. parser.add_argument("--amp", default=False, action="store_true") + # It is recommended that the batch size of the evaluation data be set to 1, + # which can accurately evaluate each picture and reduce the evaluation error of each group of pictures. + # If you want to evaluate quickly, set it to batch size. + parser.add_argument("--quick_eval", default=False, action="store_true") # Set optimizer (needed) # Option: adam/adamw/sgd parser.add_argument("--optim", type=str, default="sgd", choices=optim_choices) - # Set loss function (needed) - # Option: mse/l1/huber/smooth_l1 - parser.add_argument("--loss", type=str, default="mse", choices=loss_func_choices) # Set activation function (needed) # Option: gelu/silu/relu/relu6/lrelu parser.add_argument("--act", type=str, default="silu") diff --git a/utils/checkpoint.py b/utils/checkpoint.py index c7652db..f047fba 100644 --- a/utils/checkpoint.py +++ b/utils/checkpoint.py @@ -14,12 +14,14 @@ from collections import OrderedDict +from utils.check import check_path_is_exist + logger = logging.getLogger(__name__) coloredlogs.install(level="INFO") -def load_ckpt(ckpt_path, model, device, optimizer=None, is_train=True, is_pretrain=False, is_distributed=False, - is_use_ema=False, conditional=False): +def load_ckpt(ckpt_path, model=None, device="cpu", optimizer=None, is_train=True, is_pretrain=False, + is_distributed=False, is_use_ema=False, conditional=False, ckpt_type="df"): """ Load checkpoint weight files :param ckpt_path: Checkpoint path @@ -31,10 +33,17 @@ def load_ckpt(ckpt_path, model, device, optimizer=None, is_train=True, is_pretra :param 
is_distributed: Whether to distribute training :param is_use_ema: Whether to use ema model or model :param conditional: Whether conditional training + :param ckpt_type: Type of checkpoint :return: start_epoch + 1 """ + # Check path + check_path_is_exist(path=ckpt_path) # Load checkpoint ckpt_state = torch.load(f=ckpt_path, map_location=device) + # Load the best model params as ssim and psnr + if ckpt_type == "sr": + logger.info(msg=f"[{device}]: Successfully load the best sr checkpoint from {ckpt_path}.") + return ckpt_state["ssim"], ckpt_state["psnr"] if is_pretrain: logger.info(msg=f"[{device}]: Successfully load pretrain checkpoint, path is '{ckpt_path}'.") else: @@ -148,6 +157,16 @@ def save_ckpt(epoch, save_name, ckpt_model, ckpt_ema_model, ckpt_optimizer, resu "num_classes": num_classes if conditional else 1, "classes_name": classes_name, "conditional": conditional, "image_size": image_size, "sample": sample, "network": network, "act": act, } + # Check is sr mode + if kwargs.get("is_sr", False): + best_ssim, best_psnr = kwargs.get("ssim"), kwargs.get("psnr") + ckpt_state["ssim"] = best_ssim + ckpt_state["psnr"] = best_psnr + # Check is the best sr model? 
+        if kwargs.get("is_best", False):
+            last_filename = os.path.join(results_dir, f"ckpt_best.pt")
+            torch.save(obj=ckpt_state, f=last_filename)
+            logger.info(msg=f"Save the ckpt_best.pt, best ssim is {best_ssim}, best psnr is {best_psnr}")
     # Save last checkpoint, it must be done
     last_filename = os.path.join(results_dir, f"ckpt_last.pt")
     torch.save(obj=ckpt_state, f=last_filename)
diff --git a/utils/metrics.py b/utils/metrics.py
new file mode 100644
index 0000000..ee254ae
--- /dev/null
+++ b/utils/metrics.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+"""
+    @Date   : 2024/11/3 17:52
+    @Author : chairc
+    @Site   : https://github.com/chairc
+"""
+import math
+
+from skimage.metrics import structural_similarity
+
+
+def compute_psnr(mse, max_value=255.0):
+    """
+    Compute the peak signal-to-noise ratio (PSNR) from a mean squared error
+    :param mse: Mean squared error, it must be measured on the same scale as max_value
+    :param max_value: Maximum possible pixel value (peak signal), default 255.0
+    :return: PSNR in decibels, 100.0 when mse is 0
+    """
+    if mse < 0:
+        raise ValueError(f"The mse must be non-negative, but got {mse}.")
+    # A zero error would divide by zero, so the result is capped at 100 dB
+    if mse == 0:
+        return 100.0
+    return 20 * math.log10(max_value / math.sqrt(mse))
+
+
+def compute_ssim(image_outputs, image_sources, data_range=255):
+    """
+    Compute the mean structural similarity (SSIM) over a batch of images
+    :param image_outputs: Output image tensor, batch is axis 0 and channel is axis 1
+    :param image_sources: Source image tensor with the same shape as image_outputs
+    :param data_range: Value range of the images passed to structural_similarity, default 255
+    :return: Average SSIM of the batch
+    """
+    # Detach first, numpy() raises an error on a tensor that requires grad
+    image_outputs = image_outputs.detach().to("cpu").numpy()
+    image_sources = image_sources.detach().to("cpu").numpy()
+    # Equal shapes also guarantee equal batch sizes
+    if image_outputs.shape != image_sources.shape:
+        raise AssertionError("Image outputs and image sources shape mismatch.")
+    length = image_outputs.shape[0]
+    if length == 0:
+        raise ValueError("Image outputs and image sources must not be empty.")
+    # After indexing the batch, the channel is axis 0 of each image
+    ssim_list = [structural_similarity(image_output, image_source, channel_axis=0, data_range=data_range)
+                 for image_output, image_source in zip(image_outputs, image_sources)]
+    return sum(ssim_list) / length