Skip to content

Commit

Permalink
[github-action] formatting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
github-actions[bot] committed Nov 8, 2023
1 parent 8059b64 commit d996ba0
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions GraphHD_v2/graphhd_std_centrality.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def sparse_stochastic_graph(G):

def centrality(data):
degree_centrality = data.edge_index[0].bincount(minlength=data.num_nodes)
degree_ranked_nodes = sorted(range(data.num_nodes), key=lambda node: degree_centrality[node])
degree_ranked_nodes = sorted(
range(data.num_nodes), key=lambda node: degree_centrality[node]
)

def semi_local_centrality(data):
G = nx.Graph()
Expand All @@ -69,14 +71,18 @@ def semi_local_centrality(data):
semi_local_centrality = []

for node in G.nodes():
ego_graph = nx.ego_graph(G, node, radius=2) # Adjust the radius (2 in this case)
ego_graph = nx.ego_graph(
G, node, radius=2
) # Adjust the radius (2 in this case)
semi_local_centrality.append(len(ego_graph))

# Store the semi-local centrality scores in the PyTorch Geometric Data object
data.semi_local_centrality = torch.tensor(semi_local_centrality)

# Rank nodes based on semi-local centrality
semi_local_ranked_nodes = sorted(G.nodes(), key=lambda node: semi_local_centrality[node])
semi_local_ranked_nodes = sorted(
G.nodes(), key=lambda node: semi_local_centrality[node]
)
return semi_local_ranked_nodes

def degree_centrality(data):
Expand Down Expand Up @@ -163,7 +169,6 @@ def current_flow_betweeness_centrality(data):

return scores_nodes


def edge_current_flow_betweeness_centrality(data):
G = to_networkx(data)
G = G.to_undirected()
Expand Down Expand Up @@ -401,7 +406,6 @@ def forward(self, x):
pr = pagerank(x)
pr_sort, order = pr.sort()


node_id_hvs = torchhd.empty(x.num_nodes, self.out_features, VSA)
node_id_hvs[order] = self.node_ids.weight[: x.num_nodes]

Expand Down Expand Up @@ -447,13 +451,21 @@ def forward(self, x):
f = f1.compute().item() * 100
return acc, f, train_t, test_t


REPETITIONS = 50
RANDOMNESS = ["random"]
# DATASET = ["PTC_FM", "MUTAG", "NCI1", "ENZYMES", "PROTEINS", "DD"]
METRICS = ["none","page_rank","degree_centrality",
"closeness_centrality",
"betweenness_centrality", "load_centrality", "subgraph_centrality",
"subgraph_centrality_exp", "harmonic_centrality"]
METRICS = [
"none",
"page_rank",
"degree_centrality",
"closeness_centrality",
"betweenness_centrality",
"load_centrality",
"subgraph_centrality",
"subgraph_centrality_exp",
"harmonic_centrality",
]
DATASET = ["PTC_FM", "MUTAG", "NCI1", "ENZYMES", "PROTEINS", "DD"]
# VSAS = ["BSC", "MAP", "HRR", "FHRR"]
VSAS = ["FHRR"]
Expand Down Expand Up @@ -493,7 +505,7 @@ def forward(self, x):
"accuracy",
"f1",
"VSA",
"metric"
"metric",
]
)
writer.writerows(
Expand All @@ -506,7 +518,7 @@ def forward(self, x):
acc_final[0],
f1_final[0],
VSA,
METRIC
METRIC,
]
]
)

0 comments on commit d996ba0

Please sign in to comment.