Skip to content

Commit

Permalink
fix edge_index
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Dec 25, 2024
1 parent 9facb11 commit 77ff3cd
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 26 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
Manifest.toml
.DS_Store
/test.jl
/test.ipynb
.CondaPkg/
temp/
.vscode/
data/
2 changes: 2 additions & 0 deletions src/inmemorydataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ end

# Collection interface for `InMemoryGNNDataset`.
# Every method below forwards to the wrapped `graphs` field.

# Return the `i`-th entry of `d.graphs`.
function Base.getindex(d::InMemoryGNNDataset, i::Int)
    return d.graphs[i]
end

# Number of entries stored in `d.graphs`.
function Base.length(d::InMemoryGNNDataset)
    return length(d.graphs)
end

# Start iterating; delegates to the iteration protocol of `d.graphs`.
function Base.iterate(d::InMemoryGNNDataset)
    return iterate(d.graphs)
end

# Continue iterating from `state`, as produced by a previous `iterate` call.
function Base.iterate(d::InMemoryGNNDataset, state)
    return iterate(d.graphs, state)
end

function Base.show(io::IO, d::InMemoryGNNDataset)
if get(io, :compact, false)
Expand Down
58 changes: 32 additions & 26 deletions src/interface.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Convert `x` with `from_dlpack`, falling back to a `pyconvert`-based copy.
#
# If `from_dlpack(x)` throws for any reason, `x` is converted with
# `pyconvert(Array, x)` and the order of its dimensions is reversed before
# the result is returned.
function try_from_dlpack(x)
    converted = try
        from_dlpack(x)
    catch
        # Best-effort fallback: any failure of `from_dlpack` lands here.
        arr = pyconvert(Array, x)
        rank = ndims(arr)
        # Permutation (rank, rank-1, ..., 1): reverse the dimension order.
        reversed_dims = ntuple(k -> rank + 1 - k, rank)
        permutedims(arr, reversed_dims)
    end
    return converted
end
# function try_from_dlpack(x)
# # try
# return from_dlpack(x)
# # catch
# # a = pyconvert(Array, x)
# # n = ndims(a)
# # return permutedims(a, ntuple(i -> n-i+1, n))
# # end
# end


"""
Expand All @@ -21,32 +21,38 @@ Since torch tensors are row-major, the corresponding julia arrays
will have permuted dimensions.
"""
function pygdata_to_gnngraph(data)
edge_index = from_dlpack(data.edge_index)
# edge_index needs .to_dense(), otherwise from_dlpack will fail (TODO report to DLPack.jl)
# edge_index = from_dlpack(data.edge_index.to_dense())
num_nodes = pyconvert(Int, data.num_nodes)
num_edges = pyconvert(Int, data.num_edges)
if length(edge_index) > 0
src, dst = edge_index[:,1] .+ 1 , edge_index[:,2] .+ 1
if length(data.edge_index) > 0
py_src, py_dst = data.edge_index
src = from_dlpack(py_src)
dst = from_dlpack(py_dst)
src, dst = src .+ 1 , dst .+ 1
else
src, dst = Int[], Int[]
@assert num_edges == 0
end
@assert length(src) == num_edges
@assert length(dst) == num_edges

@assert all(1 .<= src)
if !all(src .<= num_nodes)
n = maximum(src)
# @warn lazy"Found node index $n in edge index `src`, but only $num_nodes nodes in the graph.
# Updating num_nodes to $n. This message won't be displayed again." maxlog=1
num_nodes = n
end
@assert all(src .<= num_nodes)
@assert all(1 .<= dst)
if !all(dst .<= num_nodes)
n = maximum(dst)
# @warn lazy"Found node index $n in edge index `dst`, but only $num_nodes nodes in the graph.
# Updating num_nodes to $n. This message won't be displayed again." maxlog=1
num_nodes = n
end
@assert all(dst .<= num_nodes)

