diff --git a/src/neural_de.jl b/src/neural_de.jl index 12ee3533d..3a3db46f5 100644 --- a/src/neural_de.jl +++ b/src/neural_de.jl @@ -254,15 +254,16 @@ function (n::NeuralDAE)(u_du::Tuple, p, st) nn_out = model(vcat(u, du), p) alg_out = n.constraints_model(u, p, t) iter_nn, iter_const = 0, 0 - map(n.differential_vars) do isdiff + res = map(n.differential_vars) do isdiff if isdiff iter_nn += 1 - selectdim(nn_out, 1, iter_nn) + nn_out[iter_nn] else iter_const += 1 - selectdim(alg_out, 1, iter_const) + alg_out[iter_const] end end + return res end prob = DAEProblem{false}(f, du0, u0, n.tspan, p; n.differential_vars) diff --git a/test/neural_dae.jl b/test/neural_dae.jl index 40a8767dc..b5352fff9 100644 --- a/test/neural_dae.jl +++ b/test/neural_dae.jl @@ -1,4 +1,5 @@ -using ComponentArrays, DiffEqFlux, Zygote, Optimization, OrdinaryDiffEq, Random +using ComponentArrays, + DiffEqFlux, Zygote, Optimization, OptimizationOptimJL, OrdinaryDiffEq, Random #A desired MWE for now, not a test yet. @@ -27,6 +28,7 @@ tspan = (0.0, 10.0) ndae = NeuralDAE(dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1], tspan, DImplicitEuler(); differential_vars = [true, true, false]) ps, st = Lux.setup(Xoshiro(0), ndae) +ps = ComponentArray(ps) truedu0 = similar(u₀) ndae((u₀, truedu0), ps, st) @@ -36,13 +38,11 @@ predict_n_dae(p) = first(ndae(u₀, p, st)) function loss(p) pred = predict_n_dae(p) loss = sum(abs2, sol .- pred) - loss, pred + return loss, pred end -p = p .+ rand(3) .* p - optfunc = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote()) -optprob = Optimization.OptimizationProblem(optfunc, p) +optprob = Optimization.OptimizationProblem(optfunc, ps) res = Optimization.solve(optprob, BFGS(; initial_stepnorm = 0.0001)) # Same stuff with Lux diff --git a/test/neural_gde.jl b/test/neural_gde.jl index 743d54821..e8e0e9a2b 100644 --- a/test/neural_gde.jl +++ b/test/neural_gde.jl @@ -1,5 +1,5 @@ using DiffEqFlux, ComponentArrays, GeometricFlux, GraphSignals, OrdinaryDiffEq, Random, - Test, OptimizationOptimisers, Optimization + Test, OptimizationOptimisers, Optimization, Statistics import Flux # Fully Connected Graph