From b9e4667cbef0351c5a1e30d8bed55c4c643ea1f2 Mon Sep 17 00:00:00 2001
From: Victor Boussange
Date: Wed, 22 Nov 2023 16:35:57 +0100
Subject: [PATCH] improving documentation

---
 src/inference.jl | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/src/inference.jl b/src/inference.jl
index 49f6663..cf2e355 100644
--- a/src/inference.jl
+++ b/src/inference.jl
@@ -68,9 +68,8 @@ function piecewise_ML_indep_TS(infprob;
 end
 
 """
-$(SIGNATURES)
-performs piecewise inference for a given `InferenceProblem` and `data`. Loops
-through the optimizers `optimizers`. Returns a `InferenceResult`.
+$(SIGNATURES) performs piecewise inference for a given `InferenceProblem` and
+`data`. Loops through the optimizers `optimizers`. Returns an `InferenceResult`.
 
 # Arguments
 - `infprob`: An instance of `InferenceProblem` that defines the model, the parameter constraints and its likelihood function.
@@ -91,18 +90,22 @@
 - `optimizers` : array of optimizers, e.g. `[Adam(0.01)]`
 - `epochs` : A vector with number of epochs for each optimizer in
   `optimizers`.
 - `batchsizes`: A vector of batch sizes, which should match the length of
-  `optimizers`. If nothing is provided, all segments are used at once (full batch).
+  `optimizers`. If nothing is provided, all segments are used at once (full
+  batch).
 - `verbose_loss` : Whether to display loss during training.
 - `info_per_its = 50`: The frequency at which to display the training information.
 - `plotting` : Whether to plot the convergence loss during training.
-- `cb` : A call back function. Must be of the form `cb(p_trained, losses, pred, ranges)`.
+- `cb` : A callback function. Must be of the form `cb(p_trained, losses, pred,
+  ranges)`.
 - `threshold` : The tolerance for stopping training.
 - `save_pred = true`: Whether to save the predictions.
 - `save_losses = true` : Whether to save the losses.
 - `adtype = Optimization.AutoForwardDiff()` : The automatic differentiation (AD) type to be
   used. Can be `Optimization.AutoForwardDiff()` for forward AD or
   `Optimization.AutoZygote()` for backward AD.
+- `u0s_init = nothing`: If provided, should be a vector of the form `[u0_1, ...,
+  u0_n]`, where `n` is the number of segments.
 - `multi_threading = true`: if `true`, segments in the piecewise loss are
   computed in parallel. Currently not supported with `adtype =
   Optimization.AutoZygote()`.
@@ -203,7 +206,7 @@ function inference(infprob;
     isnothing(batchsizes) && (batchsizes = fill(length(ranges),length(epochs)))
 
     @assert (length(optimizers) == length(epochs) == length(batchsizes)) "`optimizers`, `epochs`, `batchsizes` must be of same length"
-    @assert (size(data,1) == dim_prob) "The dimension of the training data does not correspond to the dimension of the state variables. This probably means that the training data corresponds to observables different from the state variables. In this case, you need to provide manually `u0s_init`."
+    @assert ((size(data,1) == dim_prob) || !isnothing(u0s_init)) "The dimension of the training data does not correspond to the dimension of the state variables. This probably means that the training data corresponds to observables different from the state variables. In this case, you need to provide `u0s_init` manually."
     for (i,opt) in enumerate(optimizers)
         OPT = typeof(opt)
         if OPT <: Union{Optim.AbstractOptimizer, Optim.Fminbox, Optim.SAMIN, Optim.ConstrainedOptimizer}
@@ -266,7 +269,7 @@ function inference(infprob;
     ### TRAINING ###
     ################
     # Container to track the losses
-    losses = Float64[]
+    losses = eltype(data)[]
     @info "Training started"
     objectivefun = OptimizationFunction(__loss, adtype)
     # similar to https://sensitivity.sciml.ai/stable/ode_fitting/stiff_ode_fit/
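Note for reviewers: the sketch below illustrates how the keyword arguments documented in this patch fit together; it is not part of the diff. It assumes the enclosing package is PiecewiseInference.jl, that `infprob` (an `InferenceProblem`) and a data matrix `data` (state variables × time steps) are constructed elsewhere, and that `Adam` comes from OptimizationOptimisers.jl. Arguments documented outside the hunks shown here are omitted.

```julia
using PiecewiseInference            # assumed package name
using OptimizationOptimisers: Adam  # assumed source of the Adam optimizer
import Optimization

# `infprob` and `data` are placeholders defined elsewhere.
res = inference(infprob;
                data = data,
                # the assertion in the patch requires these three vectors
                # to have the same length: one epoch count and one batch
                # size per optimizer
                optimizers = [Adam(0.01), Adam(0.001)],
                epochs = [200, 100],
                batchsizes = [10, 10],
                # forward-mode AD; per the docstring, multi_threading is
                # not supported with Optimization.AutoZygote()
                adtype = Optimization.AutoForwardDiff(),
                multi_threading = true)
```

Per the docstring above, omitting `batchsizes` (or passing `nothing`) falls back to full-batch training, and `u0s_init = [u0_1, ..., u0_n]` must be supplied when the rows of `data` are observables rather than state variables.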