diff --git a/LICENSE b/LICENSE index 33804ce..4721f7b 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2024 JuliaHub, Inc. and other contributors +Copyright (c) JuliaHub, Inc. and other contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/Project.toml b/Project.toml index a208be9..cdf3808 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,33 @@ name = "StateSelection" uuid = "64909d44-ed92-46a8-bbd9-f047dfbdc84b" +version = "0.1.0-DEV" authors = ["JuliaHub", "Inc. and other contributors"] -version = "1.0.0-DEV" + +[deps] +DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [compat] +DocStringExtensions = "0.9.3" +FindFirstFunctions = "1.2.0" +Graphs = "1.10.0" +LinearAlgebra = "1.11.0" +Setfield = "1.1.1" +SparseArrays = "1.11.0" +UnPack = "1.0.2" julia = "1.9" +[weakdeps] +DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6" + +[extensions] +StateSelectionDeepDiffsExt = "DeepDiffs" + [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/README.md b/README.md index eaa0f5a..5b3fa2e 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,79 @@ # StateSelection -[![Build Status](https://github.com/Keno/StateSelection.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/Keno/StateSelection.jl/actions/workflows/CI.yml?query=branch%3Amain) -[![Coverage](https://codecov.io/gh/Keno/StateSelection.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/Keno/StateSelection.jl) 
-[![Coverage](https://coveralls.io/repos/github/Keno/StateSelection.jl/badge.svg?branch=main)](https://coveralls.io/github/Keno/StateSelection.jl?branch=main) +[![Build Status](https://github.com/JuliaComputing/StateSelection.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/JuliaComputing/StateSelection.jl/actions/workflows/CI.yml?query=branch%3Amain) +[![Coverage](https://codecov.io/gh/JuliaComputing/StateSelection.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/JuliaComputing/StateSelection.jl) +[![Coverage](https://coveralls.io/repos/github/JuliaComputing/StateSelection.jl/badge.svg?branch=main)](https://coveralls.io/github/JuliaComputing/StateSelection.jl?branch=main) + +This package implements *structural* transformations suitable +for optimizing systems of (non-linear, ordinary differential, differential algebraic) equations for faster and more stable integration using a numerical solver. It is intended to serve as a common algorithmic core to a variety of downstream modeling systems, including [MTK](https://github.com/SciML/ModelingToolkit.jl), [DAECompiler](https://github.com/CedarEDA/DAECompiler.jl) and JuliaSimCompiler. + +It is intended to be independent of the actual representation of the system of equations, instead operating on structural abstractions defined in this package. In particular, it computes only the transformations that *should* be done to the system of equations, but the actual transformation of the system itself is deferred to the downstream user of this package. + +## Transformations + +This package implements *partial state selection* (PSS). Unfortunately, the terminology is not entirely consistent in the literature, so before we proceed, we define the PSS problem. + +### The state selection problem + +We will consider a *system* to be the dictum of a list of variables and equations together with + +1. The structural incidence matrix of equations and variables +2. 
The numerical incidence matrix of the integer-linear constant-coefficient sub-system +3. For each equation-variable pair an indication of whether the downstream compiler is capable of solving the equation for that variable. +4. A graph of differential relations between the variables + +Then, the *PSS problem* is to find + +A. A (possibly empty) list of equations to differentiate +B. The assignment of each variable (or their derivatives if such derivative occurs in a differentiated equation) to one of + + i. A selected differential state + ii. An assignment to an equation (declaring the equation will be solved for this variable) + iii. An assignment to a linear-system of a list of equations + iv. An algebraic state + +Such that +1. the dependency graph of variables (described further below) is acyclic +2. The provided solvability constraints are obeyed. + +### Further remarks + +There are several equivalent ways of thinking of this assignment. Two common ones in the literature are +1. an upper-triangular structural incidence matrix +2. a graph of dependencies between variables + +To form the dependency graph of variables, let the vertices of the graph be the variables. For each variable assigned to an equation, add a directed edge from all other variables in the incidence-row of the said equation to the assigned variable. + +Solutions to the PSS problem are not unique and the chosen solution materially affects the ease and stability of the resulting numerical integration. In general, a numerical integrator may need to switch the set of chosen states dynamically at runtime (known as dynamic state selection). Note that while this package has various hooks to control the selected states, and can compute the tearing sets, it does not by itself implement dynamic state selection. + +# Detailed description of the transformation steps + +A. Structural Singularity Removal + +This is a pre-pass that runs before pantelides. 
The primary objective of the Pantelides algorithm is to ensure that the +jacobian of the fundamental ODE system is non-singular at runtime. However, because Pantelides is a structural algorithm, +it can accidentally fail to fully reduce the index for systems which have full structural rank, but are numerically singular. +One common source of such systems is conservation laws commonly used in hierarchical modeling tools. + +However, fortunately, in such systems, the numerical singularity is static and apparent in the constant linear coefficients of the integer-linear subsystem (ILS) of the whole system. + +As such, this pre-pass applies a change of basis to the ILS to make the numerical rank-deficiency structurally apparent, allowing Pantelides to properly differentiate the resulting equations. + +B. Pantelides algorithm (DAE only) + +The Pantelides algorithm is used for reducing the differentiation index of the system to index 1 or 0, making it suitable for integration with a numerical integrator (such integrators are generally not capable of integrating higher-index DAE systems). This is accomplished by generating a list of equations to differentiate and adding these to the system. + +C. Dummy Derivatives (DAE only) + +The Pantelides algorithm produces systems that are over-determined in their differential relations (the sum of the number of differential relations and algebraic relations needs to match the number of variables). As such, some of these differential relations need to be removed in order to make the system solvable. However, the choice of which relations to remove can radically affect the numerical properties of the system and thus the ease and stability of the resulting integration. + +D. Tearing + +Tearing computes a (directed) dependency graph of equations (or dually variables). If this graph is fully connected and has no cycles, all output variables are uniquely determined given all input variables, so the system will have no algebraic states. 
In general, however, this dependency graph may have cycles (referred to as *algebraic loops*). Such cycles are broken numerically by choosing one or more variables in the cycle as algebraic states and (at runtime) wrapping a non-linear solver around these variables. Again, +the choice of algebraic states greatly affects the ease and stability of the resulting solve or integration. + + +## History + +The code in this package can be traced back to +JuliaComputing/StructuralTransformations.jl, which provided a port of the structural transformation core from [Modia](https://github.com/ModiaSim/Modia.jl) to MTK's data structures. This package was subsequently integrated into MTK, with a number of intervening rewrites, simplifications and cleanups (both in code and understanding), before being re-extracted to become this package. diff --git a/ext/StateSelectionDeepDiffsExt.jl b/ext/StateSelectionDeepDiffsExt.jl new file mode 100644 index 0000000..e675e3b --- /dev/null +++ b/ext/StateSelectionDeepDiffsExt.jl @@ -0,0 +1,187 @@ + +using DeepDiffs, ModelingToolkit +using StateSelection.BipartiteGraphs: Label, + BipartiteAdjacencyList, unassigned, + HighlightInt +using StateSelection: SystemStructure, + MatchedSystemStructure, + SystemStructurePrintMatrix + +""" +A utility struct for displaying the difference between two HighlightInts. 
+ +# Example +```julia +using StateSelection, DeepDiffs + +old_i = HighlightInt(1, :default, true) +new_i = HighlightInt(2, :default, false) +diff = HighlightIntDiff(new_i, old_i) + +show(diff) +``` +""" +struct HighlightIntDiff + new::HighlightInt + old::HighlightInt +end + +function Base.show(io::IO, d::HighlightIntDiff) + p_color = d.new.highlight + (d.new.match && !d.old.match) && (p_color = :light_green) + (!d.new.match && d.old.match) && (p_color = :light_red) + + (d.new.match || d.old.match) && printstyled(io, "(", color = p_color) + if d.new.i != d.old.i + Base.show(io, HighlightInt(d.old.i, :light_red, d.old.match)) + print(io, " ") + Base.show(io, HighlightInt(d.new.i, :light_green, d.new.match)) + else + Base.show(io, HighlightInt(d.new.i, d.new.highlight, false)) + end + (d.new.match || d.old.match) && printstyled(io, ")", color = p_color) +end + +""" +A utility struct for displaying the difference between two +BipartiteAdjacencyList's. + +# Example +```julia +using ModelingToolkit, DeepDiffs + +old = BipartiteAdjacencyList(...) +new = BipartiteAdjacencyList(...) +diff = BipartiteAdjacencyListDiff(new, old) + +show(diff) +``` +""" +struct BipartiteAdjacencyListDiff + new::BipartiteAdjacencyList + old::BipartiteAdjacencyList +end + +function Base.show(io::IO, l::BipartiteAdjacencyListDiff) + print(io, + LabelDiff(Label(l.new.match === true ? "∫ " : ""), + Label(l.old.match === true ? "∫ " : ""))) + (l.new.match !== true && l.old.match !== true) && print(io, " ") + + new_nonempty = isnothing(l.new.u) ? nothing : !isempty(l.new.u) + old_nonempty = isnothing(l.old.u) ? 
nothing : !isempty(l.old.u) + if new_nonempty === true && old_nonempty === true + if (!isempty(setdiff(l.new.highlight_u, l.new.u)) || + !isempty(setdiff(l.old.highlight_u, l.old.u))) + throw(ArgumentError("The provided `highlight_u` must be a sub-graph of `u`.")) + end + + new_items = Dict(i => HighlightInt(i, :nothing, i === l.new.match) for i in l.new.u) + old_items = Dict(i => HighlightInt(i, :nothing, i === l.old.match) for i in l.old.u) + + highlighted = union(map(intersect(l.new.u, l.old.u)) do i + HighlightIntDiff(new_items[i], old_items[i]) + end, + map(setdiff(l.new.u, l.old.u)) do i + HighlightInt(new_items[i].i, :light_green, + new_items[i].match) + end, + map(setdiff(l.old.u, l.new.u)) do i + HighlightInt(old_items[i].i, :light_red, + old_items[i].match) + end) + print(IOContext(io, :typeinfo => typeof(highlighted)), highlighted) + elseif new_nonempty === true + printstyled( + io, map(l.new.u) do i + HighlightInt(i, :nothing, i === l.new.match) + end, color = :light_green) + elseif old_nonempty === true + printstyled( + io, map(l.old.u) do i + HighlightInt(i, :nothing, i === l.old.match) + end, color = :light_red) + elseif old_nonempty !== nothing || new_nonempty !== nothing + print(io, + LabelDiff(Label(new_nonempty === false ? "∅" : "", :light_black), + Label(old_nonempty === false ? "∅" : "", :light_black))) + else + printstyled(io, '⋅', color = :light_black) + end +end + +""" +A utility struct for displaying the difference between two Labels +in git-style red/green highlighting. 
+ +# Example +```julia +using ModelingToolkit, DeepDiffs + +old = Label("before") +new = Label("after") +diff = LabelDiff(new, old) + +show(diff) +``` +""" +struct LabelDiff + new::Label + old::Label +end +function Base.show(io::IO, l::LabelDiff) + if l.new != l.old + printstyled(io, l.old.s, color = :light_red) + length(l.new.s) != 0 && length(l.old.s) != 0 && print(io, " ") + printstyled(io, l.new.s, color = :light_green) + else + print(io, l.new) + end +end + +""" +A utility struct for displaying the difference between two +(Matched)SystemStructure's in git-style red/green highlighting. + +# Example +```julia +using ModelingToolkit, DeepDiffs + +old = SystemStructurePrintMatrix(...) +new = SystemStructurePrintMatrix(...) +diff = SystemStructureDiffPrintMatrix(new, old) + +show(diff) +``` +""" +struct SystemStructureDiffPrintMatrix <: + AbstractMatrix{Union{LabelDiff, BipartiteAdjacencyListDiff}} + new::SystemStructurePrintMatrix + old::SystemStructurePrintMatrix +end + +function Base.size(ssdpm::SystemStructureDiffPrintMatrix) + max.(Base.size(ssdpm.new), Base.size(ssdpm.old)) +end + +function Base.getindex(ssdpm::SystemStructureDiffPrintMatrix, i::Integer, j::Integer) + checkbounds(ssdpm, i, j) + if i > 1 && (j == 4 || j == 9) + old = new = BipartiteAdjacencyList(nothing, nothing, unassigned) + (i <= size(ssdpm.new, 1)) && (new = ssdpm.new[i, j]) + (i <= size(ssdpm.old, 1)) && (old = ssdpm.old[i, j]) + BipartiteAdjacencyListDiff(new, old) + else + old = new = Label("") + (i <= size(ssdpm.new, 1)) && (new = ssdpm.new[i, j]) + (i <= size(ssdpm.old, 1)) && (old = ssdpm.old[i, j]) + LabelDiff(new, old) + end +end + +function DeepDiffs.deepdiff(old::Union{MatchedSystemStructure, SystemStructure}, + new::Union{MatchedSystemStructure, SystemStructure}) + new_sspm = SystemStructurePrintMatrix(new) + old_sspm = SystemStructurePrintMatrix(old) + Base.print_matrix(stdout, SystemStructureDiffPrintMatrix(new_sspm, old_sspm)) +end diff --git a/src/StateSelection.jl 
b/src/StateSelection.jl index 8ddca90..26fee2a 100644 --- a/src/StateSelection.jl +++ b/src/StateSelection.jl @@ -1,5 +1,34 @@ module StateSelection -# Write your package code here. +using DocStringExtensions +using Setfield: @set!, @set +using UnPack: @unpack +using Graphs + +# Graph Types +function invview end +function complete end +include("graph/bipartite.jl") +include("graph/diff.jl") +using .BipartiteGraphs + +# Math library +include("math/bareiss.jl") +include("math/sparsematrixclil.jl") +using .CLIL, .bareiss + +# Downstream interface +include("interface.jl") + +# Structural transformation passes +include("singularity_removal.jl") +include("pantelides.jl") +include("modia_tearing.jl") +include("tearing.jl") +include("partial_state_selection.jl") + +# Utilities +include("debug.jl") +include("utils.jl") end diff --git a/src/StructuralTransformations.jl b/src/StructuralTransformations.jl new file mode 100644 index 0000000..b9aaca3 --- /dev/null +++ b/src/StructuralTransformations.jl @@ -0,0 +1,73 @@ +module StructuralTransformations + +using Setfield: @set!, @set +using UnPack: @unpack + +using Symbolics: unwrap, linear_expansion, fast_substitute +using SymbolicUtils +using SymbolicUtils.Code +using SymbolicUtils.Rewriters +using SymbolicUtils: similarterm, istree + +using ModelingToolkit +using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Differential, + unknowns, equations, vars, Symbolic, diff2term, value, + operation, arguments, Sym, Term, simplify, solve_for, + isdiffeq, isdifferential, isirreducible, + empty_substitutions, get_substitutions, + get_tearing_state, get_iv, independent_variables, + has_tearing_state, defaults, InvalidSystemException, + ExtraEquationsSystemException, + ExtraVariablesSystemException, + get_postprocess_fbody, vars!, + IncrementalCycleTracker, add_edge_checked!, topological_sort, + invalidate_cache!, Substitutions, get_or_construct_tearing_state, + filter_kwargs, lower_varname, setio, 
SparseMatrixCLIL, + get_fullvars, has_equations, observed, + Schedule + +using ModelingToolkit.BipartiteGraphs +import .BipartiteGraphs: invview, complete +import ModelingToolkit: var_derivative!, var_derivative_graph! +using Graphs +using ModelingToolkit: algeqs, EquationsView, + SystemStructure, TransformationState, TearingState, + structural_simplify!, + isdiffvar, isdervar, isalgvar, isdiffeq, algeqs, is_only_discrete, + dervars_range, diffvars_range, algvars_range, + DiffGraph, complete!, + get_fullvars, system_subset + +using ModelingToolkit.DiffEqBase +using ModelingToolkit.StaticArrays +using RuntimeGeneratedFunctions: @RuntimeGeneratedFunction, + RuntimeGeneratedFunctions, + drop_expr + +RuntimeGeneratedFunctions.init(@__MODULE__) + +using SparseArrays + +using SimpleNonlinearSolve + +export tearing, partial_state_selection, dae_index_lowering, check_consistency +export dummy_derivative +export build_torn_function, build_observed_function, ODAEProblem +export sorted_incidence_matrix, + pantelides!, pantelides_reassemble, tearing_reassemble, find_solvables!, + linear_subsys_adjmat! 
+export tearing_assignments, tearing_substitution +export torn_system_jacobian_sparsity +export full_equations +export but_ordered_incidence, lowest_order_variable_mask, highest_order_variable_mask +export computed_highest_diff_variables + +include("utils.jl") +include("pantelides.jl") +include("bipartite_tearing/modia_tearing.jl") +include("tearing.jl") +include("symbolics_tearing.jl") +include("partial_state_selection.jl") +include("codegen.jl") + +end # module diff --git a/src/debug.jl b/src/debug.jl new file mode 100644 index 0000000..ed22e37 --- /dev/null +++ b/src/debug.jl @@ -0,0 +1,131 @@ + +using .BipartiteGraphs: Label, BipartiteAdjacencyList +struct SystemStructurePrintMatrix <: + AbstractMatrix{Union{Label, BipartiteAdjacencyList}} + bpg::BipartiteGraph + highlight_graph::Union{Nothing, BipartiteGraph} + var_to_diff::DiffGraph + eq_to_diff::DiffGraph + var_eq_matching::Union{Matching, Nothing} +end + +""" +Create a SystemStructurePrintMatrix to display the contents +of the provided SystemStructure. +""" +function SystemStructurePrintMatrix(s::SystemStructure) + return SystemStructurePrintMatrix(complete(s.graph), + s.solvable_graph === nothing ? nothing : + complete(s.solvable_graph), + complete(s.var_to_diff), + complete(s.eq_to_diff), + nothing) +end +Base.size(bgpm::SystemStructurePrintMatrix) = (max(nsrcs(bgpm.bpg), ndsts(bgpm.bpg)) + 1, 9) +function compute_diff_label(diff_graph, i, symbol) + di = i - 1 <= length(diff_graph) ? diff_graph[i - 1] : nothing + return di === nothing ? 
Label("") : Label(string(di, symbol)) +end +function Base.getindex(bgpm::SystemStructurePrintMatrix, i::Integer, j::Integer) + checkbounds(bgpm, i, j) + if i <= 1 + return (Label.(("#", "∂ₜ", " ", " eq", "", "#", "∂ₜ", " ", " v")))[j] + elseif j == 5 + colors = Base.text_colors + return Label("|", :light_black) + elseif j == 2 + return compute_diff_label(bgpm.eq_to_diff, i, '↑') + elseif j == 3 + return compute_diff_label(invview(bgpm.eq_to_diff), i, '↓') + elseif j == 7 + return compute_diff_label(bgpm.var_to_diff, i, '↑') + elseif j == 8 + return compute_diff_label(invview(bgpm.var_to_diff), i, '↓') + elseif j == 1 + return Label((i - 1 <= length(bgpm.eq_to_diff)) ? string(i - 1) : "") + elseif j == 6 + return Label((i - 1 <= length(bgpm.var_to_diff)) ? string(i - 1) : "") + elseif j == 4 + return BipartiteAdjacencyList( + i - 1 <= nsrcs(bgpm.bpg) ? + 𝑠neighbors(bgpm.bpg, i - 1) : nothing, + bgpm.highlight_graph !== nothing && + i - 1 <= nsrcs(bgpm.highlight_graph) ? + Set(𝑠neighbors(bgpm.highlight_graph, i - 1)) : + nothing, + bgpm.var_eq_matching !== nothing && + (i - 1 <= length(invview(bgpm.var_eq_matching))) ? + invview(bgpm.var_eq_matching)[i - 1] : unassigned) + elseif j == 9 + match = unassigned + if bgpm.var_eq_matching !== nothing && i - 1 <= length(bgpm.var_eq_matching) + match = bgpm.var_eq_matching[i - 1] + isa(match, Union{Int, Unassigned}) || (match = true) # Selected Unknown + end + return BipartiteAdjacencyList( + i - 1 <= ndsts(bgpm.bpg) ? + 𝑑neighbors(bgpm.bpg, i - 1) : nothing, + bgpm.highlight_graph !== nothing && + i - 1 <= ndsts(bgpm.highlight_graph) ? 
+ Set(𝑑neighbors(bgpm.highlight_graph, i - 1)) : + nothing, match) + else + @assert false + end +end + +function Base.show(io::IO, mime::MIME"text/plain", s::SystemStructure) + @unpack graph, solvable_graph, var_to_diff, eq_to_diff = s + if !get(io, :limit, true) || !get(io, :mtk_limit, true) + print(io, "SystemStructure with ", length(s.graph.fadjlist), " equations and ", + isa(s.graph.badjlist, Int) ? s.graph.badjlist : length(s.graph.badjlist), + " variables\n") + Base.print_matrix(io, SystemStructurePrintMatrix(s)) + else + S = incidence_matrix(s.graph, Num(Sym{Real}(:×))) + print(io, "Incidence matrix:") + show(io, mime, S) + end +end + +struct MatchedSystemStructure + structure::SystemStructure + var_eq_matching::Matching +end + +""" +Create a SystemStructurePrintMatrix to display the contents +of the provided MatchedSystemStructure. +""" +function SystemStructurePrintMatrix(ms::MatchedSystemStructure) + return SystemStructurePrintMatrix(complete(ms.structure.graph), + complete(ms.structure.solvable_graph), + complete(ms.structure.var_to_diff), + complete(ms.structure.eq_to_diff), + complete(ms.var_eq_matching, + nsrcs(ms.structure.graph))) +end + +function Base.copy(ms::MatchedSystemStructure) + MatchedSystemStructure(Base.copy(ms.structure), Base.copy(ms.var_eq_matching)) +end + +function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure) + s = ms.structure + @unpack graph, solvable_graph, var_to_diff, eq_to_diff = s + print(io, "Matched SystemStructure with ", length(graph.fadjlist), " equations and ", + isa(graph.badjlist, Int) ? 
graph.badjlist : length(graph.badjlist), + " variables\n") + Base.print_matrix(io, SystemStructurePrintMatrix(ms)) + printstyled(io, "\n\nLegend: ") + printstyled(io, "Solvable") + print(io, " | ") + printstyled(io, "(Solvable + Matched)", color = :light_yellow) + print(io, " | ") + printstyled(io, "Unsolvable", color = :light_black) + print(io, " | ") + printstyled(io, "(Unsolvable + Matched)", color = :magenta) + print(io, " | ") + printstyled(io, " ∫", color = :cyan) + printstyled(io, " SelectedState") +end diff --git a/src/graph/bipartite.jl b/src/graph/bipartite.jl new file mode 100644 index 0000000..5b5701d --- /dev/null +++ b/src/graph/bipartite.jl @@ -0,0 +1,828 @@ +module BipartiteGraphs + + +export BipartiteEdge, BipartiteGraph, DiCMOBiGraph, Unassigned, unassigned, + Matching, InducedCondensationGraph, maximal_matching, + construct_augmenting_path!, MatchedCondensationGraph + +export 𝑠vertices, 𝑑vertices, has_𝑠vertex, has_𝑑vertex, 𝑠neighbors, 𝑑neighbors, + 𝑠edges, 𝑑edges, nsrcs, ndsts, SRC, DST, set_neighbors!, invview, + delete_srcs!, delete_dsts! + +import ..invview, ..complete + +using DocStringExtensions +using UnPack +using SparseArrays +using Graphs +using Setfield + +### Matching +struct Unassigned + global unassigned + const unassigned = Unassigned.instance +end +# Behaves as a scalar +Base.length(u::Unassigned) = 1 +Base.size(u::Unassigned) = () +Base.iterate(u::Unassigned) = (unassigned, nothing) +Base.iterate(u::Unassigned, state) = nothing + +Base.show(io::IO, ::Unassigned) = printstyled(io, "u"; color = :light_black) + +struct Matching{U, V <: AbstractVector} <: AbstractVector{Union{U, Int}} #=> :Unassigned =# + match::V + inv_match::Union{Nothing, V} +end +# These constructors work around https://github.com/JuliaLang/julia/issues/41948 +function Matching{V}(m::Matching) where {V} + eltype(m) === Union{V, Int} && return m # eltype already matches: return the argument itself + VUT = typeof(similar(m.match, Union{V, Int})) + Matching{V}(convert(VUT, m.match), + m.inv_match === nothing ? 
nothing : convert(VUT, m.inv_match)) +end +Matching(m::Matching) = m +Matching{U}(v::V) where {U, V <: AbstractVector} = Matching{U, V}(v, nothing) +function Matching{U}(v::V, iv::Union{V, Nothing}) where {U, V <: AbstractVector} + Matching{U, V}(v, iv) +end +function Matching(v::V) where {U, V <: AbstractVector{Union{U, Int}}} + Matching{@isdefined(U) ? U : Unassigned, V}(v, nothing) +end +function Matching(m::Int) + Matching{Unassigned}(Union{Int, Unassigned}[unassigned for _ in 1:m], nothing) +end +function Matching{U}(m::Int) where {U} + Matching{Union{Unassigned, U}}(Union{Int, Unassigned, U}[unassigned for _ in 1:m], + nothing) +end + +Base.size(m::Matching) = Base.size(m.match) +Base.getindex(m::Matching, i::Integer) = m.match[i] +Base.iterate(m::Matching, state...) = iterate(m.match, state...) +function Base.copy(m::Matching{U}) where {U} + Matching{U}(copy(m.match), m.inv_match === nothing ? nothing : copy(m.inv_match)) +end +function Base.setindex!(m::Matching{U}, v::Union{Integer, U}, i::Integer) where {U} + if m.inv_match !== nothing + oldv = m.match[i] + # TODO: maybe default Matching to always have an `inv_match`? + + # To maintain the invariant that `m.inv_match[m.match[i]] == i`, we need + # to unassign the matching at `m.inv_match[v]` if it exists. 
+ if v isa Int && (iv = m.inv_match[v]) isa Int + m.match[iv] = unassigned + end + if isa(oldv, Int) + @assert m.inv_match[oldv] == i + m.inv_match[oldv] = unassigned + end + isa(v, Int) && (m.inv_match[v] = i) + end + return m.match[i] = v +end + +function Base.push!(m::Matching, v) + push!(m.match, v) + if v isa Integer && m.inv_match !== nothing + m.inv_match[v] = length(m.match) + end +end + +function complete(m::Matching{U}, + N = maximum((x for x in m.match if isa(x, Int)); init = 0)) where {U} + m.inv_match !== nothing && return m + inv_match = Union{U, Int}[unassigned for _ in 1:N] + for (i, eq) in enumerate(m.match) + isa(eq, Int) || continue + inv_match[eq] = i + end + return Matching{U}(collect(m.match), inv_match) +end + +@noinline function require_complete(m::Matching) + m.inv_match === nothing && + throw(ArgumentError("Backwards matching not defined. `complete` the matching first.")) +end + +function invview(m::Matching{U, V}) where {U, V} + require_complete(m) + return Matching{U, V}(m.inv_match, m.match) +end + +### +### Edges & Vertex +### +@enum VertType SRC DST + +struct BipartiteEdge{I <: Integer} <: Graphs.AbstractEdge{I} + src::I + dst::I + function BipartiteEdge(src::I, dst::V) where {I, V} + T = promote_type(I, V) + new{T}(T(src), T(dst)) + end +end + +Graphs.src(edge::BipartiteEdge) = edge.src +Graphs.dst(edge::BipartiteEdge) = edge.dst + +function Base.show(io::IO, edge::BipartiteEdge) + @unpack src, dst = edge + print(io, "[src: ", src, "] => [dst: ", dst, "]") +end + +Base.:(==)(a::BipartiteEdge, b::BipartiteEdge) = src(a) == src(b) && dst(a) == dst(b) + +### +### Graph +### +""" +$(TYPEDEF) + +A bipartite graph representation between two, possibly distinct, sets of vertices +(source and dependencies). Maps source vertices, labelled `1:N₁`, to vertices +on which they depend (labelled `1:N₂`). 
+ +# Fields +$(FIELDS) + +# Example +```julia +using StateSelection.BipartiteGraphs + +ne = 7 +srcverts = 1:6 +depverts = 1:2 + +# six source vertices +fadjlist = [[1],[1],[2],[2],[1],[1,2]] + +# two vertices they depend on +badjlist = [[1,2,5,6],[3,4,6]] + +bg = BipartiteGraph(ne, fadjlist, badjlist) +``` +""" +mutable struct BipartiteGraph{I <: Integer, M} <: Graphs.AbstractGraph{I} + ne::Int + fadjlist::Vector{Vector{I}} # `fadjlist[src] => dsts` + badjlist::Union{Vector{Vector{I}}, I} # `badjlist[dst] => srcs` or `ndsts` + metadata::M +end +function BipartiteGraph(ne::Integer, fadj::AbstractVector, + badj::Union{AbstractVector, Integer} = maximum(maximum, fadj); + metadata = nothing) + BipartiteGraph(ne, fadj, badj, metadata) +end +function BipartiteGraph(fadj::AbstractVector, + badj::Union{AbstractVector, Integer} = maximum(maximum, fadj); + metadata = nothing) + BipartiteGraph(mapreduce(length, +, fadj; init = 0), fadj, badj, metadata) +end + +@noinline function require_complete(g::BipartiteGraph) + g.badjlist isa AbstractVector || + throw(ArgumentError("The graph has no back edges. 
Use `complete`.")) +end + +function invview(g::BipartiteGraph) + require_complete(g) + BipartiteGraph(g.ne, g.badjlist, g.fadjlist) +end + +function complete(g::BipartiteGraph{I}) where {I} + isa(g.badjlist, AbstractVector) && return g + badjlist = Vector{I}[Vector{I}() for _ in 1:(g.badjlist)] + for (s, l) in enumerate(g.fadjlist) + for d in l + push!(badjlist[d], s) + end + end + BipartiteGraph(g.ne, g.fadjlist, badjlist) +end + +# Matrix whose only purpose is to pretty-print the bipartite graph +struct BipartiteAdjacencyList + u::Union{Vector{Int}, Nothing} + highlight_u::Union{Set{Int}, Nothing} + match::Union{Int, Bool, Unassigned} +end +function BipartiteAdjacencyList(u::Union{Vector{Int}, Nothing}) + BipartiteAdjacencyList(u, nothing, unassigned) +end + +struct HighlightInt + i::Int + highlight::Symbol + match::Bool +end +Base.typeinfo_implicit(::Type{HighlightInt}) = true +function Base.show(io::IO, hi::HighlightInt) + if hi.match + printstyled(io, "(", color = hi.highlight) + printstyled(io, hi.i, color = hi.highlight) + printstyled(io, ")", color = hi.highlight) + else + printstyled(io, hi.i, color = hi.highlight) + end +end + +function Base.show(io::IO, l::BipartiteAdjacencyList) + if l.match === true + printstyled(io, "∫ ", color = :cyan) + else + printstyled(io, " ") + end + if l.u === nothing + printstyled(io, '⋅', color = :light_black) + elseif isempty(l.u) + printstyled(io, '∅', color = :light_black) + elseif l.highlight_u === nothing + print(io, l.u) + else + match = l.match + isa(match, Bool) && (match = unassigned) + function choose_color(i) + solvable = i in l.highlight_u + matched = i == match + if !matched && solvable + :default + elseif !matched && !solvable + :light_black + elseif matched && solvable + :light_yellow + elseif matched && !solvable + :magenta + end + end + if !isempty(setdiff(l.highlight_u, l.u)) + # Only for debugging, shouldn't happen in practice + print(io, + map(union(l.u, l.highlight_u)) do i + HighlightInt(i, !(i in l.u) 
? :light_red : choose_color(i), + i == match) + end) + else + print(io, map(l.u) do i + HighlightInt(i, choose_color(i), i == match) + end) + end + end +end + +struct Label + s::String + c::Symbol +end +Label(s::AbstractString) = Label(s, :nothing) +Label(x::Integer) = Label(string(x)) +Base.show(io::IO, l::Label) = printstyled(io, l.s, color = l.c) + +struct BipartiteGraphPrintMatrix <: + AbstractMatrix{Union{Label, Int, BipartiteAdjacencyList}} + bpg::BipartiteGraph +end +Base.size(bgpm::BipartiteGraphPrintMatrix) = (max(nsrcs(bgpm.bpg), ndsts(bgpm.bpg)) + 1, 3) +function Base.getindex(bgpm::BipartiteGraphPrintMatrix, i::Integer, j::Integer) + checkbounds(bgpm, i, j) + if i == 1 + return (Label.(("#", "src", "dst")))[j] + elseif j == 1 + return i - 1 + elseif j == 2 + return BipartiteAdjacencyList(i - 1 <= nsrcs(bgpm.bpg) ? + 𝑠neighbors(bgpm.bpg, i - 1) : nothing) + elseif j == 3 + return BipartiteAdjacencyList(i - 1 <= ndsts(bgpm.bpg) ? + 𝑑neighbors(bgpm.bpg, i - 1) : nothing) + else + @assert false + end +end + +function Base.show(io::IO, b::BipartiteGraph) + print(io, "BipartiteGraph with (", length(b.fadjlist), ", ", + isa(b.badjlist, Int) ? b.badjlist : length(b.badjlist), ") (𝑠,𝑑)-vertices\n") + Base.print_matrix(io, BipartiteGraphPrintMatrix(b)) +end + +""" +```julia +Base.isequal(bg1::BipartiteGraph{T}, bg2::BipartiteGraph{T}) where {T <: Integer} +``` + +Test whether two [`BipartiteGraph`](@ref)s are equal. +""" +function Base.isequal(bg1::BipartiteGraph{T}, bg2::BipartiteGraph{T}) where {T <: Integer} + iseq = (bg1.ne == bg2.ne) + iseq &= (bg1.fadjlist == bg2.fadjlist) + iseq &= (bg1.badjlist == bg2.badjlist) + iseq +end + +""" +$(SIGNATURES) + +Build an empty `BipartiteGraph` with `nsrcs` sources and `ndsts` destinations. +""" +function BipartiteGraph(nsrcs::T, ndsts::T, backedge::Val{B} = Val(true); + metadata = nothing) where {T, B} + fadjlist = map(_ -> T[], 1:nsrcs) + badjlist = B ? 
map(_ -> T[], 1:ndsts) : ndsts + BipartiteGraph(0, fadjlist, badjlist, metadata) +end + +function Base.copy(bg::BipartiteGraph) + BipartiteGraph(bg.ne, map(copy, bg.fadjlist), map(copy, bg.badjlist), + deepcopy(bg.metadata)) +end +Base.eltype(::Type{<:BipartiteGraph{I}}) where {I} = I +function Base.empty!(g::BipartiteGraph) + foreach(empty!, g.fadjlist) + g.badjlist isa AbstractVector && foreach(empty!, g.badjlist) + g.ne = 0 + if g.metadata !== nothing + foreach(empty!, g.metadata) + end + g +end +Base.length(::BipartiteGraph) = error("length is not well defined! Use `ne` or `nv`.") + +if isdefined(Graphs, :has_contiguous_vertices) + Graphs.has_contiguous_vertices(::Type{<:BipartiteGraph}) = false +end +Graphs.is_directed(::Type{<:BipartiteGraph}) = false +Graphs.vertices(g::BipartiteGraph) = (𝑠vertices(g), 𝑑vertices(g)) +𝑠vertices(g::BipartiteGraph) = axes(g.fadjlist, 1) +function 𝑑vertices(g::BipartiteGraph) + g.badjlist isa AbstractVector ? axes(g.badjlist, 1) : Base.OneTo(g.badjlist) +end +has_𝑠vertex(g::BipartiteGraph, v::Integer) = v in 𝑠vertices(g) +has_𝑑vertex(g::BipartiteGraph, v::Integer) = v in 𝑑vertices(g) +function 𝑠neighbors(g::BipartiteGraph, i::Integer, + with_metadata::Val{M} = Val(false)) where {M} + M ? zip(g.fadjlist[i], g.metadata[i]) : g.fadjlist[i] +end +function 𝑑neighbors(g::BipartiteGraph, j::Integer, + with_metadata::Val{M} = Val(false)) where {M} + require_complete(g) + M ? 
zip(g.badjlist[j], (g.metadata[i][j] for i in g.badjlist[j])) : g.badjlist[j]
end
# Edge count is cached on the graph; the vertex count sums both vertex ranges.
Graphs.ne(g::BipartiteGraph) = g.ne
Graphs.nv(g::BipartiteGraph) = sum(length, vertices(g))
Graphs.edgetype(g::BipartiteGraph{I}) where {I} = BipartiteEdge{I}

