Skip to content

Commit

Permalink
refactor: make batch to be true by default for NNODE
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Mar 22, 2024
1 parent 7f3cc74 commit 7e3de98
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions src/ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,9 @@ of the physics-informed neural network which is used as a solver for a standard
the PDE operators. The reverse mode of the loss function is always
automatic differentiation (via Zygote), this is only for the derivative
in the loss function (the derivative with respect to time).
* `batch`: The batch size for the loss computation. Defaults to `false`, which
means the application of the neural network is done at individual time points one
at a time. `true` means the neural network is applied at a row vector of values
`t` simultaneously, i.e. it's the batch size for the neural network evaluations.
This requires a neural network compatible with batched data.
* `batch`: The batch size for the loss computation. Defaults to `true`, means the neural network is applied at a row vector of values
`t` simultaneously, i.e. it's the batch size for the neural network evaluations. This requires a neural network compatible with batched data.
`false` means which means the application of the neural network is done at individual time points one at a time.
This is not applicable to `QuadratureTraining` where `batch` is passed in the `strategy` which is the number of points it can parallelly compute the integrand.
* `param_estim`: Boolean to indicate whether parameters of the differential equations are learnt along with parameters of the neural network.
* `strategy`: The training strategy used to choose the points for the evaluations.
Expand Down Expand Up @@ -89,7 +87,7 @@ struct NNODE{C, O, P, B, PE, K, AL <: Union{Nothing, Function},
end
function NNODE(chain, opt, init_params = nothing;
strategy = nothing,
autodiff = false, batch = false, param_estim = false, additional_loss = nothing, kwargs...)
autodiff = false, batch = true, param_estim = false, additional_loss = nothing, kwargs...)
!(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
NNODE(chain, opt, init_params, autodiff, batch, strategy, param_estim, additional_loss, kwargs)
end
Expand Down

0 comments on commit 7e3de98

Please sign in to comment.