# if !all(src .<= num_nodes)
# n = maximum(src)
# @warn lazy"Found node index $n in edge index `src`, but only $num_nodes nodes in the graph.
# Updating num_nodes to $n. This message won't be displayed again."
# num_nodes = n
# end
# if !all(dst .<= num_nodes)
# n = maximum(dst)
# @warn lazy"Found node index $n in edge index `dst`, but only $num_nodes nodes in the graph.
# Updating num_nodes to $n. This message won't be displayed again."
# num_nodes = n
# end

ndata = (;)
edata = (;)
Expand All @@ -58,7 +64,7 @@ function pygdata_to_gnngraph(data)
k == :num_edges && continue
py_x = getproperty(data, k)
if pyisinstance(py_x, torch.Tensor)
x = try_from_dlpack(py_x)
x = from_dlpack(py_x)
last_dim = size(x, ndims(x))
if last_dim == num_nodes && num_nodes != num_edges
ndata = (; ndata..., k => x)
Expand Down
39 changes: 39 additions & 0 deletions test/datasets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,42 @@ end
@test g.gdata.y isa Matrix{Float32}
@test size(g.gdata.y) == (1, 1)
end

@testitem "MUTAG" setup=[TestModule] begin
    using .TestModule

    # Load MUTAG from the TUDataset collection.
    mutag = load_dataset("TUDataset", name="MUTAG")

    # Dataset-level metadata.
    @test mutag.num_graphs == 188
    @test mutag.node_features == [:x]
    @test mutag.edge_features == [:edge_attr]
    @test mutag.graph_features == [:y]
    @test length(mutag) == 188

    # Every graph carries a Float32 node-feature matrix with 7 rows
    # and one column per node.
    for graph in mutag
        @test graph.ndata.x isa Matrix{Float32}
        @test size(graph.ndata.x) == (7, graph.num_nodes)
    end

    # Expected edge endpoints of graph 31. The literals start from 0 and are
    # shifted by one before being compared with the loaded graph.
    expected_src = [ 0, 0, 1, 1, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 6, 6, 6, 7,
                     7, 8, 8, 9, 9, 10, 10, 10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 14,
                     14, 14, 15, 15, 16, 16, 16, 17, 18, 19, 19, 19, 20, 21] .+ 1
    expected_dst = [ 1, 5, 0, 2, 1, 3, 2, 4, 12, 3, 5, 6, 0, 4, 4, 7, 11, 6,
                     8, 7, 9, 8, 10, 9, 11, 15, 6, 10, 12, 3, 11, 13, 12, 14, 19, 13,
                     15, 16, 10, 14, 14, 17, 18, 16, 16, 13, 20, 21, 19, 19] .+ 1

    sources, targets = edge_index(mutag[31])
    @test sources isa Vector{Int}
    @test targets isa Vector{Int}
    @test sources == expected_src
    @test targets == expected_dst
end

@testitem "ESOL" setup=[TestModule] begin
    using .TestModule

    # Load ESOL from the MoleculeNet collection.
    esol = load_dataset("MoleculeNet", name="ESOL")

    # Dataset-level metadata.
    @test esol.num_graphs == 1128
    @test esol.node_features == [:x]
    @test esol.edge_features == [:edge_attr]
    @test esol.graph_features == [:y, :smiles]
    @test length(esol) == 1128

    # SMILES strings expected for the first four graphs, in order.
    # The trailing space in the first entry is intentional and must be kept.
    expected_smiles = [
        "OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)C(O)C3O ",
        "Cc1occc1C(=O)Nc2ccccc2",
        "CC(C)=CCCC(C)=CC(=O)",
        "c1ccc2c(c1)ccc3c2ccc4c5ccccc5ccc43",
    ]
    for (idx, smiles) in enumerate(expected_smiles)
        @test esol[idx].smiles == smiles
    end
end

0 comments on commit 77ff3cd

Please sign in to comment.