A generalized implementation of Selvaraju et al.'s Grad-CAM for the Flax neural network library.
This library works with any convolutional neural network written in Flax that takes images as input.
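For reference, Grad-CAM as defined by Selvaraju et al. weights each feature map $A^k$ of the chosen convolutional layer by the spatially averaged gradient of the class score $y^c$, where $Z$ is the number of spatial positions $(i, j)$. It then keeps only the positive part of the weighted sum. The notation below is the paper's, not this library's:

```latex
\alpha_k^c = \frac{1}{Z} \sum_{i} \sum_{j} \frac{\partial y^c}{\partial A_{ij}^k},
\qquad
L_{\text{Grad-CAM}}^c = \operatorname{ReLU}\!\left( \sum_k \alpha_k^c A^k \right)
```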
First install it with:
```sh
pip install -U git+https://github.com/codymlewis/flax_gradcam.git
```
Then import the library with:
```python
import fgradcam
```
Computing and plotting a Grad-CAM heatmap takes three lines of code. The first goes in your Flax linen module, after the convolutional layer you want to analyse:
```python
x = fgradcam.observe(self, x)
```
With that in place, and once the model is trained, compute the Grad-CAM heatmaps for the samples `X` you want to inspect:
```python
heatmaps = fgradcam.compute(model, variables, X)
```
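A Grad-CAM heatmap combines the observed feature maps with the gradient of a class score with respect to those maps. The sketch below shows only that combination step in NumPy. It assumes you already have both arrays in the channels-last `(N, H, W, K)` layout that `flax.linen.Conv` produces. `gradcam_from_grads` is a hypothetical helper, not part of fgradcam, and this is not a description of how `fgradcam.compute` is implemented. The final rescaling to [0, 1] is a display convenience, not part of the paper's definition.

```python
import numpy as np


def gradcam_from_grads(activations: np.ndarray, grads: np.ndarray) -> np.ndarray:
    """Combine feature maps and their gradients into Grad-CAM heatmaps.

    Hypothetical helper (not part of fgradcam). Both arrays are assumed to be
    channels-last with shape (N, H, W, K).
    """
    # alpha_k: average each channel's gradient over the spatial axes -> (N, 1, 1, K)
    alpha = grads.mean(axis=(1, 2), keepdims=True)
    # Weighted sum over channels, then ReLU keeps only positive evidence -> (N, H, W)
    cam = np.maximum((alpha * activations).sum(axis=-1), 0.0)
    # Rescale each heatmap to [0, 1]; an all-zero map stays all zeros
    peak = cam.max(axis=(1, 2), keepdims=True)
    return np.where(peak > 0, cam / np.where(peak > 0, peak, 1.0), 0.0)


# Toy usage: a global-average-pool + dense head, so the gradient is analytic.
rng = np.random.default_rng(0)
acts = rng.random((2, 4, 4, 3))        # fake feature maps: N=2, 4x4 grid, K=3
head = rng.standard_normal((3, 5))     # fake dense weights: K=3 channels -> 5 classes
c = 1                                  # class to explain
h, w = acts.shape[1:3]
# For y = mean_ij(A) @ head, dy_c/dA_ij^k = head[k, c] / (h * w) at every position
dy_dA = np.broadcast_to(head[:, c] / (h * w), acts.shape)
heatmaps = gradcam_from_grads(acts, dy_dA)
print(heatmaps.shape)  # (2, 4, 4)
```

With this pooled linear head the Grad-CAM weights are proportional to the head's weights for class `c`, which is why the paper describes Grad-CAM as a generalization of CAM.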
Finally, visualize a heatmap with:
```python
import matplotlib.pyplot as plt

fgradcam.plot(X[0], heatmaps[0])
plt.show()
```
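A Grad-CAM heatmap has the spatial resolution of the observed convolutional layer, which is usually much coarser than the input image, so an overlay has to upsample it first. The NumPy sketch below illustrates that idea. `overlay_heatmap` is a hypothetical helper, not how `fgradcam.plot` is implemented. It assumes a float RGB image in [0, 1] and a heatmap already scaled to [0, 1]. For simplicity it uses nearest-neighbour upsampling and a red-channel tint, whereas bilinear upsampling and a colormap are common alternatives.

```python
import numpy as np


def overlay_heatmap(image: np.ndarray, heatmap: np.ndarray, alpha: float = 0.4) -> np.ndarray:
    """Hypothetical helper: blend a coarse [0, 1] heatmap onto an image.

    image:   (H, W, 3) floats in [0, 1]
    heatmap: (h, w) floats in [0, 1]
    """
    H, W = image.shape[:2]
    h, w = heatmap.shape
    # Nearest-neighbour upsampling: map each output pixel to its source cell
    rows = (np.arange(H) * h) // H
    cols = (np.arange(W) * w) // W
    up = heatmap[rows[:, None], cols[None, :]]  # (H, W)
    # Put the upsampled heatmap in the red channel only
    heat_rgb = np.zeros_like(image)
    heat_rgb[..., 0] = up
    # Alpha-blend the tint over the original image
    return (1.0 - alpha) * image + alpha * heat_rgb


image = np.full((8, 8, 3), 0.5)            # flat grey test image
heat = np.array([[0.0, 1.0], [0.0, 0.0]])  # only the top-right cell is "hot"
out = overlay_heatmap(image, heat)
print(out.shape)  # (8, 8, 3)
```

The blended array is a valid float RGB image, so it can be displayed with `plt.imshow(out)`.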
A full example is given in `samples/cnn.py`, and `samples/transfer.ipynb` shows the Grad-CAM computation on a pretrained model.