From 7e3de9879de54292ac4a0ec8f41375104303120b Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Fri, 22 Mar 2024 11:15:10 +0000 Subject: [PATCH] refactor: make batch to be true by default for NNODE --- src/ode_solve.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/ode_solve.jl b/src/ode_solve.jl index ef57beabe..8af6a708d 100644 --- a/src/ode_solve.jl +++ b/src/ode_solve.jl @@ -27,11 +27,9 @@ of the physics-informed neural network which is used as a solver for a standard the PDE operators. The reverse mode of the loss function is always automatic differentiation (via Zygote), this is only for the derivative in the loss function (the derivative with respect to time). -* `batch`: The batch size for the loss computation. Defaults to `false`, which - means the application of the neural network is done at individual time points one - at a time. `true` means the neural network is applied at a row vector of values - `t` simultaneously, i.e. it's the batch size for the neural network evaluations. - This requires a neural network compatible with batched data. +* `batch`: The batch size for the loss computation. Defaults to `true`, means the neural network is applied at a row vector of values + `t` simultaneously, i.e. it's the batch size for the neural network evaluations. This requires a neural network compatible with batched data. + `false` means which means the application of the neural network is done at individual time points one at a time. This is not applicable to `QuadratureTraining` where `batch` is passed in the `strategy` which is the number of points it can parallelly compute the integrand. * `param_estim`: Boolean to indicate whether parameters of the differential equations are learnt along with parameters of the neural network. * `strategy`: The training strategy used to choose the points for the evaluations. @@ -89,7 +87,7 @@ struct NNODE{C, O, P, B, PE, K, AL <: Union{Nothing, Function}, end function NNODE(chain, opt, init_params = nothing; strategy = nothing, - autodiff = false, batch = false, param_estim = false, additional_loss = nothing, kwargs...) + autodiff = false, batch = true, param_estim = false, additional_loss = nothing, kwargs...) !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain)) NNODE(chain, opt, init_params, autodiff, batch, strategy, param_estim, additional_loss, kwargs) end