Skip to content

Commit

Permalink
Add TemporalSnapshotsGraph type (JuliaML#221)
Browse files Browse the repository at this point in the history
* Add TemporalSnapshotGraph

* Add docs
  • Loading branch information
aurorarossi authored Oct 2, 2023
1 parent d4986f1 commit 60a2f05
Showing 1 changed file with 98 additions and 0 deletions.
98 changes: 98 additions & 0 deletions src/graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,104 @@ function Base.show(io::IO, ::MIME"text/plain", d::HeteroGraph)
end
end

struct TemporalSnapshotsGraph <: AbstractGraph
num_nodes::Vector{Int}
num_edges::Vector{Int}
num_snapshots::Int
snapshots::Vector{Graph}
graph_data::Any
end


"""
TemporalSnapshotsGraph(; kws...)
A type that represents a temporal snapshot graph as a sequence of [`Graph`](@ref)s and can store graph data.
Nodes are indexed in `1:num_nodes` and snapshots are indexed in `1:num_snapshots`.
# Keyword Arguments
- `num_nodes`: a vector containing the number of nodes at each snapshot.
- `edge_index`: a tuple containing three vectors.
The first vector contains the list of the source nodes of each edge, the second the target nodes at the third contains the snapshot at which each edge exists.
Defaults to `(Int[], Int[], Int[])`.
- `node_data`: node-related data. Can be `nothing`, a vector of named tuples of arrays or a dictionary of arrays.
The arrays' last dimension size should be equal to the number of nodes.
Default `nothing`.
- `edge_data`: edge-related data. Can be `nothing`, a vector of named tuples of arrays or a dictionary of arrays.
The arrays' last dimension size should be equal to the number of edges.
Default `nothing`.
- `graph_data`: graph-related data. Can be `nothing`, or a named tuple of arrays or a dictionary of arrays.
# Examples
```julia-repl
julia> tg = MLDatasets.TemporalSnapshotsGraph(num_nodes = [10,10,10], edge_index= ([1,3,4,5,6,7,8],[2,6,7,1,2,10,9],[1,1,1,2,2,3,3]), node_data=[rand(3,10), rand(4,10), rand(2,10)])
TemporalSnapshotsGraph:
num_nodes => 3-element Vector{Int64}
num_edges => 3-element Vector{Int64}
num_snapsh => 3
snapshots => 3-element Vector{Main.MLDatasets.Graph}
graph_data => nothing
julia> tg.snapshots[1] # access the first snapshot
Graph:
num_nodes => 10
num_edges => 3
edge_index => ("3-element Vector{Int64}", "3-element Vector{Int64}")
node_data => 3×10 Matrix{Float64}
edge_data => nothing
```
"""
function TemporalSnapshotsGraph(;
num_nodes::Vector{Int},
edge_index::Tuple{Vector{Int}, Vector{Int}, Vector{Int}} = (Int[], Int[], Int[]),
node_data:: Union{Vector,Nothing} = nothing,
edge_data:: Union{Vector,Nothing} = nothing,
graph_data = nothing)

u, v, t = edge_index
@assert length(u) == length(v) == length(t)
num_snapshots = maximum(t)
if !isnothing(node_data) && !isnothing(edge_data)
@assert length(node_data) == length(edge_data) == num_snapshots
end

snapshots = Vector{Graph}(undef, num_snapshots)
num_edges = Vector{Int}(undef, num_snapshots)
for i in 1:num_snapshots
if !isnothing(node_data) && !isnothing(edge_data)
snapshot = Graph(num_nodes[i], sum(t.==i), (u[t.==i], v[t.==i]), node_data[i], edge_data[i])
elseif !isnothing(node_data)
snapshot = Graph(num_nodes[i], sum(t.==i), (u[t.==i], v[t.==i]), node_data[i],nothing)
elseif !isnothing(edge_data)
snapshot = Graph(num_nodes[i], sum(t.==i), (u[t.==i], v[t.==i]), nothing, edge_data[i])
else
snapshot = Graph(num_nodes[i], sum(t.==i), (u[t.==i], v[t.==i]), nothing, nothing)
end
snapshots[i] = snapshot
num_edges[i] = sum(t.==i)
end
return TemporalSnapshotsGraph(num_nodes, num_edges, num_snapshots, snapshots, graph_data)
end

function Base.show(io::IO, d::TemporalSnapshotsGraph)
print(io, "TemporalSnapshotsGraph($(d.num_nodes), $(d.num_edges), $(d.num_snapshots))")
end

function Base.show(io::IO, ::MIME"text/plain", d::TemporalSnapshotsGraph)
recur_io = IOContext(io, :compact => false)
print(io, "TemporalSnapshotsGraph:")
for f in fieldnames(TemporalSnapshotsGraph)
if !startswith(string(f), "_")
fstring = leftalign(string(f), 10)
print(recur_io, "\n $fstring => ")
print(recur_io, "$(_summary(getfield(d, f)))")
end
end
end

# Transform an adjacency list to edge index.
# If inneigs = true, assume neighbors from incoming edges.
function adjlist2edgeindex(adj; inneigs = false)
Expand Down

0 comments on commit 60a2f05

Please sign in to comment.