From c80655db1be326663d4c387e8dadde1b0ce6ba44 Mon Sep 17 00:00:00 2001 From: Mariusz Woloszyn Date: Sun, 8 Jan 2023 19:52:52 +0100 Subject: [PATCH] Fix for newer pytorch (https://github.com/danielzuegner/code-transformer/issues/26) --- code_transformer/preprocessing/graph/alg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/code_transformer/preprocessing/graph/alg.py b/code_transformer/preprocessing/graph/alg.py index fa903af..6572d55 100644 --- a/code_transformer/preprocessing/graph/alg.py +++ b/code_transformer/preprocessing/graph/alg.py @@ -42,7 +42,7 @@ def all_pairs_shortest_paths(edges=None, G=None, directed=False, cutoff=None): create_using = nx.Graph G = nx.from_edgelist(edges, create_using=create_using) sps = nx.all_pairs_dijkstra_path_length(G, cutoff=cutoff) - values = torch.tensor([(dct[0], key, value) for dct in sps for key, value in dct[1].items()], + values = torch.tensor([(dct[0], key, int(value)) for dct in sps for key, value in dct[1].items()], dtype=torch.long) return values @@ -95,7 +95,7 @@ def next_sibling_shortest_paths(tree_edges): sibling_edges = next_sibling_edges(tree_edges).numpy() G_siblings = nx.from_edgelist(sibling_edges, create_using=nx.DiGraph) sps = list(nx.all_pairs_dijkstra_path_length(G_siblings)) - sibling_sp_edgelist = torch.tensor([(from_node, to_node, dist) + sibling_sp_edgelist = torch.tensor([(int(from_node), int(to_node), dist) for from_node, dct in sps for to_node, dist in dct.items()], dtype=torch.long)