From 430baffd8be6e9852241831e299e2d5a21ff0d18 Mon Sep 17 00:00:00 2001
From: Victor
Date: Thu, 10 Nov 2022 10:22:28 +0100
Subject: [PATCH] tests passing

---
 src/piecewise_MLE.jl    | 25 +++++++++++++++++--------
 src/piecewise_loss.jl   |  2 +-
 test/InferenceResult.jl | 15 ++++++++-------
 test/Project.toml       |  1 +
 test/piecewise_MLE.jl   | 13 ++++++++-----
 test/piecewise_loss.jl  |  6 +++---
 test/statistics.jl      | 19 +++++++++++--------
 7 files changed, 49 insertions(+), 32 deletions(-)

diff --git a/src/piecewise_MLE.jl b/src/piecewise_MLE.jl
index aaf60a0..eba7bc2 100644
--- a/src/piecewise_MLE.jl
+++ b/src/piecewise_MLE.jl
@@ -201,16 +201,24 @@ function _piecewise_MLE(;p_init,
     kwargs...
     )
     dim_prob = get_dims(model) #used by loss_nm
+    idx_ranges = (1:length(ranges),) # idx of batches
+
     @assert (length(optimizers) == length(epochs) == length(batchsizes)) "`optimizers`, `epochs`, `batchsizes` must be of same length"
-
+    @assert (size(data_set,1) == dim_prob) "The dimension of the training data does not correspond to the dimension of the state variables. This probably means that the training data corresponds to observables different from the state variables. In this case, you need to provide manually `u0s_init`."
+    for (i,opt) in enumerate(optimizers)
+        OPT = typeof(opt)
+        if OPT <: Union{Optim.AbstractOptimizer, Optim.Fminbox, Optim.SAMIN, Optim.ConstrainedOptimizer}
+            @assert batchsizes[i] == length(idx_ranges...) "$OPT is not compatible with mini-batches - use `batchsizes = group_nb`"
+        end
+    end
+
     # initialise p_init
     p_init = _init_p(p_init, model)
     # initialise u0s
-    u0s_init = _init_u0s(u0s_init, model)
+    u0s_init = _init_u0s(u0s_init, data_set, ranges, model)
     # trainable parameters
     θ = [u0s_init;p_init]
-    # idx of batches
-    idx_ranges = (1:length(ranges),)
+
     # piecewise loss
     function _loss(θ, idx_rngs)
         return piecewise_loss(θ,
@@ -437,11 +445,11 @@ function __solve(opt::OPT, optprob, idx_ranges, batchsizes, epochs, callback) wh
                                                     Optim.SAMIN,
                                                     Optim.ConstrainedOptimizer}
     @info "Running optimizer $OPT"
-    @assert batchsizes == length(idx_ranges...) "$OPT is not compatible with mini-batches - use `batchsizes = group_nb`"
     res = Optimization.solve(optprob,
                             opt,
                             maxiters = epochs,
                             callback = callback)
+    return res
 end
 
 function __solve(opt::OPT, optprob, idx_ranges, batchsizes, epochs, callback) where OPT
@@ -452,19 +460,20 @@ function __solve(opt::OPT, optprob, idx_ranges, batchsizes, epochs, callback) wh
                             ncycle(train_loader, epochs),
                             callback=callback,
                             save_best=true)
+    return res
 end
 
 function _init_p(p_init, model)
     p_init, _ = Optimisers.destructure(p_init)
     p_init = get_p_bijector(model)(p_init) # projecting p_init in optimization space
+    return p_init
 end
 
-function _init_u0s(u0s_init, model)
+function _init_u0s(u0s_init, data_set, ranges, model)
     # initialising with data_set if not provided
     if isnothing(u0s_init)
-        @assert (size(data_set,1) == dim_prob) "The dimension of the training data does not correspond to the dimension of the state variables. This probably means that the training data corresponds to observables different from the state variables. In this case, you need to provide manually `u0s_init`."
         u0s_init = reshape(data_set[:,first.(ranges),:],:)
     end
-    u0s_init = [get_u0s_bijector(model)(u0) for u0 in u0s_init] # projecting u0s_init in optimization space
+    u0s_init = [get_u0_bijector(model)(u0) for u0 in u0s_init] # projecting u0s_init in optimization space
     return u0s_init
 end
\ No newline at end of file
diff --git a/src/piecewise_loss.jl b/src/piecewise_loss.jl
index 2be159f..596a53d 100644
--- a/src/piecewise_loss.jl
+++ b/src/piecewise_loss.jl
@@ -113,6 +113,6 @@ end
 
 function _get_u0s(θ, nb_group, dim_prob, model)
     # converting back to u0 space
-    u0_bij⁻¹ = inverse(get_u0s_bij(model))
+    u0_bij⁻¹ = inverse(get_u0_bijector(model))
     return [u0_bij⁻¹(θ[dim_prob*(i-1)+1:dim_prob*i]) for i in 1:nb_group]
 end
\ No newline at end of file
diff --git a/test/InferenceResult.jl b/test/InferenceResult.jl
index d52a2cb..6e83e37 100644
--- a/test/InferenceResult.jl
+++ b/test/InferenceResult.jl
@@ -18,17 +18,18 @@ tspan = (tsteps[1], tsteps[end])
 p_true = (b = [0.23, 0.5],)
 p_init= (b = [1., 2.],)
-dists = (Uniform(0.,5.),)
-u0s_bij = bijector(Uniform(0.,5.))
+p_bij = (bijector(Uniform(0.,3.)),)
+# u0_bij = bijector(Uniform(0.,5.))
+u0_bij = bijector(Uniform(0.,5.))
 
 u0 = ones(2)
 
-mp = ModelParams(p_true,
-                dists,
+mp = ModelParams(;p = p_true,
+                p_bij,
                 tspan,
                 u0,
-                u0s_bij,
-                BS3(),
-                sensealg = ForwardDiffSensitivity();
+                u0_bij,
+                alg = BS3(),
+                sensealg = ForwardDiffSensitivity(),
                 saveat = tsteps,
                 )
 model = MyModel(mp)
diff --git a/test/Project.toml b/test/Project.toml
index a442ab3..e68319d 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -6,6 +6,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 OptimizationFlux = "253f991c-a7b2-45f8-8852-8b9a9df78a86"
 OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
 OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
diff --git a/test/piecewise_MLE.jl b/test/piecewise_MLE.jl
index 2eb400b..c1fd46f 100644
--- a/test/piecewise_MLE.jl
+++ b/test/piecewise_MLE.jl
@@ -19,13 +19,16 @@
 p_true = (b = [0.23, 0.5],)
 p_init= (b = [1., 2.],)
 u0 = ones(2)
-dist = (bijector(Uniform(1e-3, 5e0)),)
-mp = ModelParams(p_true,
-                dist,
+p_bij = (bijector(Uniform(1e-3, 5e0)),)
+u0_bij = bijector(Uniform(1e-3,5.))
+
+mp = ModelParams(; p = p_true,
+                p_bij,
                 tspan,
                 u0,
-                BS3(),
-                sensealg = ForwardDiffSensitivity();
+                u0_bij,
+                alg = BS3(),
+                sensealg = ForwardDiffSensitivity(),
                 saveat = tsteps,
                 )
 model = MyModel(mp)
diff --git a/test/piecewise_loss.jl b/test/piecewise_loss.jl
index 2198e6a..b5ba877 100644
--- a/test/piecewise_loss.jl
+++ b/test/piecewise_loss.jl
@@ -14,11 +14,11 @@
 p_true = (r = [0.5, 1.], b = [0.23, 0.5],)
 p_init= (r = [0.7, 1.2], b = [0.2, 0.2],)
 u0 = ones(2)
-mp = ModelParams(p_true,
+mp = ModelParams(;p = p_true,
                 tspan,
                 u0,
-                BS3(),
-                sensealg = ForwardDiffSensitivity();
+                alg = BS3(),
+                sensealg = ForwardDiffSensitivity(),
                 saveat = tsteps,
                 )
 model = MyModel(mp)
diff --git a/test/statistics.jl b/test/statistics.jl
index 42dc8bd..64fe2fd 100644
--- a/test/statistics.jl
+++ b/test/statistics.jl
@@ -1,7 +1,7 @@
 using GLM, UnPack
 using LinearAlgebra
-using Bijectors: Identity
 using Distributions, DataFrames
+using Bijectors, Optimisers
 
 @model MyModel
 function (m::MyModel)(du, u, p, t)
@@ -14,17 +14,20 @@
 tsteps = range(tspan[1], tspan[end], length=1000)
 p_true = (r = [0.5, 1.], b = [0.23, 0.5],)
 p_init= (r = [0.7, 1.2],
        b = [0.2, 0.2],)
-dists = (Identity{0}(), Identity{0}())
+p_bij = (Identity{0}(), Identity{0}())
 u0 = ones(2)
-mp = ModelParams(p_true,
-                dists,
+u0_bij = bijector(Uniform(1e-3,5.))
+
+mp = ModelParams(;p = p_true,
+                p_bij,
+                u0_bij,
                 tspan,
                 u0,
-                BS3(),
-                sensealg = ForwardDiffSensitivity();
+                alg = BS3(),
+                sensealg = ForwardDiffSensitivity(),
                 saveat = tsteps,
                 )
-mymodel = MyModel(mp)
-sol = simulate(mymodel)
+model = MyModel(mp)
+sol = simulate(model)
 true_data = sol |> Array
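
Two API points in this patch are easy to miss when rebasing on top of it. First, `ModelParams` is now constructed with keyword arguments: the parameter and initial-state bijectors are passed explicitly as `p_bij` and `u0_bij`, the solver as `alg` (the accessor was renamed accordingly, `get_u0s_bij`/`get_u0s_bijector` -> `get_u0_bijector`). Second, `_piecewise_MLE` now rejects mini-batching for Optim-family optimizers (`Optim.AbstractOptimizer`, `Fminbox`, `SAMIN`, `ConstrainedOptimizer`): for those entries `batchsizes[i]` must equal the number of segment groups, as the new assertion and its `batchsizes = group_nb` hint state. Below is a minimal sketch of the new calling convention, assembled from the updated tests; the `using` lines are assumptions, since the diff does not show the test preambles (`BS3` ships with OrdinaryDiffEq, and `ForwardDiffSensitivity` lived in DiffEqSensitivity/SciMLSensitivity around this date), so adjust them to whatever test/Project.toml pins.

    using Bijectors, Distributions
    using OrdinaryDiffEq: BS3                       # assumed import for BS3()
    using DiffEqSensitivity: ForwardDiffSensitivity # assumed import; SciMLSensitivity in newer setups

    tspan  = (0., 1.)
    tsteps = range(tspan[1], tspan[2], length = 100)

    p_true = (b = [0.23, 0.5],)
    p_bij  = (bijector(Uniform(1e-3, 5e0)),) # one bijector per parameter group in `p`
    u0     = ones(2)
    u0_bij = bijector(Uniform(1e-3, 5.))     # bijector applied to the initial state

    mp = ModelParams(; p = p_true,
                     p_bij,
                     u0,
                     u0_bij,
                     tspan,
                     alg = BS3(),
                     sensealg = ForwardDiffSensitivity(),
                     saveat = tsteps)
    model = MyModel(mp)  # @model-generated type, as in the tests above
    sol = simulate(model)

Keyword construction keeps call sites readable as the field list of `ModelParams` grows, which is presumably why the positional form, with parameters, bijectors, and solver interleaved, was dropped.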