Home
Welcome to the Node-Classification-with-Graph-Neural-Networks-FNN wiki.
Setup
```python
import os

import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
```
Prepare the Dataset
The Cora dataset consists of 2,708 scientific papers, each classified into one of seven classes. The citation network consists of 5,429 links. Each paper is described by a binary word vector of size 1,433, where each entry indicates the presence (1) or absence (0) of the corresponding word.
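As a minimal illustration of what a binary word vector means, the sketch below marks each vocabulary word with 1 if it occurs in a paper and 0 otherwise, so a word that appears several times still contributes a single 1. The `binary_word_vector` helper and the six-word vocabulary are hypothetical; Cora already ships its 1,433-entry vectors precomputed in the content file, so this step is not part of the loading pipeline.

```python
import numpy as np


def binary_word_vector(paper_words, vocabulary):
    """Return a 0/1 vector marking which vocabulary words occur in the paper."""
    index = {word: i for i, word in enumerate(vocabulary)}
    vector = np.zeros(len(vocabulary), dtype=np.int8)
    for word in paper_words:
        i = index.get(word)
        if i is not None:  # words outside the vocabulary are ignored
            vector[i] = 1  # presence only: repeats do not increase the value
    return vector


# Hypothetical 6-word vocabulary; Cora's real vocabulary has 1,433 entries.
vocab = ["graph", "neural", "network", "learning", "bayesian", "kernel"]
print(binary_word_vector(["graph", "learning", "graph", "survey"], vocab))
# → [1 0 0 1 0 0]
```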
Download the dataset
The dataset has two tab-separated files: cora.cites and cora.content.
The cora.cites file contains the citation records, with two columns: cited_paper_id (target) and citing_paper_id (source). The cora.content file contains the paper content records, with 1,435 columns: paper_id, the 1,433 binary word features, and subject. Let's download the dataset.
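The content file can be read with pandas by naming the columns in file order. The sketch below is a minimal version under a few assumptions: the hypothetical `load_content` helper takes a `num_features` parameter (1,433 for Cora) so that it can run on a tiny in-memory sample with made-up paper ids and only three word features, and subjects are encoded as integer class ids in sorted order, which is one reasonable labeling choice rather than a requirement.

```python
import io

import pandas as pd


def load_content(source, num_features=1433):
    """Read a cora.content-style file: paper_id, num_features word columns, subject."""
    column_names = (
        ["paper_id"] + [f"term_{i}" for i in range(num_features)] + ["subject"]
    )
    return pd.read_csv(source, sep="\t", header=None, names=column_names)


# Two made-up rows with 3 word features instead of 1,433.
sample = io.StringIO(
    "101\t0\t1\t0\tNeural_Networks\n"
    "102\t1\t0\t1\tRule_Learning\n"
)
papers = load_content(sample, num_features=3)
print(papers.shape)  # → (2, 5)

# Encode the subject strings as integer class ids for the classifier.
class_values = sorted(papers["subject"].unique())
class_idx = {name: i for i, name in enumerate(class_values)}
papers["subject"] = papers["subject"].map(class_idx)
print(papers["subject"].tolist())  # → [0, 1]
```

With the downloaded data, the source would be `os.path.join(data_dir, "cora.content")` and `num_features` would keep its default.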
```python
zip_file = keras.utils.get_file(
    fname="cora.tgz",
    origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
    extract=True,
)
data_dir = os.path.join(os.path.dirname(zip_file), "cora")
```
Process and visualize the dataset
Then we load the citations data into a pandas DataFrame.
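A minimal sketch of this step, assuming the file is tab-separated with no header row, as described for cora.cites. The hypothetical `load_citations` helper reads an in-memory sample with made-up paper ids so that it runs without downloading anything; with the real data, the source would be `os.path.join(data_dir, "cora.cites")`. The column names follow the target/source convention used above.

```python
import io

import pandas as pd


def load_citations(source):
    """Read a cora.cites-style file: a cited paper id, a tab, then a citing paper id."""
    return pd.read_csv(source, sep="\t", header=None, names=["target", "source"])


# Made-up citation records standing in for the real cora.cites file.
sample = io.StringIO("101\t102\n101\t103\n102\t103\n")
citations = load_citations(sample)
print(citations.shape)  # → (3, 2)

# How often each paper is cited within this sample (its in-degree).
times_cited = citations["target"].value_counts()
print(times_cited)
```

Counting how often each paper appears in the target column is a quick sanity check before visualizing the graph, since heavily cited papers become high-degree nodes.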
Introduction
Many datasets in various machine learning (ML) applications have structural relationships between their entities, which can be represented as graphs. Such applications include social and communication network analysis, traffic prediction, and fraud detection. Graph representation learning aims to build and train models on graph datasets that can then be used for a variety of ML tasks.
This example demonstrates a simple implementation of a Graph Neural Network (GNN) model. The model is used for a node classification task on the Cora dataset: predicting the subject of a paper from its words and its citation network.
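The model's two inputs, the word features of each paper and the citation edges between papers, have to be put into array form first. Below is a minimal sketch under stated assumptions: the hypothetical `build_graph` helper remaps paper ids to contiguous node indices 0..N-1, stores edges as a 2 x E array of (source, target) indices, and assumes every id in the citations also appears in the content records. The input DataFrames are made up but shaped like the ones read from cora.content (with subjects already encoded as integers) and cora.cites.

```python
import numpy as np
import pandas as pd


def build_graph(papers, citations):
    """Return (node_features, edges, labels) with paper ids remapped to 0..N-1."""
    paper_idx = {pid: i for i, pid in enumerate(sorted(papers["paper_id"].unique()))}
    papers = papers.assign(paper_id=papers["paper_id"].map(paper_idx))
    papers = papers.sort_values("paper_id")  # row i now describes node i
    feature_cols = [c for c in papers.columns if c.startswith("term_")]
    node_features = papers[feature_cols].to_numpy(dtype=np.float32)
    labels = papers["subject"].to_numpy()
    # 2 x E array: row 0 holds citing (source) nodes, row 1 cited (target) nodes.
    edges = np.stack([
        citations["source"].map(paper_idx).to_numpy(),
        citations["target"].map(paper_idx).to_numpy(),
    ])
    return node_features, edges, labels


# Made-up inputs with 2 word features per paper.
papers = pd.DataFrame({
    "paper_id": [101, 205, 318],
    "term_0": [1, 0, 1],
    "term_1": [0, 1, 1],
    "subject": [0, 2, 1],
})
citations = pd.DataFrame({"target": [101, 101, 205], "source": [205, 318, 318]})
node_features, edges, labels = build_graph(papers, citations)
print(edges.tolist())  # → [[1, 2, 2], [0, 0, 1]]
```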
Note that we implement a Graph Convolution Layer from scratch to provide a better understanding of how it works. However, a number of specialized TensorFlow-based libraries provide rich GNN APIs, such as Spektral, StellarGraph, and GraphNets.
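To show the core idea of a graph convolution in isolation, here is a minimal NumPy sketch of the propagation rule from Kipf and Welling's GCN, where each node's new representation is a degree-normalized sum of its own and its neighbors' features followed by a linear transform and ReLU. This is one common formulation chosen as an assumption; it is not necessarily how the layer in this project is built. The `gcn_propagate` helper, the toy triangle graph, and the identity weights are hypothetical, and citation direction is ignored by symmetrizing the adjacency matrix.

```python
import numpy as np


def gcn_propagate(node_features, edges, weights):
    """One GCN step: ReLU(D^-1/2 (A + I) D^-1/2 @ H @ W), with A built from edges.

    Citations are treated as undirected here, which is an assumption of this sketch.
    """
    num_nodes = node_features.shape[0]
    adjacency = np.zeros((num_nodes, num_nodes), dtype=np.float32)
    adjacency[edges[0], edges[1]] = 1.0
    adjacency[edges[1], edges[0]] = 1.0  # symmetrize: ignore citation direction
    adjacency += np.eye(num_nodes, dtype=np.float32)  # self-loops keep a node's own features
    # Self-loops guarantee every degree is at least 1, so the division is safe.
    inv_sqrt_degree = 1.0 / np.sqrt(adjacency.sum(axis=1))
    # Symmetric normalization: scale entry (i, j) by 1 / sqrt(deg_i * deg_j).
    normalized = adjacency * inv_sqrt_degree[:, None] * inv_sqrt_degree[None, :]
    return np.maximum(normalized @ node_features @ weights, 0.0)


# Made-up 3-node graph (a triangle) with 2 features per node and identity weights.
node_features = np.array([[1, 0], [0, 1], [1, 1]], dtype=np.float32)
edges = np.array([[1, 2, 2], [0, 0, 1]])
weights = np.eye(2, dtype=np.float32)
output = gcn_propagate(node_features, edges, weights)
print(output)  # each node ends up with [2/3, 2/3], the mean over the whole triangle
```

In a trained model, `weights` would be a learned parameter (for example, the kernel of a Keras layer), and stacking two such steps lets each paper draw on words from papers up to two citations away. A dense adjacency matrix is used only for readability; at Cora's size, a sparse or segment-sum implementation is the usual choice.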