# Number of source / destination vertices.
nsrcs(g::BipartiteGraph) = length(𝑠vertices(g))
ndsts(g::BipartiteGraph) = length(𝑑vertices(g))

# Return `true` iff `edge` is in range and present. Relies on the forward adjacency
# lists being sorted (`insorted`).
function Graphs.has_edge(g::BipartiteGraph, edge::BipartiteEdge)
    @unpack src, dst = edge
    (src in 𝑠vertices(g) && dst in 𝑑vertices(g)) || return false # edge out of bounds
    insorted(dst, 𝑠neighbors(g, src))
end
Base.in(edge::BipartiteEdge, g::BipartiteGraph) = Graphs.has_edge(g, edge)

### Maximal matching
"""
    construct_augmenting_path!(matching::Matching, g::BipartiteGraph, vsrc, dstfilter, dcolor=falses(ndsts(g)), scolor=nothing) -> path_found::Bool

Try to construct an augmenting path in matching and if such a path is found,
update the matching accordingly.

`dcolor[vdst]` is set for every destination visited during the search so it is not
explored twice. If `scolor` is not `nothing`, `scolor[vsrc]` is set for every source
the search starts from. Only destinations for which `dstfilter(vdst)` is `true` are
considered.
"""
function construct_augmenting_path!(matching::Matching, g::BipartiteGraph, vsrc, dstfilter,
        dcolor = falses(ndsts(g)), scolor = nothing)
    scolor === nothing || (scolor[vsrc] = true)

    # if a `vdst` is unassigned and the edge `vsrc <=> vdst` exists
    for vdst in 𝑠neighbors(g, vsrc)
        if dstfilter(vdst) && matching[vdst] === unassigned
            matching[vdst] = vsrc
            return true
        end
    end

    # for every `vsrc` such that edge `vsrc <=> vdst` exists and `vdst` is uncolored
    for vdst in 𝑠neighbors(g, vsrc)
        (dstfilter(vdst) && !dcolor[vdst]) || continue
        dcolor[vdst] = true
        if construct_augmenting_path!(matching, g, matching[vdst], dstfilter, dcolor,
            scolor)
            matching[vdst] = vsrc
            return true
        end
    end
    return false
