diff --git a/test/NNODE_tests.jl b/test/NNODE_tests.jl index 0cd688e31..0a950a973 100644 --- a/test/NNODE_tests.jl +++ b/test/NNODE_tests.jl @@ -190,7 +190,7 @@ end luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1)) (u_, t_) = (u_analytical(ts), ts) function additional_loss(phi, θ) - return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_) + return sum(sum(abs2, vec(phi(t_, θ)) .- u_)) / length(u_) end alg1 = NNODE(luxchain, opt, strategy = GridTraining(0.01), additional_loss = additional_loss) @@ -203,7 +203,7 @@ end luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1)) (u_, t_) = (u_analytical(ts), ts) function additional_loss(phi, θ) - return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_) + return sum(sum(abs2, vec(phi(t_, θ)) .- u_)) / length(u_) end alg1 = NNODE(luxchain, opt, additional_loss = additional_loss) sol1 = solve(prob, alg1, verbose = false, abstol = 1e-10, maxiters = 200) @@ -215,7 +215,7 @@ end luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1)) (u_, t_) = (u_analytical(ts), ts) function additional_loss(phi, θ) - return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_) + return sum(sum(abs2, vec(phi(t_, θ)) .- u_)) / length(u_) end alg1 = NNODE(luxchain, opt, strategy = StochasticTraining(1000), additional_loss = additional_loss)