Skip to content

Commit

Permalink
implement DeepONet, refactor pinoode
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Apr 25, 2024
1 parent 8a98880 commit 2cc1d1f
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 376 deletions.
3 changes: 2 additions & 1 deletion src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ include("adaptive_losses.jl")
include("ode_solve.jl")
# include("rode_solve.jl")
include("dae_solve.jl")
include("neural_operators.jl")
include("pino_ode_solve.jl")
include("transform_inf_integral.jl")
include("discretize.jl")
Expand All @@ -58,7 +59,7 @@ include("PDE_BPINN.jl")
include("dgm.jl")


export NNODE, NNDAE, PINOODE, TRAINSET, EquationSolving, OperatorLearning
export NNODE, NNDAE, PINOODE, DeepONet
PhysicsInformedNN, discretize,
GridTraining, StochasticTraining, QuadratureTraining, QuasiRandomTraining,
WeightedIntervalTraining,
Expand Down
47 changes: 47 additions & 0 deletions src/neural_operators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#TODO: Add docstrings
"""
DeepONet(branch,trunk)
"""
struct DeepONet{} <: Lux.AbstractExplicitLayer
branch::Lux.AbstractExplicitLayer
trunk::Lux.AbstractExplicitLayer
end

function Lux.setup(rng::AbstractRNG, l::DeepONet)
branch, trunk = l.branch, l.trunk
θ_branch, st_branch = Lux.setup(rng, branch)
θ_trunk, st_trunk = Lux.setup(rng, trunk)
θ = (branch = θ_branch, trunk = θ_trunk)
st = (branch = st_branch, trunk = st_trunk)
θ, st
end

# function Lux.initialparameters(rng::AbstractRNG, e::DeepONet)
# code
# end

Lux.initialstates(::AbstractRNG, ::DeepONet) = NamedTuple()

"""
example:
branch = Lux.Chain(Lux.Dense(1, 32, Lux.σ), Lux.Dense(32, 1))
trunk = Lux.Chain(Lux.Dense(1, 32, Lux.σ), Lux.Dense(32, 1))
a = rand(1, 100, 10)
t = rand(1, 1, 10)
x = (branch = a, trunk = t)
deeponet = DeepONet(branch, trunk)
θ, st = Lux.setup(Random.default_rng(), deeponet)
y = deeponet(x, θ, st)
"""
@inline function (f::DeepONet)(x::NamedTuple, θ, st::NamedTuple)
parameters, cord = x.branch, x.trunk
branch, trunk = f.branch, f.trunk
st_branch, st_trunk = st.branch, st.trunk
θ_branch, θ_trunk = θ.branch, θ.trunk
out_b, st_b = branch(parameters, θ_branch, st_branch)
out_t, st_t = trunk(cord, θ_trunk, st_trunk)
out = out_b' * out_t
return out, (branch = st_b, trunk = st_t)
end
Loading

0 comments on commit 2cc1d1f

Please sign in to comment.