Skip to content

Commit

Permalink
tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
Victor committed Nov 10, 2022
1 parent e231600 commit 430baff
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 31 deletions.
25 changes: 17 additions & 8 deletions src/piecewise_MLE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,16 +201,24 @@ function _piecewise_MLE(;p_init,
kwargs...
)
dim_prob = get_dims(model) #used by loss_nm
idx_ranges = (1:length(ranges),) # idx of batches

@assert (length(optimizers) == length(epochs) == length(batchsizes)) "`optimizers`, `epochs`, `batchsizes` must be of same length"

@assert (size(data_set,1) == dim_prob) "The dimension of the training data does not correspond to the dimension of the state variables. This probably means that the training data corresponds to observables different from the state variables. In this case, you need to provide manually `u0s_init`."
for (i,opt) in enumerate(optimizers)
OPT = typeof(opt)
if OPT <: Union{Optim.AbstractOptimizer, Optim.Fminbox, Optim.SAMIN, Optim.ConstrainedOptimizer}
@assert batchsizes[i] == length(idx_ranges...) "$OPT is not compatible with mini-batches - use `batchsizes = group_nb`"
end
end

# initialise p_init
p_init = _init_p(p_init, model)
# initialise u0s
u0s_init = _init_u0s(u0s_init, model)
u0s_init = _init_u0s(u0s_init, data_set, ranges, model)
# trainable parameters
θ = [u0s_init;p_init]
# idx of batches
idx_ranges = (1:length(ranges),)

# piecewise loss
function _loss(θ, idx_rngs)
return piecewise_loss(θ,
Expand Down Expand Up @@ -437,11 +445,11 @@ function __solve(opt::OPT, optprob, idx_ranges, batchsizes, epochs, callback) wh
Optim.SAMIN,
Optim.ConstrainedOptimizer}
@info "Running optimizer $OPT"
@assert batchsizes == length(idx_ranges...) "$OPT is not compatible with mini-batches - use `batchsizes = group_nb`"
res = Optimization.solve(optprob,
opt,
maxiters = epochs,
callback = callback)
return res
end

function __solve(opt::OPT, optprob, idx_ranges, batchsizes, epochs, callback) where OPT
Expand All @@ -452,19 +460,20 @@ function __solve(opt::OPT, optprob, idx_ranges, batchsizes, epochs, callback) wh
ncycle(train_loader, epochs),
callback=callback,
save_best=true)
return res
end

"""
    _init_p(p_init, model)

Flatten the (possibly nested) parameter container `p_init` into a plain
vector and map it into the optimization space via the model's parameter
bijector. Returns the projected parameter vector.
"""
function _init_p(p_init, model)
    # destructure drops the structural information; only the flat vector is kept
    flat_p, _ = Optimisers.destructure(p_init)
    # project the flattened parameters into optimization space
    return get_p_bijector(model)(flat_p)
end

function _init_u0s(u0s_init, model)
"""
    _init_u0s(u0s_init, data_set, ranges, model)

Return the initial conditions projected into the optimization space.

If `u0s_init === nothing`, initial conditions are extracted from
`data_set` at the first time index of each segment in `ranges`; otherwise
the provided `u0s_init` is used as-is. In both cases the result is mapped
through the model's u0 bijector.
"""
function _init_u0s(u0s_init, data_set, ranges, model)
    # initialising with data_set if not provided
    if isnothing(u0s_init)
        # FIX: `dim_prob` was referenced here without being defined — it is a
        # local of the caller (`_piecewise_MLE`), so taking this branch raised
        # an UndefVarError. Recompute it from the model in this scope.
        dim_prob = get_dims(model)
        @assert (size(data_set, 1) == dim_prob) "The dimension of the training data does not correspond to the dimension of the state variables. This probably means that the training data corresponds to observables different from the state variables. In this case, you need to provide manually `u0s_init`."
        u0s_init = reshape(data_set[:, first.(ranges), :], :)
    end
    # projecting u0s_init in optimization space
    u0s_init = [get_u0_bijector(model)(u0) for u0 in u0s_init]
    return u0s_init
end
2 changes: 1 addition & 1 deletion src/piecewise_loss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,6 @@ end

