diff --git a/src/BPINN_ode.jl b/src/BPINN_ode.jl
index 3bbf1afea..087b9d41c 100644
--- a/src/BPINN_ode.jl
+++ b/src/BPINN_ode.jl
@@ -113,7 +113,7 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
             targetacceptancerate = 0.8),
         Integratorkwargs = (Integrator = Leapfrog,),
         autodiff = false, progress = false, verbose = false)
-    !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+    !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
     BNNODE(chain, Kernel, strategy,
         draw_samples, priorsNNw, param, l2std,
         phystd, dataset, physdt, MCMCkwargs,
diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl
index d367bf8b6..2ba1de25b 100644
--- a/src/NeuralPDE.jl
+++ b/src/NeuralPDE.jl
@@ -30,6 +30,7 @@ using DomainSets: Domain, ClosedInterval, AbstractInterval, leftendpoint, righte
 using SciMLBase: @add_kwonly, parameterless_type
 using UnPack: @unpack
 import ChainRulesCore, Lux, ComponentArrays
+using Lux: FromFluxAdaptor
 using ChainRulesCore: @non_differentiable
 
 RuntimeGeneratedFunctions.init(@__MODULE__)
diff --git a/src/advancedHMC_MCMC.jl b/src/advancedHMC_MCMC.jl
index 6f3014925..c86c87599 100644
--- a/src/advancedHMC_MCMC.jl
+++ b/src/advancedHMC_MCMC.jl
@@ -439,7 +439,7 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
         MCMCkwargs = (n_leapfrog = 30,),
         progress = false, verbose = false)
-    !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+    !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
     # NN parameter prior mean and variance(PriorsNN must be a tuple)
     if isinplace(prob)
         throw(error("The BPINN ODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))
diff --git a/src/dae_solve.jl b/src/dae_solve.jl
index 3f6bf8f0f..0c9d1323d 100644
--- a/src/dae_solve.jl
+++ b/src/dae_solve.jl
@@ -42,7 +42,7 @@ end
 function NNDAE(chain, opt, init_params = nothing; strategy = nothing,
         autodiff = false, kwargs...)
-    !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+    !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
     NNDAE(chain, opt, init_params, autodiff, strategy, kwargs)
 end
diff --git a/src/ode_solve.jl b/src/ode_solve.jl
index b9c46d346..077edc659 100644
--- a/src/ode_solve.jl
+++ b/src/ode_solve.jl
@@ -90,7 +90,7 @@ end
 function NNODE(chain, opt, init_params = nothing; strategy = nothing, autodiff = false,
         batch = false, param_estim = false, additional_loss = nothing, kwargs...)
-    !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+    !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
     NNODE(chain, opt, init_params, autodiff, batch, strategy, param_estim, additional_loss, kwargs)
 end
diff --git a/src/pinn_types.jl b/src/pinn_types.jl
index 69116c4da..5d7ffc9b1 100644
--- a/src/pinn_types.jl
+++ b/src/pinn_types.jl
@@ -107,7 +107,7 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN
         if multioutput
             !all(i -> i isa Lux.AbstractExplicitLayer, chain) && (chain = Lux.transform.(chain))
         else
-            !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+            !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
         end
         if phi === nothing
             if multioutput
@@ -243,7 +243,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN
         if multioutput
             !all(i -> i isa Lux.AbstractExplicitLayer, chain) && (chain = Lux.transform.(chain))
         else
-            !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+            !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
         end
         if phi === nothing
             if multioutput
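
A minimal sketch of the conversion this patch switches to, for reference. It is illustrative only: the two-layer Flux chain and its sizes are made up, and it assumes Flux, Lux, and Adapt are available in the environment.

using Adapt: adapt
using Flux, Lux
using Lux: FromFluxAdaptor

# Hypothetical Flux model standing in for a user-supplied `chain`.
flux_chain = Flux.Chain(Flux.Dense(1 => 16, tanh), Flux.Dense(16 => 1))

# FromFluxAdaptor(preserve_ps_st, force_preserve): with (false, false), as used in
# the patch, parameters/state are not carried over from Flux; a structurally
# matching Lux layer is returned.
lux_chain = adapt(FromFluxAdaptor(false, false), flux_chain)

lux_chain isa Lux.AbstractExplicitLayer  # true, so the solver constructors above accept it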