diff --git a/src/pino_ode_solve.jl b/src/pino_ode_solve.jl index 9769571bf..a530a1d57 100644 --- a/src/pino_ode_solve.jl +++ b/src/pino_ode_solve.jl @@ -53,7 +53,7 @@ function PINOODE(chain, PINOODE(chain, opt, bounds, init_params, strategy, additional_loss, kwargs) end -struct PINOPhi{C, S} +mutable struct PINOPhi{C, S} chain::C st::S function PINOPhi(chain::Lux.AbstractExplicitLayer, st) @@ -148,7 +148,7 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem, error("Only Lux.AbstractExplicitLayer neural networks are supported") #TODO implement for u0 - if !any(in(keys(bounds)), (:p,)) + if !any(in(keys(bounds)), (:p, :u0)) error("bounds should contain p only") end phi, init_params = generate_pino_phi_θ(chain, init_params)