"""
    _get_u0s(θ, nb_group, dim_prob, model)

Recover the per-group initial conditions from the flat trainable vector
`θ`. The first `nb_group * dim_prob` entries of `θ` hold the stacked
initial conditions in optimization space; each chunk of length `dim_prob`
is mapped back to u0 space with the inverse of the model's u0 bijector.
"""
function _get_u0s(θ, nb_group, dim_prob, model)
    # the inverse bijector converts back from optimization space to u0 space
    back_to_u0 = inverse(get_u0_bijector(model))
    return map(1:nb_group) do grp
        segment = θ[dim_prob * (grp - 1) + 1:dim_prob * grp]
        back_to_u0(segment)
    end
end
15 changes: 8 additions & 7 deletions test/InferenceResult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,18 @@ tspan = (tsteps[1], tsteps[end])
p_true = (b = [0.23, 0.5],)
p_init= (b = [1., 2.],)

dists = (Uniform(0.,5.),)
u0s_bij = bijector(Uniform(0.,5.))
p_bij = (bijector(Uniform(0.,3.)),)
# u0_bij = bijector(Uniform(0.,5.))
u0_bij = bijector(Uniform(0.,5.))

u0 = ones(2)
mp = ModelParams(p_true,
dists,
mp = ModelParams(;p = p_true,
p_bij,
tspan,
u0,
u0s_bij,
BS3(),
sensealg = ForwardDiffSensitivity();
u0_bij,
alg = BS3(),
sensealg = ForwardDiffSensitivity(),
saveat = tsteps,
)
model = MyModel(mp)
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
OptimizationFlux = "253f991c-a7b2-45f8-8852-8b9a9df78a86"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
Expand Down
13 changes: 8 additions & 5 deletions test/piecewise_MLE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,16 @@ p_true = (b = [0.23, 0.5],)
p_init= (b = [1., 2.],)

u0 = ones(2)
dist = (bijector(Uniform(1e-3, 5e0)),)
mp = ModelParams(p_true,
dist,
p_bij = (bijector(Uniform(1e-3, 5e0)),)
u0_bij = bijector(Uniform(1e-3,5.))

mp = ModelParams(; p = p_true,
p_bij,
tspan,
u0,
BS3(),
sensealg = ForwardDiffSensitivity();
u0_bij,
alg = BS3(),
sensealg = ForwardDiffSensitivity(),
saveat = tsteps,
)
model = MyModel(mp)
Expand Down
6 changes: 3 additions & 3 deletions test/piecewise_loss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ p_true = (r = [0.5, 1.], b = [0.23, 0.5],)
p_init= (r = [0.7, 1.2], b = [0.2, 0.2],)

u0 = ones(2)
mp = ModelParams(p_true,
mp = ModelParams(;p = p_true,
tspan,
u0,
BS3(),
sensealg = ForwardDiffSensitivity();
alg = BS3(),
sensealg = ForwardDiffSensitivity(),
saveat = tsteps,
)
model = MyModel(mp)
Expand Down
17 changes: 10 additions & 7 deletions test/statistics.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using GLM, UnPack
using LinearAlgebra
using Bijectors: Identity
using Distributions, DataFrames
using Bijectors, Optimisers

@model MyModel
function (m::MyModel)(du, u, p, t)
Expand All @@ -14,17 +14,20 @@ tsteps = range(tspan[1], tspan[end], length=1000)

p_true = (r = [0.5, 1.], b = [0.23, 0.5],)
p_init= (r = [0.7, 1.2], b = [0.2, 0.2],)
dists = (Identity{0}(), Identity{0}())
p_bij = (Identity{0}(), Identity{0}())
u0 = ones(2)
mp = ModelParams(p_true,
dists,
u0_bij = bijector(Uniform(1e-3,5.))

mp = ModelParams(;p = p_true,
p_bij,
u0_bij,
tspan,
u0,
BS3(),
sensealg = ForwardDiffSensitivity();
alg = BS3(),
sensealg = ForwardDiffSensitivity(),
saveat = tsteps,
)
mymodel = MyModel(mp)
model = MyModel(mp)
# FIX: the commit renamed the variable `mymodel` -> `model` but left the
# `simulate` call pointing at the old name, which raises UndefVarError
# when the test script runs.
sol = simulate(model)
true_data = sol |> Array

Expand Down

0 comments on commit 430baff

Please sign in to comment.