merged ParametricModels.jl and PiecewiseInference.jl
vboussange committed Jun 12, 2024
1 parent ab05c8d commit 6de07e9
Showing 14 changed files with 321 additions and 32 deletions.
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "PiecewiseInference"
uuid = "27a201f4-b6a1-4745-b96e-0c27845dca54"
authors = ["Victor <bvictor@ethz.ch>"]
version = "0.9.8"
version = "0.10.0"

[deps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
@@ -19,7 +19,6 @@ Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
ParametricModels = "ea05b012-1b06-4aec-a786-2a545e229cd0"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
@@ -31,3 +30,4 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Bijectors = "≥ 0.12.0"
SciMLBase = "≥ 1.82.0"
julia = "≥ 1.8"
# Optimization = "3.19"
1 change: 0 additions & 1 deletion benchmark/benchmark_multithreading.jl
@@ -2,7 +2,6 @@
Benchmarking threaded vs. non-threaded execution
=#
using PiecewiseInference
using ParametricModels
using SciMLSensitivity
using OrdinaryDiffEq
using ComponentArrays
24 changes: 12 additions & 12 deletions readme.md
@@ -4,21 +4,21 @@
[![Build status (Github Actions)](https://github.com/vboussange/PiecewiseInference.jl/workflows/CI/badge.svg)](https://github.com/vboussange/PiecewiseInference.jl/actions)
[![codecov.io](http://codecov.io/github/vboussange/PiecewiseInference.jl/coverage.svg?branch=main)](http://codecov.io/github/vboussange/PiecewiseInference.jl?branch=main)

**PiecewiseInference.jl** is designed to enhance the convergence of time series-based parameter inversion methods. It achieves this by implementing a **segmentation method** and **parameter normalization**, which regularize the inference problem, together with **minibatching**.
**PiecewiseInference.jl** is a library that improves the convergence of parameter inversion methods for dynamical models. It provides features such as
- a segmentation strategy,
- the independent estimation of initial conditions for each segment,
- parameter transformation,
- parameter and initial condition regularization,
- mini-batching.

Taken together, these features regularize the inference problem and make it efficient to solve. A minimal sketch of the intended workflow is shown below.
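
Here is an illustrative sketch of that workflow: `optimizers`, `epochs`, and `batchsizes` follow the `inference` docstring, while the model setup, the `InferenceProblem` positional signature, and the remaining keyword names are assumptions for illustration.

```julia
using PiecewiseInference, OptimizationOptimisers

# `model` is an `AbstractModel` built with `@model`, `ode_data` the observed
# time series, `tsteps` the observation times, `p_init` an initial guess
infprob = InferenceProblem(model, p_init)   # assumption: positional signature
res = inference(infprob;
                data = ode_data,            # assumption: keyword name
                tsteps = tsteps,            # assumption: keyword name
                group_size = 11,            # assumption: segment length control
                optimizers = [Adam(0.01)],  # one optimizer per training phase
                epochs = [1000],            # epochs for each optimizer
                batchsizes = [10])          # mini-batching over segments
p_est = get_p_trained(res)                  # trained parameters
```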

![](docs/animated.gif)

## Installation
PiecewiseInference.jl has [ParametricModels.jl](https://github.com/vboussange/ParametricModels.jl) in its dependency, a non-registered package. As such, to install PiecewiseInference.jl, you'll need to first add an extra registry to your Julia installation that tracks both ParametricModels.jl and PiecewiseInference.jl.

To proceed, open Julia and type the following
```julia
using Pkg
pkg"registry add https://github.com/vboussange/VBoussangeRegistry.git"
```
Then go on and
Open the Julia REPL and type
```julia
pkg"add PiecewiseInference"
using Pkg; Pkg.add(url="https://github.com/vboussange/PiecewiseInference.jl")
```

That's it! This will download the latest version of **PiecewiseInference.jl** from this git repo and install all dependencies.
@@ -27,10 +27,10 @@ That's it! This will download the latest version of **PiecewiseInference.jl** fr
## Getting started

Check out [this blog post](https://vboussange.github.io/post/piecewiseinference/) providing a hands-on tutorial.
See also the documentation and the `test` folder.
See also the API documentation and the `test` folder.

## Related packages
`DiffEqFlux` is a package with goals similar to those of `PiecewiseInference`; it provides the method `DiffEqFlux.multiple_shooting`, which is close to `PiecewiseInference.inference` except that initial conditions are not inferred. `PiecewiseInference` additionally provides several utility methods for model selection.

## Reference
- Boussange, V., Vilimelis-Aceituno, P., Pellissier, L., _Mini-batching ecological data to improve ecosystem models with machine learning_ [[bioRxiv](https://www.biorxiv.org/content/10.1101/2022.07.25.501365v1)] (2022), 46 pages.
- Boussange, V., Vilimelis-Aceituno, P., Schäfer, F., Pellissier, L., _Partitioning time series to improve process-based models with machine learning_ [[bioRxiv](https://www.biorxiv.org/content/10.1101/2022.07.25.501365v2)] (2024), 46 pages.
4 changes: 1 addition & 3 deletions src/InferenceProblem.jl
@@ -1,5 +1,5 @@
Base.@kwdef struct InferenceProblem{M <: AbstractModel,P,PP,U0P,LL,PB,UB}
m::M # model inheriting ParametricModels.AbstractModel
m::M # model inheriting AbstractModel
p0::P # parameter ComponentArray
loss_param_prior::PP # loss used to directly constrain parameters
loss_u0_prior::U0P # loss used to directly constrain ICs
@@ -65,14 +65,12 @@ function InferenceProblem(model::M,
u0_bij)
end

import ParametricModels: get_p, get_mp, get_tspan
get_p(prob::InferenceProblem) = prob.p0
get_p_bijector(prob::InferenceProblem) =prob.p_bij
get_u0_bijector(prob::InferenceProblem) = prob.u0_bij
get_tspan(prob::InferenceProblem) = get_tspan(prob.m)
get_model(prob::InferenceProblem) = prob.m
get_mp(prob::InferenceProblem) = get_mp(get_model(prob))
import ParametricModels.get_dims
get_dims(prob::InferenceProblem) = get_dims(get_model(prob))
get_loss_likelihood(prob::InferenceProblem) = prob.loss_likelihood
get_loss_param_prior(prob::InferenceProblem) = prob.loss_param_prior
5 changes: 3 additions & 2 deletions src/PiecewiseInference.jl
@@ -3,7 +3,6 @@ __precompile__()
$(DocStringExtensions.README)
"""
module PiecewiseInference
using ParametricModels
using OrdinaryDiffEq
using Optimization
using OptimizationOptimJL:Optim
@@ -18,7 +17,6 @@ module PiecewiseInference
using Distributions
import Distributions:loglikelihood #overwritten

using ParametricModels
using Optimisers, Flux
using IterTools: ncycle
using Bijectors
@@ -31,6 +29,7 @@ module PiecewiseInference
import Base.length
length(::ParamFun{N}) where N = N

include("models.jl")
include("InferenceProblem.jl")
include("InferenceResult.jl")
include("utils.jl")
@@ -46,6 +45,8 @@ module PiecewiseInference
@require PyPlot="d330b81b-6aea-500a-939a-2ce795aea3ee" include("plot_convergence.jl")
end

export AbstractModel, ComposableModel, simulate, ModelParams, @model, name, remake
export get_p, get_u0, get_alg, get_tspan, get_kwargs, get_mp, get_dims, get_prob
export InferenceProblem, get_p, get_p_bijector, get_u0_bijector, get_re, get_tspan, get_model, get_mp
export ParamFun, InferenceResult, get_p_trained, forecast
export group_ranges, AIC, AICc, AICc_TREE, moments!, moments, divisors
12 changes: 6 additions & 6 deletions src/inference.jl
@@ -89,7 +89,7 @@ $(SIGNATURES) performs piecewise inference for a given `InferenceProblem` and
segment. If not provided, initial guesses are initialized from the `data`.
- `optimizers` : array of optimizers, e.g. `[Adam(0.01)]`
- `epochs` : A vector with number of epochs for each optimizer in `optimizers`.
- `batchsizes`: An vector of batch sizes, which should match the length of
- `batchsizes`: A vector of batch sizes, which should match the length of
`optimizers`. If nothing is provided, all segments are used at once (full
batch).
- `verbose_loss` : Whether to display loss during training.
@@ -115,7 +115,6 @@ using SciMLSensitivity # provides differential equation sensitivity methods
using UnPack # provides the utility macro @unpack
using OptimizationOptimisers, OptimizationFlux # provide the optimizers
using LinearAlgebra
using ParametricModels
using PiecewiseInference
using OrdinaryDiffEq
using Distributions, Bijectors # used to constrain parameters and initial conditions
Expand All @@ -139,7 +138,7 @@ mp = ModelParams(;p = p_true,
saveat = tsteps,
)
model = MyModel(mp)
sol_data = ParametricModels.simulate(model)
sol_data = simulate(model)
ode_data = Array(sol_data)
# adding some normally distributed noise
σ_noise = 0.1
@@ -221,6 +220,7 @@ function inference(infprob;
u0s_init = _init_u0s(data, ranges)
end
# build θ, which is the parameter vector containing u0s, in the parameter space

θ = _build_θ(p0, u0s_init, infprob)

# piecewise loss
@@ -239,10 +239,10 @@ function inference(infprob;
println("inference with $(length(tsteps)) points and $nb_group groups.")

# Here we need a default behavior for Optimization.jl (see https://github.com/SciML/Optimization.jl/blob/c0a51120c7c54a89d091b599df30eb40c4c0952b/lib/OptimizationFlux/src/OptimizationFlux.jl#L53)
callback(θ, l, pred=[]) = begin
callback(state, l, pred=[]) = begin
push!(losses, l)
p_trained = to_param_space(θ, infprob)

p_trained = to_param_space(state.u, infprob)
# TODO: this is called at every parameter update, it would be nice to have it at every epoch
if length(losses)%info_per_its==0
verbose_loss && (println("Loss after $(length(losses)) iterations: $(losses[end])"))
end
182 changes: 182 additions & 0 deletions src/models.jl
@@ -0,0 +1,182 @@
import Base
function Base.merge(ca::ComponentArray{T}, ca2::ComponentArray{T2}) where {T, T2}
    ax = getaxes(ca)
    ax2 = getaxes(ca2)
    vks = valkeys(ax[1])
    vks2 = valkeys(ax2[1])
    _p = Vector{T}()
    for vk in vks
        if vk in vks2
            _val = getproperty(ca2, vk) # ca2[vk]
            # checking only the overriding value avoids an error when `ca2` lacks `vk`
            @assert !(_val isa ComponentVector) "Only non-nested `ComponentArray`s are supported by `merge`."
            _p = vcat(_p, vec(_val))
        else
            _p = vcat(_p, vec(getproperty(ca, vk))) # ca[vk]
        end
    end
    # keys present only in `ca2` are dropped; see the commented-out sketch below
    # for vk in vks2
    #     if vk not in vks
    #         _vec = vec(getproperty(ca2, vk)) # ca2[vk]
    #         _p = vcat(_p, _vec)
    #         ax = merge(ax, )
    #     end
    # end
    ComponentArray(_p, ax)
end

Base.merge(::Nothing, ca2::ComponentArray{T2}) where {T2} = ca2
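
# A minimal usage sketch of this `merge`, assuming `ComponentArrays` is loaded:
# entries of `ca2` override the matching entries of `ca`, while keys present
# only in `ca2` are dropped, as the commented-out loop above hints.
#
#   p0    = ComponentArray(r = [1.0], K = [2.0, 3.0]) # full default parameters
#   p_new = ComponentArray(K = [4.0, 5.0])            # partial override
#   p     = merge(p0, p_new)  # ComponentArray(r = [1.0], K = [4.0, 5.0])
#   merge(nothing, p_new)     # returns `p_new` unchanged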


# This piece is inspired from https://github.com/jonniedie/ComponentArrays.jl/pull/217
# import ComponentArrays: promote_type, getval, Val, indexmap
# @generated function valkeys(ax::AbstractAxis)
# idxmap = indexmap(ax)
# k = Val.(keys(idxmap))
# return :($k)
# end
# valkeys(ca::ComponentVector) = valkeys(getaxes(ca)[1])

# function merge(cvec1::ComponentVector{T1}, cvec2::ComponentVector{T2}) where {T1, T2}
# typed_dict = ComponentVector{promote_type(T1, T2)}(cvec1)
# for key in valkeys(cvec2)
# keyname = getval(key)
# val = cvec2[key]
# typed_dict = eval(:( ComponentArray($typed_dict, $keyname = $val) ))
# end
# typed_dict
# end

abstract type AbstractModel end
name(m::AbstractModel) = string(nameof(typeof(m)))
Base.show(io::IO, cm::AbstractModel) = println(io, "`Model` ", name(cm))

"""
$(SIGNATURES)
Returns the `ODEProblem` associated with `m`.
"""
function get_prob(m::AbstractModel, u0, tspan, p)
    prob = ODEProblem(m, u0, tspan, p)
    return prob
end

"""
$(SIGNATURES)
Simulates model `m` and returns an `ODESolution`.
When provided, keyword arguments overwrite the default solving options
stored in `m`.
"""
function simulate(m::AbstractModel; u0 = nothing, tspan = nothing, p = nothing, alg = nothing, kwargs...)
    isnothing(u0) && (u0 = get_u0(m))
    isnothing(tspan) && (tspan = get_tspan(m))
    if isnothing(p)
        p = get_p(m)
    else
        # p can be a subset of the full parameter set; missing entries
        # fall back to the model's default parameters
        p0 = get_p(m)
        p = merge(p0, p)
    end
    isnothing(alg) && (alg = get_alg(m))
    prob = get_prob(m, u0, tspan, p)
    # keyword arguments passed here override get_kwargs(m)
    sol = solve(prob, alg; get_kwargs(m)..., kwargs...)
    return sol
end

struct ModelParams{P,T,U0,A,K}
    p::P        # model parameters: a dictionary, named tuple, or ComponentArray
    tspan::T    # time span
    u0::U0      # initial conditions
    alg::A      # algorithm for the ODE solver
    kwargs::K   # kwargs passed to `solve`, e.g., `saveat`
end

import SciMLBase.remake
function remake(mp::ModelParams;
                p = mp.p,
                tspan = mp.tspan,
                u0 = mp.u0,
                alg = mp.alg,
                kwargs = mp.kwargs)
    ModelParams(p, tspan, u0, alg, kwargs)
end

# # for the remake fn
# function ModelParams(;p,
# p_bij::PST,
# re,
# tspan,
# u0,
# u0_bij,
# alg,
# dims,
# plength,
# kwargs) where PST <: Bijector
# ModelParams(p,
# p_bij,
# re,
# tspan,
# u0,
# u0_bij,
# alg,
# dims,
# plength,
# kwargs)
# end

# model parameters
"""
$(SIGNATURES)
Structure containing the details for the numerical simulation of a model.
# Arguments
- `tspan`: time span of the simulation
- `u0`: initial condition of the simulation
- `alg`: numerical solver
- `kwargs`: extra keyword args provided to the `solve` function.
# Optional
- `p`: default parameter values
# Example
mp = ModelParams()
"""
function ModelParams(; p = nothing, tspan = nothing, u0 = nothing, alg = nothing, kwargs...)
    ModelParams(p,
                tspan,
                u0,
                alg,
                kwargs)
end
ModelParams(p, tspan, u0, alg) = ModelParams(p, tspan, u0, alg, ())

get_mp(m::AbstractModel) = m.mp
get_p(m::AbstractModel) = m.mp.p
get_u0(m::AbstractModel) = m.mp.u0
get_alg(m::AbstractModel) = m.mp.alg
get_tspan(m::AbstractModel) = m.mp.tspan
get_kwargs(m::AbstractModel) = m.mp.kwargs
"""
$SIGNATURES
Returns the dimension of the state variable
"""
get_dims(m::AbstractModel) = length(get_u0(m))

"""
$SIGNATURES
Generates the skeleton of the model, a `struct` containing details of the numerical implementation.
"""
macro model(name)
    expr = quote
        struct $name{MP<:ModelParams} <: AbstractModel
            mp::MP
        end

        $(esc(name))(;mp) = $(esc(name))(mp)
    end
    return expr
end
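
# A usage sketch for `@model` and `ModelParams` with a hypothetical
# `MyLogistic` model; the callable-struct ODE definition is an assumption
# consistent with `get_prob` passing the model directly to `ODEProblem`.
#
#   using OrdinaryDiffEq, ComponentArrays
#
#   @model MyLogistic  # defines `struct MyLogistic{MP<:ModelParams} <: AbstractModel`
#
#   # in-place right-hand side: the model must be callable with (du, u, p, t)
#   function (m::MyLogistic)(du, u, p, t)
#       du .= p.r .* u .* (1 .- u ./ p.K)
#   end
#
#   mp = ModelParams(; p = ComponentArray(r = [0.5], K = [1.0]),
#                      u0 = [0.1],
#                      tspan = (0.0, 10.0),
#                      alg = Tsit5(),
#                      saveat = 0.5)
#   model = MyLogistic(mp)
#   sol   = simulate(model)                                # defaults from `mp`
#   sol2  = simulate(model; p = ComponentArray(r = [0.8])) # overrides `r`, keeps `K`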
2 changes: 1 addition & 1 deletion src/utils.jl
@@ -223,7 +223,7 @@ function to_optim_space(p::ComponentArray, infprob::InferenceProblem)
end

# TODO /!\ order is not guaranteed!
function to_param_space(θ::ComponentArray, infprob::InferenceProblem)
function to_param_space(θ, infprob::InferenceProblem)
@unpack p0, p_bij = infprob
pairs = [reshape(inverse(p_bij[k])(getproperty(θ,k)),:) for k in keys(p0)]
ax = getaxes(p0)
2 changes: 1 addition & 1 deletion test/InferenceResult.jl
@@ -1,5 +1,5 @@
using PiecewiseInference
using LinearAlgebra, ParametricModels, OrdinaryDiffEq, SciMLSensitivity
using LinearAlgebra, OrdinaryDiffEq, SciMLSensitivity
using ComponentArrays
using UnPack
using OptimizationOptimisers, OptimizationOptimJL
2 changes: 1 addition & 1 deletion test/evidence.jl
@@ -1,4 +1,4 @@
using LinearAlgebra, ParametricModels, OrdinaryDiffEq, SciMLSensitivity
using LinearAlgebra, OrdinaryDiffEq, SciMLSensitivity
using UnPack
using OptimizationOptimisers
using Test
2 changes: 1 addition & 1 deletion test/inference.jl
@@ -1,4 +1,4 @@
using LinearAlgebra, ParametricModels, OrdinaryDiffEq, SciMLSensitivity
using LinearAlgebra, OrdinaryDiffEq, SciMLSensitivity
using UnPack
using OptimizationOptimisers, OptimizationFlux, OptimizationOptimJL
using Test