Skip to content

Commit

Permalink
float 64 flux layers
Browse files Browse the repository at this point in the history
  • Loading branch information
AstitvaAggarwal committed Sep 3, 2023
1 parent 83acfed commit ae3fa74
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions test/BPINN_Tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ x̂1 = collect(Float64, Array(u1) + 0.02 * randn(size(u1)))
time1 = vec(collect(Float64, ta0))
physsol0_1 = [linear_analytic(prob.u0, p, time1[i]) for i in eachindex(time1)]

chainflux = Flux.Chain(Flux.Dense(1, 7, tanh), Flux.Dense(7, 1))
chainflux = Flux.Chain(Flux.Dense(1, 7, tanh), Flux.Dense(7, 1)) |> Flux.f64
chainlux = Lux.Chain(Lux.Dense(1, 7, tanh), Lux.Dense(7, 1))
init1, re1 = destructure(chainflux)
θinit, st = Lux.setup(Random.default_rng(), chainlux)
Expand Down Expand Up @@ -116,14 +116,11 @@ x̂1 = collect(Float64, Array(u1) + 0.02 * randn(size(u1)))
time1 = vec(collect(Float64, ta0))
physsol1_1 = [linear_analytic(prob.u0, p, time1[i]) for i in eachindex(time1)]

chainflux1 = Flux.Chain(Flux.Dense(1, 7, tanh), Flux.Dense(7, 1))
chainflux1 = Flux.Chain(Flux.Dense(1, 7, tanh), Flux.Dense(7, 1)) |> Flux.f64
chainlux1 = Lux.Chain(Lux.Dense(1, 7, tanh), Lux.Dense(7, 1))
init1, re1 = destructure(chainflux1)
θinit, st = Lux.setup(Random.default_rng(), chainlux1)

# weak priors call for larger NNs?
# my L2 loss also works(binds parameters)?

fh_mcmc_chain1, fhsamples1, fhstats1 = ahmc_bayesian_pinn_ode(prob, chainflux1,
dataset = dataset,
draw_samples = 2500,
Expand Down Expand Up @@ -197,6 +194,7 @@ meanscurve2 = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean
@test mean(abs.(physsol1_1 .- sol2flux.ensemblesol[1])) < 8e-2
@test mean(abs.(x̂1 .- sol2lux.ensemblesol[1])) < 8e-2
@test mean(abs.(physsol1_1 .- sol2lux.ensemblesol[1])) < 8e-2

# ESTIMATED ODE PARAMETERS (NN1 AND NN2)
@test abs(p - sol2flux.estimated_ode_params[1]) < abs(0.1 * p)
@test abs(p - sol2lux.estimated_ode_params[1]) < abs(0.1 * p)
Expand Down Expand Up @@ -224,7 +222,7 @@ time1 = vec(collect(Float64, ta0))
physsol2 = [linear_analytic(prob.u0, p, time1[i]) for i in eachindex(time1)]

chainflux12 = Flux.Chain(Flux.Dense(1, 6, tanh), Flux.Dense(6, 6, tanh),
Flux.Dense(6, 1))
Flux.Dense(6, 1)) |> Flux.f64
chainlux12 = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh), Lux.Dense(6, 1))
init1, re1 = destructure(chainflux12)
θinit, st = Lux.setup(Random.default_rng(), chainlux12)
Expand Down

0 comments on commit ae3fa74

Please sign in to comment.