Image-to-Image Translation for Medical XAI: Counterfactual Insights into Pneumonia Detection

Overview

Neural networks often operate as "black boxes," where their decision-making processes are opaque. This lack of transparency can be particularly problematic in high-stakes fields like healthcare. This project addresses the critical need for transparent and interpretable decision-making in medical image classification, where traditional methods often struggle to provide nuanced explanations. Recognizing the limitations of current approaches, particularly in conveying textural information, the project aims to pioneer a novel model-agnostic solution using adversarial image-to-image translation.

The primary objective is to generate realistic counterfactual images for image classifiers operating in medical contexts. By harnessing advanced techniques from adversarial learning, the project seeks to overcome the challenges associated with existing methods, such as saliency maps, which may not effectively capture the intricate nuances of decision-making processes rooted in texture and structure.

The proposed approach holds promise for enhancing the interpretability of deep learning models, especially in high-stakes domains like healthcare, where transparent decision-making is paramount. By providing realistic counterfactual explanations, the project aims to empower clinicians and stakeholders with insights into classifier decisions, thereby fostering trust and facilitating more informed medical interventions.

Workflow Diagram

Methodology

There are various approaches for solving image-to-image translation problems. Recent promising methods rely on adversarial learning. The original GANs approximate a function that transforms random noise vectors into images that follow the same probability distribution as a training dataset. This is achieved by combining a generator network $G$ and a discriminator network $D$. During training, the generator learns to create new images, while the discriminator learns to distinguish between real images from the training set and fake images generated by the generator. The objective of the two networks is defined as follows:

$L_{\text{original}}(G, D) = \mathbb{E}_{x \sim p_{\text{data}}(x)} [\log D(x)] + \mathbb{E}_{z \sim p_{z}(z)} [\log (1 - D(G(z)))]$

Where:

  • $x$ are real images.
  • $z$ are random noise vectors.
  • $p_{\text{data}}(x)$ is the data distribution.
  • $p_{z}(z)$ is the noise distribution.

During training, the discriminator $D$ maximizes this objective function, while the generator $G$ tries to minimize it.
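This minimax objective can be illustrated numerically. The sketch below (hypothetical discriminator scores, not code from this repository) evaluates the objective for a confident discriminator versus a fooled one:

```python
import numpy as np

def gan_objective(d_real, d_fake):
    """Value of the original GAN objective for batches of discriminator outputs:
    E[log D(x)] + E[log(1 - D(G(z)))].
    d_real: D's scores on real images; d_fake: D's scores on generated images."""
    return np.mean(np.log(d_real)) + np.mean(np.log(1.0 - d_fake))

# A discriminator that separates real from fake (D(x) -> 1, D(G(z)) -> 0)
# drives the objective toward 0, its maximum; a fooled discriminator
# (both scores 0.5) yields 2 * log(0.5), which the generator prefers.
confident = gan_objective(np.array([0.99, 0.98]), np.array([0.01, 0.02]))
fooled = gan_objective(np.array([0.5, 0.5]), np.array([0.5, 0.5]))
print(confident > fooled)  # True
```

The discriminator's gradient ascent pushes the value up; the generator's gradient descent pushes it back down toward the fooled-discriminator value.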

Image-to-Image Translation Networks

Various modified architectures have been developed to replace the random input noise vectors with images from another domain, enabling the transformation of images from one domain to another. These image-to-image translation networks commonly rely on paired datasets, where pairs of images differ only in the features defining the difference between the two domains. However, paired datasets are rare in practice.

CycleGAN

A solution to the problem of paired datasets was introduced by Zhu et al., who proposed the CycleGAN architecture. This architecture combines two GANs:

  • One GAN learns to translate images from domain $X$ to domain $Y$.
  • The other GAN learns to translate images from domain $Y$ to domain $X$.

The objective functions for the two GANs are defined as:

$L_{\text{GAN}}(G, D_Y, X, Y) = \mathbb{E}_{y \sim p_{\text{data}}(y)} [\log D_Y(y)] + \mathbb{E}_{x \sim p_{\text{data}}(x)} [\log (1 - D_Y(G(x)))]$

$L_{\text{GAN}}(F, D_X, Y, X) = \mathbb{E}_{x \sim p_{\text{data}}(x)} [\log D_X(x)] + \mathbb{E}_{y \sim p_{\text{data}}(y)} [\log (1 - D_X(F(y)))]$

Where:

  • $G$ is the generator from $X$ to $Y$.
  • $F$ is the generator from $Y$ to $X$.
  • $D_Y$ and $D_X$ are the discriminators for domains $Y$ and $X$ respectively.

Loss Functions for CycleGAN

Adversarial Loss (GAN Loss)

The adversarial loss, also known as the GAN loss, is used to train the generators and discriminators in CycleGAN. It encourages the generators to produce images that are indistinguishable from the target domain images, while the discriminators aim to distinguish between real and generated images.

Objective: $L_{\text{GAN}}(G, D, X, Y) = \mathbb{E}_{y \sim p_{\text{data}}(y)}[\log D(y)] + \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log(1 - D(G(x)))]$

  • $G$: Generator of the CycleGAN.
  • $D$: Discriminator of the CycleGAN.
  • $X$: Domain X (source domain).
  • $Y$: Domain Y (target domain).
  • $p_{\text{data}}(x)$: Probability distribution of real images from domain X.
  • $p_{\text{data}}(y)$: Probability distribution of real images from domain Y.

Cycle-Consistency Loss

The cycle-consistency loss ensures that the reconstructed images maintain the essential characteristics of the original images after being translated back and forth between the two domains.

Objective: $L_{\text{cycle}}(G, F) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[||F(G(x)) - x||_1] + \mathbb{E}_{y \sim p_{\text{data}}(y)}[||G(F(y)) - y||_1]$

  • $G$: Generator translating from domain X to domain Y.
  • $F$: Generator translating from domain Y to domain X.
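The cycle term is just an L1 distance between an image and its round-trip reconstruction. A minimal NumPy sketch (toy arrays stand in for image batches and generator outputs; this is not the repository's training code):

```python
import numpy as np

def cycle_consistency_loss(x, y, f_g_x, g_f_y):
    """L1 cycle loss: ||F(G(x)) - x||_1 + ||G(F(y)) - y||_1, averaged per pixel.
    f_g_x and g_f_y are the round-trip reconstructions of x and y."""
    return np.mean(np.abs(f_g_x - x)) + np.mean(np.abs(g_f_y - y))

# A perfect round trip (reconstruction equals the input) gives zero loss.
x = np.random.rand(2, 64, 64)
print(cycle_consistency_loss(x, x, x, x))  # 0.0
```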

Identity Loss

The identity loss encourages the generators to preserve the content of the input image when translating it to the same domain. It ensures that the generator does not change images that already belong to the target domain.

Objective: $L_{\text{identity}}(G, F) = \mathbb{E}_{y \sim p_{\text{data}}(y)}[||G(y) - y||_1] + \mathbb{E}_{x \sim p_{\text{data}}(x)}[||F(x) - x||_1]$

  • $G$: Generator translating from domain X to domain Y.
  • $F$: Generator translating from domain Y to domain X.
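The identity term mirrors the cycle term, but feeds each generator an image that is already in its target domain. A toy NumPy sketch (hypothetical arrays, not the repository's code):

```python
import numpy as np

def identity_loss(x, y, g_y, f_x):
    """L1 identity loss: ||G(y) - y||_1 + ||F(x) - x||_1, averaged per pixel.
    g_y is G applied to an image already in domain Y; f_x likewise for F and X."""
    return np.mean(np.abs(g_y - y)) + np.mean(np.abs(f_x - x))

# A generator that leaves same-domain images untouched incurs zero identity loss.
y = np.random.rand(2, 64, 64)
print(identity_loss(y, y, y, y))  # 0.0
```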

Counterfactual Loss

The counterfactual loss incorporates the predictions of a classifier into the CycleGAN's objective function. It penalizes the generators for producing translated images that are not classified as belonging to the respective counterfactual class by the classifier.

Objective: $L_{\text{counter}}(G, F, C) = ||C(G(x)) - (0, 1)||_2^2 + ||C(F(y)) - (1, 0)||_2^2$

  • $G$: Generator translating from domain X to domain Y.
  • $F$: Generator translating from domain Y to domain X.
  • $C$: Classifier predicting the probabilities of the input images belonging to each class.
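Concretely, the counterfactual term is a squared error between the classifier's output on a translated image and the one-hot vector of the opposite class. A minimal sketch (toy probability vectors, not the repository's classifier):

```python
import numpy as np

def counterfactual_loss(c_g_x, c_f_y):
    """Squared-error counterfactual loss pushing classifier outputs for
    translated images toward the opposite class's one-hot vector.
    c_g_x: classifier probabilities C(G(x)); target (0, 1), i.e. class Y.
    c_f_y: classifier probabilities C(F(y)); target (1, 0), i.e. class X."""
    target_y = np.array([0.0, 1.0])
    target_x = np.array([1.0, 0.0])
    return np.sum((c_g_x - target_y) ** 2) + np.sum((c_f_y - target_x) ** 2)

# If the classifier is fully convinced by both translations, the loss is zero.
print(counterfactual_loss(np.array([0.0, 1.0]), np.array([1.0, 0.0])))  # 0.0
```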

Complete Generator Loss

The complete objective function of the CycleGAN with counterfactual loss is composed of the adversarial loss, cycle-consistency loss, identity loss, and counterfactual loss.

Objective: $L(G, F, D_X, D_Y, C) = L_{\text{GAN}}(G, D_Y, X, Y) + L_{\text{GAN}}(F, D_X, Y, X) + \lambda L_{\text{cycle}}(G, F) + \mu L_{\text{identity}}(G, F) + \gamma L_{\text{counter}}(G, F, C)$

  • $\lambda$: Cycle-consistency loss weight.
  • $\mu$: Identity loss weight.
  • $\gamma$: Counterfactual loss weight.
  • $D_X$, $D_Y$: Discriminators for domains X and Y, respectively.
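The complete objective is a weighted sum of the four terms above. The sketch below shows the combination only; the weight values are hypothetical placeholders, not the configuration used in this repository:

```python
# Hypothetical weights for illustration; the actual values are set in the
# repository's training configuration.
LAMBDA, MU, GAMMA = 10.0, 5.0, 1.0

def total_generator_loss(l_gan_g, l_gan_f, l_cycle, l_identity, l_counter):
    """Weighted sum of the adversarial, cycle-consistency, identity,
    and counterfactual loss terms defined above."""
    return l_gan_g + l_gan_f + LAMBDA * l_cycle + MU * l_identity + GAMMA * l_counter

print(total_generator_loss(0.9, 0.9, 0.02, 0.03, 0.5))  # ≈ 2.65
```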

How to Run the Repository

Installation

  1. Clone the repository, create a virtual environment, and install the required packages:

```shell
git clone https://github.com/anindyamitra2002/Counterfactual-Image-Generation-using-CycleGAN.git
cd Counterfactual-Image-Generation-using-CycleGAN
python -m venv myenv
myenv\Scripts\activate  # Windows; on Linux/macOS use: source myenv/bin/activate
pip install -r requirements.txt
```

Running the Model

  1. Train the Classifier Model:

```shell
python pipeline.py \
  --model_type "classifier" \
  --image_size 512 \
  --batch_size 4 \
  --epochs 100 \
  --train_dir "/path/to/train/data" \
  --val_dir "/path/to/val/data" \
  --checkpoint_dir "./models" \
  --project "Your Project Name" \
  --job_name "classifier_training_job"
```

  2. Train the Generator and Discriminator:

```shell
python pipeline.py \
  --model_type "cycle-gan" \
  --image_size 512 \
  --batch_size 32 \
  --epochs 50 \
  --train_dir "/path/to/train/data" \
  --val_dir "/path/to/val/data" \
  --test_dir "/path/to/test/data" \
  --checkpoint_dir "./models" \
  --project "Your Project Name" \
  --job_name "cyclegan_training_job" \
  --classifier_path "/path/to/classifier/checkpoint" \
  --resume_ckpt_path "/path/to/cyclegan/checkpoint/for/resume/training"
```

Web Application

To run the web application:

```shell
gradio main.py
```

Evaluation Results

Evaluation results from training classifiers on 16,739 images from the RSNA Pneumonia Detection Challenge dataset. The classifier models were trained on Kaggle with 2× T4 GPUs.

| Model Name | Runtime | Epoch | Train Accuracy | Train Loss | Val Accuracy | Val Loss |
| --- | --- | --- | --- | --- | --- | --- |
| efficientnet_b7 | 2h 54m 50s | 12 | 0.962127 | 0.351027 | 0.847668 | 0.461976 |
| convnext_tiny | 2h 22m 31s | 37 | 0.959379 | 0.354263 | 0.850149 | 0.460496 |
| swin_t | 1h 17m 30s | 13 | 0.908124 | 0.402915 | 0.844195 | 0.464023 |
| efficientnet_b1 | 1h 14m | 23 | 0.968459 | 0.343911 | 0.850149 | 0.460599 |
| efficientnet_b7 | 2h 21m | 17 | 0.827360 | 0.479220 | 0.804830 | 0.499347 |
| resnext101_64x4d | 1h 48m 31s | 13 | 0.799403 | 0.503710 | 0.780681 | 0.521838 |

Evaluation results of the generator and discriminator training, performed in Lightning Studio on an L4 GPU:

| Model Name | Epochs | Val Generator Loss | Val Reconstruction Loss | Val Class Loss | Val Adversarial Loss | Val Identity Loss | Runtime |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Attention Unet | 50 | 1.638 | 0.01521 | 0.5008 | 0.9515 | 0.03367 | 4 hr 15 min 52 sec |

Conclusion

Our project demonstrates that adversarial image-to-image translation is an effective tool for generating realistic counterfactual explanations in medical imaging. This approach not only enhances the interpretability and transparency of AI systems but also fosters trust and collaboration between AI and healthcare professionals. Future research will explore multi-class classification and further refine the counterfactual generation process.

Testing Web App Link

You can access the testing web application here.


Disclaimer: This repository is intended for research purposes only. The results and models should be validated by medical professionals before any clinical application.


Feel free to reach out for any questions or contributions!


Contact Information:

README.md created with ❤️ by [Anindya Mitra].
