Skip to content

Commit

Permalink
improving documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
vboussange committed Nov 22, 2023
1 parent 2bfb8dc commit b9e4667
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions src/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,8 @@ function piecewise_ML_indep_TS(infprob;
end

"""
$(SIGNATURES)
performs piecewise inference for a given `InferenceProblem` and `data`. Loops
through the optimizers `optimizers`. Returns a `InferenceResult`.
$(SIGNATURES) performs piecewise inference for a given `InferenceProblem` and
`data`. Loops through the optimizers `optimizers`. Returns a `InferenceResult`.
# Arguments
- `infprob`: An instance of `InferenceProblem` that defines the model, the
parameter constraints and its likelihood function.
Expand All @@ -91,18 +90,22 @@ through the optimizers `optimizers`. Returns a `InferenceResult`.
- `optimizers` : array of optimizers, e.g. `[Adam(0.01)]`
- `epochs` : A vector with number of epochs for each optimizer in `optimizers`.
- `batchsizes`: An vector of batch sizes, which should match the length of
`optimizers`. If nothing is provided, all segments are used at once (full batch).
`optimizers`. If nothing is provided, all segments are used at once (full
batch).
- `verbose_loss` : Whether to display loss during training.
- `info_per_its = 50`: The frequency at which to display the training
information.
- `plotting` : Whether to plot the convergence loss during training.
- `cb` : A call back function. Must be of the form `cb(p_trained, losses, pred, ranges)`.
- `cb` : A call back function. Must be of the form `cb(p_trained, losses, pred,
ranges)`.
- `threshold` : The tolerance for stopping training.
- `save_pred = true`: Whether to save the predictions.
- `save_losses = true` : Whether to save the losses.
- `adtype = Optimization.AutoForwardDiff()` : The automatic differentiation (AD)
type to be used. Can be `Optimization.AutoForwardDiff()` for forward AD or
`Optimization.Autozygote()` for backward AD.
- `u0s_init = nothing`: if provided, should be a vector of the form `[u0_1, ...,
u0_n]` where `n` is the number of segments
- `multi_threading = true`: if `true`, segments in the piecewise loss are
computed in parallel. Currently not supported with `adtype =
Optimization.Autozygote()`
Expand Down Expand Up @@ -203,7 +206,7 @@ function inference(infprob;
isnothing(batchsizes) && (batchsizes = fill(length(ranges),length(epochs)))

@assert (length(optimizers) == length(epochs) == length(batchsizes)) "`optimizers`, `epochs`, `batchsizes` must be of same length"
@assert (size(data,1) == dim_prob) "The dimension of the training data does not correspond to the dimension of the state variables. This probably means that the training data corresponds to observables different from the state variables. In this case, you need to provide manually `u0s_init`."
@assert ((size(data,1) == dim_prob) && isnothing(u0s_init)) "The dimension of the training data does not correspond to the dimension of the state variables. This probably means that the training data corresponds to observables different from the state variables. In this case, you need to provide manually `u0s_init`."
for (i,opt) in enumerate(optimizers)
OPT = typeof(opt)
if OPT <: Union{Optim.AbstractOptimizer, Optim.Fminbox, Optim.SAMIN, Optim.ConstrainedOptimizer}
Expand Down Expand Up @@ -266,7 +269,7 @@ function inference(infprob;
### TRAINING ###
################
# Container to track the losses
losses = Float64[]
losses = eltype(data)[]

@info "Training started"
objectivefun = OptimizationFunction(__loss, adtype) # similar to https://sensitivity.sciml.ai/stable/ode_fitting/stiff_ode_fit/
Expand Down

0 comments on commit b9e4667

Please sign in to comment.