Neural networks often operate as "black boxes," where their decision-making processes are opaque. This lack of transparency can be particularly problematic in high-stakes fields like healthcare. This project addresses the critical need for transparent and interpretable decision-making in medical image classification, where traditional methods often struggle to provide nuanced explanations. Recognizing the limitations of current approaches, particularly in conveying textural information, the project aims to pioneer a novel model-agnostic solution using adversarial image-to-image translation.
The primary objective is to generate realistic counterfactual images for image classifiers operating in medical contexts. By harnessing advanced techniques from adversarial learning, the project seeks to overcome the challenges associated with existing methods, such as saliency maps, which may not effectively capture the intricate nuances of decision-making processes rooted in texture and structure.
The proposed approach holds promise for enhancing the interpretability of deep learning models, especially in high-stakes domains like healthcare, where transparent decision-making is paramount. By providing realistic counterfactual explanations, the project aims to empower clinicians and stakeholders with insights into classifier decisions, thereby fostering trust and facilitating more informed medical interventions.
There are various approaches to solving image-to-image translation problems. Recent promising methods rely on adversarial learning. The original GANs approximate a function that transforms random noise vectors into images that follow the same probability distribution as a training dataset. This is achieved by combining a generator network $G$, which maps noise vectors to images, with a discriminator network $D$, which learns to tell real images from generated ones. The two networks are trained jointly on the minimax objective:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

Where:

- $x$ are real images.
- $z$ are random noise vectors.
- $p_{\text{data}}(x)$ is the data distribution.
- $p_{z}(z)$ is the noise distribution.
During training, the discriminator $D$ is optimized to distinguish real images from generated ones, while the generator $G$ is optimized to produce images that the discriminator classifies as real. At equilibrium, the generated images are indistinguishable from samples of the data distribution.
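The sketch below illustrates this alternating training scheme in PyTorch; the module names, optimizer handling, and latent dimension are illustrative and not taken from this repository.

```python
# Minimal sketch of one GAN training step with hypothetical generator G and
# discriminator D modules; not the exact training loop used in this repository.
import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()

def gan_step(G, D, real_images, opt_G, opt_D, latent_dim=100):
    batch = real_images.size(0)
    real_labels = torch.ones(batch, 1)
    fake_labels = torch.zeros(batch, 1)

    # Discriminator update: score real images as 1 and generated images as 0.
    z = torch.randn(batch, latent_dim)
    fake_images = G(z).detach()                 # stop gradients from flowing into G
    d_loss = bce(D(real_images), real_labels) + bce(D(fake_images), fake_labels)
    opt_D.zero_grad()
    d_loss.backward()
    opt_D.step()

    # Generator update: make the discriminator score generated images as real.
    z = torch.randn(batch, latent_dim)
    g_loss = bce(D(G(z)), real_labels)          # non-saturating generator loss
    opt_G.zero_grad()
    g_loss.backward()
    opt_G.step()
    return d_loss.item(), g_loss.item()
```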
Various modified architectures have been developed to replace the random input noise vectors with images from another domain, enabling the transformation of images from one domain to another. These image-to-image translation networks commonly rely on paired datasets, where pairs of images differ only in the features defining the difference between the two domains. However, paired datasets are rare in practice.
A solution to the problem of paired datasets was introduced by Zhu et al., who proposed the CycleGAN architecture. This architecture combines two GANs:
- One GAN learns to translate images from domain $X$ to domain $Y$.
- The other GAN learns to translate images from domain $Y$ to domain $X$.
The objective functions for the two GANs are defined as:

$$\mathcal{L}_{\text{GAN}}(G, D_Y, X, Y) = \mathbb{E}_{y \sim p_{\text{data}}(y)}[\log D_Y(y)] + \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log(1 - D_Y(G(x)))]$$

$$\mathcal{L}_{\text{GAN}}(F, D_X, Y, X) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D_X(x)] + \mathbb{E}_{y \sim p_{\text{data}}(y)}[\log(1 - D_X(F(y)))]$$

Where:

- $G$ is the generator from $X$ to $Y$.
- $F$ is the generator from $Y$ to $X$.
- $D_Y$ and $D_X$ are the discriminators for domains $Y$ and $X$, respectively.
The adversarial loss, also known as the GAN loss, is used to train the generators and discriminators in CycleGAN. It encourages the generators to produce images that are indistinguishable from the target domain images, while the discriminators aim to distinguish between real and generated images.
Objective:

$$\mathcal{L}_{\text{GAN}}(G, D, X, Y) = \mathbb{E}_{y \sim p_{\text{data}}(y)}[\log D(y)] + \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log(1 - D(G(x)))]$$

Where:

- $G$: Generator of the CycleGAN.
- $D$: Discriminator of the CycleGAN.
- $X$: Domain X (source domain).
- $Y$: Domain Y (target domain).
- $p_{\text{data}}(x)$: Probability distribution of real images from domain X.
- $p_{\text{data}}(y)$: Probability distribution of real images from domain Y.
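As an illustration, the adversarial term for the $X \to Y$ direction can be computed as sketched below, assuming a least-squares (LSGAN) formulation as commonly used in CycleGAN implementations; the names `G`, `D_Y`, `real_x`, and `real_y` are placeholders rather than this repository's API.

```python
# Sketch of the adversarial loss for the X -> Y direction (least-squares variant).
import torch
import torch.nn as nn

mse = nn.MSELoss()

def adversarial_loss(G, D_Y, real_x, real_y):
    fake_y = G(real_x)                       # translate X -> Y
    pred_real = D_Y(real_y)
    pred_fake = D_Y(fake_y.detach())         # detach so the D loss does not update G
    # Discriminator targets: 1 for real Y images, 0 for translated images.
    d_loss = mse(pred_real, torch.ones_like(pred_real)) + \
             mse(pred_fake, torch.zeros_like(pred_fake))
    # Generator target: translated images should look real to D_Y.
    pred_fake_for_g = D_Y(fake_y)
    g_loss = mse(pred_fake_for_g, torch.ones_like(pred_fake_for_g))
    return d_loss, g_loss
```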
The cycle-consistency loss ensures that the reconstructed images maintain the essential characteristics of the original images after being translated back and forth between the two domains.
Objective:

$$\mathcal{L}_{\text{cyc}}(G, F) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\lVert F(G(x)) - x \rVert_1\big] + \mathbb{E}_{y \sim p_{\text{data}}(y)}\big[\lVert G(F(y)) - y \rVert_1\big]$$

Where:

- $G$: Generator translating from domain X to domain Y.
- $F$: Generator translating from domain Y to domain X.
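A minimal sketch of the cycle-consistency term, using the same placeholder names as above:

```python
# Sketch of the cycle-consistency loss: L1 reconstruction error after the
# round trips X -> Y -> X and Y -> X -> Y.
import torch.nn as nn

l1 = nn.L1Loss()

def cycle_consistency_loss(G, F, real_x, real_y):
    reconstructed_x = F(G(real_x))   # X -> Y -> X
    reconstructed_y = G(F(real_y))   # Y -> X -> Y
    return l1(reconstructed_x, real_x) + l1(reconstructed_y, real_y)
```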
The identity loss encourages the generators to preserve the content of the input image when translating it to the same domain. It ensures that the generator does not change images that already belong to the target domain.
Objective:

$$\mathcal{L}_{\text{identity}}(G, F) = \mathbb{E}_{y \sim p_{\text{data}}(y)}\big[\lVert G(y) - y \rVert_1\big] + \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\lVert F(x) - x \rVert_1\big]$$

Where:

- $G$: Generator translating from domain X to domain Y.
- $F$: Generator translating from domain Y to domain X.
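A corresponding sketch of the identity term:

```python
# Sketch of the identity loss: feeding a target-domain image through the
# corresponding generator should leave it (nearly) unchanged.
import torch.nn as nn

l1 = nn.L1Loss()

def identity_loss(G, F, real_x, real_y):
    return l1(G(real_y), real_y) + l1(F(real_x), real_x)
```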
The counterfactual loss incorporates the predictions of a classifier into the CycleGAN's objective function. It penalizes the generators for producing translated images that are not classified as belonging to the respective counterfactual class by the classifier.
Objective:

- $G$: Generator translating from domain X to domain Y.
- $F$: Generator translating from domain Y to domain X.
- $C$: Classifier predicting the probabilities of the input images belonging to each class.
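The exact formulation of this term is defined in the training code; a plausible sketch, assuming a cross-entropy penalty that pushes the classifier's prediction on each translated image toward the counterfactual class, is shown below. The class indices and the assumption that `C` returns raw logits are hypothetical, not this repository's API.

```python
# Hypothetical sketch of a counterfactual loss: the classifier C should assign
# each translated image to the opposite (counterfactual) class.
import torch
import torch.nn as nn

ce = nn.CrossEntropyLoss()  # assumes C returns raw logits over the classes

def counterfactual_loss(G, F, C, real_x, real_y, class_x=0, class_y=1):
    fake_y = G(real_x)   # images from X translated toward class Y
    fake_x = F(real_y)   # images from Y translated toward class X
    target_y = torch.full((fake_y.size(0),), class_y, dtype=torch.long)
    target_x = torch.full((fake_x.size(0),), class_x, dtype=torch.long)
    # Penalize translations that the classifier does not assign to the target class.
    return ce(C(fake_y), target_y) + ce(C(fake_x), target_x)
```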
The complete objective function of the CycleGAN with counterfactual loss is composed of the adversarial loss, cycle-consistency loss, identity loss, and counterfactual loss.
Objective:

$$\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{\text{GAN}}(G, D_Y, X, Y) + \mathcal{L}_{\text{GAN}}(F, D_X, Y, X) + \lambda \, \mathcal{L}_{\text{cyc}}(G, F) + \mu \, \mathcal{L}_{\text{identity}}(G, F) + \gamma \, \mathcal{L}_{\text{cf}}(G, F, C)$$

Where:

- $\lambda$: Cycle-consistency loss weight.
- $\mu$: Identity loss weight.
- $\gamma$: Counterfactual loss weight.
- $D_X$, $D_Y$: Discriminators for domains X and Y, respectively.
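Putting the pieces together, the generator-side objective could be assembled roughly as follows, reusing the placeholder loss functions sketched above; the weight values are illustrative, not the repository's defaults.

```python
# Sketch of the combined generator objective; assumes adversarial_loss,
# cycle_consistency_loss, identity_loss and counterfactual_loss from the
# sketches above, with illustrative loss weights.
def total_generator_loss(G, F, D_X, D_Y, C, real_x, real_y,
                         lam=10.0, mu=5.0, gamma=1.0):
    _, g_loss_xy = adversarial_loss(G, D_Y, real_x, real_y)   # X -> Y direction
    _, g_loss_yx = adversarial_loss(F, D_X, real_y, real_x)   # Y -> X direction
    return (g_loss_xy + g_loss_yx
            + lam * cycle_consistency_loss(G, F, real_x, real_y)
            + mu * identity_loss(G, F, real_x, real_y)
            + gamma * counterfactual_loss(G, F, C, real_x, real_y))
```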
- Clone the repository and install the required packages:

  ```bash
  python -m venv myenv
  myenv\Scripts\activate   # on Windows; use `source myenv/bin/activate` on Linux/macOS
  git clone https://github.com/anindyamitra2002/Counterfactual-Image-Generation-using-CycleGAN.git
  cd Counterfactual-Image-Generation-using-CycleGAN
  pip install -r requirements.txt
  ```
- Train the Classifier Model:

  ```bash
  python pipeline.py \
    --model_type "classifier" \
    --image_size 512 \
    --batch_size 4 \
    --epochs 100 \
    --train_dir "/path/to/train/data" \
    --val_dir "/path/to/val/data" \
    --checkpoint_dir "./models" \
    --project "Your Project Name" \
    --job_name "classifier_training_job"
  ```
- Train the Generator and Discriminator:

  ```bash
  python pipeline.py \
    --model_type "cycle-gan" \
    --image_size 512 \
    --batch_size 32 \
    --epochs 50 \
    --train_dir "/path/to/train/data" \
    --val_dir "/path/to/val/data" \
    --test_dir "/path/to/test/data" \
    --checkpoint_dir "./models" \
    --project "Your Project Name" \
    --job_name "cyclegan_training_job" \
    --classifier_path "/path/to/classifier/checkpoint" \
    --resume_ckpt_path "/path/to/cyclegan/checkpoint/for/resume/training"
  ```
To run the web application:

```bash
gradio main.py
```
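The interface itself is defined in `main.py`; a minimal sketch of what such a Gradio app could look like is shown below, with a hypothetical `generate_counterfactual` function standing in for the actual model inference code.

```python
# Hypothetical sketch of a Gradio interface for counterfactual generation;
# the function name and placeholder logic are assumptions, not the actual main.py.
import gradio as gr
from PIL import Image

def generate_counterfactual(image: Image.Image) -> Image.Image:
    # Placeholder: load the trained generator and translate the input image
    # to its counterfactual class. Here the input is simply returned unchanged.
    return image

demo = gr.Interface(
    fn=generate_counterfactual,
    inputs=gr.Image(type="pil", label="Input chest X-ray"),
    outputs=gr.Image(type="pil", label="Counterfactual image"),
    title="Counterfactual Image Generation",
)

if __name__ == "__main__":
    demo.launch()
```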
Evaluation results of classifier training on 16,739 images from the RSNA Pneumonia Detection Challenge dataset; the EfficientNet-B1 and the other candidate models below were trained on Kaggle with 2× T4 GPUs.
| Model Name | Runtime | Epoch | Train Accuracy | Train Loss | Val Accuracy | Val Loss |
|---|---|---|---|---|---|---|
| efficientnet_b7 | 2h 54m 50s | 12 | 0.962127 | 0.351027 | 0.847668 | 0.461976 |
| convnext_tiny | 2h 22m 31s | 37 | 0.959379 | 0.354263 | 0.850149 | 0.460496 |
| swin_t | 1h 17m 30s | 13 | 0.908124 | 0.402915 | 0.844195 | 0.464023 |
| efficientnet_b1 | 1h 14m | 23 | 0.968459 | 0.343911 | 0.850149 | 0.460599 |
| efficientnet_b7 | 2h 21m | 17 | 0.827360 | 0.479220 | 0.804830 | 0.499347 |
| resnext101_64x4d | 1h 48m 31s | 13 | 0.799403 | 0.503710 | 0.780681 | 0.521838 |
Evaluation results of the generator and discriminator training, performed in Lightning Studio on an L4 GPU:
| Model Name | Epochs | Val Generator Loss | Val Reconstruction Loss | Val Class Loss | Val Adversarial Loss | Val Identity Loss | Runtime |
|---|---|---|---|---|---|---|---|
| Attention Unet | 50 | 1.638 | 0.01521 | 0.5008 | 0.9515 | 0.03367 | 4 hr 15 min 52 sec |
Our project demonstrates that adversarial image-to-image translation is an effective tool for generating realistic counterfactual explanations in medical imaging. This approach not only enhances the interpretability and transparency of AI systems but also fosters trust and collaboration between AI and healthcare professionals. Future research will explore multi-class classification and further refine the counterfactual generation process.
You can access the testing web application here.
Disclaimer: This repository is intended for research purposes only. The results and models should be validated by medical professionals before any clinical application.
Feel free to reach out for any questions or contributions!
Contact Information:
- Email: anindyamitra2018@gmail.com
README.md created with ❤️ by [Anindya Mitra].