end

"""
    maximal_matching(g::BipartiteGraph, [srcfilter], [dstfilter])

For a bipartite graph `g`, construct a maximal matching of destination to source
vertices, subject to the constraint that vertices for which `srcfilter` or `dstfilter`
return `false` may not be matched.
"""
function maximal_matching(g::BipartiteGraph, srcfilter = vsrc -> true,
        dstfilter = vdst -> true, ::Type{U} = Unassigned) where {U}
    matching = Matching{U}(max(nsrcs(g), ndsts(g)))
    foreach(Iterators.filter(srcfilter, 𝑠vertices(g))) do vsrc
        construct_augmenting_path!(matching, g, vsrc, dstfilter)
    end
    return matching
end

###
### Populate
###
struct NoMetadata end
const NO_METADATA = NoMetadata()

# Convenience method: add the edge `i <=> j` given as two integers.
function Graphs.add_edge!(g::BipartiteGraph, i::Integer, j::Integer, md = NO_METADATA)
    add_edge!(g, BipartiteEdge(i, j), md)
end
# Insert `edge`, keeping both adjacency lists sorted. Returns `false` if the edge is
# already present and `true` after a successful insertion; errors if either endpoint
# is not a vertex of `g`. When `md` is supplied it is inserted into `g.metadata[s]`
# at the same position as the new neighbor.
function Graphs.add_edge!(g::BipartiteGraph, edge::BipartiteEdge, md = NO_METADATA)
    @unpack fadjlist, badjlist = g
    s, d = src(edge), dst(edge)
    (has_𝑠vertex(g, s) && has_𝑑vertex(g, d)) || error("edge ($edge) out of range.")
    @inbounds list = fadjlist[s]
    index = searchsortedfirst(list, d)
    @inbounds (index <= length(list) && list[index] == d) && return false # edge already in graph
    insert!(list, index, d)
    if md !== NO_METADATA
        insert!(g.metadata[s], index, md)
    end

    g.ne += 1
    if badjlist isa AbstractVector
        @inbounds list = badjlist[d]
        index = searchsortedfirst(list, s)
        insert!(list, index, s)
    end
    return true # edge successfully added
end

# Convenience method: remove the edge `i <=> j` given as two integers.
function Graphs.rem_edge!(g::BipartiteGraph, i::Integer, j::Integer)
    Graphs.rem_edge!(g, BipartiteEdge(i, j))
end
# Remove `edge` from both adjacency lists. Errors if an endpoint is out of range or
# the edge is absent; returns `true` otherwise.
# NOTE(review): `g.metadata` is not touched here although `add_edge!` may insert into
# it — confirm whether metadata should be removed alongside the edge.
function Graphs.rem_edge!(g::BipartiteGraph, edge::BipartiteEdge)
    @unpack fadjlist, badjlist = g
    s, d = src(edge), dst(edge)
    (has_𝑠vertex(g, s) && has_𝑑vertex(g, d)) || error("edge ($edge) out of range.")
    @inbounds list = fadjlist[s]
    index = searchsortedfirst(list, d)
    @inbounds (index <= length(list) && list[index] == d) ||
              error("graph does not have edge $edge")
    deleteat!(list, index)
    g.ne -= 1
    if badjlist isa AbstractVector
        @inbounds list = badjlist[d]
        index = searchsortedfirst(list, s)
        deleteat!(list, index)
    end
    return true # edge successfully deleted
end

function
Graphs.add_vertex!(g::BipartiteGraph{T}, type::VertType) where {T}
    if type === DST
        if g.badjlist isa AbstractVector
            push!(g.badjlist, T[])
            return length(g.badjlist)
        else
            # Without backward adjacency lists `badjlist` is just the destination count.
            g.badjlist += 1
            return g.badjlist
        end
    elseif type === SRC
        push!(g.fadjlist, T[])
        return length(g.fadjlist)
    else
        error("type ($type) must be either `DST` or `SRC`")
    end
end

"""
    set_neighbors!(g::BipartiteGraph, i::Integer, new_neighbors)

Replace the destination neighbors of source vertex `i` by `new_neighbors`, updating
`g.ne` and, when present, the backward adjacency lists. The neighbors are stored sorted
and de-duplicated; an empty collection (including `()`) clears the existing list in
place.
"""
function set_neighbors!(g::BipartiteGraph, i::Integer, new_neighbors)
    old_neighbors = g.fadjlist[i]
    # Normalise first so that the edge count reflects what is actually stored:
    # duplicates in `new_neighbors` collapse into a single edge. Empty input is
    # special-cased because it may be a `Tuple`, which is never passed to `sort`.
    new_list = iszero(length(new_neighbors)) ? nothing : unique!(sort(new_neighbors))
    new_nneighbors = new_list === nothing ? 0 : length(new_list)
    g.ne += new_nneighbors - length(old_neighbors)
    if isa(g.badjlist, AbstractVector)
        for n in old_neighbors
            @inbounds list = g.badjlist[n]
            index = searchsortedfirst(list, i)
            if 1 <= index <= length(list) && list[index] == i
                deleteat!(list, index)
            end
        end
        if new_list !== nothing
            for n in new_list
                @inbounds list = g.badjlist[n]
                index = searchsortedfirst(list, i)
                if !(1 <= index <= length(list) && list[index] == i)
                    insert!(list, index, i)
                end
            end
        end
    end
    if new_list === nothing # this handles Tuple as well
        # Warning: Aliases old_neighbors
        empty!(g.fadjlist[i])
    else
        g.fadjlist[i] = new_list
    end
end

# Remove every edge incident to the given source vertices (the vertices themselves stay).
function delete_srcs!(g::BipartiteGraph, srcs)
    for s in srcs
        set_neighbors!(g, s, ())
    end
    g
end
# Remove every edge incident to the given destination vertices (passed as `srcs`,
# because they are the sources of the inverted view). `invview` builds a new graph
# object that shares the adjacency lists but carries its own edge counter, so the
# updated count is copied back to `g`. The inverted view is returned, as before.
function delete_dsts!(g::BipartiteGraph, srcs)
    inv = delete_srcs!(invview(g), srcs)
    g.ne = inv.ne
    return inv
end

###
### Edges iteration
###
Graphs.edges(g::BipartiteGraph) = BipartiteEdgeIter(g, Val(SRC))
𝑠edges(g::BipartiteGraph) = BipartiteEdgeIter(g, Val(SRC))
𝑑edges(g::BipartiteGraph) = BipartiteEdgeIter(g, Val(DST))

# Lazy edge iterator; `T` (`SRC` or `DST`) selects which adjacency list is walked.
struct BipartiteEdgeIter{T, G} <: Graphs.AbstractEdgeIter
    g::G
    type::Val{T}
end

Base.length(it::BipartiteEdgeIter) = ne(it.g)
Base.eltype(it::BipartiteEdgeIter) = edgetype(it.g)

function Base.iterate(it::BipartiteEdgeIter{SRC, <:BipartiteGraph{T}},
        state = (1, 1, SRC)) where {T}
    @unpack g = it
    neqs = nsrcs(g)
    neqs == 0 && return nothing
    eq, jvar = state

    while eq <= neqs
eq′ = eq
        vars = 𝑠neighbors(g, eq′)
        if jvar > length(vars)
            eq += 1
            jvar = 1
            continue
        end
        edge = BipartiteEdge(eq′, vars[jvar])
        state = (eq, jvar + 1, SRC)
        return edge, state
    end
    return nothing
end

# Destination-major edge iteration: walks the backward adjacency list of each
# destination vertex in turn. State is `(index into neighbors, destination, DST)`.
function Base.iterate(it::BipartiteEdgeIter{DST, <:BipartiteGraph{T}},
        state = (1, 1, DST)) where {T}
    @unpack g = it
    nvars = ndsts(g)
    nvars == 0 && return nothing
    ieq, jvar = state

    while jvar <= nvars
        eqs = 𝑑neighbors(g, jvar)
        if ieq > length(eqs)
            ieq = 1
            jvar += 1
            continue
        end
        edge = BipartiteEdge(eqs[ieq], jvar)
        state = (ieq + 1, jvar, DST)
        return edge, state
    end
    return nothing
end

###
### Utils
###
# Sparse `nsrcs(g) × ndsts(g)` matrix holding `val` at every (source, destination) edge.
function Graphs.incidence_matrix(g::BipartiteGraph, val = true)
    I = Int[]
    J = Int[]
    for i in 𝑠vertices(g), n in 𝑠neighbors(g, i)
        push!(I, i)
        push!(J, n)
    end
    return sparse(I, J, val, nsrcs(g), ndsts(g))
end

