From 8ae7b0f375e90410d05f8050109336e877c867a3 Mon Sep 17 00:00:00 2001 From: Kirill Zubov Date: Wed, 17 Jul 2024 16:51:43 +0400 Subject: [PATCH] Update src/pino_ode_solve.jl Co-authored-by: Sathvik Bhagavan <35105271+sathvikbhagavan@users.noreply.github.com> --- src/pino_ode_solve.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/pino_ode_solve.jl b/src/pino_ode_solve.jl index e20e1dec2..fdfbeaf30 100644 --- a/src/pino_ode_solve.jl +++ b/src/pino_ode_solve.jl @@ -67,11 +67,8 @@ end function generate_pino_phi_θ(chain::Lux.AbstractExplicitLayer, init_params) θ, st = Lux.setup(Random.default_rng(), chain) - if init_params === nothing - init_params = ComponentArrays.ComponentArray(θ) - else - init_params = ComponentArrays.ComponentArray(init_params) - end + init_params = isnothing(init_params) ? θ : init_params + init_params = ComponentArrays.ComponentArray(init_params) PINOPhi(chain, st), init_params end