
A framework of graph classification baselines, including a TUDataset loader, GNN models and visualization.



AngusMonroe/Graph-Classification-Baseline


Graph-Classification-Baseline


Features

  • Datasets

    Implemented on top of LegacyTUDataset from dgl.data. All TUDatasets are available, and each one is downloaded automatically on first use.

  • GNN models

    Includes GCN, GAT, Gated GCN, GIN, GraphSAGE, etc. Default configurations can be found in configs/.

    Models configured with roughly 500k trainable parameters:

    | Rank | Model       | #Params | Test Acc. (%) ± s.d. | Links |
    |------|-------------|---------|----------------------|-------|
    | 1    | GatedGCN-PE | 505421  | 86.363 ± 0.127       | Paper |
    | 2    | RingGNN     | 504766  | 86.244 ± 0.025       | Paper |
    | 3    | MoNet       | 511487  | 85.582 ± 0.038       | Paper |
    | 4    | GatedGCN    | 502223  | 85.568 ± 0.088       | Paper |
    | 5    | GIN         | 508574  | 85.387 ± 0.136       | Paper |
    | 6    | 3WLGNN      | 502872  | 85.341 ± 0.207       | Paper |
    | 7    | GAT         | 526990  | 78.271 ± 0.186       | Paper |
    | 8    | GCN         | 500823  | 71.892 ± 0.334       | Paper |
    | 9    | GraphSage   | 502842  | 50.492 ± 0.001       | Paper |
  • Visualization

Setup

All tests were run on server 201 with Python 3.9 and CUDA 11.1. A ready-made environment is available at /home/jiaxing/anaconda3/envs/xjx.

pip install dgl-cu111 -f https://data.dgl.ai/wheels/repo.html
pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.9.0+cu111.html
pip install matplotlib
pip install tensorboardX
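Before running anything, it can help to confirm that every package above is importable. A minimal sketch using only the standard library; `missing_packages` is a hypothetical helper, and the list holds the import names of the pip packages above (for example torch_scatter for torch-scatter):

```python
import importlib.util

# Import names of the packages installed by the commands above.
REQUIRED = [
    "torch", "torchvision", "torchaudio", "dgl",
    "torch_scatter", "torch_sparse", "torch_cluster",
    "torch_spline_conv", "torch_geometric",
    "matplotlib", "tensorboardX",
]


def missing_packages(names=REQUIRED):
    """Return the names that cannot be found, without actually importing them."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("missing:", ", ".join(missing) if missing else "none")
```

Once nothing is missing, GPU support can be checked with `python -c "import torch; print(torch.cuda.is_available())"`.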

Usage

Quick Start

python main.py --config configs/TUs_graph_classification_GCN_DD_100k.json
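main.py takes one config per run. To sweep several configs, a small driver can call it repeatedly. This is a sketch, not part of the repository: the glob pattern assumes the config files follow the naming of the quick-start example, and `build_commands` and `run_configs` are hypothetical helpers.

```python
import subprocess
import sys
from pathlib import Path


def build_commands(config_paths):
    """One `python main.py --config <file>` command per config file."""
    return [[sys.executable, "main.py", "--config", str(p)] for p in config_paths]


def run_configs(pattern="configs/TUs_graph_classification_*.json"):
    """Run main.py sequentially for every config matching `pattern`.

    Stops at the first run that exits with a non-zero status.
    """
    for cmd in build_commands(sorted(Path(".").glob(pattern))):
        print("running:", " ".join(cmd))
        subprocess.run(cmd, check=True)
```

Running the configs one after another keeps a single GPU from being oversubscribed; `check=True` makes a failing config stop the sweep instead of being silently skipped.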

Output and Checkpoints

Output results are located in the folder defined by the variable out_dir in the corresponding config file.

If out_dir = 'out/TUs_graph_classification/', then

  • Go to out/TUs_graph_classification/results to view all result text files.

  • Directory out/TUs_graph_classification/checkpoints contains model checkpoints.

  • To view the training logs in TensorBoard on your local machine:

    • Go to the logs directory, i.e. out/TUs_graph_classification/logs/.

    • Run tensorboard --logdir='./' --port 6006.

    • Open http://localhost:6006 in your browser. The port (6006 here, though it may differ) is printed in the terminal as soon as TensorBoard starts.
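Because all three locations hang off out_dir, they can be derived from the config file. A minimal sketch, assuming the config is JSON with a top-level "out_dir" key (`load_config` and `output_dirs` are hypothetical helpers):

```python
import json
from pathlib import Path


def load_config(config_path):
    """Read a JSON config file into a dict."""
    with open(config_path) as f:
        return json.load(f)


def output_dirs(config):
    """Derive the results/, checkpoints/ and logs/ directories from config['out_dir']."""
    out_dir = Path(config["out_dir"])
    return {name: out_dir / name for name in ("results", "checkpoints", "logs")}
```

For example, `output_dirs(load_config("configs/TUs_graph_classification_GCN_DD_100k.json"))["logs"]` gives the directory to pass to `tensorboard --logdir`.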

Design your own model

1. New graph layer

Add a class MyGraphLayer to a file my_graph_layer.py in the layers/ directory. A standard template is:

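import torch
import torch.nn as nn
import dgl
import dgl.function as fn

class MyGraphLayer(nn.Module):

    def __init__(self, in_dim, out_dim, dropout):
        super().__init__()

        # write your code here, e.g. a linear projection and dropout
        self.linear = nn.Linear(in_dim, out_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, g_in, h_in, e_in):

        # write your code here
        # write the dgl reduce and update function calls here;
        # this example averages the projected features of in-neighbours
        with g_in.local_scope():
            g_in.ndata['h'] = self.linear(h_in)
            g_in.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h'))
            h_out = self.dropout(g_in.ndata['h'])

        e_out = e_in  # edge features are passed through unchanged here

        return h_out, e_out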

The layers/ directory contains the layer classes for all graph networks, as well as standard layers such as the MLP used in readout layers.

For instance, the GCN layer class GCNLayer is defined in layers/gcn_layer.py.


2. New graph network

Add a class MyGraphNetwork to a file my_gcn_net.py in the nets/ directory. The network's loss() function is also defined in MyGraphNetwork.

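import torch
import torch.nn as nn
import dgl

from layers.my_graph_layer import MyGraphLayer

class MyGraphNetwork(nn.Module):

    def __init__(self, net_params):
        super().__init__()

        # read hyperparameters from net_params (these key names are examples;
        # use whatever keys your config file defines)
        in_dim = net_params['in_dim']
        hidden_dim = net_params['hidden_dim']
        n_classes = net_params['n_classes']
        dropout = net_params['dropout']

        # write your code here
        self.layer = MyGraphLayer(in_dim, hidden_dim, dropout)
        self.readout = nn.Linear(hidden_dim, n_classes)

    def forward(self, g_in, h_in, e_in):

        # write your code here
        h, e = self.layer(g_in, h_in, e_in)

        # graph-level readout: average the node features of each graph
        with g_in.local_scope():
            g_in.ndata['h'] = h
            hg = dgl.mean_nodes(g_in, 'h')
        h_out = self.readout(hg)

        return h_out

    def loss(self, pred, label):

        # write your loss function here
        loss = nn.CrossEntropyLoss()(pred, label)

        return loss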

Register the new network class under a name, e.g. MyGNN, in the load_net.py file in the nets/ directory.

from nets.my_gcn_net import MyGraphNetwork

def MyGNN(net_params):
    return MyGraphNetwork(net_params)

def gnn_model(MODEL_NAME, net_params):
    models = {
        'MyGNN': MyGNN
    }
    return models[MODEL_NAME](net_params)

For example, the class GCNNet in nets/gcn_net.py is registered under the name GCN in nets/load_net.py.
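The dict in load_net.py has to be edited for every new model. One common alternative, shown here as a hypothetical sketch rather than how this repository works, is a decorator-based registry that also lists the known names when a config contains a typo:

```python
# Hypothetical alternative to the hand-written dict in load_net.py:
# classes register themselves under a name with a decorator.
MODEL_REGISTRY = {}


def register_model(name):
    """Class decorator that records `cls` under `name` in MODEL_REGISTRY."""
    def decorator(cls):
        if name in MODEL_REGISTRY:
            raise ValueError(f"model name {name!r} is already registered")
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator


def gnn_model(MODEL_NAME, net_params):
    """Build the model registered as MODEL_NAME, with a helpful error for typos."""
    try:
        model_cls = MODEL_REGISTRY[MODEL_NAME]
    except KeyError:
        known = ", ".join(sorted(MODEL_REGISTRY))
        raise KeyError(f"unknown model {MODEL_NAME!r}; known models: {known}") from None
    return model_cls(net_params)


@register_model("MyGNN")
class MyGraphNetwork:
    """Placeholder standing in for the real MyGraphNetwork in nets/my_gcn_net.py."""

    def __init__(self, net_params):
        self.net_params = net_params
```

The trade-off is that a model is only registered once its module has been imported, so load_net.py would still need to import each nets/ module.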

Reference

  • Benchmarking Graph Neural Networks [paper] [Github]
  • DGL PyTorch examples (mirror): https://codechina.csdn.net/mirrors/dmlc/dgl/-/tree/master/examples/pytorch