"""
    struct DiCMOBiGraph

This data structure implements a "directed, contracted, matching-oriented" view of an
original (undirected) bipartite graph. It has two modes, depending on the `Transposed`
flag, which switches the direction of the induced matching.

Essentially the graph adapter performs two largely orthogonal functions
[`Transposed == true` differences are indicated in square brackets]:

  1. It pairs an undirected bipartite graph with a matching of the destination vertex.

    This matching is used to induce an orientation on the otherwise undirected graph:
    Matched edges pass from destination to source [source to destination], all other edges
    pass in the opposite direction.

  2. It exposes the graph view obtained by contracting the destination [source] vertices
    along the matched edges.

The result of this operation is an induced, directed graph on the source [destination] vertices.
The resulting graph has a few desirable properties. In particular, this graph
is acyclic if and only if the induced directed graph on the original bipartite
graph is acyclic.

# Hypergraph interpretation

Consider the bipartite graph `B` as the incidence graph of some hypergraph `H`.
Note that a matching `M` on `B` in the above sense is equivalent to determining
an (1,n)-orientation on the hypergraph (i.e. each directed hyperedge has exactly
one head, but any arbitrary number of tails). In this setting, this is simply
the graph formed by expanding each directed hyperedge into `n` ordinary edges
between the same vertices.
"""
mutable struct DiCMOBiGraph{Transposed, I, G <: BipartiteGraph{I}, M <: Matching} <:
               Graphs.AbstractGraph{I}
    graph::G
    # Cached edge count; `missing` until `Graphs.ne` computes it on demand.
    ne::Union{Missing, Int}
    matching::M
    function DiCMOBiGraph{Transposed}(g::G, ne::Union{Missing, Int},
            m::M) where {Transposed, I, G <: BipartiteGraph{I}, M}
        new{Transposed, I, G, M}(g, ne, m)
    end
end
function DiCMOBiGraph{Transposed}(g::BipartiteGraph) where {Transposed}
    DiCMOBiGraph{Transposed}(g, 0, Matching(ndsts(g)))
end
function DiCMOBiGraph{Transposed}(g::BipartiteGraph, m::M) where {Transposed, M}
    DiCMOBiGraph{Transposed}(g, missing, m)
end

function invview(g::DiCMOBiGraph{Transposed}) where {Transposed}
    DiCMOBiGraph{!Transposed}(invview(g.graph), g.ne, invview(g.matching))
end

Graphs.is_directed(::Type{<:DiCMOBiGraph}) = true
function Graphs.nv(g::DiCMOBiGraph{Transposed}) where {Transposed}
    Transposed ? ndsts(g.graph) : nsrcs(g.graph)
end
function Graphs.vertices(g::DiCMOBiGraph{Transposed}) where {Transposed}
    Transposed ?
    𝑑vertices(g.graph) : 𝑠vertices(g.graph)
end

# Lazy neighbor iterator of vertex `v` in the contracted, oriented graph `g`.
struct CMONeighbors{Transposed, V}
    g::DiCMOBiGraph{Transposed}
    v::V
    function CMONeighbors{Transposed}(g::DiCMOBiGraph{Transposed},
            v::V) where {Transposed, V}
        new{Transposed, V}(g, v)
    end
end

Graphs.outneighbors(g::DiCMOBiGraph{false}, v) = CMONeighbors{false}(g, v)
Graphs.inneighbors(g::DiCMOBiGraph{false}, v) = inneighbors(invview(g), v)
Base.iterate(c::CMONeighbors{false}) = iterate(c, (c.g.graph.fadjlist[c.v],))
function Base.iterate(c::CMONeighbors{false}, (l, state...))
    while true
        r = iterate(l, state...)
        r === nothing && return nothing
        # If this is a matched edge, skip it, it's reversed in the induced
        # directed graph. Otherwise, if there is no matching for this destination
        # edge, also skip it, since it got deleted in the contraction.
        vsrc = c.g.matching[r[1]]
        if vsrc === c.v || !isa(vsrc, Int)
            state = (r[2],)
            continue
        end
        return vsrc, (l, r[2])
    end
end
# No closed form because of the skipping above: count by iterating.
Base.length(c::CMONeighbors{false}) = count(_ -> true, c)

# Apply `f` only when `x` is an `Int` / is not `nothing`; otherwise yield `nothing`.
liftint(f, x) = (!isa(x, Int)) ? nothing : f(x)
liftnothing(f, x) = x === nothing ? nothing : f(x)

_vsrc(c::CMONeighbors{true}) = c.g.matching[c.v]
_neighbors(c::CMONeighbors{true}) = liftint(vsrc -> c.g.graph.fadjlist[vsrc], _vsrc(c))
Base.length(c::CMONeighbors{true}) = something(liftnothing(length, _neighbors(c)), 1) - 1
Graphs.inneighbors(g::DiCMOBiGraph{true}, v) = CMONeighbors{true}(g, v)
Graphs.outneighbors(g::DiCMOBiGraph{true}, v) = outneighbors(invview(g), v)
Base.iterate(c::CMONeighbors{true}) = liftnothing(ns -> iterate(c, (ns,)), _neighbors(c))
function Base.iterate(c::CMONeighbors{true}, (l, state...))
    while true
        r = iterate(l, state...)
        r === nothing && return nothing
        # Skip the vertex itself; every other entry of the list is a neighbor.
        if r[1] === c.v
            state = (r[2],)
            continue
        end
        return r[1], (l, r[2])
    end
end

# One generator of `src => dst` pairs per vertex of `g`.
function _edges(g::DiCMOBiGraph{Transposed}) where {Transposed}
    Transposed ?
    ((w => v for w in inneighbors(g, v)) for v in vertices(g)) :
    ((v => w for w in outneighbors(g, v)) for v in vertices(g))
end

Graphs.edges(g::DiCMOBiGraph) = (Graphs.SimpleEdge(p) for p in Iterators.flatten(_edges(g)))
# Compute the edge count lazily and cache it. `init = 0` keeps the reduction defined
# for a graph without vertices, where `_edges(g)` is empty.
function Graphs.ne(g::DiCMOBiGraph)
    if g.ne === missing
        g.ne = mapreduce(x -> length(x.iter), +, _edges(g); init = 0)
    end
    return g.ne
end

Graphs.has_edge(g::DiCMOBiGraph{true}, a, b) = a in inneighbors(g, b)
Graphs.has_edge(g::DiCMOBiGraph{false}, a, b) = b in outneighbors(g, a)
# This definition is required for `induced_subgraph` to work
(::Type{<:DiCMOBiGraph})(n::Integer) = SimpleDiGraph(n)

# Condensation Graphs
abstract type AbstractCondensationGraph <: AbstractGraph{Int} end
# Build a condensation graph from its components, deriving the vertex -> component map.
# NOTE(review): entries of `scc_assignment` for vertices that appear in no component
# are left uninitialised — confirm that `sccs` always covers every vertex.
function (T::Type{<:AbstractCondensationGraph})(g, sccs::Vector{Union{Int, Vector{Int}}})
    scc_assignment = Vector{Int}(undef, isa(g, BipartiteGraph) ? ndsts(g) : nv(g))
    for (i, c) in enumerate(sccs)
        for v in c
            scc_assignment[v] = i
        end
    end
    T(g, sccs, scc_assignment)
end
function (T::Type{<:AbstractCondensationGraph})(g, sccs::Vector{Vector{Int}})
    T(g, Vector{Union{Int, Vector{Int}}}(sccs))
end

Graphs.is_directed(::Type{<:AbstractCondensationGraph}) = true
Graphs.nv(icg::AbstractCondensationGraph) = length(icg.sccs)
Graphs.vertices(icg::AbstractCondensationGraph) = Base.OneTo(nv(icg))

"""
    struct MatchedCondensationGraph

For some bipartite-graph and an orientation induced on its destination contraction,
records the condensation DAG of the digraph formed by the orientation. I.e. this
is a DAG of connected components formed by the destination vertices of some
underlying bipartite graph.
N.B.: This graph does not store explicit neighbor relations of the sccs.
Therefore, the edge multiplicity is derived from the underlying bipartite graph,
i.e. this graph is not strict.
"""
struct MatchedCondensationGraph{G <: DiCMOBiGraph} <: AbstractCondensationGraph
    graph::G
    # Records the members of a strongly connected component. For efficiency,
    # trivial sccs (with one vertex member) are stored inline. Note: the sccs
    # here need not be stored in topological order.
    sccs::Vector{Union{Int, Vector{Int}}}
    # Maps the vertices back to the scc of which they are a part
    scc_assignment::Vector{Int}
end

# Components reachable from `cc` by one edge of the underlying graph; a component is
# yielded once per such edge, and edges inside `cc` are dropped.
function Graphs.outneighbors(mcg::MatchedCondensationGraph, cc::Integer)
    Iterators.flatten((mcg.scc_assignment[v′]
                       for v′ in outneighbors(mcg.graph, v) if mcg.scc_assignment[v′] != cc)
    for v in mcg.sccs[cc])
end

# Mirror image of `outneighbors`, following incoming edges of the underlying graph.
function Graphs.inneighbors(mcg::MatchedCondensationGraph, cc::Integer)
    Iterators.flatten((mcg.scc_assignment[v′]
                       for v′ in inneighbors(mcg.graph, v) if mcg.scc_assignment[v′] != cc)
    for v in mcg.sccs[cc])
end

"""
    struct InducedCondensationGraph

For some bipartite-graph and a topologically sorted list of connected components,
represents the condensation DAG of the digraph formed by the orientation. I.e. this
is a DAG of connected components formed by the destination vertices of some
underlying bipartite graph.
N.B.: This graph does not store explicit neighbor relations of the sccs.
Therefore, the edge multiplicity is derived from the underlying bipartite graph,
i.e. this graph is not strict.
"""
struct InducedCondensationGraph{G <: BipartiteGraph} <: AbstractCondensationGraph
    graph::G
    # Records the members of a strongly connected component. For efficiency,
    # trivial sccs (with one vertex member) are stored inline. Note: the sccs
    # here are stored in topological order.
+ sccs::Vector{Union{Int, Vector{Int}}} + # Maps the vertices back to the scc of which they are a part + scc_assignment::Vector{Int} +end + +function _neighbors(icg::InducedCondensationGraph, cc::Integer) + Iterators.flatten(Iterators.flatten(icg.graph.fadjlist[vsrc] + for vsrc in icg.graph.badjlist[v]) + for v in icg.sccs[cc]) +end + +function Graphs.outneighbors(icg::InducedCondensationGraph, v::Integer) + (icg.scc_assignment[n] for n in _neighbors(icg, v) if icg.scc_assignment[n] > v) +end + +function Graphs.inneighbors(icg::InducedCondensationGraph, v::Integer) + (icg.scc_assignment[n] for n in _neighbors(icg, v) if icg.scc_assignment[n] < v) +end + +end # module diff --git a/src/graph/diff.jl b/src/graph/diff.jl new file mode 100644 index 0000000..eb7e8a1 --- /dev/null +++ b/src/graph/diff.jl @@ -0,0 +1,102 @@ +struct DiffGraph <: Graphs.AbstractGraph{Int} + primal_to_diff::Vector{Union{Int, Nothing}} + diff_to_primal::Union{Nothing, Vector{Union{Int, Nothing}}} +end + +DiffGraph(primal_to_diff::Vector{Union{Int, Nothing}}) = DiffGraph(primal_to_diff, nothing) +function DiffGraph(n::Integer, with_badj::Bool = false) + DiffGraph(Union{Int, Nothing}[nothing for _ in 1:n], + with_badj ? Union{Int, Nothing}[nothing for _ in 1:n] : nothing) +end + +function Base.copy(dg::DiffGraph) + DiffGraph(copy(dg.primal_to_diff), + dg.diff_to_primal === nothing ? nothing : copy(dg.diff_to_primal)) +end + +@noinline function require_complete(dg::DiffGraph) + dg.diff_to_primal === nothing && + error("Not complete. 
Run `complete` first.")
end

# Graphs.jl queries directedness on the *type* (its trait dispatch calls
# `is_directed(G)`), matching the other graph types in this package; the instance
# method is kept for existing callers.
Graphs.is_directed(::Type{<:DiffGraph}) = true
Graphs.is_directed(dg::DiffGraph) = true
# Edges are the `var => derivative` pairs for every variable that has a derivative.
function Graphs.edges(dg::DiffGraph)
    (i => v for (i, v) in enumerate(dg.primal_to_diff) if v !== nothing)
end
Graphs.nv(dg::DiffGraph) = length(dg.primal_to_diff)
Graphs.ne(dg::DiffGraph) = count(x -> x !== nothing, dg.primal_to_diff)
Graphs.vertices(dg::DiffGraph) = Base.OneTo(nv(dg))
# Each vertex has at most one out-neighbor (its derivative): returns `()` or a 1-tuple.
function Graphs.outneighbors(dg::DiffGraph, var::Integer)
    diff = dg.primal_to_diff[var]
    return diff === nothing ? () : (diff,)
end
# Each vertex has at most one in-neighbor (its primal); needs the reverse map, so the
# graph must be complete.
function Graphs.inneighbors(dg::DiffGraph, var::Integer)
    require_complete(dg)
    diff = dg.diff_to_primal[var]
    return diff === nothing ? () : (diff,)
end
# Append a vertex without a derivative and return its index.
function Graphs.add_vertex!(dg::DiffGraph)
    push!(dg.primal_to_diff, nothing)
    if dg.diff_to_primal !== nothing
        push!(dg.diff_to_primal, nothing)
    end
    return length(dg.primal_to_diff)
end

# Record `diff` as the derivative of `var` (delegates to `setindex!`).
function Graphs.add_edge!(dg::DiffGraph, var::Integer, diff::Integer)
    dg[var] = diff
end

# Also pass through the array interface for ease of use
Base.:(==)(dg::DiffGraph, v::AbstractVector) = dg.primal_to_diff == v
# Symmetric case: the vector is the *first* argument here, so the field must be read
# from the second one.
Base.:(==)(v::AbstractVector, dg::DiffGraph) = v == dg.primal_to_diff
Base.eltype(::DiffGraph) = Union{Int, Nothing}
Base.size(dg::DiffGraph) = size(dg.primal_to_diff)
Base.length(dg::DiffGraph) = length(dg.primal_to_diff)
Base.getindex(dg::DiffGraph, var::Integer) = dg.primal_to_diff[var]
Base.getindex(dg::DiffGraph, a::AbstractArray) = [dg[x] for x in a]

# Set (or, with `val === nothing`, clear) the derivative of `var`. When the reverse map
# exists, the entry of the previous derivative is cleared and the new one is recorded.
function Base.setindex!(dg::DiffGraph, val::Union{Integer, Nothing}, var::Integer)
    if dg.diff_to_primal !== nothing
        old_pd = dg.primal_to_diff[var]
        if old_pd !== nothing
            dg.diff_to_primal[old_pd] = nothing
        end
        if val !== nothing
            #old_dp = dg.diff_to_primal[val]
            #old_dp === nothing || error("Variable already assigned.")
            dg.diff_to_primal[val] = var
        end
    end
    return dg.primal_to_diff[var] = val
