From 315dcb8088c7cf860213bc57733c7d0fb2f9cb75 Mon Sep 17 00:00:00 2001
From: chairc <974833488@qq.com>
Date: Tue, 9 Jul 2024 00:22:20 +0800
Subject: [PATCH] Update: Modify the method import path and eliminate the magic value.

---
 config/__init__.py | 3 ++-
 config/choices.py  | 9 ---------
 sr/dataset.py      | 2 +-
 sr/interface.py    | 2 +-
 sr/train.py        | 7 ++++---
 tools/train.py     | 7 ++++---
 utils/dataset.py   | 2 +-
 webui/web.py       | 3 ++-
 8 files changed, 15 insertions(+), 20 deletions(-)

diff --git a/config/__init__.py b/config/__init__.py
index 24a0514..3ae6e8f 100644
--- a/config/__init__.py
+++ b/config/__init__.py
@@ -6,5 +6,6 @@
 @Site : https://github.com/chairc
 """
 from .choices import bool_choices, sample_choices, network_choices, optim_choices, act_choices, lr_func_choices, \
-    image_format_choices, noise_schedule_choices, RANDOM_RESIZED_CROP_SCALE, MEAN, STD
+    image_format_choices, noise_schedule_choices
+from .setting import MASTER_ADDR, MASTER_PORT, EMA_BETA, RANDOM_RESIZED_CROP_SCALE, MEAN, STD
 from .version import __version__, get_versions, get_latest_version, get_old_versions, check_version_is_latest
diff --git a/config/choices.py b/config/choices.py
index 5ab4354..b34b2b3 100644
--- a/config/choices.py
+++ b/config/choices.py
@@ -24,15 +24,6 @@
 image_format_choices = ["png", "jpg", "jpeg", "webp", "tif"]
 noise_schedule_choices = ["linear", "cosine", "sqrt_linear", "sqrt"]
 
-# Some special parameter settings
-# ****** torchvision.transforms.Compose ******
-# RandomResizedCrop
-RANDOM_RESIZED_CROP_SCALE = (0.8, 1.0)
-# Mean in datasets
-MEAN = (0.485, 0.456, 0.406)
-# Std in datasets
-STD = (0.229, 0.224, 0.225)
-
 
 # Function
 def parse_image_size_type(image_size_str):
diff --git a/sr/dataset.py b/sr/dataset.py
index 96c4a41..6900172 100644
--- a/sr/dataset.py
+++ b/sr/dataset.py
@@ -12,7 +12,7 @@
 from PIL import Image
 from torch.utils.data import Dataset, DataLoader, DistributedSampler
 
-from config.choices import MEAN, STD
+from config.setting import MEAN, STD
 
 
 class SRDataset(Dataset):
diff --git a/sr/interface.py b/sr/interface.py
index 8e320a6..07928fc 100644
--- a/sr/interface.py
+++ b/sr/interface.py
@@ -13,7 +13,7 @@
 import torchvision
 import logging
 
-from config.choices import MEAN, STD
+from config.setting import MEAN, STD
 from utils.checkpoint import load_ckpt
 from utils.initializer import sr_network_initializer, device_initializer
 from utils.utils import check_and_create_dir
diff --git a/sr/train.py b/sr/train.py
index 600d26f..34760af 100644
--- a/sr/train.py
+++ b/sr/train.py
@@ -23,6 +23,7 @@
 from tqdm import tqdm
 
 sys.path.append(os.path.dirname(sys.path[0]))
+from config.setting import MASTER_ADDR, MASTER_PORT, EMA_BETA
 from model.modules.ema import EMA
 from utils.initializer import device_initializer, seed_initializer, sr_network_initializer, optimizer_initializer, \
     lr_initializer, amp_initializer, loss_initializer
@@ -76,8 +77,8 @@ def train(rank=None, args=None):
         distributed = True
         world_size = args.world_size
         # Set address and port
-        os.environ["MASTER_ADDR"] = "localhost"
-        os.environ["MASTER_PORT"] = "12346"
+        os.environ["MASTER_ADDR"] = MASTER_ADDR
+        os.environ["MASTER_PORT"] = MASTER_PORT
         # The total number of processes is equal to the number of graphics cards
         dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo", rank=rank,
                                 world_size=world_size)
@@ -161,7 +162,7 @@ def train(rank=None, args=None):
     len_train_dataloader = len(train_dataloader)
     len_val_dataloader = len(val_dataloader)
     # Exponential Moving Average (EMA) may not be as dominant for single class as for multi class
-    ema = EMA(beta=0.995)
+    ema = EMA(beta=EMA_BETA)
     # EMA model
     ema_model = copy.deepcopy(model).eval().requires_grad_(False)
 
diff --git a/tools/train.py b/tools/train.py
index 4d5b976..43d02bd 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -24,6 +24,7 @@
 sys.path.append(os.path.dirname(sys.path[0]))
 from config.choices import sample_choices, network_choices, optim_choices, act_choices, lr_func_choices, \
     image_format_choices, noise_schedule_choices, parse_image_size_type
+from config.setting import MASTER_ADDR, MASTER_PORT, EMA_BETA
 from model.modules.ema import EMA
 from utils.check import check_image_size
 from utils.dataset import get_dataset
@@ -98,8 +99,8 @@ def train(rank=None, args=None):
         distributed = True
         world_size = args.world_size
         # Set address and port
-        os.environ["MASTER_ADDR"] = "localhost"
-        os.environ["MASTER_PORT"] = "12345"
+        os.environ["MASTER_ADDR"] = MASTER_ADDR
+        os.environ["MASTER_PORT"] = MASTER_PORT
         # The total number of processes is equal to the number of graphics cards
         dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo", rank=rank,
                                 world_size=world_size)
@@ -182,7 +183,7 @@ def train(rank=None, args=None):
     # Initialize the diffusion model
     diffusion = sample_initializer(sample=sample, image_size=image_size, device=device, schedule_name=noise_schedule)
     # Exponential Moving Average (EMA) may not be as dominant for single class as for multi class
-    ema = EMA(beta=0.995)
+    ema = EMA(beta=EMA_BETA)
     # EMA model
     ema_model = copy.deepcopy(model).eval().requires_grad_(False)
 
diff --git a/utils/dataset.py b/utils/dataset.py
index a5fdca4..3c6b286 100644
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -10,7 +10,7 @@
 from torch.utils.data import DataLoader, DistributedSampler
 from typing import Union
 
-from config.choices import RANDOM_RESIZED_CROP_SCALE, MEAN, STD
+from config.setting import RANDOM_RESIZED_CROP_SCALE, MEAN, STD
 from utils.check import check_path_is_exist
 
 
diff --git a/webui/web.py b/webui/web.py
index e3e0bdd..269d75f 100644
--- a/webui/web.py
+++ b/webui/web.py
@@ -24,7 +24,8 @@
 
 sys.path.append(os.path.dirname(sys.path[0]))
 from config.choices import bool_choices, sample_choices, network_choices, optim_choices, act_choices, lr_func_choices, \
-    image_format_choices, RANDOM_RESIZED_CROP_SCALE, MEAN, STD
+    image_format_choices
+from config.setting import RANDOM_RESIZED_CROP_SCALE, MEAN, STD
 from model.modules.ema import EMA
 from utils.initializer import device_initializer, seed_initializer, network_initializer, optimizer_initializer, \
     sample_initializer, lr_initializer, amp_initializer, classes_initializer
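
Note: the new config/setting.py module that this patch imports from is not included in the diff. Based on the literals removed above, it presumably looks something like the sketch below; the comments, the grouping, and the single MASTER_PORT value are assumptions (the two train scripts previously hard-coded 12345 and 12346 respectively).

    # config/setting.py -- hypothetical sketch, inferred from the constants removed in this patch.

    # Distributed training rendezvous address and port (previously hard-coded in sr/train.py and tools/train.py)
    MASTER_ADDR = "localhost"
    MASTER_PORT = "12345"  # assumption: one shared port replaces the two different hard-coded ports

    # Exponential Moving Average decay factor (previously EMA(beta=0.995))
    EMA_BETA = 0.995

    # torchvision.transforms.RandomResizedCrop scale range (previously in config/choices.py)
    RANDOM_RESIZED_CROP_SCALE = (0.8, 1.0)

    # Normalization mean and std used by the dataset transforms (previously in config/choices.py)
    MEAN = (0.485, 0.456, 0.406)
    STD = (0.229, 0.224, 0.225)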