Skip to content

Commit

Permalink
Update: Move shared constants (MASTER_ADDR, MASTER_PORT, EMA_BETA, RANDOM_RESIZED_CROP_SCALE, MEAN, STD) from config.choices to config.setting, update their import paths, and replace hard-coded magic values with the named constants.
Browse files Browse the repository at this point in the history
  • Loading branch information
chairc committed Jul 8, 2024
1 parent b36719a commit 315dcb8
Show file tree
Hide file tree
Showing 8 changed files with 15 additions and 20 deletions.
3 changes: 2 additions & 1 deletion config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
@Site : https://github.com/chairc
"""
from .choices import bool_choices, sample_choices, network_choices, optim_choices, act_choices, lr_func_choices, \
image_format_choices, noise_schedule_choices, RANDOM_RESIZED_CROP_SCALE, MEAN, STD
image_format_choices, noise_schedule_choices
from .setting import MASTER_ADDR, MASTER_PORT, EMA_BETA, RANDOM_RESIZED_CROP_SCALE, MEAN, STD
from .version import __version__, get_versions, get_latest_version, get_old_versions, check_version_is_latest
9 changes: 0 additions & 9 deletions config/choices.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,6 @@
image_format_choices = ["png", "jpg", "jpeg", "webp", "tif"]
noise_schedule_choices = ["linear", "cosine", "sqrt_linear", "sqrt"]

# Some special parameter settings
# ****** torchvision.transforms.Compose ******
# RandomResizedCrop
RANDOM_RESIZED_CROP_SCALE = (0.8, 1.0)
# Mean in datasets
MEAN = (0.485, 0.456, 0.406)
# Std in datasets
STD = (0.229, 0.224, 0.225)


# Function
def parse_image_size_type(image_size_str):
Expand Down
2 changes: 1 addition & 1 deletion sr/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from PIL import Image
from torch.utils.data import Dataset, DataLoader, DistributedSampler

from config.choices import MEAN, STD
from config.setting import MEAN, STD


class SRDataset(Dataset):
Expand Down
2 changes: 1 addition & 1 deletion sr/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torchvision
import logging

from config.choices import MEAN, STD
from config.setting import MEAN, STD
from utils.checkpoint import load_ckpt
from utils.initializer import sr_network_initializer, device_initializer
from utils.utils import check_and_create_dir
Expand Down
7 changes: 4 additions & 3 deletions sr/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tqdm import tqdm

sys.path.append(os.path.dirname(sys.path[0]))
from config.setting import MASTER_ADDR, MASTER_PORT, EMA_BETA
from model.modules.ema import EMA
from utils.initializer import device_initializer, seed_initializer, sr_network_initializer, optimizer_initializer, \
lr_initializer, amp_initializer, loss_initializer
Expand Down Expand Up @@ -76,8 +77,8 @@ def train(rank=None, args=None):
distributed = True
world_size = args.world_size
# Set address and port
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12346"
os.environ["MASTER_ADDR"] = MASTER_ADDR
os.environ["MASTER_PORT"] = MASTER_PORT
# The total number of processes is equal to the number of graphics cards
dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo", rank=rank,
world_size=world_size)
Expand Down Expand Up @@ -161,7 +162,7 @@ def train(rank=None, args=None):
len_train_dataloader = len(train_dataloader)
len_val_dataloader = len(val_dataloader)
# Exponential Moving Average (EMA) may not be as dominant for single class as for multi class
ema = EMA(beta=0.995)
ema = EMA(beta=EMA_BETA)
# EMA model
ema_model = copy.deepcopy(model).eval().requires_grad_(False)

Expand Down
7 changes: 4 additions & 3 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
sys.path.append(os.path.dirname(sys.path[0]))
from config.choices import sample_choices, network_choices, optim_choices, act_choices, lr_func_choices, \
image_format_choices, noise_schedule_choices, parse_image_size_type
from config.setting import MASTER_ADDR, MASTER_PORT, EMA_BETA
from model.modules.ema import EMA
from utils.check import check_image_size
from utils.dataset import get_dataset
Expand Down Expand Up @@ -98,8 +99,8 @@ def train(rank=None, args=None):
distributed = True
world_size = args.world_size
# Set address and port
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12345"
os.environ["MASTER_ADDR"] = MASTER_ADDR
os.environ["MASTER_PORT"] = MASTER_PORT
# The total number of processes is equal to the number of graphics cards
dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo", rank=rank,
world_size=world_size)
Expand Down Expand Up @@ -182,7 +183,7 @@ def train(rank=None, args=None):
# Initialize the diffusion model
diffusion = sample_initializer(sample=sample, image_size=image_size, device=device, schedule_name=noise_schedule)
# Exponential Moving Average (EMA) may not be as dominant for single class as for multi class
ema = EMA(beta=0.995)
ema = EMA(beta=EMA_BETA)
# EMA model
ema_model = copy.deepcopy(model).eval().requires_grad_(False)

Expand Down
2 changes: 1 addition & 1 deletion utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torch.utils.data import DataLoader, DistributedSampler
from typing import Union

from config.choices import RANDOM_RESIZED_CROP_SCALE, MEAN, STD
from config.setting import RANDOM_RESIZED_CROP_SCALE, MEAN, STD
from utils.check import check_path_is_exist


Expand Down
3 changes: 2 additions & 1 deletion webui/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@

sys.path.append(os.path.dirname(sys.path[0]))
from config.choices import bool_choices, sample_choices, network_choices, optim_choices, act_choices, lr_func_choices, \
image_format_choices, RANDOM_RESIZED_CROP_SCALE, MEAN, STD
image_format_choices
from config.setting import RANDOM_RESIZED_CROP_SCALE, MEAN, STD
from model.modules.ema import EMA
from utils.initializer import device_initializer, seed_initializer, network_initializer, optimizer_initializer, \
sample_initializer, lr_initializer, amp_initializer, classes_initializer
Expand Down

0 comments on commit 315dcb8

Please sign in to comment.