This repository implements Graph Convolutional Networks (GCNs) in JAX (available on GitHub). The code defines a GCN with two graph convolutional layers, following the model from the paper Semi-Supervised Classification with Graph Convolutional Networks.
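The two-layer model from the paper computes Z = softmax(Â · ReLU(Â · X · W1) · W2), where Â = D^-1/2 (A + I) D^-1/2 is the adjacency matrix with self-loops added and symmetrically normalized by the degree matrix D of A + I. The following is a minimal sketch of that forward pass, not this repository's code. The names `normalize_adjacency`, `init_params` and `gcn_forward` are hypothetical. Dropout and bias terms are left out, and `log_softmax` replaces `softmax` so the output can be fed straight into a cross-entropy loss.

```python
import jax
import jax.numpy as jnp


def normalize_adjacency(adj):
    """Renormalized adjacency D^-1/2 (A + I) D^-1/2 with self-loops."""
    adj_hat = adj + jnp.eye(adj.shape[0])        # add self-loops
    deg = adj_hat.sum(axis=1)                     # degree of each node in A + I
    d_inv_sqrt = 1.0 / jnp.sqrt(deg)
    return adj_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]


def init_params(key, in_dim, hidden_dim, out_dim):
    """Glorot-uniform weight matrices for the two graph convolutional layers."""
    k1, k2 = jax.random.split(key)
    glorot = jax.nn.initializers.glorot_uniform()
    return {"w1": glorot(k1, (in_dim, hidden_dim)),
            "w2": glorot(k2, (hidden_dim, out_dim))}


def gcn_forward(params, adj_norm, features):
    """Two-layer GCN: log_softmax(A_norm @ ReLU(A_norm @ X @ W1) @ W2)."""
    hidden = jax.nn.relu(adj_norm @ features @ params["w1"])   # layer 1
    logits = adj_norm @ hidden @ params["w2"]                   # layer 2
    return jax.nn.log_softmax(logits, axis=-1)


# Toy example: a 3-node path graph with one-hot node features.
adj = jnp.array([[0., 1., 0.], [1., 0., 1.], [0., 1., 0.]])
features = jnp.eye(3)
params = init_params(jax.random.PRNGKey(0), in_dim=3, hidden_dim=4, out_dim=2)
log_probs = gcn_forward(params, normalize_adjacency(adj), features)
print(log_probs.shape)  # (3, 2)
```

Each layer aggregates features from a node's neighbours and from the node itself (through the self-loop). Stacking two layers therefore lets every prediction depend on the node's 2-hop neighbourhood.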
Run
python train.py
to train a model on the Cora dataset.
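Training is semi-supervised. Every node takes part in the forward pass, but the cross-entropy loss is computed only on the small set of labeled nodes. The sketch below shows one way to write such a training step in JAX. It is an illustration under stated assumptions, not the contents of `train.py`:

- It does not load Cora. A hypothetical 4-node graph stands in for the dataset.
- `masked_loss`, `train_step` and `LEARNING_RATE` are illustrative names.
- It uses plain gradient descent, whereas the paper trains with Adam, dropout and L2 regularization.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical value for this toy example


def normalize_adjacency(adj):
    adj_hat = adj + jnp.eye(adj.shape[0])                  # add self-loops
    d_inv_sqrt = 1.0 / jnp.sqrt(adj_hat.sum(axis=1))
    return adj_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]


def gcn_forward(params, adj_norm, features):
    hidden = jax.nn.relu(adj_norm @ features @ params["w1"])
    return jax.nn.log_softmax(adj_norm @ hidden @ params["w2"], axis=-1)


def masked_loss(params, adj_norm, features, labels_onehot, train_mask):
    """Cross-entropy averaged over labeled nodes only (train_mask == 1)."""
    log_probs = gcn_forward(params, adj_norm, features)
    per_node = -(labels_onehot * log_probs).sum(axis=-1)
    return (per_node * train_mask).sum() / train_mask.sum()


@jax.jit
def train_step(params, adj_norm, features, labels_onehot, train_mask):
    loss, grads = jax.value_and_grad(masked_loss)(
        params, adj_norm, features, labels_onehot, train_mask)
    # Plain gradient descent on every weight matrix in the params pytree.
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


# Toy graph: two disconnected pairs (0-1 and 2-3); only nodes 0 and 2 are labeled.
adj = jnp.array([[0., 1., 0., 0.],
                 [1., 0., 0., 0.],
                 [0., 0., 0., 1.],
                 [0., 0., 1., 0.]])
adj_norm = normalize_adjacency(adj)
features = jnp.eye(4)
labels_onehot = jax.nn.one_hot(jnp.array([0, 0, 1, 1]), 2)
train_mask = jnp.array([1., 0., 1., 0.])

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {"w1": 0.5 * jax.random.normal(k1, (4, 8)),
          "w2": 0.5 * jax.random.normal(k2, (8, 2))}

losses = []
for step in range(200):
    params, loss = train_step(params, adj_norm, features, labels_onehot, train_mask)
    losses.append(float(loss))
print(losses[0], losses[-1])
```

Unlabeled nodes still receive predictions, because the graph convolutions mix their features with those of labeled neighbours. In this toy graph, node 1 ends up with the same representation as node 0.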
I implemented a sparse matrix multiplication function to support sparse adjacency matrices. It is enabled by default. If it causes errors, you can disable it by adding the --no-sparse flag to the run command:
python train.py --no-sparse
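One common way to multiply a sparse adjacency matrix by a dense feature matrix in JAX is to store only the nonzero entries in COO form (row indices, column indices, values). The per-row contributions are then summed with `jax.ops.segment_sum`. This is a minimal sketch of that technique and not necessarily how this repository's function works. The helper `sparse_matmul` and its arguments are hypothetical.

```python
import jax
import jax.numpy as jnp


def sparse_matmul(rows, cols, values, dense, num_rows):
    """Compute A @ dense for a sparse A given in COO form.

    rows, cols: int arrays with the coordinates of the nonzeros of A.
    values:     the corresponding nonzero values of A.
    dense:      dense matrix of shape (A.shape[1], k).
    num_rows:   A.shape[0], the number of output rows.
    """
    # Each nonzero A[i, j] contributes A[i, j] * dense[j] to output row i.
    contributions = values[:, None] * dense[cols]
    # Sum all contributions that share the same output row index.
    return jax.ops.segment_sum(contributions, rows, num_segments=num_rows)


# Check against a dense matrix product.
dense_adj = jnp.array([[0., 1., 0.], [1., 0., 2.], [0., 3., 0.]])
rows, cols = jnp.nonzero(dense_adj)
values = dense_adj[rows, cols]
x = jnp.arange(6.0).reshape(3, 2)
out = sparse_matmul(rows, cols, values, x, num_rows=3)
print(jnp.allclose(out, dense_adj @ x))  # True
```

The cost of this approach scales with the number of edges rather than with the square of the number of nodes. That matters for citation graphs like Cora, where almost all adjacency entries are zero. If the function is wrapped in `jax.jit`, `num_rows` must be a static Python integer, because `segment_sum` needs a fixed output size.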
This is a JAX implementation of the paper Semi-Supervised Classification with Graph Convolutional Networks. If you use it in your research, please cite the paper:
@article{kipf2016semi,
title={Semi-Supervised Classification with Graph Convolutional Networks},
author={Kipf, Thomas N and Welling, Max},
journal={arXiv preprint arXiv:1609.02907},
year={2016}
}