diff --git a/src/ode_solve.jl b/src/ode_solve.jl index bb417b649..eb5ae942e 100644 --- a/src/ode_solve.jl +++ b/src/ode_solve.jl @@ -358,6 +358,8 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem, !(chain isa Lux.AbstractExplicitLayer) && error("Only Lux.AbstractExplicitLayer neural networks are supported") phi, init_params = generate_phi_θ(chain, t0, u0, init_params) + ((eltype(eltype(init_params).types[1]) <: Complex || eltype(eltype(init_params).types[2]) <: Complex) && alg.strategy isa QuadratureTraining) && + error("QuadratureTraining cannot be used with complex parameters. Use other strategies.") init_params = if alg.param_estim ComponentArrays.ComponentArray(; depvar = ComponentArrays.ComponentArray(init_params), p = prob.p)