PyTorch implementation of "Conditional diffusion model with spatial attention and latent embedding" [MICCAI 2024]
Diffusion models have been used extensively for high-quality image and video generation tasks. In this paper, we propose a novel conditional diffusion model with spatial attention and latent embedding (cDAL) for medical image segmentation. In cDAL, a convolutional neural network (CNN) based discriminator is used at every time-step of the diffusion process to distinguish between the generated labels and the real ones. A spatial attention map is computed based on the features learned by the discriminator to help cDAL generate more accurate segmentation of discriminative regions in an input image. Additionally, we incorporate a random latent embedding into each layer of our model to significantly reduce the number of training and sampling time-steps, thereby making it much faster than other diffusion models for image segmentation.
We trained cDAL on several datasets, including MoNuSeg2018, Chest X-Ray (CXR) and Hippocampus.
We train cDAL on each dataset as described below. Use parameters_monu.json for MoNuSeg and parameters_lung.json for CXR.
To train the model on the Hippocampus dataset, use train_cDAL_hippo.py. The corresponding parameters are defined in the code.
To train on either MoNuSeg or CXR, use train_cDal_monu_and_lung.py. All necessary parameters are included in parameters_monu.json and parameters_lung.json. These files can be loaded directly into the code, or you can modify the parameters in the code file.
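As a rough illustration of loading a parameter file directly into the code, the sketch below reads a JSON file into an argparse.Namespace and applies optional in-code overrides. It is not taken from the training scripts: the helper name load_params, the demo file name, and the keys batch_size and num_timesteps are hypothetical, and the real parameter files may use different keys.

```python
import argparse
import json
import os
import tempfile


def load_params(json_path, overrides=None):
    """Read a JSON parameter file and return it as an argparse.Namespace."""
    with open(json_path, "r", encoding="utf-8") as f:
        params = json.load(f)
    # Optional in-code overrides take precedence over the values in the file.
    if overrides:
        params.update(overrides)
    return argparse.Namespace(**params)


if __name__ == "__main__":
    # Demo with a temporary stand-in file; these keys are made up for illustration
    # and are not the contents of parameters_monu.json or parameters_lung.json.
    demo = {"batch_size": 4, "num_timesteps": 2}
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "parameters_demo.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(demo, f)
        args = load_params(path, overrides={"batch_size": 8})
    print(args.batch_size, args.num_timesteps)
```

With this pattern, the same entry point can be pointed at either parameter file, and individual values can be changed without editing the JSON.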
Here you can find the general website of the challenge, where the training and test sets can be downloaded.
This is the link for the Lung Segmentation from Chest X-Ray dataset. To preprocess the images, we followed the same standard.
At this link you can find the Hippocampus dataset. The dataset can also be downloaded directly from this Google Drive link.
We have released pretrained checkpoints for MoNuSeg and CXR in 'saved_models'. Simply download the saved_models directory into the code directory. Use parameters_monu.json for MoNuSeg and parameters_lung.json for CXR.
After training, samples can be generated by running sampling_monu_and_lung.py for the MoNuSeg and CXR datasets or sampling_hippo.py for the Hippocampus dataset.
The Hippocampus dataset uses the metrics_hippo.py file for evaluation, since its labels must be processed with one-hot encoding.
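To make the role of one-hot encoding in multi-class evaluation concrete, here is a minimal NumPy sketch that converts integer label maps to one-hot form and computes a per-class Dice score. It is an illustration only and is not the code in metrics_hippo.py; the function names one_hot and dice_per_class, the choice of Dice as the metric, and the value num_classes=3 in the demo are assumptions.

```python
import numpy as np


def one_hot(labels, num_classes):
    """Convert an integer label map of shape (...) to one-hot of shape (num_classes, ...)."""
    labels = np.asarray(labels)
    # Class indices shaped (num_classes, 1, ..., 1) so they broadcast against the label map.
    classes = np.arange(num_classes).reshape((num_classes,) + (1,) * labels.ndim)
    return (labels[None, ...] == classes).astype(np.float32)


def dice_per_class(pred, target, num_classes, eps=1e-8):
    """Dice score for each class, computed on one-hot encoded label maps."""
    p = one_hot(pred, num_classes)
    t = one_hot(target, num_classes)
    axes = tuple(range(1, p.ndim))  # sum over all spatial axes, keep the class axis
    intersection = (p * t).sum(axis=axes)
    denom = p.sum(axis=axes) + t.sum(axis=axes)
    # eps keeps the score defined when a class is absent from both maps.
    return (2.0 * intersection + eps) / (denom + eps)


if __name__ == "__main__":
    # Tiny 2x2 label maps with three classes, used purely as a demo.
    pred = np.array([[0, 1], [2, 2]])
    target = np.array([[0, 1], [1, 2]])
    print(dice_per_class(pred, target, num_classes=3))
```

Encoding each class as its own binary channel lets overlap metrics be computed per structure rather than on a single merged foreground mask.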
We evaluated the models with a single NVIDIA Quadro RTX 6000 GPU.
We use the MONAI implementation to process the Hippocampus dataset and compute the one-hot encoding. Also, we use the DDGAN implementation for our diffusion model and time-dependent discriminator.
Please check the LICENSE file.
Cite our paper using the following BibTeX entry: