This is code for the paper "Understanding Graph Neural Networks with Generalized Geometric Scattering Transforms". For the tables presented in the paper, see notebooks/results_eval.ipynb.
This code implements a generalized geometric scattering transform in PyTorch and PyTorch Lightning, configured with Hydra.
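The NumPy snippet below sketches what a geometric scattering transform with a tunable asymmetry parameter can look like. It rests on an assumption: that the datamodule.transform_args.alpha option chooses the diffusion operator K_alpha = D^{-alpha} T D^{alpha}, where T = (I + D^{-1/2} A D^{-1/2}) / 2 is the symmetric lazy diffusion matrix. Under that reading, alpha = -0.5 gives the lazy random walk (I + A D^{-1}) / 2 and alpha = 0 gives T. The helper names (diffusion_operator, wavelet_bank, scattering_features), the dyadic wavelets W_0 = I - K, W_j = K^{2^{j-1}} - K^{2^j}, and the moment aggregation are also illustrative assumptions. None of this is the repository's code, and the sketch does not model the power or cheb_order options.

```python
import numpy as np


def diffusion_operator(A: np.ndarray, alpha: float) -> np.ndarray:
    """K_alpha = D^{-alpha} T D^{alpha} with T = (I + D^{-1/2} A D^{-1/2}) / 2 (assumed family)."""
    d = A.sum(axis=1)
    if np.any(d == 0):
        raise ValueError("isolated nodes are not supported in this sketch")
    n = A.shape[0]
    d_inv_sqrt = np.diag(d ** -0.5)
    T = 0.5 * (np.eye(n) + d_inv_sqrt @ A @ d_inv_sqrt)  # symmetric lazy diffusion
    return np.diag(d ** -alpha) @ T @ np.diag(d ** alpha)  # similarity transform of T


def wavelet_bank(K: np.ndarray, J: int) -> list:
    """Dyadic diffusion wavelets W_0 = I - K, W_j = K^(2^(j-1)) - K^(2^j) for j = 1..J."""
    W = [np.eye(K.shape[0]) - K]
    prev = K  # holds K^(2^(j-1))
    for _ in range(1, J + 1):
        nxt = prev @ prev  # K^(2^j) by repeated squaring
        W.append(prev - nxt)
        prev = nxt
    return W


def scattering_features(A, x, alpha=0.0, J=3, moments=(1, 2)) -> np.ndarray:
    """Zeroth-, first- and second-order scattering moments of a node signal x."""
    W = wavelet_bank(diffusion_operator(A, alpha), J)
    feats = [np.sum(np.abs(x) ** q) for q in moments]  # zeroth order
    first = [np.abs(Wj @ x) for Wj in W]  # |W_j x|
    for u in first:
        feats += [np.sum(u ** q) for q in moments]
    for j, u in enumerate(first):
        for jp in range(j + 1, len(W)):  # second order only for j < j'
            v = np.abs(W[jp] @ u)  # |W_j' |W_j x||
            feats += [np.sum(v ** q) for q in moments]
    return np.array(feats)


if __name__ == "__main__":
    # Small hypothetical graph and signal, for illustration only
    A = np.array([[0, 1, 1, 0], [1, 0, 1, 0], [1, 1, 0, 1], [0, 0, 1, 0]], dtype=float)
    x = np.array([1.0, 0.0, 2.0, -1.0])
    for alpha in (-0.5, 0.0, 0.5):
        print(alpha, scattering_features(A, x, alpha=alpha).round(3))
```

Every K_alpha is similar to T, so all values of alpha share the same eigenvalues. Only the eigenvectors change, and with them the wavelets and the resulting features.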
Install dependencies
# clone project
git clone https://github.com/atong01/trainable_symmetry
cd trainable_symmetry
# [OPTIONAL] create conda environment
conda create -n myenv python=3.9
conda activate myenv
# install PyTorch following the official instructions at
# https://pytorch.org/get-started/
# install requirements
pip install -r requirements.txt
Copy .env.example to .env and configure the directories in .env as needed.
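The snippet below is a hypothetical stdlib sketch of the copy step and of the simple KEY=VALUE format that .env files use. The real variable names come from .env.example and are not reproduced here. The function names and the DATA_DIR key are illustrative only, and the snippet is not the loader the project itself uses.

```python
import shutil
from pathlib import Path


def copy_env_example(root: Path) -> Path:
    """Copy .env.example to .env unless .env already exists (hypothetical helper)."""
    target = root / ".env"
    if not target.exists():
        shutil.copyfile(root / ".env.example", target)
    return target


def read_env_file(text: str) -> dict:
    """Parse KEY=VALUE lines, skipping blank lines and # comments (simplified sketch)."""
    env = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        env[key.strip()] = value.strip().strip("\"'")  # drop surrounding quotes
    return env


if __name__ == "__main__":
    sample = "# hypothetical entry\nDATA_DIR=/path/to/data\n"
    print(read_env_file(sample))  # {'DATA_DIR': '/path/to/data'}
```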
To reproduce the experiments in the paper (also available in scripts/basic.sh):
python src/train.py -m datamodule.transform_args.alpha=-0.5,-0.25,0.0,0.25,0.5 \
datamodule.dataset=NCI1,NCI109,DD,PROTEINS,MUTAG,PTC_MR,REDDIT-BINARY,REDDIT-MULTI-5K,COLLAB,IMDB-BINARY,IMDB-MULTI \
logger=wandb \
datamodule.transform_args.power=1,2 \
seed=0,1,2,3,4,5,6,7,8,9
python src/train.py -m datamodule.transform_args.alpha=-0.5,-0.25,0.0,0.25,0.5 \
datamodule.dataset=NCI1,NCI109,DD,PROTEINS,MUTAG,PTC_MR,REDDIT-BINARY,REDDIT-MULTI-5K,COLLAB,IMDB-BINARY,IMDB-MULTI \
logger=wandb \
datamodule.transform_args.power=1 \
+datamodule.transform_args.cheb_order=10,100 \
seed=0,1,2,3,4,5,6,7,8,9
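With -m (multirun), Hydra's default basic sweeper launches one job per combination of the comma-separated values, so each command above expands to a Cartesian product. The Python snippet below only counts the runs implied by the two commands. It reproduces the grids and does not call Hydra.

```python
from itertools import product

alphas = [-0.5, -0.25, 0.0, 0.25, 0.5]
datasets = ["NCI1", "NCI109", "DD", "PROTEINS", "MUTAG", "PTC_MR", "REDDIT-BINARY",
            "REDDIT-MULTI-5K", "COLLAB", "IMDB-BINARY", "IMDB-MULTI"]
seeds = list(range(10))

# First command: power in {1, 2}
sweep_1 = list(product(alphas, datasets, [1, 2], seeds))
# Second command: power fixed at 1, cheb_order in {10, 100}
sweep_2 = list(product(alphas, datasets, [1], [10, 100], seeds))

print(len(sweep_1), len(sweep_2))  # 1100 1100
```

The two commands therefore launch 2,200 training runs in total.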
# train on CPU
python src/train.py trainer=cpu
# train on GPU
python src/train.py trainer=gpu
You can override any parameter from the command line, for example:
python src/train.py trainer.max_epochs=20
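Hydra distinguishes key=value, which overrides a key that already exists in the config, from +key=value, which adds a key that is absent. The second reproduction command uses the latter form for +datamodule.transform_args.cheb_order. The stdlib sketch below mimics only that rule on a nested dict. The apply_override function and the example config values are hypothetical. Unlike Hydra, it stores every value as a raw string rather than parsing its type.

```python
def apply_override(cfg: dict, override: str) -> None:
    """Apply one 'a.b=v' or '+a.b=v' override to a nested dict (hypothetical mimic of Hydra's rule)."""
    add_new = override.startswith("+")
    dotted, _, value = override.lstrip("+").partition("=")
    *parents, leaf = dotted.split(".")
    node = cfg
    for key in parents:
        # In this sketch '+' may create missing groups; plain overrides may not
        node = node.setdefault(key, {}) if add_new else node[key]
    if add_new and leaf in node:
        raise KeyError(f"{dotted} already exists; drop the '+' to override it")
    if not add_new and leaf not in node:
        raise KeyError(f"{dotted} is not in the config; prefix it with '+' to add it")
    node[leaf] = value  # Hydra would parse the type; this sketch keeps the raw string


# Hypothetical defaults, for illustration only
cfg = {"trainer": {"max_epochs": 10}, "datamodule": {"transform_args": {"alpha": 0.0, "power": 1}}}
apply_override(cfg, "trainer.max_epochs=20")
apply_override(cfg, "+datamodule.transform_args.cheb_order=10")
print(cfg["trainer"]["max_epochs"], cfg["datamodule"]["transform_args"]["cheb_order"])  # 20 10
```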
Citation
@misc{perlmutter_understanding_2019,
doi = {10.48550/ARXIV.1911.06253},
url = {https://arxiv.org/abs/1911.06253},
author = {Perlmutter, Michael and Gao, Feng and Wolf, Guy and Hirn, Matthew},
keywords = {Machine Learning (stat.ML), Machine Learning (cs.LG), FOS: Computer and information sciences},
title = {Understanding Graph Neural Networks with Asymmetric Geometric Scattering Transforms},
publisher = {arXiv},
year = {2019},
}