Skip to content

Commit

Permalink
refactor: use FromFluxAdaptor for converting Flux to Lux as Lux.trans…
Browse files Browse the repository at this point in the history
…form is deprecated
  • Loading branch information
sathvikbhagavan committed Mar 22, 2024
1 parent c74384c commit c13cd61
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/BPINN_ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
autodiff = false, progress = false, verbose = false)
!(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
!(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
BNNODE(chain, Kernel, strategy,
draw_samples, priorsNNw, param, l2std,
phystd, dataset, physdt, MCMCkwargs,
Expand Down
1 change: 1 addition & 0 deletions src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ using DomainSets: Domain, ClosedInterval, AbstractInterval, leftendpoint, righte
using SciMLBase: @add_kwonly, parameterless_type
using UnPack: @unpack
import ChainRulesCore, Lux, ComponentArrays
using Lux: FromFluxAdaptor
using ChainRulesCore: @non_differentiable

RuntimeGeneratedFunctions.init(@__MODULE__)
Expand Down
2 changes: 1 addition & 1 deletion src/advancedHMC_MCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
MCMCkwargs = (n_leapfrog = 30,),
progress = false, verbose = false)

!(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
!(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
# NN parameter prior mean and variance(PriorsNN must be a tuple)
if isinplace(prob)
throw(error("The BPINN ODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))
Expand Down
2 changes: 1 addition & 1 deletion src/dae_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ end

function NNDAE(chain, opt, init_params = nothing; strategy = nothing, autodiff = false,
kwargs...)
!(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
!(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
NNDAE(chain, opt, init_params, autodiff, strategy, kwargs)
end

Expand Down
2 changes: 1 addition & 1 deletion src/ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ end
function NNODE(chain, opt, init_params = nothing;
strategy = nothing,
autodiff = false, batch = false, param_estim = false, additional_loss = nothing, kwargs...)
!(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
!(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
NNODE(chain, opt, init_params, autodiff, batch, strategy, param_estim, additional_loss, kwargs)
end

Expand Down
4 changes: 2 additions & 2 deletions src/pinn_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN
if multioutput
!all(i -> i isa Lux.AbstractExplicitLayer, chain) && (chain = Lux.transform.(chain))
else
!(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
!(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
end
if phi === nothing
if multioutput
Expand Down Expand Up @@ -243,7 +243,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN
if multioutput
!all(i -> i isa Lux.AbstractExplicitLayer, chain) && (chain = Lux.transform.(chain))
else
!(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
!(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
end
if phi === nothing
if multioutput
Expand Down

0 comments on commit c13cd61

Please sign in to comment.