end
Base.iterate(dg::DiffGraph, state...) = iterate(dg.primal_to_diff, state...)
+ +function complete(dg::DiffGraph) + dg.diff_to_primal !== nothing && return dg + diff_to_primal = Union{Int, Nothing}[nothing for _ in 1:length(dg.primal_to_diff)] + for (var, diff) in edges(dg) + diff_to_primal[diff] = var + end + return DiffGraph(dg.primal_to_diff, diff_to_primal) +end + +function invview(dg::DiffGraph) + require_complete(dg) + return DiffGraph(dg.diff_to_primal, dg.primal_to_diff) +end + +struct DiffChainIterator{Descend} + var_to_diff::DiffGraph + v::Int +end + +function Base.iterate(di::DiffChainIterator{Descend}, v = nothing) where {Descend} + if v === nothing + vv = di.v + return (vv, vv) + end + g = Descend ? invview(di.var_to_diff) : di.var_to_diff + v′ = g[v] + v′ === nothing ? nothing : (v′, v′) +end diff --git a/src/interface.jl b/src/interface.jl new file mode 100644 index 0000000..1ae1fac --- /dev/null +++ b/src/interface.jl @@ -0,0 +1,39 @@ +abstract type SystemStructure; end +is_only_discrete(::SystemStructure) = false + +abstract type TransformationState{T} end +abstract type AbstractTearingState{T} <: TransformationState{T} end + +struct SelectedState end + +function linear_subsys_adjmat! end +function eq_derivative! end +function var_derivative! 
end + +function eq_derivative_graph!(s::SystemStructure, eq::Int) + add_vertex!(s.graph, SRC) + s.solvable_graph === nothing || add_vertex!(s.solvable_graph, SRC) + # the new equation is created by differentiating `eq` + eq_diff = add_vertex!(s.eq_to_diff) + add_edge!(s.eq_to_diff, eq, eq_diff) + return eq_diff +end + +function var_derivative_graph!(s::SystemStructure, v::Int) + sg = g = add_vertex!(s.graph, DST) + var_diff = add_vertex!(s.var_to_diff) + add_edge!(s.var_to_diff, v, var_diff) + s.solvable_graph === nothing || (sg = add_vertex!(s.solvable_graph, DST)) + @assert sg == g == var_diff + return var_diff +end + +function complete!(s::SystemStructure) + s.var_to_diff = complete(s.var_to_diff) + s.eq_to_diff = complete(s.eq_to_diff) + s.graph = complete(s.graph) + if s.solvable_graph !== nothing + s.solvable_graph = complete(s.solvable_graph) + end + s +end diff --git a/src/math/bareiss.jl b/src/math/bareiss.jl new file mode 100644 index 0000000..c876036 --- /dev/null +++ b/src/math/bareiss.jl @@ -0,0 +1,271 @@ +# Keeps compatibility with bariess code moved to Base/stdlib on older releases +module bareiss + +using LinearAlgebra +using SparseArrays +using SparseArrays: AbstractSparseMatrixCSC, getcolptr +import Base: swaprows! 
+ +export bareiss_zero!, bareiss!, bareiss_update_virtual_colswap_mtk!, exactdiv + +macro swap(a, b) + esc(:(($a, $b) = ($b, $a))) +end + +function bareiss_update!(zero!, M::StridedMatrix, k, swapto, pivot, + prev_pivot::Base.BitInteger) + flag = zero(prev_pivot) + prev_pivot = Base.MultiplicativeInverses.SignedMultiplicativeInverse(prev_pivot) + @inbounds for i in (k + 1):size(M, 2) + Mki = M[k, i] + @simd ivdep for j in (k + 1):size(M, 1) + M[j, i], r = divrem(M[j, i] * pivot - M[j, k] * Mki, prev_pivot) + flag = flag | r + end + end + iszero(flag) || error("Overflow occurred") + zero!(M, (k + 1):size(M, 1), k) +end + +function bareiss_update!(zero!, M::StridedMatrix, k, swapto, pivot, prev_pivot) + @inbounds for i in (k + 1):size(M, 2), j in (k + 1):size(M, 1) + M[j, i] = exactdiv(M[j, i] * pivot - M[j, k] * M[k, i], prev_pivot) + end + zero!(M, (k + 1):size(M, 1), k) +end + +@views function bareiss_update!(zero!, M::AbstractMatrix, k, swapto, pivot, prev_pivot) + if prev_pivot isa Base.BitInteger + prev_pivot = Base.MultiplicativeInverses.SignedMultiplicativeInverse(prev_pivot) + end + V = M[(k + 1):end, (k + 1):end] + V .= exactdiv.(V .* pivot .- M[(k + 1):end, k] * M[k, (k + 1):end]', prev_pivot) + zero!(M, (k + 1):size(M, 1), k) + if M isa AbstractSparseMatrixCSC + dropzeros!(M) + end +end + +function bareiss_update_virtual_colswap!(zero!, M::AbstractMatrix, k, swapto, pivot, + prev_pivot) + if prev_pivot isa Base.BitInteger + prev_pivot = Base.MultiplicativeInverses.SignedMultiplicativeInverse(prev_pivot) + end + V = @view M[(k + 1):end, :] + V .= @views exactdiv.(V .* pivot .- M[(k + 1):end, swapto[2]] * M[k, :]', prev_pivot) + zero!(M, (k + 1):size(M, 1), swapto[2]) +end + +bareiss_zero!(M, i, j) = M[i, j] .= zero(eltype(M)) + +function find_pivot_col(M, i) + p = findfirst(!iszero, @view M[i, i:end]) + p === nothing && return nothing + idx = CartesianIndex(i, p + i - 1) + (idx, M[idx]) +end + +function find_pivot_any(M, i) + p = findfirst(!iszero, @view 
M[i:end, i:end]) + p === nothing && return nothing + idx = p + CartesianIndex(i - 1, i - 1) + (idx, M[idx]) +end + +function exactdiv(a::Integer, b) + d, r = divrem(a, b) + @assert r == 0 + return d +end + +function bareiss_update_virtual_colswap_mtk!(zero!, M::AbstractMatrix, k, swapto, pivot, + last_pivot; pivot_equal_optimization = true) + if pivot_equal_optimization + error("MTK pivot micro-optimization not implemented for `$(typeof(M))`. + Turn off the optimization for debugging or use a different matrix type.") + end + bareiss.bareiss_update_virtual_colswap!(zero!, M, k, swapto, pivot, last_pivot) +end + +const bareiss_colswap = (Base.swapcols!, swaprows!, bareiss_update!, bareiss_zero!) +const bareiss_virtcolswap = ((M, i, j) -> nothing, swaprows!, + bareiss_update_virtual_colswap!, bareiss_zero!) + +""" + bareiss!(M, [swap_strategy]) + +Perform Bareiss's fraction-free row-reduction algorithm on the matrix `M`. +Optionally, a specific pivoting method may be specified. + +swap_strategy is an optional argument that determines how the swapping of rows and columns is performed. +bareiss_colswap (the default) swaps the columns and rows normally. +bareiss_virtcolswap pretends to swap the columns which can be faster for sparse matrices. +""" +function bareiss!(M::AbstractMatrix{T}, swap_strategy = bareiss_colswap; + find_pivot = find_pivot_any, column_pivots = nothing) where {T} + swapcols!, swaprows!, update!, zero! 
= swap_strategy + prev = one(eltype(M)) + n = size(M, 1) + pivot = one(T) + column_permuted = false + for k in 1:n + r = find_pivot(M, k) + r === nothing && return (k - 1, pivot, column_permuted) + (swapto, pivot) = r + if column_pivots !== nothing && k != swapto[2] + column_pivots[k] = swapto[2] + column_permuted |= true + end + if CartesianIndex(k, k) != swapto + swapcols!(M, k, swapto[2]) + swaprows!(M, k, swapto[1]) + end + update!(zero!, M, k, swapto, pivot, prev) + prev = pivot + end + return (n, pivot, column_permuted) +end + +function nullspace(A; col_order = nothing) + n = size(A, 2) + workspace = zeros(Int, 2 * n) + column_pivots = @view workspace[1:n] + pivots_cache = @view workspace[(n + 1):(2n)] + @inbounds for i in 1:n + column_pivots[i] = i + end + B = copy(A) + (rank, d, column_permuted) = bareiss!(B; column_pivots) + reduce_echelon!(B, rank, d, pivots_cache) + + # The first rank entries in col_order are columns that give a basis + # for the column space. The remainder give the free variables. 
+ if col_order !== nothing + resize!(col_order, size(A, 2)) + col_order .= 1:size(A, 2) + for (i, cp) in enumerate(column_pivots) + @swap(col_order[i], col_order[cp]) + end + end + + fill!(pivots_cache, 0) + N = reduced_echelon_nullspace(rank, B, pivots_cache) + apply_inv_pivot_rows!(N, column_pivots) +end + +function apply_inv_pivot_rows!(M, ipiv) + for i in size(M, 1):-1:1 + swaprows!(M, i, ipiv[i]) + end + M +end + +### +### Modified from AbstractAlgebra.jl +### +### https://github.com/Nemocas/AbstractAlgebra.jl/blob/4803548c7a945f3f7bd8c63f8bb7c79fac92b11a/LICENSE.md +function reduce_echelon!(A::AbstractMatrix{T}, rank, d, + pivots_cache = zeros(Int, size(A, 2))) where {T} + m, n = size(A) + isreduced = true + @inbounds for i in 1:rank + for j in 1:(i - 1) + if A[j, i] != zero(T) + isreduced = false + @goto out + end + end + if A[i, i] != one(T) + isreduced = false + @goto out + end + end + @label out + @inbounds for i in (rank + 1):m, j in 1:n + A[i, j] = zero(T) + end + isreduced && return A + + @inbounds if rank > 1 + t = zero(T) + q = zero(T) + d = -d + pivots = pivots_cache + np = rank + j = k = 1 + for i in 1:rank + while iszero(A[i, j]) + pivots[np + k] = j + j += 1 + k += 1 + end + pivots[i] = j + j += 1 + end + while k <= n - rank + pivots[np + k] = j + j += 1 + k += 1 + end + for k in 1:(n - rank) + for i in (rank - 1):-1:1 + t = A[i, pivots[np + k]] * d + for j in (i + 1):rank + t += A[i, pivots[j]] * A[j, pivots[np + k]] + q + end + A[i, pivots[np + k]] = exactdiv(-t, A[i, pivots[i]]) + end + end + d = -d + for i in 1:rank + for j in 1:rank + if i == j + A[j, pivots[i]] = d + else + A[j, pivots[i]] = zero(T) + end + end + end + end + return A +end + +function reduced_echelon_nullspace(rank, A::AbstractMatrix{T}, + pivots_cache = zeros(Int, size(A, 2))) where {T} + n = size(A, 2) + nullity = n - rank + U = zeros(T, n, nullity) + @inbounds if rank == 0 + for i in 1:nullity + U[i, i] = one(T) + end + elseif nullity != 0 + pivots = @view 
pivots_cache[1:rank] + nonpivots = @view pivots_cache[(rank + 1):n] + j = k = 1 + for i in 1:rank + while iszero(A[i, j]) + nonpivots[k] = j + j += 1 + k += 1 + end + pivots[i] = j + j += 1 + end + while k <= nullity + nonpivots[k] = j + j += 1 + k += 1 + end + d = -A[1, pivots[1]] + for i in 1:nullity + for j in 1:rank + U[pivots[j], i] = A[j, nonpivots[i]] + end + U[nonpivots[i], i] = d + end + end + return U +end + +end diff --git a/src/math/sparsematrixclil.jl b/src/math/sparsematrixclil.jl new file mode 100644 index 0000000..bbc8e93 --- /dev/null +++ b/src/math/sparsematrixclil.jl @@ -0,0 +1,349 @@ +module CLIL + +using SparseArrays +using FindFirstFunctions: findfirstequal +using ..bareiss + +const _debug_mode = Base.JLOptions().check_bounds == 1 + +export SparseMatrixCLIL, nonzerosmap, CLILVector + +""" + SparseMatrixCLIL{T, Ti} + +The SparseMatrixCLIL represents a sparse matrix in two distinct ways: + +1. As a sparse (in both row and column) n x m matrix +2. As a row-dense, column-sparse k x m matrix + +The data structure keeps a permutation between the row order of the two representations. +Swapping the rows in one does not affect the other. + +On construction, the second representation is equivalent to the first with fully-sparse +rows removed, though this may cease being true as row permutations are being applied +to the matrix. + +The default structure of the `SparseMatrixCLIL` type is the second structure, while +the first is available via the thin `AsSubMatrix` wrapper. 
"""
struct SparseMatrixCLIL{T, Ti <: Integer} <: AbstractSparseMatrix{T, Ti}
    nparentrows::Int
    ncols::Int
    nzrows::Vector{Ti}
    row_cols::Vector{Vector{Ti}} # issorted
    row_vals::Vector{Vector{T}}
end
# The matrix size is that of the row-dense view: one row per entry of `nzrows`.
Base.size(S::SparseMatrixCLIL) = (length(S.nzrows), S.ncols)
# Copy the row bookkeeping and every per-row vector.
function Base.copy(S::SparseMatrixCLIL{T, Ti}) where {T, Ti}
    SparseMatrixCLIL(S.nparentrows, S.ncols, copy(S.nzrows), map(copy, S.row_cols),
        map(copy, S.row_vals))
end

swap!(v, i, j) = v[i], v[j] = v[j], v[i]
# Swap rows `i` and `j` by exchanging the per-row vectors and their `nzrows` entries.
function Base.swaprows!(S::SparseMatrixCLIL, i, j)
    i == j && return
    swap!(S.nzrows, i, j)
    swap!(S.row_cols, i, j)
    swap!(S.row_vals, i, j)
end

# Convert to the requested value type `T` *and* index type `Ti`. Every vector is
# rebuilt through the `Vector{…}` constructor, so the result never aliases `S` and
# always has exactly the type that was asked for.
function Base.convert(::Type{SparseMatrixCLIL{T, Ti}}, S::SparseMatrixCLIL) where {T, Ti}
    return SparseMatrixCLIL{T, Ti}(S.nparentrows,
        S.ncols,
        Vector{Ti}(S.nzrows),
        Vector{Ti}[Vector{Ti}(cols) for cols in S.row_cols],
        Vector{T}[Vector{T}(vals) for vals in S.row_vals])
end

# Build from any matrix: each row keeps the column indices and values of its nonzero
# entries; every row of `mm` (including all-zero ones) gets an entry in `nzrows`.
function SparseMatrixCLIL(mm::AbstractMatrix)
    nrows, ncols = size(mm)
    row_cols = [findall(!iszero, row) for row in eachrow(mm)]
    row_vals = [row[cols] for (row, cols) in zip(eachrow(mm), row_cols)]
    SparseMatrixCLIL(nrows, ncols, Int[1:length(row_cols);], row_cols, row_vals)
end

# Thin wrapper around a `SparseVector`; indexing is forwarded to the wrapped vector.
struct CLILVector{T, Ti} <: AbstractSparseVector{T, Ti}
    vec::SparseVector{T, Ti}
end
Base.hash(v::CLILVector, s::UInt) = hash(v.vec, s) ⊻ 0xc71be0e9ccb75fbd
Base.size(v::CLILVector) = Base.size(v.vec)
Base.getindex(v::CLILVector, idx::Integer...) = Base.getindex(v.vec, idx...)
Base.setindex!(vec::CLILVector, v, idx::Integer...) = Base.setindex!(vec.vec, v, idx...)

# Row `i` as a `CLILVector`. The wrapped `SparseVector` is built directly from the
# row's index and value vectors, so it shares storage with the matrix.
function Base.view(a::SparseMatrixCLIL, i::Integer, ::Colon)
    CLILVector(SparseVector(a.ncols, a.row_cols[i], a.row_vals[i]))
end
SparseArrays.nonzeroinds(a::CLILVector) = SparseArrays.nonzeroinds(a.vec)
SparseArrays.nonzeros(a::CLILVector) = SparseArrays.nonzeros(a.vec)
SparseArrays.nnz(a::CLILVector) = nnz(a.vec)

# Overwrite row `i` with copies of the stored entries of `v`. Throws a `BoundsError`
# if the lengths disagree and errors if `v` stores an explicit zero.
function Base.setindex!(S::SparseMatrixCLIL, v::CLILVector, i::Integer, c::Colon)
    if v.vec.n != S.ncols
        throw(BoundsError(v, 1:(S.ncols)))
    end
    any(iszero, v.vec.nzval) && error("setindex failed")
    S.row_cols[i] = copy(v.vec.nzind)
    S.row_vals[i] = copy(v.vec.nzval)
    return v
end

# Reset to all-zero; for sparse vectors this drops every stored entry.
zero!(a::AbstractArray{T}) where {T} = a[:] .= zero(T)
zero!(a::SparseVector) = (empty!(a.nzind); empty!(a.nzval))
zero!(a::CLILVector) = zero!(a.vec)
SparseArrays.dropzeros!(a::CLILVector) = SparseArrays.dropzeros!(a.vec)

struct NonZeros{T <: AbstractArray}
    v::T
end
Base.pairs(nz::NonZeros{<:CLILVector}) = NonZerosPairs(nz.v)

struct NonZerosPairs{T <: AbstractArray}
    v::T
end

