Skip to content

Commit

Permalink
fix: checking complex type in the parameters of nn
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Jul 3, 2024
1 parent 6627aa0 commit c7157d9
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ using DomainSets: Domain, ClosedInterval, AbstractInterval, leftendpoint, righte
using SciMLBase: @add_kwonly, parameterless_type
using UnPack: @unpack
import ChainRulesCore, Lux, ComponentArrays
using Lux: FromFluxAdaptor
using Lux: FromFluxAdaptor, recursive_eltype
using ChainRulesCore: @non_differentiable

RuntimeGeneratedFunctions.init(@__MODULE__)
Expand Down
3 changes: 1 addition & 2 deletions src/ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,7 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
!(chain isa Lux.AbstractExplicitLayer) &&
error("Only Lux.AbstractExplicitLayer neural networks are supported")
phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
((eltype(eltype(init_params).types[1]) <: Complex ||
eltype(eltype(init_params).types[2]) <: Complex) &&
(recursive_eltype(init_params) <: Complex &&
alg.strategy isa QuadratureTraining) &&
error("QuadratureTraining cannot be used with complex parameters. Use other strategies.")

Expand Down

0 comments on commit c7157d9

Please sign in to comment.