diff --git a/src/graph.jl b/src/graph.jl index 8a4db55f..df5422eb 100644 --- a/src/graph.jl +++ b/src/graph.jl @@ -198,6 +198,104 @@ function Base.show(io::IO, ::MIME"text/plain", d::HeteroGraph) end end +struct TemporalSnapshotsGraph <: AbstractGraph + num_nodes::Vector{Int} + num_edges::Vector{Int} + num_snapshots::Int + snapshots::Vector{Graph} + graph_data::Any +end + + +""" + TemporalSnapshotsGraph(; kws...) + +A type that represents a temporal snapshot graph as a sequence of [`Graph`](@ref)s and can store graph data. + +Nodes are indexed in `1:num_nodes` and snapshots are indexed in `1:num_snapshots`. + +# Keyword Arguments + +- `num_nodes`: a vector containing the number of nodes at each snapshot. +- `edge_index`: a tuple containing three vectors. + The first vector contains the list of the source nodes of each edge, the second the target nodes at the third contains the snapshot at which each edge exists. + Defaults to `(Int[], Int[], Int[])`. +- `node_data`: node-related data. Can be `nothing`, a vector of named tuples of arrays or a dictionary of arrays. + The arrays' last dimension size should be equal to the number of nodes. + Default `nothing`. +- `edge_data`: edge-related data. Can be `nothing`, a vector of named tuples of arrays or a dictionary of arrays. + The arrays' last dimension size should be equal to the number of edges. + Default `nothing`. +- `graph_data`: graph-related data. Can be `nothing`, or a named tuple of arrays or a dictionary of arrays. + +# Examples + +```julia-repl +julia> tg = MLDatasets.TemporalSnapshotsGraph(num_nodes = [10,10,10], edge_index= ([1,3,4,5,6,7,8],[2,6,7,1,2,10,9],[1,1,1,2,2,3,3]), node_data=[rand(3,10), rand(4,10), rand(2,10)]) +TemporalSnapshotsGraph: + num_nodes => 3-element Vector{Int64} + num_edges => 3-element Vector{Int64} + num_snapsh => 3 + snapshots => 3-element Vector{Main.MLDatasets.Graph} + graph_data => nothing + +julia> tg.snapshots[1] # access the first snapshot +Graph: + num_nodes => 10 + num_edges => 3 + edge_index => ("3-element Vector{Int64}", "3-element Vector{Int64}") + node_data => 3×10 Matrix{Float64} + edge_data => nothing +``` +""" +function TemporalSnapshotsGraph(; + num_nodes::Vector{Int}, + edge_index::Tuple{Vector{Int}, Vector{Int}, Vector{Int}} = (Int[], Int[], Int[]), + node_data:: Union{Vector,Nothing} = nothing, + edge_data:: Union{Vector,Nothing} = nothing, + graph_data = nothing) + + u, v, t = edge_index + @assert length(u) == length(v) == length(t) + num_snapshots = maximum(t) + if !isnothing(node_data) && !isnothing(edge_data) + @assert length(node_data) == length(edge_data) == num_snapshots + end + + snapshots = Vector{Graph}(undef, num_snapshots) + num_edges = Vector{Int}(undef, num_snapshots) + for i in 1:num_snapshots + if !isnothing(node_data) && !isnothing(edge_data) + snapshot = Graph(num_nodes[i], sum(t.==i), (u[t.==i], v[t.==i]), node_data[i], edge_data[i]) + elseif !isnothing(node_data) + snapshot = Graph(num_nodes[i], sum(t.==i), (u[t.==i], v[t.==i]), node_data[i],nothing) + elseif !isnothing(edge_data) + snapshot = Graph(num_nodes[i], sum(t.==i), (u[t.==i], v[t.==i]), nothing, edge_data[i]) + else + snapshot = Graph(num_nodes[i], sum(t.==i), (u[t.==i], v[t.==i]), nothing, nothing) + end + snapshots[i] = snapshot + num_edges[i] = sum(t.==i) + end +return TemporalSnapshotsGraph(num_nodes, num_edges, num_snapshots, snapshots, graph_data) +end + +function Base.show(io::IO, d::TemporalSnapshotsGraph) + print(io, "TemporalSnapshotsGraph($(d.num_nodes), $(d.num_edges), $(d.num_snapshots))") +end + +function Base.show(io::IO, ::MIME"text/plain", d::TemporalSnapshotsGraph) + recur_io = IOContext(io, :compact => false) + print(io, "TemporalSnapshotsGraph:") + for f in fieldnames(TemporalSnapshotsGraph) + if !startswith(string(f), "_") + fstring = leftalign(string(f), 10) + print(recur_io, "\n $fstring => ") + print(recur_io, "$(_summary(getfield(d, f)))") + end + end +end + # Transform an adjacency list to edge index. # If inneigs = true, assume neighbors from incoming edges. function adjlist2edgeindex(adj; inneigs = false)