Base.IteratorSize(::Type{<:NonZerosPairs}) = Base.SizeUnknown()
# N.B.: Because of how we're using this, this must be robust to modification of
# the underlying vector. As such, we treat this as an iteration over indices
# that happens to short cut using the sparse structure and sortedness of the
# array.
#
# The state is `(idx, col)`: the position that was returned last and the column
# index that was stored there at the time.
function Base.iterate(nzp::NonZerosPairs{<:CLILVector}, (idx, col))
    v = nzp.v.vec
    nzind = v.nzind
    nzval = v.nzval
    # The vector may have been emptied since the previous step (e.g. by `zero!`);
    # nothing is left to visit and `nzind[idx]` below would be out of bounds.
    isempty(nzind) && return nothing
    # If entries were removed, clamp the cursor to the last stored entry and let the
    # re-synchronisation below locate the first column after `col`.
    if idx > length(nzind)
        idx = length(nzind)
    end
    oldcol = nzind[idx]
    if col != oldcol
        # The vector was changed since the last iteration. Find our
        # place in the vector again.
        tail = col > oldcol ? (@view nzind[(idx + 1):end]) : (@view nzind[1:idx])
        tail_i = searchsortedfirst(tail, col + 1)
        # No remaining indices.
        tail_i > length(tail) && return nothing
        new_idx = col > oldcol ?
                  idx + tail_i : tail_i
        new_col = nzind[new_idx]
        return (new_col => nzval[new_idx], (new_idx, new_col))
    end
    idx == length(nzind) && return nothing
    new_col = nzind[idx + 1]
    return (new_col => nzval[idx + 1], (idx + 1, new_col))
end

# First step: yield the first stored entry, if there is one.
function Base.iterate(nzp::NonZerosPairs{<:CLILVector})
    v = nzp.v.vec
    nzind = v.nzind
    nzval = v.nzval
    isempty(nzind) && return nothing
    return nzind[1] => nzval[1], (1, nzind[1])
end

# Arguably this is how nonzeros should behave in the first place, but let's
# build something that works for us here and worry about it later.
nonzerosmap(a::CLILVector) = NonZeros(a)

function bareiss.bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swapto, pivot,
        last_pivot; pivot_equal_optimization = true)
    # for ei in nzrows(>= k)
    eadj = M.row_cols
    old_cadj = M.row_vals
    vpivot = swapto[2]

    ## N.B.: Micro-optimization
    #
    # For rows that do not have an entry in the eliminated column, all this
    # update does is multiply the row in question by `pivot/last_pivot` (the
    # result of which is guaranteed to be integer by general properties of the
    # bareiss algorithm, even if `pivot/last_pivot` is not).
    #
    # Thus, when `pivot == last pivot`, we can skip the update for any rows that
    # do not have an entry in the eliminated column (because we'd simply be
    # multiplying by 1).
    #
    # As an additional MTK-specific enhancement, we further allow the case
    # when the absolute values are equal, i.e. effectively multiplying the row
    # by `-1`. To ensure this is legal, we need to show two things.
    # 1. The multiplication does not change the answer and
    # 2. The multiplication does not affect the fraction-freeness of the Bareiss
    #    algorithm.
    #
    # For point 1, remember that we're working on a system of linear equations,
    # so it is always legal for us to multiply any row by a scalar without changing
    # the underlying system of equations.
