The project is based on the paper DiffMIC: Dual-Guidance Diffusion Network for Medical Image Classification.
Run the code: python3 ./main.py
main.py
calls the modules in src/
. The program loads a saved DCG checkpoint saved_dcg.pth
if it exists. It trains and saves a DCG checkpoint otherwise. Similarly, for diffusion, the program loads a saved diffusion checkpoint saved_diff.pth
if it exists. It trains and saves a diffusion checkpoint otherwise. After loading the DCG and diffusion models, the program runs inference on the test images and outputs the predicted classification [5 Classes: 0, 1, 2, 3, 4].
Download APTOS2019 dataset. Your dataset folder under "your_data_path" should be like:
dataset/aptos/
test/...
train/...
aptos_test.json
aptos_train.json
num_images
: Gives total number of images. 70:10:20 split implemented automatically for train:val:test. To change the train:val:test ratio, make the change in APTOSDataset
class. Minimum 1000 images are set to class DataProcessor
. If num_images
> 1000, it will automatically set that as number of images used.
train_batch_size
: Currently set to 25 as the machine can handle only that, however higher batch_size ~ 32 is recommended.
valid_batch_size
: Currently set to 25 as the machine can handle only that, however higher batch_size ~ 32 is recommended.
test_batch_size
: Currently set to 2 as the inference can happen only for that, however higher batch_size ~ 25 is recommended.
timestep
: Currently experimented with 500, 80, 60, 50. Optimally ~ 60 is recommended.
num_classes
: Selected as 5 for 5 classes - 0, 1, 2, 3, 4. Don't change it unless you change the dataset.
include_guidance
: Ensures that DCG priors are used in diffusion. true
indicates using DCG priors for diffusion, false
indicates not using DCG priors for diffusion.
weight
: For a detailed description look at Diffusion section of README.
dataloader.py
: Using DataProcessor.get_dataloaders()
, we get the train and test data, images and labels included. Train dataloader is used in DCG and Diffusion process for training. Test dataloader is used for evaluation of Diffusion.
transforms.py
: Transformation functions used in Data Pre-processing.
main.py
: class DCG is defined in this file. There is a function that train the DCG model, and another that loads the DCG model from a saved checkpoint also defined in this file.
networks.py
: The networks used in DCG are defined: Saliency Map, Global Network, Downsample Network, Local Network, and Attention Module.
utils.py
: Helper functions used in DCG are defined.
diffusion.py
: Contains code for forward and reverse diffusion.
A class for weighted_loss
has been implemented which takes weight
as a parameter that is passed from the diffusion parameters params["diffusion"]["weight"]
. There are three types of loss:
- Unweighted MMD loss (
weight=None
) - Weighted MMD loss with the inverse of number of images in each class (
weight=n
) - Weighted MMD loss with the square root of the inverse of number of images in each class (
weight=sqrt
)
Weights are calculated currently using total images. There are [1805, 370, 999, 193, 295] images for classes [0, 1, 2, 3, 4] respectively. Thus the weight is currently static and set to [1/1805, 1/370, 1/999, 1/193, 1/295]
.
unet_model.py
: Contains the UNet model used in diffusion.
metrics.py
: Contains the code for the classification metrics: Accuracy, confusion matrix, f1 score, and t-SNE.
Plots are generated in the plots/
directory. dcg_loss.png
is generated during the training of the DCG, and diffusion_loss.png
is generated during the training of the diffusion model. Confusion matrices for DCG and diffusion are saved as dcg_confusion.png
and diff_confusion.png
, respectively. t-SNE
plots are generated for different timesteps t1, t2, t3
in the diffusion parameters params["diffusion"]["t-sne"]
during the inference step.
The plot functions have a parameter mode
which can take a value of either dcg
or diffusion
, such that the corresponding plots for each model are generated. The default value is mode=dcg
. If the value of mode
is not dcg
or diffusion
, an error will be raised and the plots will not be generated.
project.log
: Logs during runtime are saved in this file.
report.txt
: Accuracy of the DCG, and the diffusion model are saved in this file. The f1 score, and the t-SNE values are also saved.
- Weights in Diffusion section can be made dynamic by finding the total number of images for each batch and then setting weights based on number of images of each class in the batch. To make sure we don't encounter a condition where a batch doesn't contain an image of a particular label, we can set the weight to 1 for that.
- Randomly flip images horizontally and vertically to increase size of the dataset for classes with low number of objects. This can help with imbalanced class issue.
- Images can be sampled evenly from each classes for test, validation and test datasets. This will ensure that the model is trained evenly on all classes.
Code is largely based on scott-yjyang/DiffMIC, XzwHan/CARD, CompVis/stable-diffusion, MedSegDiff, nyukat/GMIC