diff --git a/test/BPINN_Tests.jl b/test/BPINN_Tests.jl index 75cb40555..f0e71c7ba 100644 --- a/test/BPINN_Tests.jl +++ b/test/BPINN_Tests.jl @@ -43,7 +43,7 @@ x̂1 = collect(Float64, Array(u1) + 0.02 * randn(size(u1))) time1 = vec(collect(Float64, ta0)) physsol0_1 = [linear_analytic(prob.u0, p, time1[i]) for i in eachindex(time1)] -chainflux = Flux.Chain(Flux.Dense(1, 7, tanh), Flux.Dense(7, 1)) +chainflux = Flux.Chain(Flux.Dense(1, 7, tanh), Flux.Dense(7, 1)) |> Flux.f64 chainlux = Lux.Chain(Lux.Dense(1, 7, tanh), Lux.Dense(7, 1)) init1, re1 = destructure(chainflux) θinit, st = Lux.setup(Random.default_rng(), chainlux) @@ -116,14 +116,11 @@ x̂1 = collect(Float64, Array(u1) + 0.02 * randn(size(u1))) time1 = vec(collect(Float64, ta0)) physsol1_1 = [linear_analytic(prob.u0, p, time1[i]) for i in eachindex(time1)] -chainflux1 = Flux.Chain(Flux.Dense(1, 7, tanh), Flux.Dense(7, 1)) +chainflux1 = Flux.Chain(Flux.Dense(1, 7, tanh), Flux.Dense(7, 1)) |> Flux.f64 chainlux1 = Lux.Chain(Lux.Dense(1, 7, tanh), Lux.Dense(7, 1)) init1, re1 = destructure(chainflux1) θinit, st = Lux.setup(Random.default_rng(), chainlux1) -# weak priors call for larger NNs? -# my L2 loss also works(binds parameters)? - fh_mcmc_chain1, fhsamples1, fhstats1 = ahmc_bayesian_pinn_ode(prob, chainflux1, dataset = dataset, draw_samples = 2500, @@ -197,6 +194,7 @@ meanscurve2 = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean @test mean(abs.(physsol1_1 .- sol2flux.ensemblesol[1])) < 8e-2 @test mean(abs.(x̂1 .- sol2lux.ensemblesol[1])) < 8e-2 @test mean(abs.(physsol1_1 .- sol2lux.ensemblesol[1])) < 8e-2 + # ESTIMATED ODE PARAMETERS (NN1 AND NN2) @test abs(p - sol2flux.estimated_ode_params[1]) < abs(0.1 * p) @test abs(p - sol2lux.estimated_ode_params[1]) < abs(0.1 * p) @@ -224,7 +222,7 @@ time1 = vec(collect(Float64, ta0)) physsol2 = [linear_analytic(prob.u0, p, time1[i]) for i in eachindex(time1)] chainflux12 = Flux.Chain(Flux.Dense(1, 6, tanh), Flux.Dense(6, 6, tanh), - Flux.Dense(6, 1)) + Flux.Dense(6, 1)) |> Flux.f64 chainlux12 = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh), Lux.Dense(6, 1)) init1, re1 = destructure(chainflux12) θinit, st = Lux.setup(Random.default_rng(), chainlux12)