+ # + # For point 2, note that the factorization we're now computing is the same + # as if we had multiplied the corresponding row (accounting for row swaps) + # in the original matrix by `last_pivot/pivot`, ensuring that the matrix + # itself has integral entries when `last_pivot/pivot` is integral (here we + # have -1, which counts). We could use the more general integrality + # condition, but that would in turn disturb the growth bounds on the + # factorization matrix entries that the bareiss algorithm guarantees. To be + # conservative, we leave it at this, as this captures the most important + # case for MTK (where most pivots are `1` or `-1`). + pivot_equal = pivot_equal_optimization && abs(pivot) == abs(last_pivot) + @inbounds for ei in (k + 1):size(M, 1) + # eliminate `v` + coeff = 0 + ivars = eadj[ei] + vj = findfirstequal(vpivot, ivars) + if vj !== nothing + coeff = old_cadj[ei][vj] + deleteat!(old_cadj[ei], vj) + deleteat!(eadj[ei], vj) + elseif pivot_equal + continue + end + + # the pivot row + kvars = eadj[k] + kcoeffs = old_cadj[k] + # the elimination target + ivars = eadj[ei] + icoeffs = old_cadj[ei] + + numkvars = length(kvars) + numivars = length(ivars) + tmp_incidence = similar(eadj[ei], numkvars + numivars) + tmp_coeffs = similar(old_cadj[ei], numkvars + numivars) + tmp_len = 0 + kvind = ivind = 0 + if _debug_mode + # in debug mode, we at least check to confirm we're iterating over + # `v`s in the correct order + vars = sort(union(ivars, kvars)) + vi = 0 + end + if numivars > 0 && numkvars > 0 + kvv = kvars[kvind += 1] + ivv = ivars[ivind += 1] + dobreak = false + while true + if kvv == ivv + v = kvv + ck = kcoeffs[kvind] + ci = icoeffs[ivind] + kvind += 1 + ivind += 1 + if kvind > numkvars + dobreak = true + else + kvv = kvars[kvind] + end + if ivind > numivars + dobreak = true + else + ivv = ivars[ivind] + end + p1 = Base.Checked.checked_mul(pivot, ci) + p2 = Base.Checked.checked_mul(coeff, ck) + ci = exactdiv(Base.Checked.checked_sub(p1, p2), 
last_pivot) + elseif kvv < ivv + v = kvv + ck = kcoeffs[kvind] + kvind += 1 + if kvind > numkvars + dobreak = true + else + kvv = kvars[kvind] + end + p2 = Base.Checked.checked_mul(coeff, ck) + ci = exactdiv(Base.Checked.checked_neg(p2), last_pivot) + else # kvv > ivv + v = ivv + ci = icoeffs[ivind] + ivind += 1 + if ivind > numivars + dobreak = true + else + ivv = ivars[ivind] + end + ci = exactdiv(Base.Checked.checked_mul(pivot, ci), last_pivot) + end + if _debug_mode + @assert v == vars[vi += 1] + end + if v != vpivot && !iszero(ci) + tmp_incidence[tmp_len += 1] = v + tmp_coeffs[tmp_len] = ci + end + dobreak && break + end + elseif numkvars > 0 + ivind = 1 + kvv = kvars[kvind += 1] + elseif numivars > 0 + kvind = 1 + ivv = ivars[ivind += 1] + end + if kvind <= numkvars + v = kvv + while true + if _debug_mode + @assert v == vars[vi += 1] + end + if v != vpivot + ck = kcoeffs[kvind] + p2 = Base.Checked.checked_mul(coeff, ck) + ci = exactdiv(Base.Checked.checked_neg(p2), last_pivot) + if !iszero(ci) + tmp_incidence[tmp_len += 1] = v + tmp_coeffs[tmp_len] = ci + end + end + (kvind == numkvars) && break + v = kvars[kvind += 1] + end + elseif ivind <= numivars + v = ivv + while true + if _debug_mode + @assert v == vars[vi += 1] + end + if v != vpivot + p1 = Base.Checked.checked_mul(pivot, icoeffs[ivind]) + ci = exactdiv(p1, last_pivot) + if !iszero(ci) + tmp_incidence[tmp_len += 1] = v + tmp_coeffs[tmp_len] = ci + end + end + (ivind == numivars) && break + v = ivars[ivind += 1] + end + end + resize!(tmp_incidence, tmp_len) + resize!(tmp_coeffs, tmp_len) + eadj[ei] = tmp_incidence + old_cadj[ei] = tmp_coeffs + end +end + +struct AsSubMatrix{T, Ti <: Integer} <: AbstractSparseMatrix{T, Ti} + M::SparseMatrixCLIL{T, Ti} +end +Base.size(S::AsSubMatrix) = (S.M.nparentrows, S.M.ncols) + +function Base.getindex(S::SparseMatrixCLIL{T}, i1::Integer, i2::Integer) where {T} + checkbounds(S, i1, i2) + + col = S.row_cols[i1] + nncol = searchsortedfirst(col, i2) + (nncol > 
length(col) || col[nncol] != i2) && return zero(T) + + return S.row_vals[i1][nncol] +end + +function Base.getindex(S::AsSubMatrix{T}, i1::Integer, i2::Integer) where {T} + checkbounds(S, i1, i2) + S = S.M + + nnrow = findfirst(==(i1), S.nzrows) + isnothing(nnrow) && return zero(T) + + col = S.row_cols[nnrow] + nncol = searchsortedfirst(col, i2) + (nncol > length(col) || col[nncol] != i2) && return zero(T) + + return S.row_vals[nnrow][nncol] +end + +end diff --git a/src/modia_tearing.jl b/src/modia_tearing.jl new file mode 100644 index 0000000..97d3b45 --- /dev/null +++ b/src/modia_tearing.jl @@ -0,0 +1,115 @@ +# This code is derived from the Modia project and is licensed as follows: +# https://github.com/ModiaSim/Modia.jl/blob/b61daad643ef7edd0c1ccce6bf462c6acfb4ad1a/LICENSE + +function try_assign_eq!(ict::IncrementalCycleTracker, vj::Integer, eq::Integer) + G = ict.graph + add_edge_checked!(ict, Iterators.filter(!=(vj), 𝑠neighbors(G.graph, eq)), vj) do G + G.matching[vj] = eq + G.ne += length(𝑠neighbors(G.graph, eq)) - 1 + end +end + +function try_assign_eq!(ict::IncrementalCycleTracker, vars, v_active, eq::Integer, + condition::F = _ -> true) where {F} + G = ict.graph + for vj in vars + (vj in v_active && G.matching[vj] === unassigned && condition(vj)) || continue + try_assign_eq!(ict, vj, eq) && return true + end + return false +end + +function tearEquations!(ict::IncrementalCycleTracker, Gsolvable, es::Vector{Int}, + v_active::BitSet, isder′::F) where {F} + check_der = isder′ !== nothing + if check_der + has_der = Ref(false) + isder = let has_der = has_der, isder′ = isder′ + v -> begin + r = isder′(v) + has_der[] |= r + r + end + end + end + # Heuristic: As a first pass, try to assign any equations that only have one + # solvable variable. 
+ for only_single_solvable in (true, false) + for eq in es # iterate only over equations that are not in eSolvedFixed + vs = Gsolvable[eq] + ((length(vs) == 1) ⊻ only_single_solvable) && continue + if check_der + # if there're differentiated variables, then only consider them + try_assign_eq!(ict, vs, v_active, eq, isder) + if has_der[] + has_der[] = false + continue + end + end + try_assign_eq!(ict, vs, v_active, eq) + end + end + + return ict +end + +function tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, eqs, vars, + isder::F) where {F} + tearEquations!(ict, solvable_graph.fadjlist, eqs, vars, isder) + for var in vars + var_eq_matching[var] = ict.graph.matching[var] + end + return nothing +end + +function tear_graph_modia(structure::SystemStructure, isder::F = nothing, + ::Type{U} = Unassigned; + varfilter::F2 = v -> true, + eqfilter::F3 = eq -> true) where {F, U, F2, F3} + # It would be possible here to simply iterate over all variables and attempt to + # use tearEquations! to produce a matching that greedily selects the minimal + # number of torn variables. However, we can do this process faster if we first + # compute the strongly connected components. In the absence of cycles and + # non-solvability, a maximal matching on the original graph will give us an + # optimal assignment. However, even with cycles, we can use the maximal matching + # to give us a good starting point for a good matching and then proceed to + # reverse edges in each scc to improve the solution. Note that it is possible + # to have optimal solutions that cannot be found by this process. We will not + # find them here [TODO: It would be good to have an explicit example of this.] + + @unpack graph, solvable_graph = structure + var_eq_matching = maximal_matching(graph, eqfilter, varfilter, U) + var_eq_matching = complete(var_eq_matching, + max(length(var_eq_matching), + maximum(x -> x isa Int ? 
x : 0, var_eq_matching, init = 0))) + full_var_eq_matching = copy(var_eq_matching) + var_sccs = find_var_sccs(graph, var_eq_matching) + vargraph = DiCMOBiGraph{true}(graph) + ict = IncrementalCycleTracker(vargraph; dir = :in) + + ieqs = Int[] + filtered_vars = BitSet() + for vars in var_sccs + for var in vars + if varfilter(var) + push!(filtered_vars, var) + if var_eq_matching[var] !== unassigned + push!(ieqs, var_eq_matching[var]) + end + end + var_eq_matching[var] = unassigned + end + tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, ieqs, + filtered_vars, + isder) + + # clear cache + vargraph.ne = 0 + for var in vars + vargraph.matching[var] = unassigned + end + empty!(ieqs) + empty!(filtered_vars) + end + return var_eq_matching, full_var_eq_matching, var_sccs +end diff --git a/src/pantelides.jl b/src/pantelides.jl new file mode 100644 index 0000000..e7e8a36 --- /dev/null +++ b/src/pantelides.jl @@ -0,0 +1,136 @@ +using .BipartiteGraphs: 𝑑neighbors, 𝑠neighbors, nsrcs, ndsts, + construct_augmenting_path!, unassigned, DiCMOBiGraph + +""" + computed_highest_diff_variables(structure) + +Computes which variables are the "highest-differentiated" for purposes of +pantelides. Ordinarily this is relatively straightforward. However, in our +case, there is one complicating condition: + + We allow variables in the structure graph that don't appear in the + system at all. What we are interested in is the highest-differentiated + variable that actually appears in the system. + +This function takes care of these complications are returns a boolean array +for every variable, indicating whether it is considered "highest-differentiated". 
+""" +function computed_highest_diff_variables(structure) + @unpack graph, var_to_diff = structure + + nvars = length(var_to_diff) + varwhitelist = falses(nvars) + for var in 1:nvars + if var_to_diff[var] === nothing && !varwhitelist[var] + # This variable is structurally highest-differentiated, but may not actually appear in the + # system (complication 1 above). Ascend the differentiation graph to find the highest + # differentiated variable that does appear in the system or the alias graph). + while isempty(𝑑neighbors(graph, var)) + var′ = invview(var_to_diff)[var] + var′ === nothing && break + var = var′ + end + varwhitelist[var] = true + end + end + + # Remove any variables from the varwhitelist for whom a higher-differentiated + # var is already on the whitelist. + for var in 1:nvars + varwhitelist[var] || continue + var′ = var + while (var′ = var_to_diff[var′]) !== nothing + if varwhitelist[var′] + varwhitelist[var] = false + break + end + end + end + + return varwhitelist +end + +""" + pantelides!(state::TransformationState; kwargs...) + +Perform Pantelides algorithm. +""" +function pantelides!(state::TransformationState; finalize = true, maxiters = 8000) + @unpack graph, solvable_graph, var_to_diff, eq_to_diff = state.structure + neqs = nsrcs(graph) + nvars = nv(var_to_diff) + vcolor = falses(nvars) + ecolor = falses(neqs) + var_eq_matching = Matching(nvars) + neqs′ = neqs + nnonemptyeqs = count( + eq -> !isempty(𝑠neighbors(graph, eq)) && eq_to_diff[eq] === nothing, + 1:neqs′) + + varwhitelist = computed_highest_diff_variables(state.structure) + + if nnonemptyeqs > count(varwhitelist) + throw(InvalidSystemException("System is structurally singular")) + end + + for k in 1:neqs′ + eq′ = k + eq_to_diff[eq′] === nothing || continue + isempty(𝑠neighbors(graph, eq′)) && continue + pathfound = false + # In practice, `maxiters=8000` should never be reached, otherwise, the + # index would be on the order of thousands. 
+ for iii in 1:maxiters + # run matching on (dx, y) variables + # + # the derivatives and algebraic variables are zeros in the variable + # association list + resize!(vcolor, nvars) + fill!(vcolor, false) + resize!(ecolor, neqs) + fill!(ecolor, false) + pathfound = construct_augmenting_path!(var_eq_matching, graph, eq′, + v -> varwhitelist[v], vcolor, ecolor) + pathfound && break # terminating condition + if is_only_discrete(state.structure) + error("The discrete system has high structural index. This is not supported.") + end + for var in eachindex(vcolor) + vcolor[var] || continue + if var_to_diff[var] === nothing + # introduce a new variable + nvars += 1 + var_diff = var_derivative!(state, var) + push!(var_eq_matching, unassigned) + push!(varwhitelist, false) + @assert length(var_eq_matching) == var_diff + end + varwhitelist[var] = false + varwhitelist[var_to_diff[var]] = true + end + + for eq in eachindex(ecolor) + ecolor[eq] || continue + # introduce a new equation + neqs += 1 + eq_derivative!(state, eq) + end + + for var in eachindex(vcolor) + vcolor[var] || continue + # the newly introduced `var`s and `eq`s have the inherits + # assignment + var_eq_matching[var_to_diff[var]] = eq_to_diff[var_eq_matching[var]] + end + eq′ = eq_to_diff[eq′] + end # for _ in 1:maxiters + pathfound || + error("maxiters=$maxiters reached! File a bug report if your system has a reasonable index (<100), and you are using the default `maxiters`. 
Try to increase the maxiters by `pantelides(sys::ODESystem; maxiters=1_000_000)` if your system has an incredibly high index and it is truly extremely large.") + end # for k in 1:neqs′ + + finalize && for var in 1:ndsts(graph) + varwhitelist[var] && continue + var_eq_matching[var] = unassigned + end + return var_eq_matching +end diff --git a/src/partial_state_selection.jl b/src/partial_state_selection.jl new file mode 100644 index 0000000..d2e28b8 --- /dev/null +++ b/src/partial_state_selection.jl @@ -0,0 +1,360 @@ +using .BipartiteGraphs: Unassigned, maximal_matching + +function partial_state_selection_graph!(state::TransformationState) + var_eq_matching = complete(pantelides!(state)) + complete!(state.structure) + partial_state_selection_graph!(state.structure, var_eq_matching) +end + +function ascend_dg(xs, dg, level) + while level > 0 + xs = Int[dg[x] for x in xs] + level -= 1 + end + return xs +end + +function ascend_dg_all(xs, dg, level, maxlevel) + r = Int[] + while true + if level <= 0 + append!(r, xs) + end + maxlevel <= 0 && break + xs = Int[dg[x] for x in xs if dg[x] !== nothing] + level -= 1 + maxlevel -= 1 + end + return r +end + +function pss_graph_modia!(structure::SystemStructure, maximal_top_matching, varlevel, + inv_varlevel, inv_eqlevel) + @unpack eq_to_diff, var_to_diff, graph, solvable_graph = structure + + # var_eq_matching is a maximal matching on the top-differentiated variables. + # Find Strongly connected components. Note that after pantelides, we expect + # a balanced system, so a maximal matching should be possible. + var_sccs::Vector{Union{Vector{Int}, Int}} = find_var_sccs(graph, maximal_top_matching) + var_eq_matching = Matching{Union{Unassigned, SelectedState}}(ndsts(graph)) + for vars in var_sccs + # TODO: We should have a way to not have the scc code look at unassigned vars. 
+ if length(vars) == 1 && maximal_top_matching[vars[1]] === unassigned + continue + end + + # Now proceed level by level from lowest to highest and tear the graph. + eqs = [maximal_top_matching[var] + for var in vars if maximal_top_matching[var] !== unassigned] + isempty(eqs) && continue + maxeqlevel = maximum(map(x -> inv_eqlevel[x], eqs)) + maxvarlevel = level = maximum(map(x -> inv_varlevel[x], vars)) + old_level_vars = () + ict = IncrementalCycleTracker( + DiCMOBiGraph{true}(graph, + complete(Matching(ndsts(graph)), nsrcs(graph))), + dir = :in) + + while level >= 0 + to_tear_eqs_toplevel = filter(eq -> inv_eqlevel[eq] >= level, eqs) + to_tear_eqs = ascend_dg(to_tear_eqs_toplevel, invview(eq_to_diff), level) + + to_tear_vars_toplevel = filter(var -> inv_varlevel[var] >= level, vars) + to_tear_vars = ascend_dg(to_tear_vars_toplevel, invview(var_to_diff), level) + + assigned_eqs = Int[] + + if old_level_vars !== () + # Inherit constraints from previous level. + # TODO: Is this actually a good idea or do we want full freedom + # to tear differently on each level? Does it make a difference + # whether we're using heuristic or optimal tearing? + removed_eqs = Int[] + removed_vars = Int[] + for var in old_level_vars + old_assign = var_eq_matching[var] + if isa(old_assign, SelectedState) + push!(removed_vars, var) + continue + elseif !isa(old_assign, Int) || + ict.graph.matching[var_to_diff[var]] !== unassigned + continue + end + # Make sure the ict knows about this edge, so it doesn't accidentally introduce + # a cycle. 
+ assgned_eq = eq_to_diff[old_assign] + ok = try_assign_eq!(ict, var_to_diff[var], assgned_eq) + @assert ok + var_eq_matching[var_to_diff[var]] = assgned_eq + push!(removed_eqs, eq_to_diff[ict.graph.matching[var]]) + push!(removed_vars, var_to_diff[var]) + push!(removed_vars, var) + end + to_tear_eqs = setdiff(to_tear_eqs, removed_eqs) + to_tear_vars = setdiff(to_tear_vars, removed_vars) + end + tearEquations!(ict, solvable_graph.fadjlist, to_tear_eqs, BitSet(to_tear_vars), + nothing) + + for var in to_tear_vars + @assert var_eq_matching[var] === unassigned + assgned_eq = ict.graph.matching[var] + var_eq_matching[var] = assgned_eq + isa(assgned_eq, Int) && push!(assigned_eqs, assgned_eq) + end + + if level != 0 + remaining_vars = collect(v + for v in to_tear_vars + if var_eq_matching[v] === unassigned) + if !isempty(remaining_vars) + remaining_eqs = setdiff(to_tear_eqs, assigned_eqs) + nlsolve_matching = maximal_matching(graph, + Base.Fix2(in, remaining_eqs), + Base.Fix2(in, remaining_vars)) + for var in remaining_vars + if nlsolve_matching[var] === unassigned && + var_eq_matching[var] === unassigned + var_eq_matching[var] = SelectedState() + end + end + end + end + + old_level_vars = to_tear_vars + level -= 1 + end + end + return complete(var_eq_matching, nsrcs(graph)) +end + +function partial_state_selection_graph!(structure::SystemStructure, var_eq_matching) + @unpack eq_to_diff, var_to_diff, graph, solvable_graph = structure + eq_to_diff = complete(eq_to_diff) + + inv_eqlevel = map(1:nsrcs(graph)) do eq + level = 0 + while invview(eq_to_diff)[eq] !== nothing + eq = invview(eq_to_diff)[eq] + level += 1 + end + level + end + + varlevel = map(1:ndsts(graph)) do var + graph_level = level = 0 + while var_to_diff[var] !== nothing + var = var_to_diff[var] + level += 1 + if !isempty(𝑑neighbors(graph, var)) + graph_level = level + end + end + graph_level + end + + inv_varlevel = map(1:ndsts(graph)) do var + level = 0 + while invview(var_to_diff)[var] !== nothing + var = 
invview(var_to_diff)[var] + level += 1 + end + level + end + + var_eq_matching = pss_graph_modia!(structure, + complete(var_eq_matching), varlevel, inv_varlevel, + inv_eqlevel) + + var_eq_matching +end + +function dummy_derivative_graph!(state::TransformationState, jac = nothing; + state_priority = nothing, log = Val(false), kwargs...) + complete!(state.structure) + var_eq_matching = complete(pantelides!(state)) + dummy_derivative_graph!(state.structure, var_eq_matching, jac, state_priority, log) +end + +function dummy_derivative_graph!( + structure::SystemStructure, var_eq_matching, jac = nothing, + state_priority = nothing, ::Val{log} = Val(false)) where {log} + @unpack eq_to_diff, var_to_diff, graph = structure + diff_to_eq = invview(eq_to_diff) + diff_to_var = invview(var_to_diff) + invgraph = invview(graph) + + var_sccs = find_var_sccs(graph, var_eq_matching) + eqcolor = falses(nsrcs(graph)) + dummy_derivatives = Int[] + col_order = Int[] + nvars = ndsts(graph) + eqs = Int[] + next_eq_idxs = Int[] + next_var_idxs = Int[] + new_eqs = Int[] + new_vars = Int[] + eqs_set = BitSet() + for vars in var_sccs + empty!(eqs) + for var in vars + eq = var_eq_matching[var] + eq isa Int || continue + diff_to_eq[eq] === nothing && continue + push!(eqs, eq) + end + isempty(eqs) && continue + + rank_matching = Matching(nvars) + isfirst = true + if jac === nothing + J = nothing + else + _J = jac(eqs, vars) + # only accept small integers to avoid overflow + is_all_small_int = all(_J) do x′ + x = unwrap(x′) + x isa Number || return false + isinteger(x) && typemin(Int8) <= x <= typemax(Int8) + end + J = is_all_small_int ? Int.(unwrap.(_J)) : nothing + end + while true + nrows = length(eqs) + iszero(nrows) && break + + if state_priority !== nothing && isfirst + sort!(vars, by = state_priority) + end + # TODO: making the algorithm more robust + # 1. If the Jacobian is a integer matrix, use Bareiss to check + # linear independence. (done) + # + # 2. 
If the Jacobian is a single row, generate pivots. (Dynamic + # state selection.) + # + # 3. If the Jacobian is a polynomial matrix, use Gröbner basis (?) + if J !== nothing + if !isfirst + J = J[next_eq_idxs, next_var_idxs] + end + N = ModelingToolkit.nullspace(J; col_order) # modifies col_order + rank = length(col_order) - size(N, 2) + for i in 1:rank + push!(dummy_derivatives, vars[col_order[i]]) + end + else + empty!(eqs_set) + union!(eqs_set, eqs) + rank = 0 + for var in vars + eqcolor .= false + # We need `invgraph` here because we are matching from + # variables to equations. + pathfound = construct_augmenting_path!(rank_matching, invgraph, var, + Base.Fix2(in, eqs_set), eqcolor) + pathfound || continue + push!(dummy_derivatives, var) + rank += 1 + rank == nrows && break + end + fill!(rank_matching, unassigned) + end + if rank != nrows + @warn "The DAE system is singular!" + end + + # prepare the next iteration + if J !== nothing + empty!(next_eq_idxs) + empty!(next_var_idxs) + end + empty!(new_eqs) + empty!(new_vars) + for (i, eq) in enumerate(eqs) + ∫eq = diff_to_eq[eq] + # descend by one diff level, but the next iteration of equations + # must still be differentiated + ∫eq === nothing && continue + ∫∫eq = diff_to_eq[∫eq] + ∫∫eq === nothing && continue + if J !== nothing + push!(next_eq_idxs, i) + end + push!(new_eqs, ∫eq) + end + for (i, var) in enumerate(vars) + ∫var = diff_to_var[var] + ∫var === nothing && continue + if J !== nothing + push!(next_var_idxs, i) + end + push!(new_vars, ∫var) + end + eqs, new_eqs = new_eqs, eqs + vars, new_vars = new_vars, vars + isfirst = false + end + end + + if (n_diff_eqs = count(!isnothing, diff_to_eq)) != + (n_dummys = length(dummy_derivatives)) + @warn "The number of dummy derivatives ($n_dummys) does not match the number of differentiated equations ($n_diff_eqs)." 
+ end + + ret = tearing_with_dummy_derivatives(structure, BitSet(dummy_derivatives)) + if log + ret + else + ret[1] + end +end + +function is_present(structure, v)::Bool + @unpack var_to_diff, graph = structure + while true + # if a higher derivative is present, then it's present + isempty(𝑑neighbors(graph, v)) || return true + v = var_to_diff[v] + v === nothing && return false + end +end + +# Derivatives that are either in the dummy derivatives set or ended up not +# participating in the system at all are not considered differential +function is_some_diff(structure, dummy_derivatives, v)::Bool + !(v in dummy_derivatives) && is_present(structure, v) +end + +# We don't want tearing to give us `y_t ~ D(y)`, so we skip equations with +# actually differentiated variables. +function isdiffed((structure, dummy_derivatives), v)::Bool + @unpack var_to_diff, graph = structure + diff_to_var = invview(var_to_diff) + diff_to_var[v] !== nothing && is_some_diff(structure, dummy_derivatives, v) +end + +function tearing_with_dummy_derivatives(structure, dummy_derivatives) + @unpack var_to_diff = structure + # We can eliminate variables that are not selected (differential + # variables). Selected unknowns are differentiated variables that are not + # dummy derivatives. 
+ can_eliminate = falses(length(var_to_diff)) + for (v, dv) in enumerate(var_to_diff) + dv = var_to_diff[v] + if dv === nothing || !is_some_diff(structure, dummy_derivatives, dv) + can_eliminate[v] = true + end + end + var_eq_matching, full_var_eq_matching, var_sccs = tear_graph_modia(structure, + Base.Fix1(isdiffed, (structure, dummy_derivatives)), + Union{Unassigned, SelectedState}; + varfilter = Base.Fix1(getindex, can_eliminate)) + for v in 𝑑vertices(structure.graph) + is_present(structure, v) || continue + dv = var_to_diff[v] + (dv === nothing || !is_some_diff(structure, dummy_derivatives, dv)) && continue + var_eq_matching[v] = SelectedState() + end + return var_eq_matching, full_var_eq_matching, var_sccs, can_eliminate +end diff --git a/src/singularity_removal.jl b/src/singularity_removal.jl new file mode 100644 index 0000000..570b549 --- /dev/null +++ b/src/singularity_removal.jl @@ -0,0 +1,276 @@ +using Graphs.Experimental.Traversals +using .BipartiteGraphs: set_neighbors! + +function extreme_var(var_to_diff, v, level = nothing, ::Val{descend} = Val(true); + callback = _ -> nothing) where {descend} +g = descend ? invview(var_to_diff) : var_to_diff +callback(v) +while (v′ = g[v]) !== nothing + v::Int = v′ + callback(v) + if level !== nothing + descend ? (level -= 1) : (level += 1) + end +end +level === nothing ? v : (v => level) +end + +function alias_eliminate_graph!(state::TransformationState; kwargs...) + mm = linear_subsys_adjmat!(state; kwargs...) 
+ if size(mm, 1) == 0 + return mm # No linear subsystems + end + + @unpack graph, var_to_diff, solvable_graph = state.structure + mm = alias_eliminate_graph!(state, mm) + s = state.structure + for g in (s.graph, s.solvable_graph) + g === nothing && continue + for (ei, e) in enumerate(mm.nzrows) + set_neighbors!(g, e, mm.row_cols[ei]) + end + end + + return mm +end + +# For debug purposes +function aag_bareiss(sys) + state = TearingState(sys) + complete!(state.structure) + mm = linear_subsys_adjmat!(state) + return aag_bareiss!(state.structure.graph, state.structure.var_to_diff, mm) +end + + +""" +$(SIGNATURES) + +Find the first linear variable such that `𝑠neighbors(adj, i)[j]` is true given +the `constraint`. +""" +@inline function find_first_linear_variable(M::SparseMatrixCLIL, + range, + mask, + constraint) + eadj = M.row_cols + @inbounds for i in range + vertices = eadj[i] + if constraint(length(vertices)) + for (j, v) in enumerate(vertices) + if (mask === nothing || mask[v]) + return (CartesianIndex(i, v), M.row_vals[i][j]) + end + end + end + end + return nothing +end + +@inline function find_first_linear_variable(M::AbstractMatrix, + range, + mask, + constraint) + @inbounds for i in range + row = @view M[i, :] + if constraint(count(!iszero, row)) + for (v, val) in enumerate(row) + if mask === nothing || mask[v] + return CartesianIndex(i, v), val + end + end + end + end + return nothing +end + +function find_masked_pivot(variables, M, k) + r = find_first_linear_variable(M, k:size(M, 1), variables, isequal(1)) + r !== nothing && return r + r = find_first_linear_variable(M, k:size(M, 1), variables, isequal(2)) + r !== nothing && return r + r = find_first_linear_variable(M, k:size(M, 1), variables, _ -> true) + return r +end + +count_nonzeros(a::AbstractArray) = count(!iszero, a) + +# N.B.: Ordinarily sparse vectors allow zero stored elements. 
+# Here we have a guarantee that they won't, so we can make this identification
+count_nonzeros(a::CLILVector) = nnz(a)
+
+# Linear variables are highest order differentiated variables that only appear
+# in linear equations with only linear variables. Also, if a variable's any
+# derivatives is nonlinear, then all of them are not linear variables.
+function find_linear_variables(graph, linear_equations, var_to_diff, irreducibles)
+    stack = Int[]
+    linear_variables = falses(length(var_to_diff))
+    var_to_lineq = Dict{Int, BitSet}()
+    # Un-mark `v` and, transitively, every variable that shares a linear equation
+    # with an un-marked variable.
+    mark_not_linear! = let linear_variables = linear_variables, stack = stack,
+        var_to_lineq = var_to_lineq
+
+        v -> begin
+            linear_variables[v] = false
+            push!(stack, v)
+            while !isempty(stack)
+                v = pop!(stack)
+                eqs = get(var_to_lineq, v, nothing)
+                eqs === nothing && continue
+                for eq in eqs, v′ in 𝑠neighbors(graph, eq)
+                    if linear_variables[v′]
+                        linear_variables[v′] = false
+                        push!(stack, v′)
+                    end
+                end
+            end
+        end
+    end
+    # Start optimistic: everything that occurs in a linear equation is a candidate.
+    for eq in linear_equations, v in 𝑠neighbors(graph, eq)
+        linear_variables[v] = true
+        vlineqs = get!(() -> BitSet(), var_to_lineq, v)
+        push!(vlineqs, eq)
+    end
+    # Irreducible variables, and their whole differentiation chain, are never linear.
+    for v in irreducibles
+        lv = extreme_var(var_to_diff, v)
+        while true
+            mark_not_linear!(lv)
+            lv = var_to_diff[lv]
+            lv === nothing && break
+        end
+    end
+
+    linear_equations_set = BitSet(linear_equations)
+    for (v, islinear) in enumerate(linear_variables)
+        islinear || continue
+        lv = extreme_var(var_to_diff, v)
+        oldlv = lv
+        # A candidate that is itself a derivative is removed outright; otherwise
+        # it is removed if any member of its chain occurs in a non-linear equation.
+        remove = invview(var_to_diff)[v] !== nothing
+        while !remove
+            for eq in 𝑑neighbors(graph, lv)
+                if !(eq in linear_equations_set)
+                    remove = true
+                end
+            end
+            lv = var_to_diff[lv]
+            lv === nothing && break
+        end
+        lv = oldlv
+        remove && while true
+            mark_not_linear!(lv)
+            lv = var_to_diff[lv]
+            lv === nothing && break
+        end
+    end
+
+    return linear_variables
+end
+
+# Run Bareiss on a copy of `mm_orig` (row swaps are mirrored onto `mm_orig`
+# itself). On `OverflowError` the factorization is retried once with `BigInt`
+# entries. Returns `(mm, solvable_variables, (rank1, rank2, rank3, pivots))`.
+function aag_bareiss!(structure, mm_orig::SparseMatrixCLIL{T, Ti}) where {T, Ti}
+    @unpack graph, var_to_diff = structure
+    mm = copy(mm_orig)
+    linear_equations_set = BitSet(mm_orig.nzrows)
+
+    # All unassigned (not a pivot) algebraic variables that only appears in
+    # linear algebraic equations can be set to 0.
+    #
+    # For all the other variables, we can update the original system with
+    # Bareiss'ed coefficients as Gaussian elimination is nullspace preserving
+    # and we are only working on linear homogeneous subsystem.
+
+    is_algebraic = let var_to_diff = var_to_diff
+        v -> var_to_diff[v] === nothing === invview(var_to_diff)[v]
+    end
+    is_linear_variables = is_algebraic.(1:length(var_to_diff))
+    is_highest_diff = computed_highest_diff_variables(structure)
+    for i in 𝑠vertices(graph)
+        # only consider linear algebraic equations
+        (i in linear_equations_set && all(is_algebraic, 𝑠neighbors(graph, i))) &&
+            continue
+        for j in 𝑠neighbors(graph, i)
+            is_linear_variables[j] = false
+        end
+    end
+    solvable_variables = findall(is_linear_variables)
+
+    local bar
+    try
+        bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff)
+    catch e
+        e isa OverflowError || rethrow(e)
+        mm = convert(SparseMatrixCLIL{BigInt, Ti}, mm_orig)
+        bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff)
+    end
+
+    return mm, solvable_variables, bar
+end
+
+# Bareiss with a three-stage pivot preference: first columns flagged in
+# `is_linear_variables`, then columns flagged in `is_highest_diff`, then any
+# column. `rank1`/`rank2` record how many pivots were found before each of the
+# first two stages ran dry; `rank3` is the rank reported by `bareiss!`. `pivots`
+# collects the pivot columns in the order they were chosen.
+function do_bareiss!(M, Mold, is_linear_variables, is_highest_diff)
+    rank1r = Ref{Union{Nothing, Int}}(nothing)
+    rank2r = Ref{Union{Nothing, Int}}(nothing)
+    find_pivot = let rank1r = rank1r
+        (M, k) -> begin
+            if rank1r[] === nothing
+                r = find_masked_pivot(is_linear_variables, M, k)
+                r !== nothing && return r
+                rank1r[] = k - 1
+            end
+            if rank2r[] === nothing
+                r = find_masked_pivot(is_highest_diff, M, k)
+                r !== nothing && return r
+                rank2r[] = k - 1
+            end
+            # TODO: It would be better to sort the variables by
+            # derivative order here to enable more elimination
+            # opportunities.
+            return find_masked_pivot(nothing, M, k)
+        end
+    end
+    pivots = Int[]
+    find_and_record_pivot = let pivots = pivots
+        (M, k) -> begin
+            r = find_pivot(M, k)
+            r === nothing && return nothing
+            push!(pivots, r[1][2])
+            return r
+        end
+    end
+    # Keep the row order of the untouched original in sync with the working copy.
+    myswaprows! = let Mold = Mold
+        (M, i, j) -> begin
+            Mold !== nothing && Base.swaprows!(Mold, i, j)
+            Base.swaprows!(M, i, j)
+        end
+    end
+    # The first entry is a no-op: columns are never physically swapped.
+    bareiss_ops = ((M, i, j) -> nothing, myswaprows!,
+        bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
+
+    rank3, = bareiss!(M, bareiss_ops; find_pivot = find_and_record_pivot)
+    rank2 = something(rank2r[], rank3)
+    rank1 = something(rank1r[], rank2)
+    (rank1, rank2, rank3, pivots)
+end
+
+function alias_eliminate_graph!(state::TransformationState, ils::SparseMatrixCLIL)
+    @unpack structure = state
+    @unpack graph, solvable_graph, var_to_diff, eq_to_diff = state.structure
+    # Step 1: Perform Bareiss factorization on the adjacency matrix of the linear
+    # subsystem of the system we're interested in.
+    #
+    ils, solvable_variables, (rank1, rank2, rank3, pivots) = aag_bareiss!(structure, ils)
+
+    ## Step 2: Simplify the system using the Bareiss factorization
+    rk1vars = BitSet(@view pivots[1:rank1])
+    for v in solvable_variables
+        v in rk1vars && continue
+        # `v` was not chosen as a first-stage pivot: append a new single-entry row
+        # `1 * v` to the matrix and a matching new equation vertex, connected only
+        # to `v`, to both graphs.
+        @set! ils.nparentrows += 1
+        push!(ils.nzrows, ils.nparentrows)
+        push!(ils.row_cols, [v])
+        push!(ils.row_vals, [convert(eltype(ils), 1)])
+        add_vertex!(graph, SRC)
+        add_vertex!(solvable_graph, SRC)
+        add_edge!(graph, ils.nparentrows, v)
+        add_edge!(solvable_graph, ils.nparentrows, v)
+        add_vertex!(eq_to_diff)
+    end
+
+    return ils
+end
diff --git a/src/tearing.jl b/src/tearing.jl
new file mode 100644
index 0000000..54b1ec6
--- /dev/null
+++ b/src/tearing.jl
@@ -0,0 +1,51 @@
+# Carries the equation, the variable it was solved for, and the RHS obtained.
+struct EquationSolveError
+    eq::Any
+    var::Any
+    rhs::Any
+end
+
+# Print the equation, the variable (in bold) and the obtained right-hand side,
+# all taken from the fields of `ese`.
+function Base.showerror(io::IO, ese::EquationSolveError)
+    print(io, "EquationSolveError: While solving\n\n\t")
+    print(io, ese.eq)
+    print(io, "\nfor ")
+    printstyled(io, ese.var, bold = true)
+    print(io, ", obtained RHS\n\n\t")
+    println(io, ese.rhs)
+end
+
+# In-place running sum over the nonzero entries of `A`; zero entries are left
+# untouched and do not contribute.
+function masked_cumsum!(A::Vector)
+    acc = zero(eltype(A))
+    for i in eachindex(A)
+        iszero(A[i]) && continue
+        A[i] = (acc += A[i])
+    end
+end
+
+# Build the bipartite graph that remains after eliminating equations/variables.
+# `eq_rename`/`var_rename` map old indices to new ones, with `0` meaning
+# "eliminated". An edge to an eliminated variable is replaced by edges to the
+# surviving variables found in its in-neighborhood of the matched directed graph.
+function contract_variables(graph::BipartiteGraph, var_eq_matching::Matching,
+        var_rename, eq_rename, nelim_eq, nelim_var)
+    dig = DiCMOBiGraph{true}(graph, var_eq_matching)
+
+    # Update bipartite graph
+    var_deps = map(1:ndsts(graph)) do v
+        [var_rename[v′]
+         for v′ in neighborhood(dig, v, Inf; dir = :in) if var_rename[v′] != 0]
+    end
+
+    newgraph = BipartiteGraph(nsrcs(graph) - nelim_eq, ndsts(graph) - nelim_var)
+    for e in 𝑠vertices(graph)
+        ne = eq_rename[e]
+        ne == 0 && continue
+        for v in 𝑠neighbors(graph, e)
+            newvar = var_rename[v]
+            if newvar != 0
+                add_edge!(newgraph, ne, newvar)
+            else
+                for nv in var_deps[v]
+                    add_edge!(newgraph, ne, nv)
+                end
+            end
+        end
+    end
+
+    return newgraph
+end
diff --git a/src/utils.jl b/src/utils.jl
new file mode 100644
index 0000000..80684c8
--- /dev/null
+++ b/src/utils.jl
@@ -0,0 +1,166 @@
+###
+### Bipartite graph utilities
+###
+using .BipartiteGraphs: 𝑠vertices, 𝑠neighbors
+
+n_concrete_eqs(state::TransformationState) = n_concrete_eqs(state.structure)
+function 
n_concrete_eqs(graph::BipartiteGraph) + neqs = count(e -> !isempty(𝑠neighbors(graph, e)), 𝑠vertices(graph)) +end + +function error_reporting(state, bad_idxs, n_highest_vars, iseqs, orig_inputs) + io = IOBuffer() + neqs = n_concrete_eqs(state) + if iseqs + error_title = "More equations than variables, here are the potential extra equation(s):\n" + out_arr = has_equations(state) ? equations(state)[bad_idxs] : bad_idxs + else + error_title = "More variables than equations, here are the potential extra variable(s):\n" + out_arr = get_fullvars(state)[bad_idxs] + unset_inputs = intersect(out_arr, orig_inputs) + n_missing_eqs = n_highest_vars - neqs + n_unset_inputs = length(unset_inputs) + if n_unset_inputs > 0 + println(io, "In particular, the unset input(s) are:") + Base.print_array(io, unset_inputs) + println(io) + println(io, "The rest of potentially unset variable(s) are:") + end + end + + Base.print_array(io, out_arr) + msg = String(take!(io)) + if iseqs + throw(ExtraEquationsSystemException("The system is unbalanced. There are " * + "$n_highest_vars highest order derivative variables " + * "and $neqs equations.\n" + * error_title + * msg)) + else + throw(ExtraVariablesSystemException("The system is unbalanced. 
There are " * + "$n_highest_vars highest order derivative variables " + * "and $neqs equations.\n" + * error_title + * msg)) + end +end + +### +### Structural check +### +function check_consistency(state::TransformationState, orig_inputs) + fullvars = get_fullvars(state) + neqs = n_concrete_eqs(state) + @unpack graph, var_to_diff = state.structure + highest_vars = computed_highest_diff_variables(complete!(state.structure)) + n_highest_vars = 0 + for (v, h) in enumerate(highest_vars) + h || continue + isempty(𝑑neighbors(graph, v)) && continue + n_highest_vars += 1 + end + is_balanced = n_highest_vars == neqs + + if neqs > 0 && !is_balanced + varwhitelist = var_to_diff .== nothing + var_eq_matching = maximal_matching(graph, eq -> true, v -> varwhitelist[v]) # not assigned + # Just use `error_reporting` to do conditional + iseqs = n_highest_vars < neqs + if iseqs + eq_var_matching = invview(complete(var_eq_matching, nsrcs(graph))) # extra equations + bad_idxs = findall(isequal(unassigned), @view eq_var_matching[1:nsrcs(graph)]) + else + bad_idxs = findall(isequal(unassigned), var_eq_matching) + end + error_reporting(state, bad_idxs, n_highest_vars, iseqs, orig_inputs) + end + + # This is defined to check if Pantelides algorithm terminates. For more + # details, check the equation (15) of the original paper. + extended_graph = (@set graph.fadjlist = Vector{Int}[graph.fadjlist; + map(collect, edges(var_to_diff))]) + extended_var_eq_matching = maximal_matching(extended_graph) + + unassigned_var = [] + for (vj, eq) in enumerate(extended_var_eq_matching) + if eq === unassigned && !isempty(𝑑neighbors(graph, vj)) + push!(unassigned_var, fullvars[vj]) + end + end + + if !isempty(unassigned_var) || !is_balanced + io = IOBuffer() + Base.print_array(io, unassigned_var) + unassigned_var_str = String(take!(io)) + errmsg = "The system is structurally singular! 
" * + "Here are the problematic variables: \n" * + unassigned_var_str + throw(InvalidSystemException(errmsg)) + end + + return nothing +end + +### +### BLT ordering +### + +""" + find_var_sccs(g::BipartiteGraph, assign=nothing) + +Find strongly connected components of the variables defined by `g`. `assign` +gives the undirected bipartite graph a direction. When `assign === nothing`, we +assume that the ``i``-th variable is assigned to the ``i``-th equation. +""" +function find_var_sccs(g::BipartiteGraph, assign = nothing) + cmog = DiCMOBiGraph{true}(g, + Matching(assign === nothing ? Base.OneTo(nsrcs(g)) : assign)) + sccs = Graphs.strongly_connected_components(cmog) + foreach(sort!, sccs) + return sccs +end + +function sorted_incidence_matrix(ts::TransformationState, val = true; only_algeqs = false, + only_algvars = false) + var_eq_matching, var_scc = algebraic_variables_scc(ts) + s = ts.structure + graph = ts.structure.graph + varsmap = zeros(Int, ndsts(graph)) + eqsmap = zeros(Int, nsrcs(graph)) + varidx = 0 + eqidx = 0 + for vs in var_scc, v in vs + eq = var_eq_matching[v] + if eq !== unassigned + eqsmap[eq] = (eqidx += 1) + varsmap[v] = (varidx += 1) + end + end + for i in diffvars_range(s) + varsmap[i] = (varidx += 1) + end + for i in dervars_range(s) + varsmap[i] = (varidx += 1) + end + for i in 1:nsrcs(graph) + if eqsmap[i] == 0 + eqsmap[i] = (eqidx += 1) + end + end + + I = Int[] + J = Int[] + algeqs_set = algeqs(s) + for eq in 𝑠vertices(graph) + only_algeqs && (eq in algeqs_set || continue) + for var in 𝑠neighbors(graph, eq) + only_algvars && (isalgvar(s, var) || continue) + i = eqsmap[eq] + j = varsmap[var] + (iszero(i) || iszero(j)) && continue + push!(I, i) + push!(J, j) + end + end + sparse(I, J, val, nsrcs(graph), ndsts(graph)) +end