Interpretability and Explainability pipeline for PyTorch
Esegeta (of ancient Greek origin) is Italian for an interpreter of sacred texts.
As AI-based apps and tools spread through critical domains such as medicine, life support, and manufacturing, as well as ordinary households, it is increasingly important to strip away the black-box nature of machine learning models. The goal of this project is a complete package for testing different interpretability and explainability methods on any given deep learning model. The package is intended to work for both semantic segmentation and classification models.
The implemented methods are resource-heavy and need considerable computing power, so a CUDA-enabled GPU is the minimum requirement. We are working to make the package run on CPUs as well, but some methods may not work or may take a very long time on a CPU.
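Before starting a run it can help to confirm which device PyTorch will actually use. The snippet below is a minimal sketch and not part of this library; `pick_device` is a hypothetical helper name.

```python
import torch


def pick_device() -> torch.device:
    """Return the CUDA device if one is usable, otherwise fall back to the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


device = pick_device()
print(f"Running on: {device}")
if device.type == "cpu":
    # Per the note above, some methods may be very slow or unsupported on a CPU.
    print("No CUDA-enabled GPU found; expect long run times.")
```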
The following packages (with their dependencies) must be installed in your Python environment for this library to work:
- numpy
- matplotlib
- torch
- torchio
- torchvision
- Captum [1]
- torchray [2]
- lime [6]
- torch-lucent [3]
- CNN Visualization [4]
It is recommended to use a conda distribution with a Python 3.6 environment, Torch 1.6, and CUDA 10.2.
To install torchray you must have the Python cocoapi installed. Installing it with pip is not straightforward, so follow the cocoapi installation instructions first and then install torchray.
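A quick way to see which of the listed packages are still missing is to ask Python whether each module can be found. This is only a sketch: the import names below are assumptions (a package's install name can differ from its module name, for example torch-lucent is assumed to import as `lucent`), and CNN Visualization is left out because it is a code repository rather than an installable package.

```python
from importlib.util import find_spec

# Assumed import names for the packages listed above.
REQUIRED_MODULES = [
    "numpy", "matplotlib", "torch", "torchio", "torchvision",
    "captum", "torchray", "lime", "lucent",
]


def missing_modules(names=REQUIRED_MODULES):
    """Return the module names that cannot be found in the current environment."""
    return [name for name in names if find_spec(name) is None]


missing = missing_modules()
print("Missing:", ", ".join(missing) if missing else "none")
```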
The library is continuously evolving. In the first phase we bring together methods from third-party libraries; next, we will implement our own methods to extend the capability of this package. The methods implemented so far are:
- Captum
  - Saliency
  - Integrated gradients
  - Feature ablation
  - Guided Backpropagation
  - Deep lift
  - De-convolution
  - Guided grad-cam
  - Layer activation
  - Layer conductance
  - Layer grad shap
  - Gradient shap
  - Internal influence
  - Input X-gradient
  - Deep lift shap
  - Layer gradient X-Activation
  - Layer Deep lift
  - Layer Grad Cam
  - Shapley Value Sampling (takes a long time for image input; not recommended for image data)
- CNN Visualization
  - Guided Backpropagation
  - Integrated Gradients
  - Guided grad cam
  - Score cam
  - Vanilla Backpropagation
  - Grad Cam
  - Grad Image Times
  - Layer Activation Guided Backpropagation
  - Layer Visualization
  - Deep Dream
- Torch Ray
  - Excitation backpropagation
  - Contrastive excitation backpropagation
  - Rise
  - Deconv
  - Grad Cam
  - Gradient
  - Guided Backpropagation
  - Linear Approx.
- Lucent
  - Render visualization
- Lime
  - Lime segmentation, using the image explainer
More information on the implemented methods and their parameters is available at: https://github.com/soumickmj/TorchEsegeta/blob/master/EsegetaMethodInfo.pdf
- Exception Handling
- Logging
- Side by side visualization of multiple target attributions
- Extended for 3D models
- Timeout for long running methods
- Multi GPU and multi threading
- Automatic mixed precision support
- Extended for patch based models.
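The timeout feature is described further down as Linux-only. One common way to build such a guard on POSIX systems is an alarm signal, sketched below. This is an assumption about the general technique, not this library's implementation; `run_with_timeout` and `MethodTimeout` are hypothetical names, and the approach only works in the main thread.

```python
import signal
import time


class MethodTimeout(Exception):
    """Raised when a call runs longer than the allowed time."""


def run_with_timeout(func, seconds, *args, **kwargs):
    """Run func(*args, **kwargs); raise MethodTimeout after `seconds` (POSIX only)."""

    def _on_alarm(signum, frame):
        raise MethodTimeout(f"{func.__name__} exceeded {seconds} s")

    previous = signal.signal(signal.SIGALRM, _on_alarm)
    signal.setitimer(signal.ITIMER_REAL, seconds)  # start the countdown
    try:
        return func(*args, **kwargs)
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)    # cancel any pending alarm
        signal.signal(signal.SIGALRM, previous)    # restore the old handler


try:
    run_with_timeout(time.sleep, 0.1, 2)  # stands in for a slow attribution method
except MethodTimeout as err:
    print("Skipped:", err)
```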
This pipeline is built on top of third-party libraries, so it is affected by their underlying implementation, complexity, and constraints. We have tried to generalize as much as we can, but configuring the pipeline is still a challenge for the user. It is a one-time activity and worth the effort, because the pipeline can then produce output from many methods across several libraries. Follow the instructions below carefully. For an example, check the pipelineTester.py file and run it as is, providing the 2D segmentation configuration JSON file and marking the use of any method as True.
For a classification task the model can be given as it is. For a segmentation task, the model's forward function needs to be wrapped with a few lines of custom code so that the output has shape batch*totalClass. For an example, check the default section of pipeline.py and modelWrapper.py. At present the wrapper_fnction argument of the interpret_explain function of the Pipeline class takes two values: "threshold_based" (default) for threshold-based segmentation models and "multi_class" for multi-class semantic segmentation models. Pass this argument according to your model. Classification models do not require a wrapper except in very special cases. Alternatively, you can specify a wrapper function separately: declare it in the methods.py file under the Interpretability class and set 'aux_func_flag' to true in the configuration JSON file. This alternative does not work for all methods because of constraints in the underlying libraries. You can also add your own wrappers to the sermentWrapper.py file and update the dictionary maintained in pipeline.py.
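The idea of reducing a segmentation map to a batch*totalClass output can be sketched as follows. This is an illustration under assumptions, not the code in modelWrapper.py: the class name `ThresholdSumWrapper`, the 0.5 threshold, and the choice to sum the scores above the threshold are all hypothetical.

```python
import torch
from torch import nn


class ThresholdSumWrapper(nn.Module):
    """Hypothetical wrapper: collapses a segmentation map to (batch, classes)."""

    def __init__(self, model: nn.Module, threshold: float = 0.5):
        super().__init__()
        self.model = model
        self.threshold = threshold

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.model(x)  # (batch, classes, *spatial)
        mask = (out > self.threshold).to(out.dtype)
        # Sum the scores of all pixels/voxels above the threshold, per class.
        return (out * mask).flatten(start_dim=2).sum(dim=2)


# Toy stand-in for a real segmentation network (assumption for the demo).
toy_model = nn.Sequential(nn.Conv2d(1, 2, kernel_size=3, padding=1), nn.Sigmoid())
wrapped = ThresholdSumWrapper(toy_model)
scores = wrapped(torch.rand(4, 1, 16, 16))
print(scores.shape)  # torch.Size([4, 2])
```

With one score per class, attribution methods written for classifiers can treat each segmentation class as a target. Because the reduction flattens every spatial dimension, the same sketch also accepts 3D volumes.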
For now, test with a single input only; multi-input support is planned for a future version. This restriction removes some complexity from the pipeline, and once we are happy with its performance the feature can be added with minor changes. If any zooming is required (for non-patch-based testing), handle it in the pipeline.py file for the variable inp, or do it manually before passing the input to the pipeline. In the configuration JSON file, set the attributes 'isDepthFirst' and 'is_3d' according to the model's requirements.
Use the Pipeline class or the Pipeline multi thread class, as your use case requires.
Pass the path of the desired configuration file in pipeline-tester.py. To help you get started, the library contains four sample configuration JSON files: two for 2D models and two for 3D models, covering both the segmentation and the classification task.
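Before handing a configuration file to the pipeline, a few sanity checks can catch simple mistakes. The sketch below uses tag names from the tag summary, but the flat layout, the example values (such as "INFO" for log_level), and the `check_config` helper are assumptions; the sample files shipped with the library are the authoritative reference.

```python
import json

# Illustrative values only; the real sample files may nest or type things differently.
EXAMPLE_CONFIG = """
{
  "is_3d": false,
  "isDepthFirst": false,
  "batch_dim_present": true,
  "default": false,
  "dataset": "my_dataset",
  "test_run": "run_01",
  "patch_overlap": 0,
  "patch_size": -1,
  "amp_enbled": false,
  "share_gpu_threads": 1,
  "timeout_enabled": false,
  "log_level": "INFO"
}
"""


def check_config(cfg):
    """Return a list of human-readable problems found in the configuration."""
    problems = []
    for key in ("is_3d", "isDepthFirst", "batch_dim_present", "default"):
        if not isinstance(cfg.get(key), bool):
            problems.append(f"{key} should be true or false")
    # patch_size == -1 means the patcher is off, so no overlap is expected.
    if cfg.get("patch_size") == -1 and cfg.get("patch_overlap") != 0:
        problems.append("patch_overlap should be 0 when patch_size is -1")
    return problems


config = json.loads(EXAMPLE_CONFIG)
problems = check_config(config)
print(problems if problems else "configuration looks consistent")
```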
Also see the configuration file tag descriptions below.
Tag summary :
is_3d :
Make it True if it is a 3D model.
isDepthFirst :
Make it True if the model accepts input in a depth-first manner. Make sure to provide the input in the same manner.
batch_dim_present :
True or False, depending on whether the batch dimension is present in the input data.
default :
Generally keep it False. Only make it True if you want to test with the default models provided with this library. Only 2D segmentation will work, because the other default models require weight checkpoints, which are not shipped with the library as they are huge in size and proprietary.
dataset :
Mention your dataset name.
test_run :
Run reference number.
patch_overlap :
If using the patcher, this is the overlap pixel count between patches. Otherwise make it 0.
patch_size :
If using the patcher, this is the patch size. It must be set to -1 if the patcher is not used.
amp_enbled :
Make it True to use automatic mixed precision.
share_gpu_threads :
Number of methods running on the same GPU per thread.
timeout_enabled :
Make it True if you want to enable the timeout functionality for long-running methods (Linux only).
log_level :
Level of information you want to see in the log file.
uncertainity_metrics :
Make this flag False for all methods for now. It is reserved for a future functionality; making it True may generate unexpected output.
uncertainity_cascading :
Make this flag 0 for all methods. It is reserved for a future functionality; enabling it may generate unexpected output.
For all other method-related tags, please check the documentation of the corresponding library (listed in the references).
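The relationship between patch_size and patch_overlap can be made concrete with a little arithmetic along one axis. The function below is a sketch under assumptions, not the library's patcher: `patch_starts` is a hypothetical name, and it places one extra patch flush with the end when the stride does not divide the axis evenly.

```python
def patch_starts(length, patch_size, patch_overlap):
    """Start indices of overlapping patches along one axis of size `length`."""
    if patch_size == -1 or patch_size >= length:
        return [0]  # patcher not used, or a single patch covers the whole axis
    stride = patch_size - patch_overlap
    if stride <= 0:
        raise ValueError("patch_overlap must be smaller than patch_size")
    starts = list(range(0, length - patch_size + 1, stride))
    if starts[-1] + patch_size < length:
        starts.append(length - patch_size)  # final patch flush with the end
    return starts


print(patch_starts(256, 64, 16))  # [0, 48, 96, 144, 192]
print(patch_starts(100, 64, 16))  # [0, 36]
```

Neighbouring patches share patch_overlap pixels, so a larger overlap means more patches and a longer run time per input.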
- Handling multiple inputs at a time
- Uncertainty/Evaluation methods for generated attributions
Example outputs (images not shown): 2D classification with Lucent; 2D segmentation with Captum Guided Backpropagation; 3D segmentation with CNN Visualization Vanilla Backpropagation.
[1] Captum : https://captum.ai/
[2] Torchray : https://github.com/facebookresearch/TorchRay
[3] Lucent: https://github.com/greentfrapp/lucent
[4] CNN Visualization : https://github.com/utkuozbulak/pytorch-cnn-visualizations
[5] DS6 Paper: https://arxiv.org/pdf/2006.10802.pdf
[6] Lime: https://github.com/marcotcr/lime
If you like this repository, please click on Star!
If you use any of our approaches in your research or use code from this repository, please cite one of the following (or both) in your publications:
TorchEsegeta pipeline, including the methods for Segmentation:-
BibTeX entry:
@article{chatterjee2022torchesegeta,
author = {Chatterjee, Soumick and Das, Arnab and Mandal, Chirag and Mukhopadhyay, Budhaditya and Vipinraj, Manish and Shukla, Aniruddh and Nagaraja Rao, Rajatha and Sarasaen, Chompunuch and Speck, Oliver and Nürnberger, Andreas},
title = {TorchEsegeta: Framework for Interpretability and Explainability of Image-Based Deep Learning Models},
journal = {Applied Sciences},
volume = {12},
year = {2022},
number = {4},
article-number = {1834},
url = {https://www.mdpi.com/2076-3417/12/4/1834},
issn = {2076-3417},
doi = {10.3390/app12041834}
}
This was also presented as an abstract at ISMRM 2021:
Initial version of TorchEsegeta for Classification models can be cited using:-
BibTeX entry:
@Article{jimaging10020045,
AUTHOR = {Chatterjee, Soumick and Saad, Fatima and Sarasaen, Chompunuch and Ghosh, Suhita and Krug, Valerie and Khatun, Rupali and Mishra, Rahul and Desai, Nirja and Radeva, Petia and Rose, Georg and Stober, Sebastian and Speck, Oliver and Nürnberger, Andreas},
TITLE = {Exploration of Interpretability Techniques for Deep COVID-19 Classification Using Chest X-ray Images},
JOURNAL = {Journal of Imaging},
VOLUME = {10},
YEAR = {2024},
NUMBER = {2},
ARTICLE-NUMBER = {45},
URL = {https://www.mdpi.com/2313-433X/10/2/45},
ISSN = {2313-433X},
DOI = {10.3390/jimaging10020045}
}
Thank you so much for your support.