Skip to content

Commit

Permalink
refactor: error out if QuadratureTraining is used with complex parame…
Browse files Browse the repository at this point in the history
…ters for NNODE
  • Loading branch information
sathvikbhagavan committed Mar 28, 2024
1 parent fd9afba commit dad8815
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,8 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,

!(chain isa Lux.AbstractExplicitLayer) && error("Only Lux.AbstractExplicitLayer neural networks are supported")
phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
((eltype(eltype(init_params).types[1]) <: Complex || eltype(eltype(init_params).types[2]) <: Complex) && alg.strategy isa QuadratureTraining) &&

Check warning on line 361 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L361

Added line #L361 was not covered by tests
error("QuadratureTraining cannot be used with complex parameters. Use other strategies.")

init_params = if alg.param_estim
ComponentArrays.ComponentArray(; depvar = ComponentArrays.ComponentArray(init_params), p = prob.p)
Expand Down

0 comments on commit dad8815

Please sign in to comment.