This code implements a graph neural network pipeline for whole slide image (WSI) classification using multi-stain embeddings and self-attention graph pooling.
The main entry point is `main.py`, which handles loading the dataset, creating the graphs, and training and evaluation.
These are the main command-line arguments:
- `--dataset_name`: Name of the dataset.
- `--PATH_patches`: Path to the CSV file listing patch file locations and labels.
- `--embedding_vector_size`: Dimension of each patch embedding vector (1024 for the VGG16 extractor used here).
- `--learning_rate`: Learning rate for training.
- `--pooling_ratio`: Pooling ratio used by the SAGPooling layers.
- `--heads`: Number of GAT (Graph Attention Network) attention heads.
- `--K`: Number of nearest neighbours in the k-Nearest-Neighbour Graph (k-NNG) built from the patch embeddings of each WSI.
- `--train_fraction`: Fraction of the data used for training.
- `--num_epochs`: Number of training epochs.
- `--n_classes`: Number of classes for classification.
Here's an example of how to use the code:

```bash
python main.py --dataset_name RA --PATH_patches df_labels.csv --embedding_vector_size 1024 --learning_rate 0.0001 --pooling_ratio 0.7 --heads 2 --K 5 --train_fraction 0.7 --num_epochs 30 --n_classes 2
```
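For reference, here is a minimal sketch of how these flags could be parsed with `argparse`. The defaults shown mirror the example command above and are assumptions, not necessarily the repository's actual defaults:

```python
import argparse

# Hypothetical parser sketch; flag names follow the list above, defaults
# are taken from the example command and are assumptions.
parser = argparse.ArgumentParser(description="MUSTANG WSI graph classification")
parser.add_argument("--dataset_name", type=str, default="RA")
parser.add_argument("--PATH_patches", type=str, default="df_labels.csv")
parser.add_argument("--embedding_vector_size", type=int, default=1024)
parser.add_argument("--learning_rate", type=float, default=0.0001)
parser.add_argument("--pooling_ratio", type=float, default=0.7)
parser.add_argument("--heads", type=int, default=2)
parser.add_argument("--K", type=int, default=5)
parser.add_argument("--train_fraction", type=float, default=0.7)
parser.add_argument("--num_epochs", type=int, default=30)
parser.add_argument("--n_classes", type=int, default=2)
args = parser.parse_args()
```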
The code will automatically handle the following pipeline:
- Load the patches CSV file containing filenames and labels
- Split data into train and test sets based on patient IDs
- Create DataLoader for patches
- Extract 1024-dim VGG16 embeddings for each patch
- Build kNN graphs for each WSI using patch embeddings
- Create Graph DataLoader for WSI graphs
- Initialize GNN model (GAT + SAGPooling)
- Train model using the Adam optimizer and cross-entropy loss (see the training sketch after this list)
- Evaluate model on test set
- Save training curves, metrics and predictions
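As referenced above, here is a minimal sketch of the training step, under the assumption that `train_graphs` is a list of `torch_geometric.data.Data` objects (one k-NNG per patient, with `.y` holding the patient label) and `model` is the GAT + SAGPooling network (see the model sketch under step E below):

```python
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader

# Assumed names: train_graphs (one Data object per patient) and model
# (GAT + SAGPooling network). lr and epoch count follow the example command.
loader = DataLoader(train_graphs, batch_size=1, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model.train()
for epoch in range(30):
    total_loss = 0.0
    for data in loader:
        optimizer.zero_grad()
        logits = model(data.x, data.edge_index, data.batch)  # [1, n_classes]
        loss = F.cross_entropy(logits, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"epoch {epoch}: mean loss {total_loss / len(loader):.4f}")
```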
The trained model is saved to `results/trained_model.pth`.
To evaluate a trained model on new test data, set `TRAIN = False` and `TEST = True` in `main.py`. This will run the model on the test set and print out the evaluation metrics.
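A hedged sketch of that evaluation path. `GATSAGPool` refers to the hypothetical model sketch under step E below, and `test_graphs` stands in for the held-out patient graphs; neither name is taken from the repository:

```python
import torch
from torch_geometric.loader import DataLoader

# Hypothetical names: GATSAGPool is the model sketch shown under step E,
# test_graphs a list of held-out torch_geometric.data.Data patient graphs.
model = GATSAGPool(in_dim=1024, heads=2, ratio=0.7, n_classes=2)
model.load_state_dict(torch.load("results/trained_model.pth"))
model.eval()

correct = 0
test_loader = DataLoader(test_graphs, batch_size=1)
with torch.no_grad():
    for data in test_loader:
        pred = model(data.x, data.edge_index, data.batch).argmax(dim=1)
        correct += int((pred == data.y).sum())
print(f"test accuracy: {correct / len(test_graphs):.3f}")
```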
The MUSTANG pipeline is composed of:
• A - Segmentation: An automated segmentation step, in which a UNet is used to segment tissue areas on the WSIs. Users can either use the trained weights provided in our GitHub repository or supply their own.
• B - Patching: After segmentation, the tissue area is divided into patches of a user-chosen size; the patches can be overlapping or non-overlapping.
• C - Feature extraction: Each image patch is passed through a VGG16 CNN feature extractor and embedded into a [1 × 1024] feature vector. All feature vectors from a given patient are aggregated into a matrix; the number of rows varies from patient to patient, since each patient has a variable set of WSIs, each with its own dimensions. (See the first sketch after this list.)
• D - k-Nearest-Neighbour Graph: The matrix of feature vectors of each patient is used to build a sparse directed k-NNG using the Euclidean distance metric, with a default of k = 5. Each node's attribute is its [1 × 1024] feature vector. This graph is used as input to the GNN. (See the second sketch after this list.)
• E - Graph classification: The k-NNG is passed successively through four Graph Attention Network (GAT) layers and SAGPooling layers. The SAGPooling readouts from each layer are concatenated, passed through three MLP layers, and finally classified. (See the model sketch after this list.)
• F - Prediction: A pathotype or diagnosis prediction is obtained at the patient-level.
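To make step C concrete, here is a minimal sketch of one plausible patch-embedding path. How the pipeline reduces VGG16's 4096-d fc features to 1024 dimensions is an assumption here (modelled as a linear projection); the repository's extractor may differ:

```python
import torch
import torchvision

# Step C sketch. VGG16's first fc layer emits 4096-d features; the final
# Linear(4096, 1024) is an assumed projection down to [1 x 1024].
vgg = torchvision.models.vgg16(weights="IMAGENET1K_V1")
backbone = torch.nn.Sequential(
    vgg.features, vgg.avgpool, torch.nn.Flatten(),
    *list(vgg.classifier.children())[:2],   # up to the first fc layer -> 4096-d
    torch.nn.Linear(4096, 1024),            # assumed projection to 1024-d
)
backbone.eval()
with torch.no_grad():
    patch = torch.randn(1, 3, 224, 224)     # one RGB patch (dummy input)
    vec = backbone(patch)                    # shape [1, 1024]
```

Step D maps naturally onto PyTorch Geometric's `knn_graph`, which uses Euclidean distance by default. A minimal sketch, assuming `embeddings` is the [num_patches × 1024] matrix for one patient (requires the `torch-cluster` package):

```python
import torch
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph

embeddings = torch.randn(250, 1024)      # stand-in for real VGG16 features
edge_index = knn_graph(embeddings, k=5)  # sparse directed k-NNG, Euclidean
graph = Data(x=embeddings, edge_index=edge_index, y=torch.tensor([1]))
```

Finally, a hedged sketch of the step E architecture. The hidden width, activation, and the mean/max readout are assumptions; only the four (GAT → SAGPooling) blocks, the concatenated readouts, and the three MLP layers come from the description above:

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, SAGPooling
from torch_geometric.nn import global_mean_pool, global_max_pool

class GATSAGPool(torch.nn.Module):
    """Sketch of step E: 4 x (GAT -> SAGPooling) with concatenated readouts.
    Layer widths and the (mean || max) readout are assumptions, not the
    repository's exact architecture."""
    def __init__(self, in_dim=1024, hidden=256, heads=2, ratio=0.7, n_classes=2):
        super().__init__()
        self.convs, self.pools = torch.nn.ModuleList(), torch.nn.ModuleList()
        dim = in_dim
        for _ in range(4):
            self.convs.append(GATConv(dim, hidden, heads=heads, concat=False))
            self.pools.append(SAGPooling(hidden, ratio=ratio))
            dim = hidden
        # three MLP layers over the four concatenated [mean || max] readouts
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(4 * 2 * hidden, hidden), torch.nn.ReLU(),
            torch.nn.Linear(hidden, hidden // 2), torch.nn.ReLU(),
            torch.nn.Linear(hidden // 2, n_classes),
        )

    def forward(self, x, edge_index, batch):
        readouts = []
        for conv, pool in zip(self.convs, self.pools):
            x = F.relu(conv(x, edge_index))
            x, edge_index, _, batch, _, _ = pool(x, edge_index, batch=batch)
            readouts.append(torch.cat(
                [global_mean_pool(x, batch), global_max_pool(x, batch)], dim=1))
        return self.mlp(torch.cat(readouts, dim=1))
```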
If this code or article were useful to you, please consider citing:
@inproceedings{Gallagher-Syed_2023_BMVC,
author = {Amaya Gallagher-Syed and Luca Rossi and Felice Rivellese and Costantino Pitzalis and Myles Lewis and Michael Barnes and Gregory Slabaugh},
title = {Multi-Stain Self-Attention Graph Multiple Instance Learning Pipeline for Histopathology Whole Slide Images},
booktitle = {34th British Machine Vision Conference 2023, {BMVC} 2023, Aberdeen, UK, November 20-24, 2023},
publisher = {BMVA},
year = {2023},
url = {https://papers.bmvc2023.org/0789.pdf}
}
Amaya Gallagher-Syed, Luca Rossi, Felice Rivellese, Costantino Pitzalis, Myles Lewis, Michael Barnes, Gregory Slabaugh, "Multi-Stain Self-Attention Graph Multiple Instance Learning Pipeline for Histopathology Whole Slide Images", 34th British Machine Vision Conference (BMVC), Aberdeen, UK, 2023. https://doi.org/10.48550/arXiv.2309.10650