-
Notifications
You must be signed in to change notification settings - Fork 1
/
generate_graph.py
executable file
·73 lines (50 loc) · 1.69 KB
/
generate_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#!/usr/bin/env python3
# encoding: UTF-8
"""
Filename: generate_graph.py
Author: David Oniani
E-mail: oniani.david@mayo.edu
Description:
Create a pickle file for use in node2vec.
"""
import os
import pickle
import networkx as nx
import numpy as np
import pandas as pd
# --- Inputs (relative paths, resolved against the current working directory) ---
# Edge list read with a "," delimiter; node IDs are parsed as int.
EDGES_FILE: str = "data/edges.csv"
# Per-node feature table; its first column is used as the node ID (index).
FEATS_FILE: str = "data/features.csv"

# --- Output: the pickled (adjacency, features) tuple is written to PICKLE_DIR/PICKLE_FILE ---
PICKLE_DIR: str = "pickle"
PICKLE_FILE: str = "adj_feat.pkl"
# ID of the synthetic root node that is linked directly to every other node.
ROOT_NODE: int = 0


def _build_graph(edges_file: str, feature_ids) -> "nx.Graph":
    """Return the edge-list graph extended with feature-only nodes and a root.

    Node insertion order is: edge-list nodes, then the root, then feature IDs
    not already present (in file order). That order later fixes the row and
    column order of both output matrices.

    Args:
        edges_file: path to a comma-delimited edge list with integer node IDs.
        feature_ids: iterable of node IDs that have a feature row.

    Returns:
        An undirected networkx graph in which ROOT_NODE is adjacent to every
        other node.
    """
    g = nx.read_edgelist(edges_file, delimiter=",", nodetype=int)
    # NOTE(review): if the edge list already uses 0 as a real node ID, the
    # root merges with that node — confirm real IDs never include 0.
    g.add_node(ROOT_NODE)
    # Re-adding an existing node is a no-op, so only new IDs are appended.
    g.add_nodes_from(feature_ids)
    # Link the root to everything (no self-loop); snapshot the nodes first so
    # we never iterate a view of the graph while mutating it.
    g.add_edges_from((ROOT_NODE, node) for node in list(g) if node != ROOT_NODE)
    return g


def _feature_matrix(node_order: list, df: pd.DataFrame) -> np.ndarray:
    """Return the rows of ``df`` stacked in ``node_order`` order.

    Args:
        node_order: node IDs, in the same order used for the adjacency matrix.
        df: feature table indexed by node ID.

    Returns:
        A float array of shape ``(len(node_order), df.shape[1])`` whose row
        ``i`` holds the features of ``node_order[i]``.

    Raises:
        ValueError: if some node in ``node_order`` has no feature row.
    """
    # For duplicated IDs the last row wins. This keeps the original
    # "overwrite the node attribute" semantics without leaving extra
    # zero-filled rows that would misalign with the adjacency matrix.
    df = df[~df.index.duplicated(keep="last")]
    missing = [node for node in node_order if node not in df.index]
    if missing:
        raise ValueError(
            f"{len(missing)} graph node(s) have no feature row, e.g. {missing[:10]}"
        )
    return df.loc[node_order].to_numpy(dtype=float)


def main() -> None:
    """Build the root-augmented graph and pickle ``(adj, features)``.

    Reads edges from EDGES_FILE and per-node features from FEATS_FILE. It
    adds a root node connected to every other node, then writes the tuple
    ``(adj, features)`` to PICKLE_DIR/PICKLE_FILE (the directory is created
    if needed). Row/column ``i`` of ``adj`` and row ``i`` of ``features``
    refer to the same node.

    Raises:
        ValueError: if a graph node (including the root) has no feature row.
        RuntimeError: if the constructed graph is unexpectedly disconnected.
    """
    df = pd.read_csv(FEATS_FILE, index_col=0)
    g = _build_graph(EDGES_FILE, df.index)

    # Explicit check rather than `assert`, which `python -O` strips.
    # (The root edges make the graph connected by construction.)
    if not nx.is_connected(g):
        raise RuntimeError("constructed graph is not connected")

    # Use a single explicit node order for both matrices so rows line up.
    node_order = list(g.nodes())
    adj = nx.adjacency_matrix(g, nodelist=node_order)  # sparse format
    features = _feature_matrix(node_order, df)

    os.makedirs(PICKLE_DIR, exist_ok=True)
    with open(os.path.join(PICKLE_DIR, PICKLE_FILE), "wb") as f:
        pickle.dump((adj, features), f)


if __name__ == "__main__":
    main()