Skip to content

Commit

Permalink
Fix for newer pytorch (danielzuegner#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
emsi committed Jan 8, 2023
1 parent 362ec53 commit c80655d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions code_transformer/preprocessing/graph/alg.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def all_pairs_shortest_paths(edges=None, G=None, directed=False, cutoff=None):
create_using = nx.Graph
G = nx.from_edgelist(edges, create_using=create_using)
sps = nx.all_pairs_dijkstra_path_length(G, cutoff=cutoff)
values = torch.tensor([(dct[0], key, value) for dct in sps for key, value in dct[1].items()],
values = torch.tensor([(dct[0], key, int(value)) for dct in sps for key, value in dct[1].items()],
dtype=torch.long)
return values

Expand Down Expand Up @@ -95,7 +95,7 @@ def next_sibling_shortest_paths(tree_edges):
sibling_edges = next_sibling_edges(tree_edges).numpy()
G_siblings = nx.from_edgelist(sibling_edges, create_using=nx.DiGraph)
sps = list(nx.all_pairs_dijkstra_path_length(G_siblings))
sibling_sp_edgelist = torch.tensor([(from_node, to_node, dist)
sibling_sp_edgelist = torch.tensor([(int(from_node), int(to_node), dist)
for from_node, dct in sps
for to_node, dist in dct.items()],
dtype=torch.long)
Expand Down

0 comments on commit c80655d

Please sign in to comment.