diff --git a/README.md b/README.md
index afec445..a90a009 100644
--- a/README.md
+++ b/README.md
@@ -28,17 +28,25 @@
+
+
+
+
+
+
[中文文档](README_zh.md)
-### About the Model
+
+
+### 0 About the Model
This diffusion model is based on the classic DDPM (Denoising Diffusion Probabilistic Models), DDIM (Denoising Diffusion Implicit Models) and PLMS (Pseudo Numerical Methods for Diffusion Models on Manifolds) presented in the papers "[Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)", "[Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502)" and "[Pseudo Numerical Methods for Diffusion Models on Manifolds](https://openreview.net/forum?id=PlKWVd2yBkY)".
We named this project IDDM: Integrated Design Diffusion Model. It aims to reproduce the model, write trainers and generators, and improve and optimize certain algorithms and network structures. This repository is **actively maintained**.
-If you have any questions, please check [**the existing issues**](https://github.com/chairc/Integrated-Design-Diffusion-Model/issues/9) first. If the issue persists, feel free to open a new one for assistance, or you can contact me via email at chenyu1998424@gmail.com or chairc1998@163.com.
+If you have any questions, please check [**the existing issues**](https://github.com/chairc/Integrated-Design-Diffusion-Model/issues/9) first. If the issue persists, feel free to open a new one for assistance, or you can contact me via email at chenyu1998424@gmail.com or chairc1998@163.com. **If you think my project is interesting, please give me a ⭐⭐⭐Star⭐⭐⭐ :)**
**Repository Structure**
@@ -46,6 +54,7 @@ If you have any questions, please check [**the existing issues**](https://github
Integrated Design Diffusion Model
├── config
│ ├── choices.py
+│ ├── model_list.py
│ ├── setting.py
│ └── version.py
├── datasets
@@ -70,11 +79,15 @@ Integrated Design Diffusion Model
│ │ ├── base.py
│ │ ├── cspdarkunet.py
│ │ └── unet.py
-│ └── samples
+│ ├── samples
+│ │ ├── base.py
+│ │ ├── ddim.py
+│ │ ├── ddpm.py
+│ │ └── plms.py
+│ └── trainers
│ ├── base.py
-│ ├── ddim.py
-│ ├── ddpm.py
-│ └── plms.py
+│ ├── dm.py
+│ └── sr.py
├── results
├── sr
│ ├── dataset.py
@@ -98,37 +111,49 @@ Integrated Design Diffusion Model
│ ├── initializer.py
│ ├── logger.py
│ ├── lr_scheduler.py
+│ ├── metrics.py
│ ├── processing.py
│ └── utils.py
├── webui
│ └──web.py
-└── weight
+└── weights
```
-### Next Steps
-- [x] 1. Implement cosine learning rate optimization. (2023-07-31)
-- [x] 2. Use a more advanced U-Net network model. (2023-11-09)
-- [x] 3. Generate larger-sized images. (2023-11-09)
-- [x] 4. Implement multi-GPU distributed training. (2023-07-15)
-- [x] 5. Enable fast deployment and API on cloud servers. (2023-08-28)
-- [x] 6. Adding DDIM Sampling Method. (2023-08-03)
-- [x] 7. Support other image generation. (2023-09-16)
-- [x] 8. Low-resolution generated images for super-resolution enhancement.[~~Super resolution model, the effect is uncertain.~~] (2024-02-18)
-- [ ] 9. Use Latent Diffusion and reduce GPU memory usage
-- [x] 10. Reconstruct the overall structure of the model (2023-12-06)
-- [x] 11. Write visual webui interface. (2024-01-23)
-- [x] 12. Adding PLMS Sampling Method. (2024-03-12)
-- [x] 13. Adding FID calculator to verify image quality. (2024-05-06)
-- [x] 14. Adding the deployment of image-generating Sockets and Web server. (2024-11-13)
-### Training
+### 1 Next Steps
+
+- [x] [2023-07-15] Adding implement multi-GPU distributed training.
+- [x] [2023-07-31] Adding implement cosine learning rate optimization.
+- [x] [2023-08-03] Adding DDIM Sampling Method.
+- [x] [2023-08-28] Adding fast deployment and API on cloud servers.
+- [x] [2023-09-16] Support other image generation.
+- [x] [2023-11-09] Adding a more advanced U-Net network model.
+- [x] [2023-11-09] Support generate larger-sized images.
+- [x] [2023-12-06] Reconstruct the overall structure of the model.
+- [x] [2024-01-23] Adding visual webui training interface.
+- [x] [2024-02-18] Support low-resolution generated images for super-resolution enhancement.[~~Super resolution model, the effect is uncertain~~]
+- [x] [2024-03-12] Adding PLMS Sampling Method.
+- [x] [2024-05-06] Adding FID calculator to verify image quality.
+- [x] [2024-06-11] Adding visual webui generate interface.
+- [x] [2024-07-07] Support custom images length and width input.
+- [x] [2024-11-13] Adding the deployment of image-generating Sockets and Web server.
+- [x] [2024-11-26] Adding PSNR and SSIM calculators to verify super resolution image quality.
+- [x] [2024-12-10] Adding pretrain model download.
+- [x] [2024-12-25] Reconstruct the overall structure of the trainer.
+- [ ] [Maybe 2025-01-31] Adding the deployment of Docker and image.
+- [ ] [To be determined] Reconstruct the project by Baidu PaddlePaddle.
+- [ ] [To be determined] ~~Use Latent Diffusion and reduce GPU memory usage~~
+
-#### Note
-The training GPU implements environment for this README is as follows: models are trained and tested with the NVIDIA RTX 3060 GPU with 6GB memory, NVIDIA RTX 2080Ti GPU with 11GB memory and NVIDIA RTX 6000 (×2) GPU with 24GB (total 48GB, distributed training) memory. **The above GPUs can all be trained normally**.
+### 2 Training
-#### Start Your First Training (Using cifar10 as an Example, Single GPU Mode)
+**Note before training**
+
+The training GPU implements environment for this README is as follows: models are trained and tested with the NVIDIA RTX 3060 GPU with 6GB memory, NVIDIA RTX 2080Ti GPU with 11GB memory and NVIDIA RTX 6000 (×2) GPU with 22GB (total 44GB, distributed training) memory. **The above GPUs can all be trained normally**.
+
+#### 2.1 Start Your First Training (Using cifar10 as an Example, Single GPU Mode)
1. **Import the Dataset**
@@ -183,7 +208,9 @@ The training GPU implements environment for this README is as follows: models ar
**↓↓↓↓↓↓↓↓↓↓The following is an explanation of various training methods and detailed training parameters↓↓↓↓↓↓↓↓↓↓**
-#### Normal Training
+#### 2.2 Normal Training
+
+##### 2.2.1 Command Training
1. Take the `landscape` dataset as an example and place the dataset files in the `datasets` folder. The overall path of the dataset should be `/your/path/datasets/landscape`, the images path should be `/your/path/datasets/landscape/images`, and the image files should be located at `/your/path/datasets/landscape/images/*.jpg`.
@@ -238,7 +265,45 @@ The training GPU implements environment for this README is as follows: models ar
python train.py --pretrain --pretrain_path /your/pretrain/path/model.pt --sample ddpm --run_name df --epochs 300 --batch_size 16 --image_size 64 --dataset_path /your/dataset/path --result_path /your/save/path
```
-#### Distributed Training
+##### 2.2.2 Python Training
+
+```python
+from model.trainers.dm import DMTrainer
+from tools.train import init_train_args
+
+# Function 1
+# Initialize arguments
+args = init_train_args()
+# Customize your parameters, or you can configure them by entering the init_train_args method
+setattr(args, "conditional", True) # True for conditional training, False for non-conditional training
+setattr(args, "sample", "ddpm") # Sampler
+setattr(args, "network", "unet") # Deep learning network
+setattr(args, "epochs", 300) # Number of iterations
+setattr(args, "image_size", 64) # Image size
+setattr(args, "result_path", "/your/dataset/path/") # Dataset path
+setattr(args, "result_path", "/your/save/path/") # Result path
+setattr(args, "vis", True) # Enable visualization
+# ...
+# OR use args["parameter_name"] = "your setting"
+# Start training
+DMTrainer(args=args).train()
+
+# Function 2
+args = init_train_args()
+# Input args and update some params
+DMTrainer(args=args, dataset_path="/your/dataset/path/").train()
+
+# Function 3
+DMTrainer(
+ conditional=True, sample="ddpm", dataset_path="/your/dataset/path/",
+ network="unet", epochs=300, image_size=64, result_path="/your/save/path/",
+ vis=True, # Any params...
+).train()
+```
+
+#### 2.3 Distributed Training
+
+##### 2.3.1 Command Training
1. The basic configuration is similar to regular training, but note that enabling distributed training requires setting `--distributed`. To prevent arbitrary use of distributed training, we have several conditions for enabling distributed training, such as `args.distributed`, `torch.cuda.device_count() > 1`, and `torch.cuda.is_available()`.
@@ -262,7 +327,66 @@ The training GPU implements environment for this README is as follows: models ar
![IDDM Distributed Training](assets/IDDM_training.png)
-#### Training Parameters
+##### 2.3.2 Python Training
+
+```python
+from torch import multiprocessing as mp
+from model.trainers.dm import DMTrainer
+from tools.train import init_train_args
+
+# Function 1
+# Initialize arguments
+args = init_train_args()
+gpus = torch.cuda.device_count()
+# Customize your parameters, or you can configure them by entering the init_train_args method
+setattr(args, "distributed", True) # Enable distributed training
+setattr(args, "world_size", 2) # Number of distributed nodes
+setattr(args, "conditional", True) # True for conditional training, False for non-conditional training
+setattr(args, "sample", "ddpm") # Sampler
+setattr(args, "network", "unet") # Deep learning network
+setattr(args, "epochs", 300) # Number of iterations
+setattr(args, "image_size", 64) # Image size
+setattr(args, "result_path", "/your/dataset/path/") # Dataset path
+setattr(args, "result_path", "/your/save/path/") # Result path
+setattr(args, "vis", True) # Enable visualization
+# ...
+# OR use args["parameter_name"] = "your setting"
+# Start training
+mp.spawn(DMTrainer(args=args, dataset_path="/your/dataset/path/").train, nprocs=gpus)
+
+# Function 2
+args = init_train_args()
+# Input args and update some params
+mp.spawn(DMTrainer(args=args, dataset_path="/your/dataset/path/").train, nprocs=gpus)
+
+# Function 3
+mp.spawn(DMTrainer(
+ conditional=True, sample="ddpm", dataset_path="/your/dataset/path/",
+ network="unet", epochs=300, image_size=64, result_path="/your/save/path/",
+ vis=True, # Any params...
+).train, nprocs=gpus)
+```
+
+#### 2.4 Model Repositories
+
+**Note**: The model repo will continue to update pre-trained models.
+
+##### 2.4.1 Diffusion Model Pre-training Model
+
+| Model Name | Conditional | Datasets | Model Size | Download Link |
+| :---------------------------: | :---------: | :-----------: | :--------: | :----------------------------------------------------------: |
+| `celebahq-120-weight.pt` | ✓ | CelebA-HQ | 120×120 | [Download](https://github.com/chairc/Integrated-Design-Diffusion-Model/releases/download/v1.1.5/celebahq-120-weight.pt) |
+| `animate-ganyu-120-weight.pt` | ✓ | Animate-ganyu | 120×120 | [Download](https://github.com/chairc/Integrated-Design-Diffusion-Model/releases/download/v1.1.5/animate-ganyu-120-weight.pt) |
+| `neu-120-weight.pt` | ✓ | NEU-DET | 120×120 | [Download](https://github.com/chairc/Integrated-Design-Diffusion-Model/releases/download/v1.1.5/cifar10-64-weight.pt) |
+| `neu-cls-64-weight.pt` | ✓ | NEU-CLS | 64×64 | [Download](https://github.com/chairc/Integrated-Design-Diffusion-Model/releases/download/v1.1.7/neu-cls-64-weight.pt) |
+| `cifar-64-weight.pt` | ✓ | Cifar-10 | 64×64 | [Download](https://github.com/chairc/Integrated-Design-Diffusion-Model/releases/download/v1.1.5/cifar10-64-weight.pt) |
+| `animate-face-64-weight.pt` | ✓ | Animate-face | 64×64 | [Download](https://github.com/chairc/Integrated-Design-Diffusion-Model/releases/download/v1.1.5/animate-face-64-weight.pt) |
+
+##### 2.4.2 Super Resolution Pre-trained Model
+
+Coming soon :-)
+
+#### 2.5 Training Parameters
**Parameter Explanation**
@@ -305,7 +429,11 @@ The training GPU implements environment for this README is as follows: models ar
-### Generation
+### 3 Generation
+
+#### 3.1 Start Your First Generation
+
+##### 3.1.1 Command Generation
1. Open the `generate.py` file and locate the `--weight_path` parameter. Modify the path in the parameter to the path of your model weights, for example `/your/path/weight/model.pt`.
@@ -337,7 +465,26 @@ The training GPU implements environment for this README is as follows: models ar
3. Wait for the generation process to complete.
-#### Generation Parameters
+##### 3.1.2 Python generation
+
+```python
+from tools.generate import Generator, init_generate_args
+
+# Initialize generation arguments, or you can configure them by entering the init_generate_args method
+args = init_generate_args()
+# Customize your parameters
+args["weight_path"] = "/your/model/path/model.pt"
+args["result_path"] = "/your/save/path/"
+# ...
+# args["parameter_name"] = "your setting"
+gen_model = Generator(gen_args=args, deploy=False)
+# Number of generations
+num_images = 2
+for i in range(num_images):
+ gen_model.generate(index=i)
+```
+
+#### 3.2 Generation Parameters
**Parameter Explanation**
@@ -350,6 +497,7 @@ The training GPU implements environment for this README is as follows: models ar
| --num_images | | Number of generated images | int | Number of images to generate |
| --weight_path | | Path to model weights | str | Path to the model weights file, required for network generation |
| --result_path | | Save path | str | Path to save the generated images |
+| --use_gpu | | Set the use GPU | int | Set the use GPU in generate, input is GPU's id |
| --sample | | Sampling method | str | Set the sampling method type, currently supporting DDPM and DDIM. **(No need to set for models after version 1.1.1)** |
| --network | | Training network | str | Set the training network, currently supporting UNet, CSPDarkUNet. **(No need to set for models after version 1.1.1)** |
| --act | | Activation function | str | Activation function selection. Currently supports gelu, silu, relu, relu6 and lrelu. If you do not set the same activation function as the model, mosaic phenomenon will occur. **(No need to set for models after version 1.1.1)** |
@@ -357,80 +505,43 @@ The training GPU implements environment for this README is as follows: models ar
| --class_name | ✓ | Class name | int | Index of the class to generate images. if class name is `-1`, the model would output one image per class. |
| --cfg_scale | ✓ | Classifier-free guidance weight | int | Weight for classifier-free guidance interpolation, for better generation model performance |
-### Result
-
-We conducted training on the following four datasets using the `DDPM` sampler with an image size of `64*64`. we also enabled `conditional`, using the `gelu` activation function, `linear` learning function and setting learning rate to `3e-4`. The datasets are `cifar10`, `NEUDET`, `NRSD-MN`, and `WOOD` in `300` epochs. The results are shown in the following figure:
-
-#### cifar10
-![cifar_244_ema](assets/cifar_244_ema.jpg)
-![cifar_294_ema](assets/cifar_294_ema.jpg)
+### 4 Result
-#### NEUDET
+We conducted training on the following 5 datasets using the `DDPM` sampler with an image size of `64*64`. we also enabled `conditional`, using the `gelu` activation function, `linear` learning function and setting learning rate to `3e-4`. The datasets are `cifar10`, `NEUDET`, `NRSD-MN`, `WOOD` and `Animate face` in `300` epochs. The results are shown in the following figure:
-![neudet_290_ema](assets/neudet_290_ema.jpg)
+#### 4.1 cifar10 dataset
-![neudet_270_ema](assets/neudet_270_ema.jpg)
+![cifar_244_ema](assets/cifar_244_ema.jpg)![cifar_294_ema](assets/cifar_294_ema.jpg)
-![neudet_276_ema](assets/neudet_276_ema.jpg)
+#### 4.2 NEU-DET dataset
-![neudet_265_ema](assets/neudet_265_ema.jpg)
+![neudet_290_ema](assets/neudet_290_ema.jpg)![neudet_270_ema](assets/neudet_270_ema.jpg)![neudet_276_ema](assets/neudet_276_ema.jpg)![neudet_265_ema](assets/neudet_265_ema.jpg)![neudet_240_ema](assets/neudet_240_ema.jpg)![neudet_244_ema](assets/neudet_244_ema.jpg)![neudet_245_ema](assets/neudet_245_ema.jpg)![neudet_298_ema](assets/neudet_298_ema.jpg)
-![neudet_240_ema](assets/neudet_240_ema.jpg)
+#### 4.3 NRSD dataset
-![neudet_244_ema](assets/neudet_244_ema.jpg)
+![nrsd_180_ema](assets/nrsd_180_ema.jpg)![nrsd_188_ema](assets/nrsd_188_ema.jpg)![nrsd_194_ema](assets/nrsd_194_ema.jpg)![nrsd_203_ema](assets/nrsd_203_ema.jpg)![nrsd_210_ema](assets/nrsd_210_ema.jpg)![nrsd_217_ema](assets/nrsd_217_ema.jpg)![nrsd_218_ema](assets/nrsd_218_ema.jpg)![nrsd_248_ema](assets/nrsd_248_ema.jpg)![nrsd_276_ema](assets/nrsd_276_ema.jpg)![nrsd_285_ema](assets/nrsd_285_ema.jpg)![nrsd_295_ema](assets/nrsd_295_ema.jpg)![nrsd_298_ema](assets/nrsd_298_ema.jpg)
-![neudet_298_ema](assets/neudet_298_ema.jpg)
-
-#### NRSD
-
-![nrsd_180_ema](assets/nrsd_180_ema.jpg)
-
-![nrsd_188_ema](assets/nrsd_188_ema.jpg)
-
-![nrsd_194_ema](assets/nrsd_194_ema.jpg)
-
-![nrsd_203_ema](assets/nrsd_203_ema.jpg)
-
-![nrsd_210_ema](assets/nrsd_210_ema.jpg)
-
-![nrsd_217_ema](assets/nrsd_217_ema.jpg)
-
-![nrsd_218_ema](assets/nrsd_218_ema.jpg)
-
-![nrsd_248_ema](assets/nrsd_248_ema.jpg)
-
-![nrsd_276_ema](assets/nrsd_276_ema.jpg)
-
-![nrsd_285_ema](assets/nrsd_285_ema.jpg)
-
-![nrsd_298_ema](assets/nrsd_298_ema.jpg)
-
-#### WOOD
+#### 4.4 WOOD dataset
![wood_495](assets/wood_495.jpg)
-#### Animate face (~~JUST FOR FUN~~)
+#### 4.5 Animate face dataset (~~JUST FOR FUN~~)
-![model_428_ema](assets/animate_face_428_ema.jpg)
+![model_428_ema](assets/animate_face_428_ema.jpg)![model_440_ema](assets/animate_face_440_ema.jpg)![model_488_ema](assets/animate_face_488_ema.jpg)![model_497_ema](assets/animate_face_497_ema.jpg)![model_499_ema](assets/animate_face_499_ema.jpg)![model_459_ema](assets/animate_face_459_ema.jpg)
-![model_488_ema](assets/animate_face_488_ema.jpg)
-
-![model_497_ema](assets/animate_face_497_ema.jpg)
-
-![model_499_ema](assets/animate_face_499_ema.jpg)
-
-![model_459_ema](assets/animate_face_459_ema.jpg)
-
-#### Base on the 64×64 model to generate 160×160 (every size) images
+#### 4.6 Base on the 64×64 model to generate 160×160 (every size) images (Industrial surface defect generation only)
Of course, based on the 64×64 U-Net model, we generate 160×160 `NEU-DET` images in the `generate.py` file (single output, each image occupies 21GB of GPU memory). **Attention this [[issues]](https://github.com/chairc/Integrated-Design-Diffusion-Model/issues/9#issuecomment-1886422210)**! If it's an image with defect textures where the features are not clear, generating a large size directly might not have these issues, such as in NRSD or NEU datasets. However, if the image contains a background with specific distinctive features, you may need to use super-resolution or resizing to increase the size, for example, in Cifar10, CelebA-HQ, etc. **If you really need large-sized images, you can directly train with large pixel images if there is enough GPU memory.** Detailed images are as follows:
![model_499_ema](assets/neu160_0.jpg)![model_499_ema](assets/neu160_1.jpg)![model_499_ema](assets/neu160_2.jpg)![model_499_ema](assets/neu160_3.jpg)![model_499_ema](assets/neu160_4.jpg)![model_499_ema](assets/neu160_5.jpg)
-### Evaluation
+
+### 5 Evaluation
+
+#### 5.1 Start Your First Evaluation
1. During the data preparation stage, use `generate.py` to create the dataset. The amount and size of the generated dataset should be similar to the training set (**Note**: The training set required for evaluation should be resized to the size used during training, which is the `image_size`. For example, if the training set path is `/your/path/datasets/landscape` with an image size of **256**, and the generated set path is `/your/path/generate/landscape` with a size of 64, use the `resize` method to convert the images in the training set path to **64**. The new evaluation training set path will be `/your/new/path/datasets/landscape`).
@@ -452,7 +563,7 @@ Of course, based on the 64×64 U-Net model, we generate 160×160 `NEU-DET` image
python FID_calculator_plus.py /your/input/path /your/output/path --save_stats
```
-#### Evaluation Parameters
+#### 5.2 Evaluation Parameters
| **Parameter Name** | Usage | Parameter Type | Explanation |
| ------------------ | --------------------------- | :------------: | ------------------------------------------------------------ |
@@ -463,7 +574,9 @@ Of course, based on the 64×64 U-Net model, we generate 160×160 `NEU-DET` image
| --save_stats | Save stats | bool | Generate npz archives from the sample directory |
| --use_gpu | Specify GPU | int | Generally used to set the specific GPU for training, input the GPU number |
-### About Citation
+
+
+### 6 About Citation
If this project is used for experiments in an academic paper, where possible please cite our project appropriately and we appreciate this. The specific citation format can be found at **[this website](https://zenodo.org/records/10866128)**.
@@ -479,7 +592,11 @@ If this project is used for experiments in an academic paper, where possible ple
}
```
-### Acknowledgements
+**Citation detail**: ![image-20241124174339833](assets/image-citation.png)
+
+
+
+### 7 Acknowledgements
**People**:
@@ -489,8 +606,4 @@ If this project is used for experiments in an academic paper, where possible ple
[@JetBrains](https://www.jetbrains.com/)
-[@Python](https://www.python.org/)
-
-[@Pytorch](https://pytorch.org/)
-
![JetBrains logo](assets/jetbrains.svg)
\ No newline at end of file
diff --git a/README_zh.md b/README_zh.md
index 5fb06de..bc569dd 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -27,17 +27,25 @@
+
+
+
+
+
+
[English Document](README.md)
-### 关于模型
+
+
+### 0 关于模型
该扩散模型为经典的ddpm、ddim和plms,来源于论文《**[Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)**》、《**[Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502)**》和《**[Pseudo Numerical Methods for Diffusion Models on Manifolds](https://openreview.net/forum?id=PlKWVd2yBkY)**》。
我们将此项目命名为IDDM: Integrated Design Diffusion Model,中文名为集成设计扩散模型。在此项目中进行模型复现、训练器和生成器编写、部分算法和网络结构的改进与优化,该仓库**持续维护**。
-如果有任何问题,请先到此[**issue**](https://github.com/chairc/Integrated-Design-Diffusion-Model/issues/9)进行问题查询,若无法解决可以加入我们的QQ群:949120343、开启新issue提问或联系我的邮箱:chenyu1998424@gmail.com/chairc1998@163.com
+如果有任何问题,请先到此[**issue**](https://github.com/chairc/Integrated-Design-Diffusion-Model/issues/9)进行问题查询,若无法解决可以加入我们的QQ群:949120343、开启新issue提问或联系我的邮箱:chenyu1998424@gmail.com/chairc1998@163.com。**如果你认为我的项目有意思请给我点一颗⭐⭐⭐Star⭐⭐⭐吧。**
**本仓库整体结构**
@@ -45,6 +53,7 @@
Integrated Design Diffusion Model
├── config
│ ├── choices.py
+│ ├── model_list.py
│ ├── setting.py
│ └── version.py
├── datasets
@@ -69,11 +78,15 @@ Integrated Design Diffusion Model
│ │ ├── base.py
│ │ ├── cspdarkunet.py
│ │ └── unet.py
-│ └── samples
+│ ├── samples
+│ │ ├── base.py
+│ │ ├── ddim.py
+│ │ ├── ddpm.py
+│ │ └── plms.py
+│ └── trainers
│ ├── base.py
-│ ├── ddim.py
-│ ├── ddpm.py
-│ └── plms.py
+│ ├── dm.py
+│ └── sr.py
├── results
├── sr
│ ├── dataset.py
@@ -97,37 +110,49 @@ Integrated Design Diffusion Model
│ ├── initializer.py
│ ├── logger.py
│ ├── lr_scheduler.py
+│ ├── metrics.py
│ ├── processing.py
│ └── utils.py
├── webui
│ └──web.py
-└── weight
+└── weights
```
-### 接下来要做
-- [x] 1. 新增cosine学习率优化(2023-07-31)
-- [x] 2. 使用效果更优的U-Net网络模型(2023-11-09)
-- [x] 3. 更大尺寸的生成图像(2023-11-09)
-- [x] 4. 多卡分布式训练(2023-07-15)
-- [x] 5. 云服务器快速部署和接口(2023-08-28)
-- [x] 6. 增加DDIM采样方法(2023-08-03)
-- [x] 7. 支持其它图像生成(2023-09-16)
-- [x] 8. 低分辨率生成图像进行超分辨率增强[~~超分模型效果待定~~](2024-02-18)
-- [ ] 9. 使用Latent方式降低显存消耗
-- [x] 10. 重构model整体结构(2023-12-06)
-- [x] 11. 编写可视化webui界面(2024-01-23)
-- [x] 12. 增加PLMS采样方法(2024-03-12)
-- [x] 13. 增加FID方法验证图像质量(2024-05-06)
-- [x] 14. 增加生成图像Socket和网站服务部署(2024-11-13)
-### 训练
+### 1 接下来要做
+
+- [x] [2023-07-15] 增加多卡分布式训练
+- [x] [2023-07-31] 增加cosine学习率优化
+- [x] [2023-08-03] 增加DDIM采样方法
+- [x] [2023-08-28] 云服务器快速部署和接口
+- [x] [2023-09-16] 支持其它图像生成
+- [x] [2023-11-09] 增加效果更优的U-Net网络模型
+- [x] [2023-11-09] 支持更大尺寸的生成图像
+- [x] [2023-12-06] 重构model整体结构
+- [x] [2024-01-23] 增加可视化webui训练界面
+- [x] [2024-02-18] 支持低分辨率生成图像进行超分辨率增强[~~超分模型效果待定~~]
+- [x] [2024-03-12] 增加PLMS采样方法
+- [x] [2024-05-06] 增加FID方法验证图像质量
+- [x] [2024-06-11] 增加可视化webui生成界面
+- [x] [2024-07-07] 支持自定义图像长宽输入
+- [x] [2024-11-13] 增加生成图像Socket和网站服务部署
+- [x] [2024-11-26] 增加PSNR和SSIM方法验证超分图像质量
+- [x] [2024-12-10] 增加预训练模型下载
+- [x] [2024-12-25] 重构训练器结构
+- [ ] [预计2025-01-31] 增加Docker部署与镜像
+- [ ] [待定] 重构项目利用百度飞桨框架
+- [ ] [待定] ~~使用Latent方式降低显存消耗~~
-#### 注意
-本自README的训练GPU环境如下:使用具有6GB显存的NVIDIA RTX 3060显卡、具有11GB显存的NVIDIA RTX 2080Ti显卡和具有24GB(总计48GB,分布式训练)显存的NVIDIA RTX 6000(×2)显卡对模型进行训练和测试。**上述GPU均可正常训练**。
-#### 开始你的第一个训练(以cifar10为例,模式单卡)
+### 2 训练
+
+**训练前需注意**
+
+本自README的训练GPU环境如下:使用具有6GB显存的NVIDIA RTX 3060显卡、具有11GB显存的NVIDIA RTX 2080Ti显卡和具有22GB(总计44GB,分布式训练)显存的NVIDIA RTX 6000(×2)显卡对模型进行训练和测试。**上述GPU均可正常训练**。
+
+#### 2.1 开始你的第一个训练(以cifar10为例,模式单卡)
1. **导入数据集**
@@ -181,7 +206,9 @@ Integrated Design Diffusion Model
**↓↓↓↓↓↓↓↓↓↓下方为多种训练方式、训练详细参数讲解↓↓↓↓↓↓↓↓↓↓**
-#### 普通训练
+#### 2.2 普通训练
+
+##### 2.2.1 命令训练
1. 以`landscape`数据集为例,将数据集文件放入`datasets`文件夹中,该数据集的总路径如下`/your/path/datasets/landscape`,图片存放在`/your/path/datasets/landscape/images`,数据集图片路径如下`/your/path/datasets/landscape/images/*.jpg`
@@ -234,7 +261,45 @@ Integrated Design Diffusion Model
python train.py --pretrain --pretrain_path /your/pretrain/path/model.pt --sample ddpm --run_name df --epochs 300 --batch_size 16 --image_size 64 --dataset_path /your/dataset/path --result_path /your/save/path
```
-#### 分布式训练
+##### 2.2.2 Python脚本训练
+
+```python
+from model.trainers.dm import DMTrainer
+from tools.train import init_train_args
+
+# 方法一
+# 初始化参数
+args = init_train_args()
+# 自定义你的参数,也可以进入init_train_args方法配置
+setattr(args, "conditional", True) # True为条件训练,False为非条件训练
+setattr(args, "sample", "ddpm") # 采样器
+setattr(args, "network", "unet") # 深度学习网络
+setattr(args, "epochs", 300) # 迭代次数
+setattr(args, "image_size", 64) # 图像大小
+setattr(args, "dataset_path", "/你/的/数/据/集/路/径/") # 数据集保存路径
+setattr(args, "result_path", "/你/的/保/存/路/径/") # 结果保存路径
+setattr(args, "vis", True) # 开启可视化
+# ...
+# 或者使用args["参数名称"] = "你的设置"
+# 开启训练
+DMTrainer(args=args).train()
+
+# 方法二
+args = init_train_args()
+# 输入args,修改指定参数输入
+DMTrainer(args=args, dataset_path="/你/的/数/据/集").train()
+
+# 方法三
+DMTrainer(
+ conditional=True, sample="ddpm", dataset_path="/你/的/数/据/集/路/径/",
+ network="unet", epochs=300, image_size=64, result_path="/你/的/保/存/路/径/",
+ vis=True, # 任意参数...
+).train()
+```
+
+#### 2.3 分布式训练
+
+##### 2.3.1 命令训练
1. 基本配置与普通训练相似,值得注意的是开启分布式训练需要设置`--distributed`。为了防止随意设置分布式训练,我们为开启分布式训练设置了几个基本条件,例如`args.distributed`、`torch.cuda.device_count() > 1`和`torch.cuda.is_available()`。
@@ -258,10 +323,66 @@ Integrated Design Diffusion Model
![IDDM分布式训练过程](assets/IDDM_training.png)
+##### 2.3.2 Python脚本训练
+
+```python
+from torch import multiprocessing as mp
+from model.trainers.dm import DMTrainer
+from tools.train import init_train_args
+
+# 方法一
+# 初始化参数
+args = init_train_args()
+gpus = torch.cuda.device_count()
+# 自定义你的参数,也可以进入init_train_args方法配置
+setattr(args, "distributed", True) # 开启分布式训练
+setattr(args, "world_size", 2) # 训练结点个数
+setattr(args, "conditional", True) # True为条件训练,False为非条件训练
+setattr(args, "sample", "ddpm") # 采样器
+setattr(args, "network", "unet") # 深度学习网络
+setattr(args, "epochs", 300) # 迭代次数
+setattr(args, "image_size", 64) # 图像大小
+setattr(args, "dataset_path", "/你/的/数/据/集/路/径/") # 数据集保存路径
+setattr(args, "result_path", "/你/的/保/存/路/径/") # 结果保存路径
+setattr(args, "vis", True) # 开启可视化
+# ...
+# 或者使用args["参数名称"] = "你的设置"
+# 开启训练
+mp.spawn(DMTrainer(args=args).train, nprocs=gpus)
+
+# 方法二
+args = init_train_args()
+# 输入args,修改指定参数输入
+mp.spawn(DMTrainer(args=args, dataset_path="/你/的/数/据/集").train, nprocs=gpus)
+
+# 方法三
+mp.spawn(DMTrainer(
+ conditional=True, sample="ddpm", dataset_path="/你/的/数/据/集/路/径/",
+ network="unet", epochs=300, image_size=64, result_path="/你/的/保/存/路/径/",
+ vis=True, # 任意参数...
+).train, nprocs=gpus)
+```
+
+#### 2.4 模型库
+
+**注意**:模型库将持续更新预训练模型。
+
+##### 2.4.1 扩散模型预训练模型
+
+| 模型名称 | 是否条件训练 | 数据集 | 模型大小 | 下载链接 |
+| :---------------------------: | :----------: | :-----------: | :------: | :----------------------------------------------------------: |
+| `celebahq-120-weight.pt` | ✓ | CelebA-HQ | 120×120 | [模型下载](https://github.com/chairc/Integrated-Design-Diffusion-Model/releases/download/v1.1.5/celebahq-120-weight.pt) |
+| `animate-ganyu-120-weight.pt` | ✓ | Animate-ganyu | 120×120 | [模型下载](https://github.com/chairc/Integrated-Design-Diffusion-Model/releases/download/v1.1.5/animate-ganyu-120-weight.pt) |
+| `neu-120-weight.pt` | ✓ | NEU-DET | 120×120 | [模型下载](https://github.com/chairc/Integrated-Design-Diffusion-Model/releases/download/v1.1.5/cifar10-64-weight.pt) |
+| `neu-cls-64-weight.pt` | ✓ | NEU-CLS | 64×64 | [模型下载](https://github.com/chairc/Integrated-Design-Diffusion-Model/releases/download/v1.1.7/neu-cls-64-weight.pt) |
+| `cifar-64-weight.pt` | ✓ | Cifar-10 | 64×64 | [模型下载](https://github.com/chairc/Integrated-Design-Diffusion-Model/releases/download/v1.1.5/cifar10-64-weight.pt) |
+| `animate-face-64-weight.pt` | ✓ | Animate-face | 64×64 | [模型下载](https://github.com/chairc/Integrated-Design-Diffusion-Model/releases/download/v1.1.5/animate-face-64-weight.pt) |
+##### 2.4.2 超分辨率预训练模型
-#### 训练参数
+很快就来:-)
+#### 2.5 训练参数
**参数讲解**
@@ -304,8 +425,11 @@ Integrated Design Diffusion Model
+### 3 生成
-### 生成
+#### 3.1 开始你的第一个生成
+
+##### 3.1.1 命令生成
1. 打开`generate.py`文件,找到`--weight_path`参数,将参数中的路径修改为模型权重路径,例如`/your/path/weight/model.pt`
@@ -335,7 +459,26 @@ Integrated Design Diffusion Model
3. 等待生成即可
-#### 生成参数
+##### 3.1.2 Python脚本生成
+
+```python
+from tools.generate import Generator, init_generate_args
+
+# 初始化生成参数,也可以进入init_generate_args方法配置
+args = init_generate_args()
+# 自定义你的参数
+args["weight_path"] = "/你/的/模/型/路/径/model.pt"
+args["result_path"] = "/你/的/保/存/路/径/"
+# ...
+# args["参数名称"] = "你的设置"
+gen_model = Generator(gen_args=args, deploy=False)
+# 生成数量
+num_images = 2
+for i in range(num_images):
+ gen_model.generate(index=i)
+```
+
+#### 3.2 生成参数
**参数讲解**
@@ -348,6 +491,7 @@ Integrated Design Diffusion Model
| --num_images | | 生成图片个数 | int | 单次生成图片个数 |
| --weight_path | | 权重路径 | str | 模型权重路径,网络生成需要加载文件 |
| --result_path | | 保存路径 | str | 保存路径 |
+| --use_gpu | | 设置运行指定的GPU | int | 生成中设置指定的运行GPU,输入为GPU的编号 |
| --sample | | 采样方式 | str | 设置采样器类别,当前支持ddpm,ddim(**1.1.1版本后的模型可不用设置**) |
| --network | | 训练网络 | str | 设置训练网络,当前支持UNet,CSPDarkUNet(**1.1.1版本后的模型可不用设置**) |
| --act | | 激活函数 | str | 激活函数选择,目前支持gelu、silu、relu、relu6和lrelu。如果不选择,会产生马赛克现象(**1.1.1版本后的模型可不用设置**) |
@@ -357,79 +501,41 @@ Integrated Design Diffusion Model
-### 结果
-
-我们在以下4个数据集做了训练,开启`conditional`,采样器为`DDPM`,图片尺寸均为`64*64`,激活函数为`gelu`,学习率为`3e-4`,采用`线性`学习方法,迭代次数为`300`,分别是`cifar10`,`NEUDET`,`NRSD-MN`和`WOOD`。结果如下图所示:
-
-#### cifar10
+### 4 结果
-![cifar_244_ema](assets/cifar_244_ema.jpg)
+我们在以下5个数据集做了训练,开启`conditional`,采样器为`DDPM`,图片尺寸均为`64*64`,激活函数为`gelu`,学习率为`3e-4`,采用`线性`学习方法,迭代次数为`300`,分别是`cifar10`,`NEU-DET`,`NRSD-MN`,`WOOD`和`Animate face`。结果如下图所示:
-![cifar_294_ema](assets/cifar_294_ema.jpg)
+#### 4.1 cifar10数据集
-#### NEUDET
+![cifar_244_ema](assets/cifar_244_ema.jpg)![cifar_294_ema](assets/cifar_294_ema.jpg)
-![neudet_290_ema](assets/neudet_290_ema.jpg)
+#### 4.2 NEU-DET数据集
-![neudet_270_ema](assets/neudet_270_ema.jpg)
+![neudet_290_ema](assets/neudet_290_ema.jpg)![neudet_270_ema](assets/neudet_270_ema.jpg)![neudet_276_ema](assets/neudet_276_ema.jpg)![neudet_265_ema](assets/neudet_265_ema.jpg)![neudet_240_ema](assets/neudet_240_ema.jpg)![neudet_244_ema](assets/neudet_244_ema.jpg)![neudet_245_ema](assets/neudet_245_ema.jpg)![neudet_298_ema](assets/neudet_298_ema.jpg)
-![neudet_276_ema](assets/neudet_276_ema.jpg)
+#### 4.3 NRSD数据集
-![neudet_265_ema](assets/neudet_265_ema.jpg)
+![nrsd_180_ema](assets/nrsd_180_ema.jpg)![nrsd_188_ema](assets/nrsd_188_ema.jpg)![nrsd_194_ema](assets/nrsd_194_ema.jpg)![nrsd_203_ema](assets/nrsd_203_ema.jpg)![nrsd_210_ema](assets/nrsd_210_ema.jpg)![nrsd_217_ema](assets/nrsd_217_ema.jpg)![nrsd_218_ema](assets/nrsd_218_ema.jpg)![nrsd_248_ema](assets/nrsd_248_ema.jpg)![nrsd_276_ema](assets/nrsd_276_ema.jpg)![nrsd_285_ema](assets/nrsd_285_ema.jpg)![nrsd_295_ema](assets/nrsd_295_ema.jpg)![nrsd_298_ema](assets/nrsd_298_ema.jpg)
-![neudet_240_ema](assets/neudet_240_ema.jpg)
-
-![neudet_244_ema](assets/neudet_244_ema.jpg)
-
-![neudet_298_ema](assets/neudet_298_ema.jpg)
-
-#### NRSD
-
-![nrsd_180_ema](assets/nrsd_180_ema.jpg)
-
-![nrsd_188_ema](assets/nrsd_188_ema.jpg)
-
-![nrsd_194_ema](assets/nrsd_194_ema.jpg)
-
-![nrsd_203_ema](assets/nrsd_203_ema.jpg)
-
-![nrsd_210_ema](assets/nrsd_210_ema.jpg)
-
-![nrsd_217_ema](assets/nrsd_217_ema.jpg)
-
-![nrsd_218_ema](assets/nrsd_218_ema.jpg)
-
-![nrsd_248_ema](assets/nrsd_248_ema.jpg)
-
-![nrsd_276_ema](assets/nrsd_276_ema.jpg)
-
-![nrsd_285_ema](assets/nrsd_285_ema.jpg)
-
-![nrsd_298_ema](assets/nrsd_298_ema.jpg)
-
-#### WOOD
+#### 4.4 WOOD数据集
![wood_495](assets/wood_495.jpg)
-#### Animate face(整活生成)
+#### 4.5 Animate face数据集(~~整活生成~~)
-![model_428_ema](assets/animate_face_428_ema.jpg)
+![model_428_ema](assets/animate_face_428_ema.jpg)![model_440_ema](assets/animate_face_440_ema.jpg)![model_488_ema](assets/animate_face_488_ema.jpg)![model_497_ema](assets/animate_face_497_ema.jpg)![model_499_ema](assets/animate_face_499_ema.jpg)![model_459_ema](assets/animate_face_459_ema.jpg)
-![model_488_ema](assets/animate_face_488_ema.jpg)
+#### 4.6 基于64×64模型生成160×160(任意大尺寸)图像(仅限工业表面缺陷生成)
-![model_497_ema](assets/animate_face_497_ema.jpg)
+当然,我们根据64×64的基础模型,在`generate.py`文件中生成160×160的`NEU-DET`图片(单张输出,每张图片占用显存21GB)。**请注意这个**[[**issue**]](https://github.com/chairc/Integrated-Design-Diffusion-Model/issues/9#issuecomment-1886422210):如果是缺陷纹理那种图片,特征物不明显的直接生成大尺寸就不会有这些问题,例如NRSD、NEU数据集。如果是含有背景有特定明显特征的则需要超分或者resize提升尺寸,例如Cifar10、CelebA-HQ等。**如果实在需要大尺寸图像,在显存足够的情况下直接训练大像素图片**。详细图片如下:
-![model_499_ema](assets/animate_face_499_ema.jpg)
-
-![model_459_ema](assets/animate_face_459_ema.jpg)
+![model_499_ema](assets/neu160_0.jpg)![model_499_ema](assets/neu160_1.jpg)![model_499_ema](assets/neu160_2.jpg)![model_499_ema](assets/neu160_3.jpg)![model_499_ema](assets/neu160_4.jpg)![model_499_ema](assets/neu160_5.jpg)
-#### 基于64×64模型生成160×160(任意大尺寸)图像
-当然,我们根据64×64的基础模型,在`generate.py`文件中生成160×160的`NEU-DET`图片(单张输出,每张图片占用显存21GB)。**请注意这个**[[**issue**]](https://github.com/chairc/Integrated-Design-Diffusion-Model/issues/9#issuecomment-1886422210):如果是缺陷纹理那种图片,特征物不明显的直接生成大尺寸就不会有这些问题,例如NRSD、NEU数据集。如果是含有背景有特定明显特征的则需要超分或者resize提升尺寸,例如Cifar10、CelebA-HQ等。**如果实在需要大尺寸图像,在显存足够的情况下直接训练大像素图片。**详细图片如下:
-![model_499_ema](assets/neu160_0.jpg)![model_499_ema](assets/neu160_1.jpg)![model_499_ema](assets/neu160_2.jpg)![model_499_ema](assets/neu160_3.jpg)![model_499_ema](assets/neu160_4.jpg)![model_499_ema](assets/neu160_5.jpg)
+### 5 评估
-### 评估
+#### 5.1 开始你的第一个评估
1. 数据准备阶段,使用`generate.py`生成数据集,数据集生成量应该与训练集的数量、尺寸相似(**注意**:评估时所需要的训练集应为进行了`resize`后的结果,即为训练时的`image_size`大小。例如,训练集的路径为`/your/path/datasets/landscape`,图片尺寸为**256**;生成集的路径为`/your/path/generate/landscape`,尺寸为64,使用`resize`方法将训练集路径中的图片转为**64**,此时新的评估用训练集路径为`/your/new/path/datasets/landscape`)。
@@ -451,7 +557,7 @@ Integrated Design Diffusion Model
python FID_calculator_plus.py /your/input/path /your/output/path --save_stats
```
-#### 评估参数
+#### 5.2 评估参数
| **参数名称** | 参数使用方法 | 参数类型 | 参数解释 |
| ------------- | ----------------- | :------: | ------------------------------------------------------------ |
@@ -464,7 +570,7 @@ Integrated Design Diffusion Model
-### 关于引用
+### 6 关于引用
如果在学术论文中使用该项目进行实验,在可能的情况下,请适当引用我们的项目,为此我们表示感谢。具体引用格式可访问[**此网站**](https://zenodo.org/records/10866128)。
@@ -480,7 +586,11 @@ Integrated Design Diffusion Model
}
```
-### 致谢
+**引用详情可以参考此处**:![image-20241124174257466](assets/image-citation.png)
+
+
+
+### 7 致谢
**人员**:
@@ -490,9 +600,5 @@ Integrated Design Diffusion Model
[@JetBrains](https://www.jetbrains.com/)
-[@Python](https://www.python.org/)
-
-[@Pytorch](https://pytorch.org/)
-
![JetBrains logo](assets/jetbrains.svg)
diff --git a/assets/animate_face_440_ema.jpg b/assets/animate_face_440_ema.jpg
new file mode 100644
index 0000000..5ba5a5c
Binary files /dev/null and b/assets/animate_face_440_ema.jpg differ
diff --git a/assets/image-citation.png b/assets/image-citation.png
new file mode 100644
index 0000000..5ebd395
Binary files /dev/null and b/assets/image-citation.png differ
diff --git a/assets/neudet_245_ema.jpg b/assets/neudet_245_ema.jpg
new file mode 100644
index 0000000..0e76193
Binary files /dev/null and b/assets/neudet_245_ema.jpg differ
diff --git a/assets/nrsd_295_ema.jpg b/assets/nrsd_295_ema.jpg
new file mode 100644
index 0000000..4aa22b8
Binary files /dev/null and b/assets/nrsd_295_ema.jpg differ
diff --git a/config/choices.py b/config/choices.py
index 6a6a853..8e782e5 100644
--- a/config/choices.py
+++ b/config/choices.py
@@ -24,6 +24,7 @@
image_format_choices = ["png", "jpg", "jpeg", "webp", "tif"]
noise_schedule_choices = ["linear", "cosine", "sqrt_linear", "sqrt"]
loss_func_choices = ["mse", "l1", "huber", "smooth_l1"]
+sr_loss_func_choices = ["mse"]
sr_network_choices = ["srv1"]
image_type_choices = {"RGB": 3, "GRAY": 1}
diff --git a/model/trainers/__init__.py b/model/trainers/__init__.py
new file mode 100644
index 0000000..7e8a24b
--- /dev/null
+++ b/model/trainers/__init__.py
@@ -0,0 +1,10 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+"""
+ @Date : 2024/9/18 9:59
+ @Author : chairc
+ @Site : https://github.com/chairc
+"""
+from .base import Trainer
+from .dm import DMTrainer
+from .sr import SRTrainer
diff --git a/model/trainers/base.py b/model/trainers/base.py
new file mode 100644
index 0000000..b3137fc
--- /dev/null
+++ b/model/trainers/base.py
@@ -0,0 +1,175 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+"""
+ @Date : 2024/9/18 9:59
+ @Author : chairc
+ @Site : https://github.com/chairc
+"""
+import argparse
+import logging
+import coloredlogs
+
+logger = logging.getLogger(__name__)
+coloredlogs.install(level="INFO")
+
+
+class Trainer:
+ """
+ Base trainer
+ """
+
+ def __init__(self, args=None, **kwargs):
+ """
+ Initialize trainer
+ :param args: Args parser
+ :param kwargs: Parameters of trainer
+ """
+ # Can be set input params
+ self.args, self.kwargs, self.args_flag, self.rank = args, kwargs, False, None
+ # Check input params valid
+ if not kwargs and args is None:
+ raise ValueError("Trainer must provide arguments")
+ # New argparse
+ if self.args is None:
+ self.args_flag = True
+ self.args = argparse.ArgumentParser().parse_args()
+ logger.info(msg="[Note]: Trainer initializer successfully. But 'args' is None")
+
+ # Random seed
+ self.seed = self.check_args_and_kwargs(kwarg="seed", default=0)
+ # Network
+ self.network = self.check_args_and_kwargs(kwarg="network", default="unet")
+ # Batch size
+ self.batch_size = self.check_args_and_kwargs(kwarg="batch_size", default=2)
+ # Number of workers
+ self.num_workers = self.check_args_and_kwargs(kwarg="num_workers", default=0)
+ # Input image size
+ self.image_size = self.check_args_and_kwargs(kwarg="image_size", default=64)
+ # Number of epochs
+ self.epochs = self.check_args_and_kwargs(kwarg="epochs", default=300)
+ # Whether to enable automatic mixed precision training
+ self.amp = self.check_args_and_kwargs(kwarg="amp", default=False)
+ # Select optimizer
+ self.optim = self.check_args_and_kwargs(kwarg="optim", default="adamw")
+ # Loss function
+ self.loss_name = self.check_args_and_kwargs(kwarg="loss", default="mse")
+ # Select activation function
+ self.act = self.check_args_and_kwargs(kwarg="act", default="gelu")
+ # Learning rate
+ self.init_lr = self.check_args_and_kwargs(kwarg="lr", default=3e-4)
+ # Learning rate function
+ self.lr_func = self.check_args_and_kwargs(kwarg="lr_func", default="linear")
+ # Saving path
+ self.result_path = self.check_args_and_kwargs(kwarg="result_path", default="")
+ # Save model interval
+ self.save_model_interval = self.check_args_and_kwargs(kwarg="save_model_interval", default=False)
+ # Save model interval and save it every X epochs
+ self.save_model_interval_epochs = self.check_args_and_kwargs(kwarg="save_model_interval_epochs", default=10)
+ # Save model interval in the start epoch
+ self.start_model_interval = self.check_args_and_kwargs(kwarg="start_model_interval", default=-1)
+ # Save image format
+ self.image_format = self.check_args_and_kwargs(kwarg="image_format", default="png")
+ # Resume training
+ self.resume = self.check_args_and_kwargs(kwarg="resume", default=False)
+ # Resume training epoch num
+ self.start_epoch = self.check_args_and_kwargs(kwarg="start_epoch", default=-1)
+ # Pretrain
+ self.pretrain = self.check_args_and_kwargs(kwarg="pretrain", default=False)
+ # Pretrain path
+ self.pretrain_path = self.check_args_and_kwargs(kwarg="pretrain_path", default="")
+ # Set the use GPU in normal training
+ self.use_gpu = self.check_args_and_kwargs(kwarg="use_gpu", default=0)
+ # Enable distributed training
+ self.distributed = self.check_args_and_kwargs(kwarg="distributed", default=False)
+ # Set the main GPU
+ self.main_gpu = self.check_args_and_kwargs(kwarg="main_gpu", default=0)
+ # Number of distributed node
+ self.world_size = self.check_args_and_kwargs(kwarg="world_size", default=2)
+
+ # Default params
+ self.results_dir = None
+ self.results_tb_dir = None
+ self.results_logging = None
+ self.results_vis_dir = None
+ self.device = None
+ self.save_models = None
+ self.model = None
+ self.ema = None
+ self.ema_model = None
+ self.epoch = None
+ self.optimizer = None
+ self.scaler = None
+ self.loss_func = None
+ self.tb_logger = None
+
+ def check_args_and_kwargs(self, kwarg, default=None):
+ """
+ Check args with **kwargs
+ :param kwarg: **kwargs params
+ :param default: Default params
+ :return: Used params
+ """
+ # Prevent loading parameters from failing and call default values
+ if self.args_flag:
+ value = self.kwargs.get(kwarg, default)
+ else:
+ # Get the self.args
+ arg = getattr(self.args, kwarg)
+ value = self.kwargs.get(kwarg, arg)
+ # Load the params
+ if self.kwargs.get(kwarg) is not None or self.args_flag:
+ # The value of kwargs modifies the value of args
+ setattr(self.args, kwarg, value)
+ logger.info(msg=f"[Note]: args.{kwarg} already set => {value}")
+ return value
+
+ def train(self, rank=None):
+ """
+ Training method
+ :param rank: Device id
+ """
+ # Init rank
+ self.rank = rank
+
+ # Training
+ self.before_train()
+ self.train_in_epochs()
+ self.after_train()
+
+ def before_train(self):
+ """
+ Before training method
+ """
+ pass
+
+ def train_in_epochs(self):
+ """
+ Train in epochs method
+ """
+ self.before_iter()
+ self.train_in_iter()
+ self.after_iter()
+
+ def before_iter(self):
+ """
+ Before training one iter method
+ """
+ pass
+
+ def train_in_iter(self):
+ """
+ Train in one iter method
+ """
+ pass
+
+ def after_iter(self):
+ """
+ After training one iter
+ """
+ pass
+
+ def after_train(self):
+ """
+ After training method
+ """
+ pass
diff --git a/model/trainers/dm.py b/model/trainers/dm.py
new file mode 100644
index 0000000..5e64272
--- /dev/null
+++ b/model/trainers/dm.py
@@ -0,0 +1,331 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+"""
+ @Date : 2024/9/18 13:10
+ @Author : chairc
+ @Site : https://github.com/chairc
+"""
+import os
+import sys
+import copy
+import logging
+import coloredlogs
+import numpy as np
+import torch
+
+from torch import nn as nn
+from torch import distributed as dist
+from torch.utils.tensorboard import SummaryWriter
+from torch.cuda.amp import autocast
+from tqdm import tqdm
+
+sys.path.append(os.path.dirname(sys.path[0]))
+from config.setting import MASTER_ADDR, MASTER_PORT, EMA_BETA
+from model.modules.ema import EMA
+from utils.check import check_image_size, check_pretrain_path, check_is_distributed
+from utils.dataset import get_dataset
+from utils.initializer import device_initializer, seed_initializer, network_initializer, optimizer_initializer, \
+ sample_initializer, lr_initializer, amp_initializer, loss_initializer, classes_initializer
+from utils.utils import plot_images, save_images, setup_logging, save_train_logging, download_model_pretrain_model
+from utils.checkpoint import load_ckpt, save_ckpt
+from model.trainers.base import Trainer
+
+logger = logging.getLogger(__name__)
+coloredlogs.install(level="INFO")
+
+
+class DMTrainer(Trainer):
+ """
+ Diffusion model trainer
+ """
+
+ def __init__(self, **kwargs):
+ """
+ Initialize diffusion model trainer
+ :param kwargs: Parameters of trainer
+ """
+ super(DMTrainer, self).__init__(**kwargs)
+ # Can be set input params
+ # Check args is None, and input kwargs in initialize diffusion model trainer
+ # e.g. trainer = DMTrainer(run_name="dm-temp", dataset_path="/your/dataset/path/dir")
+ # Run name
+ self.run_name = self.check_args_and_kwargs(kwarg="run_name", default="df")
+ # Whether to enable conditional training
+ self.conditional = self.check_args_and_kwargs(kwarg="conditional", default=False)
+ # Sample type
+ self.sample = self.check_args_and_kwargs(kwarg="sample", default="ddpm")
+ # Dataset path
+ self.dataset_path = self.check_args_and_kwargs(kwarg="dataset_path", default="")
+ # Enable data visualization
+ self.vis = self.check_args_and_kwargs(kwarg="vis", default=True)
+ # Number of visualization images generated
+ self.num_vis = self.check_args_and_kwargs(kwarg="num_vis", default=-1)
+ # Noise schedule
+ self.noise_schedule = self.check_args_and_kwargs(kwarg="noise_schedule", default="linear")
+ # classifier-free guidance interpolation weight, users can better generate model effect
+ self.cfg_scale = self.check_args_and_kwargs(kwarg="cfg_scale", default=3)
+
+ # Default params
+ self.num_classes = None
+ self.diffusion = None
+ self.pbar = None
+ self.dataloader = None
+ self.len_dataloader = None
+
+ def before_train(self):
+ """
+ Before training diffusion model method
+ """
+ # =================================Before training=================================
+ logger.info(msg=f"[{self.rank}]: Start diffusion model training")
+ # Output params to console
+ logger.info(msg=f"[{self.rank}]: Input params: {self.args}")
+ # Step1: Set path and create log
+ # Create data logging path
+ self.results_logging = setup_logging(save_path=self.result_path, run_name=self.run_name)
+ self.results_dir = self.results_logging[1]
+ self.results_vis_dir = self.results_logging[2]
+ self.results_tb_dir = self.results_logging[3]
+ # Tensorboard
+ self.tb_logger = SummaryWriter(log_dir=self.results_tb_dir)
+ # Train log
+ self.args = save_train_logging(arg=self.args, save_path=self.results_dir)
+
+ # Step2: Get the parameters of the initializer and args
+ # Initialize the seed
+ seed_initializer(seed_id=self.seed)
+ # Input image size
+ self.image_size = check_image_size(image_size=self.image_size)
+ # Number of classes
+ self.num_classes = classes_initializer(dataset_path=self.dataset_path)
+ # Initialize and save the model identification bit
+ # Check here whether it is single-GPU training or multi-GPU training
+ self.save_models = True
+ # Whether to enable distributed training
+ if check_is_distributed(distributed=self.distributed):
+ self.distributed = True
+ # Set address and port
+ os.environ["MASTER_ADDR"] = MASTER_ADDR
+ os.environ["MASTER_PORT"] = MASTER_PORT
+ # The total number of processes is equal to the number of graphics cards
+ dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo", rank=self.rank,
+ world_size=self.world_size)
+ # Set device ID
+ self.device = device_initializer(device_id=self.rank, is_train=True)
+ # There may be random errors, using this function can reduce random errors in cudnn
+ # torch.backends.cudnn.deterministic = True
+ # Synchronization during distributed training
+ dist.barrier()
+ # If the distributed training is not the main GPU, the save model flag is False
+ if dist.get_rank() != self.main_gpu:
+ self.save_models = False
+ logger.info(msg=f"[{self.device}]: Successfully Use distributed training.")
+ else:
+ self.distributed = False
+ # Run device initializer
+ self.device = device_initializer(device_id=self.use_gpu, is_train=True)
+ logger.info(msg=f"[{self.device}]: Successfully Use normal training.")
+
+ # =================================About model initializer=================================
+ # Step3: Init model
+ # Network
+ Network = network_initializer(network=self.network, device=self.device)
+ # Model
+ if not self.conditional:
+ self.model = Network(device=self.device, image_size=self.image_size, act=self.act).to(self.device)
+ else:
+ self.model = Network(num_classes=self.num_classes, device=self.device, image_size=self.image_size,
+ act=self.act).to(self.device)
+ # Distributed training
+ if self.distributed:
+ self.model = nn.parallel.DistributedDataParallel(module=self.model, device_ids=[self.device],
+ find_unused_parameters=True)
+ # Model optimizer
+ self.optimizer = optimizer_initializer(model=self.model, optim=self.optim, init_lr=self.init_lr,
+ device=self.device)
+ # Resume training
+ if self.resume:
+ ckpt_path = None
+ # Determine which checkpoint to load
+ # 'start_epoch' is correct
+ if self.start_epoch is not None:
+ ckpt_path = os.path.join(self.results_dir, f"ckpt_{str(self.start_epoch - 1).zfill(3)}.pt")
+ # Parameter 'ckpt_path' is None in the train mode
+ if ckpt_path is None:
+ ckpt_path = os.path.join(self.results_dir, "ckpt_last.pt")
+ self.start_epoch = load_ckpt(ckpt_path=ckpt_path, model=self.model, device=self.device,
+ optimizer=self.optimizer, is_distributed=self.distributed,
+ conditional=self.conditional)
+ logger.info(msg=f"[{self.device}]: Successfully load resume model checkpoint.")
+ else:
+ # Pretrain mode
+ if self.pretrain:
+ # TODO: If pretrain path is none, download the official pretrain model
+ if check_pretrain_path(pretrain_path=self.pretrain_path):
+ # If you want to train on a specified data set, such as neu or cifar 10
+ # You can set the df_type to exp and add model_name="neu-cls" or model_name="cifar10"
+ self.pretrain_path = download_model_pretrain_model(pretrain_type="df", network=self.network,
+ conditional=self.conditional,
+ image_size=self.image_size, df_type="default")
+ load_ckpt(ckpt_path=self.pretrain_path, model=self.model, device=self.device, is_pretrain=self.pretrain,
+ is_distributed=self.distributed, conditional=self.conditional)
+ logger.info(msg=f"[{self.device}]: Successfully load pretrain model checkpoint.")
+ self.start_epoch = 0
+ # Set harf-precision
+ self.scaler = amp_initializer(amp=self.amp, device=self.device)
+ # Loss function
+ self.loss_func = loss_initializer(loss_name=self.loss_name, device=self.device)
+ # Initialize the diffusion model
+ self.diffusion = sample_initializer(sample=self.sample, image_size=self.image_size, device=self.device,
+ schedule_name=self.noise_schedule)
+ # Exponential Moving Average (EMA) may not be as dominant for single class as for multi class
+ self.ema = EMA(beta=EMA_BETA)
+ # EMA model
+ self.ema_model = copy.deepcopy(self.model).eval().requires_grad_(False)
+
+ # =================================About data=================================
+ # Step4: Set data
+ # Dataloader
+ self.dataloader = get_dataset(image_size=self.image_size, dataset_path=self.dataset_path,
+ batch_size=self.batch_size, num_workers=self.num_workers,
+ distributed=self.distributed)
+ # Number of dataset batches in the dataloader
+ self.len_dataloader = len(self.dataloader)
+
+ def train_in_epochs(self):
+ """
+ Train in epochs diffusion model method
+ """
+ # Step5: Training
+ logger.info(msg=f"[{self.device}]: Start training.")
+ # Start iterating
+ for self.epoch in range(self.start_epoch, self.epochs):
+ self.before_iter()
+ self.train_in_iter()
+ self.after_iter()
+
+ def before_iter(self):
+ """
+ Before training one iter diffusion model method
+ """
+ logger.info(msg=f"[{self.device}]: Start epoch {self.epoch}:")
+ # Set learning rate
+ current_lr = lr_initializer(lr_func=self.lr_func, optimizer=self.optimizer, epoch=self.epoch,
+ epochs=self.epochs, init_lr=self.init_lr, device=self.device)
+ self.tb_logger.add_scalar(tag=f"[{self.device}]: Current LR", scalar_value=current_lr, global_step=self.epoch)
+ self.pbar = tqdm(self.dataloader)
+
+ def train_in_iter(self):
+ """
+ Train in one iter diffusion model method
+ """
+ # Initialize images and labels
+ images, labels, loss_list = None, None, []
+ for i, (images, labels) in enumerate(self.pbar):
+ # The images are all resized in dataloader
+ images = images.to(self.device)
+ # Generates a tensor of size images.shape[0] randomly sampled time steps
+ time = self.diffusion.sample_time_steps(images.shape[0]).to(self.device)
+ # Add noise, return as x value at time t and standard normal distribution
+ x_time, noise = self.diffusion.noise_images(x=images, time=time)
+ # Enable Automatic mixed precision training
+ # Automatic mixed precision training
+ # Note: If your Pytorch version > 2.4.1, with torch.amp.autocast("cuda", enabled=amp):
+ with autocast(enabled=self.amp):
+ # Unconditional training
+ if not self.conditional:
+ # Unconditional model prediction
+ predicted_noise = self.model(x_time, time)
+ # Conditional training, need to add labels
+ else:
+ labels = labels.to(self.device)
+ # Random unlabeled hard training, using only time steps and no class information
+ if np.random.random() < 0.1:
+ labels = None
+ # Conditional model prediction
+ predicted_noise = self.model(x_time, time, labels)
+ # To calculate the MSE loss
+ # You need to use the standard normal distribution of x at time t and the predicted noise
+ loss = self.loss_func(noise, predicted_noise)
+ # The optimizer clears the gradient of the model parameters
+ self.optimizer.zero_grad()
+ # Update loss and optimizer
+ # Fp16 + Fp32
+ self.scaler.scale(loss).backward()
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+ # EMA
+ self.ema.step_ema(ema_model=self.ema_model, model=self.model)
+
+ # TensorBoard logging
+ self.pbar.set_postfix(MSE=loss.item())
+ self.tb_logger.add_scalar(tag=f"[{self.device}]: MSE", scalar_value=loss.item(),
+ global_step=self.epoch * self.len_dataloader + i)
+ loss_list.append(loss.item())
+ # Loss per epoch
+ self.tb_logger.add_scalar(tag=f"[{self.device}]: Loss", scalar_value=sum(loss_list) / len(loss_list),
+ global_step=self.epoch)
+
+ def after_iter(self):
+ """
+ After training one iter diffusion model method
+ """
+ # Saving and validating models in the main process
+ if self.save_models:
+ # Saving model, set the checkpoint name
+ save_name = f"ckpt_{str(self.epoch).zfill(3)}"
+ # Init ckpt params
+ ckpt_model, ckpt_ema_model, ckpt_optimizer = None, None, None
+ if not self.conditional:
+ ckpt_model = self.model.state_dict()
+ ckpt_optimizer = self.optimizer.state_dict()
+ # Enable visualization
+ if self.vis:
+ # images.shape[0] is the number of images in the current batch
+ n = self.num_vis if self.num_vis > 0 else self.batch_size
+ sampled_images = self.diffusion.sample(model=self.model, n=n)
+ save_images(images=sampled_images,
+ path=os.path.join(self.results_vis_dir, f"{save_name}.{self.image_format}"))
+ else:
+ ckpt_model = self.model.state_dict()
+ ckpt_ema_model = self.ema_model.state_dict()
+ ckpt_optimizer = self.optimizer.state_dict()
+ # Enable visualization
+ if self.vis:
+ labels = torch.arange(self.num_classes).long().to(self.device)
+ n = self.num_vis if self.num_vis > 0 else len(labels)
+ sampled_images = self.diffusion.sample(model=self.model, n=n, labels=labels,
+ cfg_scale=self.cfg_scale)
+ ema_sampled_images = self.diffusion.sample(model=self.ema_model, n=n, labels=labels,
+ cfg_scale=self.cfg_scale)
+ # This is a method to display the results of each model during training and can be commented out
+ # plot_images(images=sampled_images)
+ save_images(images=sampled_images,
+ path=os.path.join(self.results_vis_dir, f"{save_name}.{self.image_format}"))
+ save_images(images=ema_sampled_images,
+ path=os.path.join(self.results_vis_dir, f"ema_{save_name}.{self.image_format}"))
+ # Save checkpoint
+ save_ckpt(epoch=self.epoch, save_name=save_name, ckpt_model=ckpt_model, ckpt_ema_model=ckpt_ema_model,
+ ckpt_optimizer=ckpt_optimizer, results_dir=self.results_dir,
+ save_model_interval=self.save_model_interval,
+ save_model_interval_epochs=self.save_model_interval_epochs,
+ start_model_interval=self.start_model_interval, conditional=self.conditional,
+ image_size=self.image_size, sample=self.sample, network=self.network, act=self.act,
+ num_classes=self.num_classes)
+ logger.info(msg=f"[{self.device}]: Finish epoch {self.epoch}:")
+
+ # Synchronization during distributed training
+ if self.distributed:
+ logger.info(msg=f"[{self.device}]: Synchronization during distributed training.")
+ dist.barrier()
+
+ def after_train(self):
+ """
+ After training diffusion model method
+ """
+ logger.info(msg=f"[{self.device}]: Finish training.")
+ logger.info(msg="[Note]: If you want to evaluate model quality, use 'FID_calculator.py' to evaluate.")
+
+ # Clean up the distributed environment
+ if self.distributed:
+ dist.destroy_process_group()
diff --git a/model/trainers/sr.py b/model/trainers/sr.py
new file mode 100644
index 0000000..8942ab2
--- /dev/null
+++ b/model/trainers/sr.py
@@ -0,0 +1,353 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+"""
+ @Date : 2024/9/18 13:09
+ @Author : chairc
+ @Site : https://github.com/chairc
+"""
+import os
+import sys
+import copy
+import logging
+import time
+
+import coloredlogs
+import torch
+
+from torch import nn as nn
+from torch import distributed as dist
+from torch.utils.tensorboard import SummaryWriter
+from torch.cuda.amp import autocast
+from tqdm import tqdm
+
+sys.path.append(os.path.dirname(sys.path[0]))
+from config.setting import MASTER_ADDR, MASTER_PORT, EMA_BETA
+from model.modules.ema import EMA
+from model.trainers.base import Trainer
+from utils.initializer import device_initializer, seed_initializer, sr_network_initializer, optimizer_initializer, \
+ lr_initializer, amp_initializer, loss_initializer
+from utils.utils import save_images, setup_logging, save_train_logging, check_and_create_dir
+from utils.check import check_is_distributed
+from utils.checkpoint import load_ckpt, save_ckpt
+from utils.metrics import compute_psnr, compute_ssim
+from sr.interface import post_image
+from sr.dataset import get_sr_dataset
+
+logger = logging.getLogger(__name__)
+coloredlogs.install(level="INFO")
+
+
+class SRTrainer(Trainer):
+ """
+ Super resolution trainer
+ """
+
+ def __init__(self, **kwargs):
+ """
+ Initialize super resolution trainer
+ :param kwargs: Parameters of trainer
+ """
+ super(SRTrainer, self).__init__(**kwargs)
+ # Can be set input params
+ # Run name
+ self.run_name = self.check_args_and_kwargs(kwarg="run_name", default="sr")
+ # Datasets
+ self.train_dataset_path = self.check_args_and_kwargs(kwarg="train_dataset_path", default="")
+ self.val_dataset_path = self.check_args_and_kwargs(kwarg="val_dataset_path", default="")
+ # Evaluate quickly
+ self.quick_eval = self.check_args_and_kwargs(kwarg="quick_eval", default=False)
+
+ # Default params
+ self.train_dataloader = None
+ self.val_dataloader = None
+ self.len_train_dataloader = None
+ self.len_val_dataloader = None
+ self.save_val_vis_dir = None
+ self.best_ssim = 0
+ self.best_psnr = 0
+ self.avg_val_loss = 0
+ self.avg_ssim = 0
+ self.avg_psnr = 0
+
+ def before_train(self):
+ """
+ Before training super resolution model method
+ """
+ logger.info(msg=f"[{self.rank}]: Start super resolution model training")
+ logger.info(msg=f"[{self.rank}]: Input params: {self.args}")
+ # Step1: Set path and create log
+ # Create data logging path
+ self.results_logging = setup_logging(save_path=self.result_path, run_name=self.run_name)
+ self.results_dir = self.results_logging[1]
+ self.results_vis_dir = self.results_logging[2]
+ self.results_tb_dir = self.results_logging[3]
+ # Train log
+ self.args = save_train_logging(arg=self.args, save_path=self.results_dir)
+
+ # Step2: Get the parameters of the initializer and args
+ # Initialize the seed
+ seed_initializer(seed_id=self.seed)
+ # Initialize and save the model identification bit
+ # Check here whether it is single-GPU training or multi-GPU training
+ self.save_models = True
+ # Whether to enable distributed training
+ if check_is_distributed(distributed=self.distributed):
+ self.distributed = True
+ # Set address and port
+ os.environ["MASTER_ADDR"] = MASTER_ADDR
+ os.environ["MASTER_PORT"] = MASTER_PORT
+ # The total number of processes is equal to the number of graphics cards
+ dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo", rank=self.rank,
+ world_size=self.world_size)
+ # Set device ID
+ self.device = device_initializer(device_id=self.rank, is_train=True)
+ # There may be random errors, using this function can reduce random errors in cudnn
+ # torch.backends.cudnn.deterministic = True
+ # Synchronization during distributed training
+ dist.barrier()
+ # If the distributed training is not the main GPU, the save model flag is False
+ if dist.get_rank() != self.main_gpu:
+ self.save_models = False
+ logger.info(msg=f"[{self.device}]: Successfully Use distributed training.")
+ else:
+ self.distributed = False
+ # Run device initializer
+ self.device = device_initializer(device_id=self.use_gpu, is_train=True)
+ logger.info(msg=f"[{self.device}]: Successfully Use normal training.")
+ # Dataloader
+ self.train_dataloader = get_sr_dataset(image_size=self.image_size, dataset_path=self.train_dataset_path,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers, distributed=self.distributed)
+ # Quick eval batch size
+ val_batch_size = self.batch_size if self.quick_eval else 1
+ self.val_dataloader = get_sr_dataset(image_size=self.image_size, dataset_path=self.val_dataset_path,
+ batch_size=val_batch_size,
+ num_workers=self.num_workers, distributed=self.distributed)
+ # Network
+ Network = sr_network_initializer(network=self.network, device=self.device)
+ # Model
+ self.model = Network(act=self.act).to(self.device)
+ # Distributed training
+ if self.distributed:
+ self.model = nn.parallel.DistributedDataParallel(module=self.model, device_ids=[self.device],
+ find_unused_parameters=True)
+ # Model optimizer
+ self.optimizer = optimizer_initializer(model=self.model, optim=self.optim, init_lr=self.init_lr,
+ device=self.device)
+ # Resume training
+ if self.resume:
+ ckpt_path = None
+ # Determine which checkpoint to load
+ # 'start_epoch' is correct
+ if self.start_epoch is not None:
+ ckpt_path = os.path.join(self.results_dir, f"ckpt_{str(self.start_epoch - 1).zfill(3)}.pt")
+ # Parameter 'ckpt_path' is None in the train mode
+ if ckpt_path is None:
+ ckpt_path = os.path.join(self.results_dir, "ckpt_last.pt")
+ # The best model
+ ckpt_best_path = os.path.join(self.results_dir, "ckpt_best.pt")
+ # Get model state
+ self.start_epoch = load_ckpt(ckpt_path=ckpt_path, model=self.model, device=self.device,
+ optimizer=self.optimizer, is_distributed=self.distributed)
+ # Get best ssim and psnr
+ self.best_ssim, self.best_psnr = load_ckpt(ckpt_path=ckpt_best_path, device=self.device, ckpt_type="sr")
+ logger.info(msg=f"[{self.device}]: Successfully load resume model checkpoint.")
+ logger.info(msg=f"[{self.device}]: The start epoch is {self.start_epoch}, best ssim is {self.best_ssim}, "
+ f"best psnr is {self.best_psnr}.")
+ else:
+ # Pretrain mode
+ if self.pretrain:
+ load_ckpt(ckpt_path=self.pretrain_path, model=self.model, device=self.device, is_pretrain=self.pretrain,
+ is_distributed=self.distributed)
+ logger.info(msg=f"[{self.device}]: Successfully load pretrain model checkpoint.")
+ # Init
+ self.start_epoch, self.best_ssim, self.best_psnr = 0, 0, 0
+ logger.info(msg=f"[{self.device}]: The start epoch is {self.start_epoch}, best ssim is {self.best_ssim}, "
+ f"best psnr is {self.best_psnr}.")
+ # Set harf-precision
+ self.scaler = amp_initializer(amp=self.amp, device=self.device)
+ # Loss function
+ self.loss_func = loss_initializer(loss_name=self.loss_name, device=self.device)
+ # Tensorboard
+ self.tb_logger = SummaryWriter(log_dir=self.results_tb_dir)
+ # Number of dataset batches in the dataloader
+ self.len_train_dataloader = len(self.train_dataloader)
+ self.len_val_dataloader = len(self.val_dataloader)
+ # Exponential Moving Average (EMA) may not be as dominant for single class as for multi class
+ self.ema = EMA(beta=EMA_BETA)
+ # EMA model
+ self.ema_model = copy.deepcopy(self.model).eval().requires_grad_(False)
+
+ def train_in_epochs(self):
+ """
+ Train in epochs super resolution model method
+ """
+ logger.info(msg=f"[{self.device}]: Start training.")
+ # Start iterating
+ for self.epoch in range(self.start_epoch, self.epochs):
+ self.before_iter()
+ self.train_in_iter()
+ self.after_iter()
+
+ def before_iter(self):
+ """
+ Before training one iter super resolution model method
+ """
+ logger.info(msg=f"[{self.device}]: Start epoch {self.epoch}:")
+ # Set learning rate
+ current_lr = lr_initializer(lr_func=self.lr_func, optimizer=self.optimizer, epoch=self.epoch,
+ epochs=self.epochs,
+ init_lr=self.init_lr, device=self.device)
+ self.tb_logger.add_scalar(tag=f"[{self.device}]: Current LR", scalar_value=current_lr, global_step=self.epoch)
+ # Create vis dir
+ self.save_val_vis_dir = os.path.join(self.results_vis_dir, str(self.epoch))
+ check_and_create_dir(self.save_val_vis_dir)
+
+ def train_in_iter(self):
+ """
+ Train in one iter super resolution model method
+ """
+ # Initialize images and labels
+ train_loss_list, val_loss_list, ssim_list, psnr_list = [], [], [], []
+ # Train
+ self.model.train()
+ logger.info(msg="Start train mode.")
+ train_pbar = tqdm(self.train_dataloader)
+ for i, (lr_images, hr_images) in enumerate(train_pbar):
+ # The images are all resized in train dataloader
+ lr_images = lr_images.to(self.device)
+ hr_images = hr_images.to(self.device)
+ # Enable Automatic mixed precision training
+ # Automatic mixed precision training
+ # Note: If your Pytorch version > 2.4.1, with torch.amp.autocast("cuda", enabled=amp):
+ with autocast(enabled=self.amp):
+ output = self.model(lr_images)
+ # To calculate the MSE loss
+ # You need to use the standard normal distribution of x at time t and the predicted noise
+ train_loss = self.loss_func(output, hr_images)
+ # The optimizer clears the gradient of the model parameters
+ self.optimizer.zero_grad()
+ # Update loss and optimizer
+ # Fp16 + Fp32
+ self.scaler.scale(train_loss).backward()
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+ # EMA
+ self.ema.step_ema(ema_model=self.ema_model, model=self.model)
+
+ # TensorBoard logging
+ train_pbar.set_postfix(MSE=train_loss.item())
+ self.tb_logger.add_scalar(tag=f"[{self.device}]: Train loss({self.loss_func})",
+ scalar_value=train_loss.item(),
+ global_step=self.epoch * self.len_train_dataloader + i)
+ train_loss_list.append(train_loss.item())
+ # Loss per epoch
+ self.tb_logger.add_scalar(tag=f"[{self.device}]: Train loss",
+ scalar_value=sum(train_loss_list) / len(train_loss_list),
+ global_step=self.epoch)
+ logger.info(msg="Finish train mode.")
+
+ # Val
+ self.model.eval()
+ logger.info(msg="Start val mode.")
+ val_pbar = tqdm(self.val_dataloader)
+ for i, (lr_images, hr_images) in enumerate(val_pbar):
+ # The images are all resized in val dataloader
+ lr_images = lr_images.to(self.device)
+ hr_images = hr_images.to(self.device)
+ # Enable Automatic mixed precision training
+ # Automatic mixed precision training
+ with torch.no_grad():
+ output = self.model(lr_images)
+ # To calculate the MSE loss
+ # You need to use the standard normal distribution of x at time t and the predicted noise
+ val_loss = self.loss_func(output, hr_images)
+ # The optimizer clears the gradient of the model parameters
+ self.optimizer.zero_grad()
+
+ # TensorBoard logging
+ val_pbar.set_postfix(MSE=val_loss.item())
+ self.tb_logger.add_scalar(tag=f"[{self.device}]: Val loss({self.loss_func})", scalar_value=val_loss.item(),
+ global_step=self.epoch * self.len_val_dataloader + i)
+ val_loss_list.append(val_loss.item())
+
+ # Metric
+ ssim_res = compute_ssim(image_outputs=output, image_sources=hr_images)
+ psnr_res = compute_psnr(mse=val_loss.item())
+ self.tb_logger.add_scalar(tag=f"[{self.device}]: SSIM({self.loss_func})", scalar_value=ssim_res,
+ global_step=self.epoch * self.len_val_dataloader + i)
+ self.tb_logger.add_scalar(tag=f"[{self.device}]: PSNR({self.loss_func})", scalar_value=psnr_res,
+ global_step=self.epoch * self.len_val_dataloader + i)
+ ssim_list.append(ssim_res)
+ psnr_list.append(psnr_res)
+
+ # Save super resolution image and high resolution image
+ lr_images = post_image(lr_images, device=self.device)
+ sr_images = post_image(output, device=self.device)
+ hr_images = post_image(hr_images, device=self.device)
+ image_name = time.time()
+ for lr_index, lr_image in enumerate(lr_images):
+ save_images(images=lr_image,
+ path=os.path.join(self.save_val_vis_dir,
+ f"{i}_{image_name}_{lr_index}_lr.{self.image_format}"))
+ for sr_index, sr_image in enumerate(sr_images):
+ save_images(images=sr_image,
+ path=os.path.join(self.save_val_vis_dir,
+ f"{i}_{image_name}_{sr_index}_sr.{self.image_format}"))
+ for hr_index, hr_image in enumerate(hr_images):
+ save_images(images=hr_image,
+ path=os.path.join(self.save_val_vis_dir,
+ f"{i}_{image_name}_{hr_index}_hr.{self.image_format}"))
+ # Loss, ssim and psnr per epoch
+ self.avg_val_loss = sum(val_loss_list) / len(val_loss_list)
+ self.avg_ssim = sum(ssim_list) / len(ssim_list)
+ self.avg_psnr = sum(psnr_list) / len(psnr_list)
+ self.tb_logger.add_scalar(tag=f"[{self.device}]: Val loss", scalar_value=self.avg_val_loss,
+ global_step=self.epoch)
+ self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg ssim", scalar_value=self.avg_ssim, global_step=self.epoch)
+ self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg psnr", scalar_value=self.avg_psnr, global_step=self.epoch)
+ logger.info(f"Val loss: {self.avg_val_loss}, SSIM: {self.avg_ssim}, PSNR: {self.avg_psnr}")
+ logger.info(msg="Finish val mode.")
+
+ def after_iter(self):
+ """
+ After training one iter diffusion model method
+ """
+ # Saving and validating models in the main process
+ if self.save_models:
+ # Saving model, set the checkpoint name
+ save_name = f"ckpt_{str(self.epoch).zfill(3)}"
+ # Init ckpt params
+ ckpt_model = self.model.state_dict()
+ ckpt_ema_model = self.ema_model.state_dict()
+ ckpt_optimizer = self.optimizer.state_dict()
+ # Save the best model
+ if (self.avg_ssim > self.best_ssim) and (self.avg_psnr > self.best_psnr):
+ is_best = True
+ self.best_ssim = self.avg_ssim
+ self.best_psnr = self.avg_psnr
+ else:
+ is_best = False
+ # Save checkpoint
+ save_ckpt(epoch=self.epoch, save_name=save_name, ckpt_model=ckpt_model, ckpt_ema_model=ckpt_ema_model,
+ ckpt_optimizer=ckpt_optimizer, results_dir=self.results_dir,
+ save_model_interval=self.save_model_interval, start_model_interval=self.start_model_interval,
+ save_model_interval_epochs=self.save_model_interval_epochs, image_size=self.image_size,
+ network=self.network, act=self.act, is_sr=True, is_best=is_best, ssim=self.avg_ssim,
+ psnr=self.avg_psnr)
+ logger.info(msg=f"[{self.device}]: Finish epoch {self.epoch}:")
+
+ # Synchronization during distributed training
+ if self.distributed:
+ logger.info(msg=f"[{self.device}]: Synchronization during distributed training.")
+ dist.barrier()
+
+ def after_train(self):
+ """
+ After training super resolution model method
+ """
+ logger.info(msg=f"[{self.device}]: Finish training.")
+
+ # Clean up the distributed environment
+ if self.distributed:
+ dist.destroy_process_group()
diff --git a/sr/train.py b/sr/train.py
index 4605694..3612ceb 100644
--- a/sr/train.py
+++ b/sr/train.py
@@ -8,318 +8,22 @@
import os
import sys
import argparse
-import copy
import logging
-import time
import coloredlogs
import torch
-from torch import nn as nn
-from torch import distributed as dist
from torch import multiprocessing as mp
-from torch.utils.tensorboard import SummaryWriter
-from torch.cuda.amp import autocast
-from tqdm import tqdm
sys.path.append(os.path.dirname(sys.path[0]))
-from config.choices import sr_network_choices, optim_choices
-from config.setting import MASTER_ADDR, MASTER_PORT, EMA_BETA
+from config.choices import sr_network_choices, optim_choices, sr_loss_func_choices, image_format_choices
from config.version import get_version_banner
-from model.modules.ema import EMA
-from utils.initializer import device_initializer, seed_initializer, sr_network_initializer, optimizer_initializer, \
- lr_initializer, amp_initializer, loss_initializer
-from utils.utils import save_images, setup_logging, save_train_logging, check_and_create_dir
-from utils.checkpoint import load_ckpt, save_ckpt
-from utils.metrics import compute_psnr, compute_ssim
-from sr.interface import post_image
-from sr.dataset import get_sr_dataset
+from model.trainers.sr import SRTrainer
logger = logging.getLogger(__name__)
coloredlogs.install(level="INFO")
-def train(rank=None, args=None):
- """
- Training
- :param rank: Device id
- :param args: Input parameters
- :return: None
- """
- logger.info(msg=f"[{rank}]: Input params: {args}")
- # Initialize the seed
- seed_initializer(seed_id=args.seed)
- # Network
- network = args.network
- # Run name
- run_name = args.run_name
- # Input image size
- image_size = args.image_size
- # Datasets
- train_dataset_path = args.train_dataset_path
- val_dataset_path = args.val_dataset_path
- # Batch size
- batch_size = args.batch_size
- # Number of workers
- num_workers = args.num_workers
- # Select optimizer
- optim = args.optim
- # Loss function only mse
- loss_func = "mse"
- # Select activation function
- act = args.act
- # Learning rate
- init_lr = args.lr
- # Learning rate function
- lr_func = args.lr_func
- # Initialize and save the model identification bit
- # Check here whether it is single-GPU training or multi-GPU training
- save_models = True
- # Whether to enable distributed training
- if args.distributed and torch.cuda.device_count() > 1 and torch.cuda.is_available():
- distributed = True
- world_size = args.world_size
- # Set address and port
- os.environ["MASTER_ADDR"] = MASTER_ADDR
- os.environ["MASTER_PORT"] = MASTER_PORT
- # The total number of processes is equal to the number of graphics cards
- dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo", rank=rank,
- world_size=world_size)
- # Set device ID
- device = device_initializer(device_id=rank, is_train=True)
- # There may be random errors, using this function can reduce random errors in cudnn
- # torch.backends.cudnn.deterministic = True
- # Synchronization during distributed training
- dist.barrier()
- # If the distributed training is not the main GPU, the save model flag is False
- if dist.get_rank() != args.main_gpu:
- save_models = False
- logger.info(msg=f"[{device}]: Successfully Use distributed training.")
- else:
- distributed = False
- # Run device initializer
- device = device_initializer(device_id=args.use_gpu, is_train=True)
- logger.info(msg=f"[{device}]: Successfully Use normal training.")
- # Whether to enable automatic mixed precision training
- amp = args.amp
- # Save model interval
- save_model_interval = args.save_model_interval
- # Save model interval in the start epoch
- start_model_interval = args.start_model_interval
- # Saving path
- result_path = args.result_path
- # Create data logging path
- results_logging = setup_logging(save_path=result_path, run_name=run_name)
- results_dir = results_logging[1]
- results_vis_dir = results_logging[2]
- results_tb_dir = results_logging[3]
- # Dataloader
- train_dataloader = get_sr_dataset(image_size=image_size, dataset_path=train_dataset_path, batch_size=batch_size,
- num_workers=num_workers, distributed=distributed)
- # Quick eval batch size
- val_batch_size = batch_size if args.quick_eval else 1
- val_dataloader = get_sr_dataset(image_size=image_size, dataset_path=val_dataset_path, batch_size=val_batch_size,
- num_workers=num_workers, distributed=distributed)
- # Resume training
- resume = args.resume
- # Pretrain
- pretrain = args.pretrain
- # Network
- Network = sr_network_initializer(network=network, device=device)
- # Model
- model = Network(act=act).to(device)
- # Distributed training
- if distributed:
- model = nn.parallel.DistributedDataParallel(module=model, device_ids=[device], find_unused_parameters=True)
- # Model optimizer
- optimizer = optimizer_initializer(model=model, optim=optim, init_lr=init_lr, device=device)
- # Resume training
- if resume:
- ckpt_path = None
- start_epoch = args.start_epoch
- # Determine which checkpoint to load
- # 'start_epoch' is correct
- if start_epoch is not None:
- ckpt_path = os.path.join(results_dir, f"ckpt_{str(start_epoch - 1).zfill(3)}.pt")
- # Parameter 'ckpt_path' is None in the train mode
- if ckpt_path is None:
- ckpt_path = os.path.join(results_dir, "ckpt_last.pt")
- # The best model
- ckpt_best_path = os.path.join(results_dir, "ckpt_best.pt")
- # Get model state
- start_epoch = load_ckpt(ckpt_path=ckpt_path, model=model, device=device, optimizer=optimizer,
- is_distributed=distributed)
- # Get best ssim and psnr
- best_ssim, best_psnr = load_ckpt(ckpt_path=ckpt_best_path, device=device, ckpt_type="sr")
- logger.info(msg=f"[{device}]: Successfully load resume model checkpoint.")
- logger.info(msg=f"[{device}]: The start epoch is {start_epoch}, best ssim is {best_ssim}, "
- f"best psnr is {best_psnr}.")
- else:
- # Pretrain mode
- if pretrain:
- pretrain_path = args.pretrain_path
- load_ckpt(ckpt_path=pretrain_path, model=model, device=device, is_pretrain=pretrain,
- is_distributed=distributed)
- logger.info(msg=f"[{device}]: Successfully load pretrain model checkpoint.")
- # Init
- start_epoch, best_ssim, best_psnr = 0, 0, 0
- logger.info(msg=f"[{device}]: The start epoch is {start_epoch}, best ssim is {best_ssim}, "
- f"best psnr is {best_psnr}.")
- # Set harf-precision
- scaler = amp_initializer(amp=amp, device=device)
- # Loss function
- loss_func = loss_initializer(loss_name=loss_func, device=device)
- # Tensorboard
- tb_logger = SummaryWriter(log_dir=results_tb_dir)
- # Train log
- save_train_logging(args, results_dir)
- # Number of dataset batches in the dataloader
- len_train_dataloader = len(train_dataloader)
- len_val_dataloader = len(val_dataloader)
- # Exponential Moving Average (EMA) may not be as dominant for single class as for multi class
- ema = EMA(beta=EMA_BETA)
- # EMA model
- ema_model = copy.deepcopy(model).eval().requires_grad_(False)
-
- logger.info(msg=f"[{device}]: Start training.")
- # Start iterating
- for epoch in range(start_epoch, args.epochs):
- logger.info(msg=f"[{device}]: Start epoch {epoch}:")
- # Set learning rate
- current_lr = lr_initializer(lr_func=lr_func, optimizer=optimizer, epoch=epoch, epochs=args.epochs,
- init_lr=init_lr, device=device)
- tb_logger.add_scalar(tag=f"[{device}]: Current LR", scalar_value=current_lr, global_step=epoch)
- # Create vis dir
- save_val_vis_dir = os.path.join(results_vis_dir, str(epoch))
- check_and_create_dir(save_val_vis_dir)
- # Initialize images and labels
- train_loss_list, val_loss_list, ssim_list, psnr_list = [], [], [], []
-
- # Train
- model.train()
- logger.info(msg="Start train mode.")
- train_pbar = tqdm(train_dataloader)
- for i, (lr_images, hr_images) in enumerate(train_pbar):
- # The images are all resized in train dataloader
- lr_images = lr_images.to(device)
- hr_images = hr_images.to(device)
- # Enable Automatic mixed precision training
- # Automatic mixed precision training
- # Note: If your Pytorch version > 2.4.1, with torch.amp.autocast("cuda", enabled=amp):
- with autocast(enabled=amp):
- output = model(lr_images)
- # To calculate the MSE loss
- # You need to use the standard normal distribution of x at time t and the predicted noise
- train_loss = loss_func(output, hr_images)
- # The optimizer clears the gradient of the model parameters
- optimizer.zero_grad()
- # Update loss and optimizer
- # Fp16 + Fp32
- scaler.scale(train_loss).backward()
- scaler.step(optimizer)
- scaler.update()
- # EMA
- ema.step_ema(ema_model=ema_model, model=model)
-
- # TensorBoard logging
- train_pbar.set_postfix(MSE=train_loss.item())
- tb_logger.add_scalar(tag=f"[{device}]: Train loss({loss_func})", scalar_value=train_loss.item(),
- global_step=epoch * len_train_dataloader + i)
- train_loss_list.append(train_loss.item())
- # Loss per epoch
- tb_logger.add_scalar(tag=f"[{device}]: Train loss", scalar_value=sum(train_loss_list) / len(train_loss_list),
- global_step=epoch)
- logger.info(msg="Finish train mode.")
-
- # Val
- model.eval()
- logger.info(msg="Start val mode.")
- val_pbar = tqdm(val_dataloader)
- for i, (lr_images, hr_images) in enumerate(val_pbar):
- # The images are all resized in val dataloader
- lr_images = lr_images.to(device)
- hr_images = hr_images.to(device)
- # Enable Automatic mixed precision training
- # Automatic mixed precision training
- with torch.no_grad():
- output = model(lr_images)
- # To calculate the MSE loss
- # You need to use the standard normal distribution of x at time t and the predicted noise
- val_loss = loss_func(output, hr_images)
- # The optimizer clears the gradient of the model parameters
- optimizer.zero_grad()
-
- # TensorBoard logging
- val_pbar.set_postfix(MSE=val_loss.item())
- tb_logger.add_scalar(tag=f"[{device}]: Val loss({loss_func})", scalar_value=val_loss.item(),
- global_step=epoch * len_val_dataloader + i)
- val_loss_list.append(val_loss.item())
-
- # Metric
- ssim_res = compute_ssim(image_outputs=output, image_sources=hr_images)
- psnr_res = compute_psnr(mse=val_loss.item())
- tb_logger.add_scalar(tag=f"[{device}]: SSIM({loss_func})", scalar_value=ssim_res,
- global_step=epoch * len_val_dataloader + i)
- tb_logger.add_scalar(tag=f"[{device}]: PSNR({loss_func})", scalar_value=psnr_res,
- global_step=epoch * len_val_dataloader + i)
- ssim_list.append(ssim_res)
- psnr_list.append(psnr_res)
-
- # Save super resolution image and high resolution image
- lr_images = post_image(lr_images, device=device)
- sr_images = post_image(output, device=device)
- hr_images = post_image(hr_images, device=device)
- image_name = time.time()
- for lr_index, lr_image in enumerate(lr_images):
- save_images(images=lr_image, path=os.path.join(save_val_vis_dir, f"{i}_{image_name}_{lr_index}_lr.jpg"))
- for sr_index, sr_image in enumerate(sr_images):
- save_images(images=sr_image, path=os.path.join(save_val_vis_dir, f"{i}_{image_name}_{sr_index}_sr.jpg"))
- for hr_index, hr_image in enumerate(hr_images):
- save_images(images=hr_image, path=os.path.join(save_val_vis_dir, f"{i}_{image_name}_{hr_index}_hr.jpg"))
- # Loss, ssim and psnr per epoch
- avg_val_loss = sum(val_loss_list) / len(val_loss_list)
- avg_ssim = sum(ssim_list) / len(ssim_list)
- avg_psnr = sum(psnr_list) / len(psnr_list)
- tb_logger.add_scalar(tag=f"[{device}]: Val loss", scalar_value=avg_val_loss, global_step=epoch)
- tb_logger.add_scalar(tag=f"[{device}]: Avg ssim", scalar_value=avg_ssim, global_step=epoch)
- tb_logger.add_scalar(tag=f"[{device}]: Avg psnr", scalar_value=avg_psnr, global_step=epoch)
- logger.info(f"Val loss: {avg_val_loss}, SSIM: {avg_ssim}, PSNR: {avg_psnr}")
- logger.info(msg="Finish val mode.")
-
- # Saving and validating models in the main process
- if save_models:
- # Saving model, set the checkpoint name
- save_name = f"ckpt_{str(epoch).zfill(3)}"
- # Init ckpt params
- ckpt_model = model.state_dict()
- ckpt_ema_model = ema_model.state_dict()
- ckpt_optimizer = optimizer.state_dict()
- # Save the best model
- if (avg_ssim > best_ssim) and (avg_psnr > best_psnr):
- is_best = True
- best_ssim = avg_ssim
- best_psnr = avg_psnr
- else:
- is_best = False
- # Save checkpoint
- save_ckpt(epoch=epoch, save_name=save_name, ckpt_model=ckpt_model, ckpt_ema_model=ckpt_ema_model,
- ckpt_optimizer=ckpt_optimizer, results_dir=results_dir, save_model_interval=save_model_interval,
- save_model_interval_epochs=None, start_model_interval=start_model_interval, image_size=image_size,
- network=network, act=act, is_sr=True, is_best=is_best, ssim=avg_ssim, psnr=avg_psnr)
- logger.info(msg=f"[{device}]: Finish epoch {epoch}:")
-
- # Synchronization during distributed training
- if distributed:
- logger.info(msg=f"[{device}]: Synchronization during distributed training.")
- dist.barrier()
-
- logger.info(msg=f"[{device}]: Finish training.")
-
- # Clean up the distributed environment
- if distributed:
- dist.destroy_process_group()
-
-
def main(args):
"""
Main function
@@ -328,12 +32,16 @@ def main(args):
"""
if args.distributed:
gpus = torch.cuda.device_count()
- mp.spawn(train, args=(args,), nprocs=gpus)
+ mp.spawn(SRTrainer(args=args).train, nprocs=gpus)
else:
- train(args=args)
+ SRTrainer(args=args).train()
-if __name__ == "__main__":
+def init_sr_train_args():
+ """
+ Init super resolution model training arguments
+ :return: args
+ """
# Training model parameters
# required: Must be set
# needed: Set as needed
@@ -375,6 +83,9 @@ def main(args):
# Set optimizer (needed)
# Option: adam/adamw/sgd
parser.add_argument("--optim", type=str, default="sgd", choices=optim_choices)
+ # Set loss function
+ # Option: mse only
+ parser.add_argument("--loss", type=str, default="mse", choices=sr_loss_func_choices)
# Set activation function (needed)
# Option: gelu/silu/relu/relu6/lrelu
parser.add_argument("--act", type=str, default="silu")
@@ -387,10 +98,16 @@ def main(args):
parser.add_argument("--result_path", type=str, default="/your/path/Diffusion-Model/results")
# Whether to save weight each training (recommend)
parser.add_argument("--save_model_interval", default=False, action="store_true")
+ # Save model interval and save it every X epochs (needed)
+ parser.add_argument("--save_model_interval_epochs", type=int, default=10)
# Start epoch for saving models (needed)
# This option saves disk space. If not set, the default is '-1'. If set,
# it starts saving models from the specified epoch. It needs to be used with '--save_model_interval'
parser.add_argument("--start_model_interval", type=int, default=-1)
+ # Generated image format
+ # Recommend to use png for better generation quality.
+ # Option: jpg/png
+ parser.add_argument("--image_format", type=str, default="png", choices=image_format_choices)
# Resume interrupted training (needed)
# 1. Set to 'True' to resume interrupted training and check if the parameter 'run_name' is correct.
# 2. Set the resume interrupted epoch number. (If not, we would select the last)
@@ -418,7 +135,11 @@ def main(args):
# The value of world size will correspond to the actual number of GPUs or distributed nodes being used
parser.add_argument("--world_size", type=int, default=2)
- args = parser.parse_args()
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = init_sr_train_args()
# Get version banner
get_version_banner()
main(args)
diff --git a/tools/train.py b/tools/train.py
index e32f565..d64cc4d 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -8,306 +8,22 @@
import os
import sys
import argparse
-import copy
import logging
import coloredlogs
-import numpy as np
import torch
-from torch import nn as nn
-from torch import distributed as dist
from torch import multiprocessing as mp
-from torch.utils.tensorboard import SummaryWriter
-from torch.cuda.amp import autocast
-from tqdm import tqdm
sys.path.append(os.path.dirname(sys.path[0]))
from config.choices import sample_choices, network_choices, optim_choices, act_choices, lr_func_choices, \
image_format_choices, noise_schedule_choices, parse_image_size_type, loss_func_choices
-from config.setting import MASTER_ADDR, MASTER_PORT, EMA_BETA
from config.version import get_version_banner
-from model.modules.ema import EMA
-from utils.check import check_image_size
-from utils.dataset import get_dataset
-from utils.initializer import device_initializer, seed_initializer, network_initializer, optimizer_initializer, \
- sample_initializer, lr_initializer, amp_initializer, loss_initializer, classes_initializer
-from utils.utils import plot_images, save_images, setup_logging, save_train_logging
-from utils.checkpoint import load_ckpt, save_ckpt
+from model.trainers.dm import DMTrainer
logger = logging.getLogger(__name__)
coloredlogs.install(level="INFO")
-def train(rank=None, args=None):
- """
- Training
- :param rank: Device id
- :param args: Input parameters
- :return: None
- """
- # =================================Before training=================================
- # Output params to console
- logger.info(msg=f"[{rank}]: Input params: {args}")
- # Step1: Set path and create log
- # Saving path
- result_path = args.result_path
- # Run name
- run_name = args.run_name
- # Create data logging path
- results_logging = setup_logging(save_path=result_path, run_name=run_name)
- results_dir = results_logging[1]
- results_vis_dir = results_logging[2]
- results_tb_dir = results_logging[3]
- # Tensorboard
- tb_logger = SummaryWriter(log_dir=results_tb_dir)
- # Train log
- save_train_logging(arg=args, save_path=results_dir)
-
- # Step2: Get the parameters of the initializer and args
- # Initialize the seed
- seed_initializer(seed_id=args.seed)
- # Sample type
- sample = args.sample
- # Network
- network = args.network
- # Input image size
- image_size = check_image_size(image_size=args.image_size)
- # Select optimizer
- optim = args.optim
- # Loss function
- loss_name = args.loss
- # Select activation function
- act = args.act
- # Learning rate
- init_lr = args.lr
- # Learning rate function
- lr_func = args.lr_func
- # Batch size
- batch_size = args.batch_size
- # Number of workers
- num_workers = args.num_workers
- # Dataset path
- dataset_path = args.dataset_path
- # Number of classes
- num_classes = classes_initializer(dataset_path=dataset_path)
- # classifier-free guidance interpolation weight, users can better generate model effect
- cfg_scale = args.cfg_scale
- # Whether to enable conditional training
- conditional = args.conditional
- # Initialize and save the model identification bit
- # Check here whether it is single-GPU training or multi-GPU training
- save_models = True
- # Whether to enable distributed training
- if args.distributed and torch.cuda.device_count() > 1 and torch.cuda.is_available():
- distributed = True
- world_size = args.world_size
- # Set address and port
- os.environ["MASTER_ADDR"] = MASTER_ADDR
- os.environ["MASTER_PORT"] = MASTER_PORT
- # The total number of processes is equal to the number of graphics cards
- dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo", rank=rank,
- world_size=world_size)
- # Set device ID
- device = device_initializer(device_id=rank, is_train=True)
- # There may be random errors, using this function can reduce random errors in cudnn
- # torch.backends.cudnn.deterministic = True
- # Synchronization during distributed training
- dist.barrier()
- # If the distributed training is not the main GPU, the save model flag is False
- if dist.get_rank() != args.main_gpu:
- save_models = False
- logger.info(msg=f"[{device}]: Successfully Use distributed training.")
- else:
- distributed = False
- # Run device initializer
- device = device_initializer(device_id=args.use_gpu, is_train=True)
- logger.info(msg=f"[{device}]: Successfully Use normal training.")
- # Whether to enable automatic mixed precision training
- amp = args.amp
- # Save model interval
- save_model_interval = args.save_model_interval
- # Save model interval and save it every X epochs
- save_model_interval_epochs = args.save_model_interval_epochs
- # Save model interval in the start epoch
- start_model_interval = args.start_model_interval
- # Enable data visualization
- vis = args.vis
- # Number of visualization images generated
- num_vis = args.num_vis
- # Generated image format
- image_format = args.image_format
- # Noise schedule
- noise_schedule = args.noise_schedule
- # Resume training
- resume = args.resume
- # Pretrain
- pretrain = args.pretrain
-
- # =================================About model initializer=================================
- # Step3: Init model
- # Network
- Network = network_initializer(network=network, device=device)
- # Model
- if not conditional:
- model = Network(device=device, image_size=image_size, act=act).to(device)
- else:
- model = Network(num_classes=num_classes, device=device, image_size=image_size, act=act).to(device)
- # Distributed training
- if distributed:
- model = nn.parallel.DistributedDataParallel(module=model, device_ids=[device], find_unused_parameters=True)
- # Model optimizer
- optimizer = optimizer_initializer(model=model, optim=optim, init_lr=init_lr, device=device)
- # Resume training
- if resume:
- ckpt_path = None
- start_epoch = args.start_epoch
- # Determine which checkpoint to load
- # 'start_epoch' is correct
- if start_epoch is not None:
- ckpt_path = os.path.join(results_dir, f"ckpt_{str(start_epoch - 1).zfill(3)}.pt")
- # Parameter 'ckpt_path' is None in the train mode
- if ckpt_path is None:
- ckpt_path = os.path.join(results_dir, "ckpt_last.pt")
- start_epoch = load_ckpt(ckpt_path=ckpt_path, model=model, device=device, optimizer=optimizer,
- is_distributed=distributed, conditional=conditional)
- logger.info(msg=f"[{device}]: Successfully load resume model checkpoint.")
- else:
- # Pretrain mode
- if pretrain:
- pretrain_path = args.pretrain_path
- load_ckpt(ckpt_path=pretrain_path, model=model, device=device, is_pretrain=pretrain,
- is_distributed=distributed, conditional=conditional)
- logger.info(msg=f"[{device}]: Successfully load pretrain model checkpoint.")
- start_epoch = 0
- # Set harf-precision
- scaler = amp_initializer(amp=amp, device=device)
- # Loss function
- loss_func = loss_initializer(loss_name=loss_name, device=device)
- # Initialize the diffusion model
- diffusion = sample_initializer(sample=sample, image_size=image_size, device=device, schedule_name=noise_schedule)
- # Exponential Moving Average (EMA) may not be as dominant for single class as for multi class
- ema = EMA(beta=EMA_BETA)
- # EMA model
- ema_model = copy.deepcopy(model).eval().requires_grad_(False)
-
- # =================================About data=================================
- # Step4: Set data
- # Dataloader
- dataloader = get_dataset(image_size=image_size, dataset_path=dataset_path, batch_size=batch_size,
- num_workers=num_workers, distributed=distributed)
- # Number of dataset batches in the dataloader
- len_dataloader = len(dataloader)
-
- # =================================Training=================================
- # Step5: Training
- logger.info(msg=f"[{device}]: Start training.")
- # Start iterating
- for epoch in range(start_epoch, args.epochs):
- logger.info(msg=f"[{device}]: Start epoch {epoch}:")
- # Set learning rate
- current_lr = lr_initializer(lr_func=lr_func, optimizer=optimizer, epoch=epoch, epochs=args.epochs,
- init_lr=init_lr, device=device)
- tb_logger.add_scalar(tag=f"[{device}]: Current LR", scalar_value=current_lr, global_step=epoch)
- pbar = tqdm(dataloader)
- # Initialize images and labels
- images, labels, loss_list = None, None, []
- for i, (images, labels) in enumerate(pbar):
- # The images are all resized in dataloader
- images = images.to(device)
- # Generates a tensor of size images.shape[0] randomly sampled time steps
- time = diffusion.sample_time_steps(images.shape[0]).to(device)
- # Add noise, return as x value at time t and standard normal distribution
- x_time, noise = diffusion.noise_images(x=images, time=time)
- # Enable Automatic mixed precision training
- # Automatic mixed precision training
- # Note: If your Pytorch version > 2.4.1, with torch.amp.autocast("cuda", enabled=amp):
- with autocast(enabled=amp):
- # Unconditional training
- if not conditional:
- # Unconditional model prediction
- predicted_noise = model(x_time, time)
- # Conditional training, need to add labels
- else:
- labels = labels.to(device)
- # Random unlabeled hard training, using only time steps and no class information
- if np.random.random() < 0.1:
- labels = None
- # Conditional model prediction
- predicted_noise = model(x_time, time, labels)
- # To calculate the MSE loss
- # You need to use the standard normal distribution of x at time t and the predicted noise
- loss = loss_func(noise, predicted_noise)
- # The optimizer clears the gradient of the model parameters
- optimizer.zero_grad()
- # Update loss and optimizer
- # Fp16 + Fp32
- scaler.scale(loss).backward()
- scaler.step(optimizer)
- scaler.update()
- # EMA
- ema.step_ema(ema_model=ema_model, model=model)
-
- # TensorBoard logging
- pbar.set_postfix(MSE=loss.item())
- tb_logger.add_scalar(tag=f"[{device}]: MSE", scalar_value=loss.item(),
- global_step=epoch * len_dataloader + i)
- loss_list.append(loss.item())
- # Loss per epoch
- tb_logger.add_scalar(tag=f"[{device}]: Loss", scalar_value=sum(loss_list) / len(loss_list), global_step=epoch)
-
- # Saving and validating models in the main process
- if save_models:
- # Saving model, set the checkpoint name
- save_name = f"ckpt_{str(epoch).zfill(3)}"
- # Init ckpt params
- ckpt_model, ckpt_ema_model, ckpt_optimizer = None, None, None
- if not conditional:
- ckpt_model = model.state_dict()
- ckpt_optimizer = optimizer.state_dict()
- # Enable visualization
- if vis:
- # images.shape[0] is the number of images in the current batch
- n = num_vis if num_vis > 0 else batch_size
- sampled_images = diffusion.sample(model=model, n=n)
- save_images(images=sampled_images,
- path=os.path.join(results_vis_dir, f"{save_name}.{image_format}"))
- else:
- ckpt_model = model.state_dict()
- ckpt_ema_model = ema_model.state_dict()
- ckpt_optimizer = optimizer.state_dict()
- # Enable visualization
- if vis:
- labels = torch.arange(num_classes).long().to(device)
- n = num_vis if num_vis > 0 else len(labels)
- sampled_images = diffusion.sample(model=model, n=n, labels=labels, cfg_scale=cfg_scale)
- ema_sampled_images = diffusion.sample(model=ema_model, n=n, labels=labels, cfg_scale=cfg_scale)
- # This is a method to display the results of each model during training and can be commented out
- # plot_images(images=sampled_images)
- save_images(images=sampled_images,
- path=os.path.join(results_vis_dir, f"{save_name}.{image_format}"))
- save_images(images=ema_sampled_images,
- path=os.path.join(results_vis_dir, f"ema_{save_name}.{image_format}"))
- # Save checkpoint
- save_ckpt(epoch=epoch, save_name=save_name, ckpt_model=ckpt_model, ckpt_ema_model=ckpt_ema_model,
- ckpt_optimizer=ckpt_optimizer, results_dir=results_dir, save_model_interval=save_model_interval,
- save_model_interval_epochs=save_model_interval_epochs,
- start_model_interval=start_model_interval, conditional=conditional, image_size=image_size,
- sample=sample, network=network, act=act, num_classes=num_classes)
- logger.info(msg=f"[{device}]: Finish epoch {epoch}:")
-
- # Synchronization during distributed training
- if distributed:
- logger.info(msg=f"[{device}]: Synchronization during distributed training.")
- dist.barrier()
-
- logger.info(msg=f"[{device}]: Finish training.")
- logger.info(msg="[Note]: If you want to evaluate model quality, use 'FID_calculator.py' to evaluate.")
-
- # Clean up the distributed environment
- if distributed:
- dist.destroy_process_group()
-
-
def main(args):
"""
Main function
@@ -316,12 +32,16 @@ def main(args):
"""
if args.distributed:
gpus = torch.cuda.device_count()
- mp.spawn(train, args=(args,), nprocs=gpus)
+ mp.spawn(DMTrainer(args=args).train, nprocs=gpus)
else:
- train(args=args)
+ DMTrainer(args=args).train()
def init_train_args():
+ """
+ Init diffusion model training arguments
+ :return: args
+ """
# Training model parameters
# required: Must be set
# needed: Set as needed
@@ -345,7 +65,7 @@ def init_train_args():
# File name for initializing the model (required)
parser.add_argument("--run_name", type=str, default="df")
# Total epoch for training (required)
- parser.add_argument("--epochs", type=int, default=3)
+ parser.add_argument("--epochs", type=int, default=300)
# Batch size for training (required)
parser.add_argument("--batch_size", type=int, default=2)
# Number of sub-processes used for data loading (needed)
@@ -390,7 +110,7 @@ def init_train_args():
# it starts saving models from the specified epoch. It needs to be used with '--save_model_interval'
parser.add_argument("--start_model_interval", type=int, default=-1)
# Enable visualization of dataset information for model selection based on visualization (recommend)
- parser.add_argument("--vis", default=False, action="store_true")
+ parser.add_argument("--vis", default=True, action="store_true")
# Number of visualization images generated (recommend)
# If not filled, the default is the number of image classes (unconditional) or images.shape[0] (conditional)
parser.add_argument("--num_vis", type=int, default=-1)
diff --git a/utils/check.py b/utils/check.py
index 517b664..58dd010 100644
--- a/utils/check.py
+++ b/utils/check.py
@@ -104,3 +104,14 @@ def check_pretrain_path(pretrain_path):
if pretrain_path is None or not os.path.exists(pretrain_path):
return True
return False
+
+
+def check_is_distributed(distributed):
+ """
+ Check the distributed is valid
+ :param distributed: Distributed
+ :return: Boolean
+ """
+ if distributed and torch.cuda.device_count() > 1 and torch.cuda.is_available():
+ return True
+ return False