Skip to content

Commit

Permalink
add multioutput
Browse files Browse the repository at this point in the history
  • Loading branch information
xtalax committed Jun 19, 2023
1 parent 17d0150 commit 0011392
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 23 deletions.
19 changes: 18 additions & 1 deletion src/discretize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,23 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem,
phi.st)
end

if multioutput
dvs = v.
acum = [0; accumulate(+, map(length, init_params))]
sep = [(acum[i] + 1):acum[i + 1] for i in 1:(length(acum) - 1)]
phimap = map(enumerate(dvs)) do (i, dv)
if (phi isa Vector && phi[1].f isa Optimisers.Restructure) ||
(!(phi isa Vector) && phi.f isa Optimisers.Restructure)
# Flux.Chain
dv => (coord, expr_θ) -> phi[i](coord, expr_θ[sep[i]])
else # Lux.AbstractExplicitLayer
dv => (coord, expr_θ) -> phi[i](coord, expr_θ.depvar.$(dv))
end
end |> Dict
else
phimap = nothing
end

eltypeθ = eltype(flat_init_params)

if adaloss === nothing
Expand All @@ -276,7 +293,7 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem,
pinnrep = PINNRepresentation(eqs, bcs, domains, eq_params, defaults, default_p,
param_estim, additional_loss, adaloss, v, logger,
multioutput, iteration, init_params, flat_init_params, phi,
derivative,
phimap, derivative,
strategy, eqdata, nothing, nothing, nothing, nothing)

#integral = get_numeric_integral(pinnrep)
Expand Down
19 changes: 10 additions & 9 deletions src/loss_function_generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq;
integrand = nothing,
transformation_vars = nothing)
@unpack varmap, eqdata,
phi, derivative, integral,
phi, phimap, derivative, integral,
multioutput, init_params, strategy, eq_params,
param_estim, default_p = pinnrep

Expand Down Expand Up @@ -77,7 +77,11 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq;
end

function build_loss_function(pinnrep, eqs)
@unpack eq_params, param_estim, default_p, phi, derivative, integral = pinnrep
@unpack eq_params, param_estim, default_p, phi, phimap, multioutput, derivative, integral = pinnrep

if multioutput
phi = phimap
end

_loss_function = build_symbolic_loss_function(pinnrep, eqs,
eq_params = eq_params,
Expand Down Expand Up @@ -134,8 +138,7 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative
# Orthodox derivatives
n(w) = length(arguments(w))
rs = reduce(vcat, [reduce(vcat, [[@rule $((Differential(x)^d)(w)) =>
derivative(phi,
ufunc, coord,
derivative(ufunc(w, coord, θ, phi), coord,
[get_ε(n(w),
j, eltypeθ, i) for i in 1:d],
d, θ)
Expand All @@ -152,17 +155,15 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative
ε1 = [get_ε(n(w), j, eltypeθ, i) for i in 1:2]
ε2 = [get_ε(n(w), k, eltypeθ, i) for i in 1:2]
[@rule $((Differential(x))((Differential(y))(w))) =>
derivative(phi,
(cord_, θ_, phi_) ->
derivative(phi_, ufunc, cord_,
ε2, 1, θ_),
derivative((cord_, θ_) -> derivative(ufunc(w, coord, θ, phi), cord_,
ε2, 1, θ_),
coord, ε1, 1, θ)]
end
end
end
end
vr = mapreduce(vcat, dvs, init = []) do w
@rule w => ufunc(coord, θ, phi)
@rule w => ufunc(w, coord, θ, phi)(coord, θ)
end

return [mx; rs; vr]
Expand Down
29 changes: 16 additions & 13 deletions src/pinn_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,10 @@ mutable struct PINNRepresentation
The representation of the test function of the PDE solution
"""
phi::Any
"""
the map of vars to chains
"""
phimap::Any
"""
The function used for computing the derivative
"""
Expand Down Expand Up @@ -346,12 +350,11 @@ function (f::Phi{<:Optimisers.Restructure})(x, θ)
f.f(θ)(adapt(parameterless_type(θ), x))
end

ufunc(cord, θ, phi) = phi(cord, θ)
@register_symbolic ufunc(cord, θ, phi)
ufunc(u, cord, θ, phi) = phi isa Dict ? phi[u](cord, θ) : phi(cord, θ)


# the method to calculate the derivative
function numeric_derivative(phi, u, x, εs, order, θ)
function numeric_derivative(phi, x, εs, order, θ)
_type = parameterless_type(ComponentArrays.getdata(θ))

ε = εs[order]
Expand All @@ -366,31 +369,31 @@ function numeric_derivative(phi, u, x, εs, order, θ)

if order > 4 || any(x -> x != εs[1], εs)
@show "me"
return (numeric_derivative(phi, u, x .+ ε, @view(εs[1:(end-1)]), order - 1, θ)
return (numeric_derivative(phi, x .+ ε, @view(εs[1:(end-1)]), order - 1, θ)
.-
numeric_derivative(phi, u, x .- ε, @view(εs[1:(end-1)]), order - 1, θ)) .*
numeric_derivative(phi, x .- ε, @view(εs[1:(end-1)]), order - 1, θ)) .*
_epsilon ./ 2
elseif order == 4
@show "me4"
return (u(x .+ 2 .* ε, θ, phi) .- 4 .* u(x .+ ε, θ, phi)
return (phi(x .+ 2 .* ε, θ) .- 4 .* phi(x .+ ε, θ)
.+
6 .* u(x, θ, phi)
6 .* phi(x, θ)
.-
4 .* u(x .- ε, θ, phi) .+ u(x .- 2 .* ε, θ, phi)) .* _epsilon^4
4 .* phi(x .- ε, θ) .+ phi(x .- 2 .* ε, θ)) .* _epsilon^4
elseif order == 3
@show "me3"
return (u(x .+ 2 .* ε, θ, phi) .- 2 .* u(x .+ ε, θ, phi) .+ 2 .* u(x .- ε, θ, phi)
return (phi(x .+ 2 .* ε, θ) .- 2 .* phi(x .+ ε, θ) .+ 2 .* phi(x .- ε, θ)
-
u(x .- 2 .* ε, θ, phi)) .* _epsilon^3 ./ 2
phi(x .- 2 .* ε, θ)) .* _epsilon^3 ./ 2
elseif order == 2
@show "me2"
return (u(x .+ ε, θ, phi) .+ u(x .- ε, θ, phi) .- 2 .* u(x, θ, phi)) .* _epsilon^2
return (phi(x .+ ε, θ) .+ phi(x .- ε, θ) .- 2 .* phi(x, θ)) .* _epsilon^2
elseif order == 1
@show "me1"
return (u(x .+ ε, θ, phi) .- u(x .- ε, θ, phi)) .* _epsilon ./ 2
return (phi(x .+ ε, θ) .- phi(x .- ε, θ)) .* _epsilon ./ 2
else
error("This shouldn't happen! Got an order of $(order).")
end
end
# Hacky workaround for metaprogramming with symbolics
@register_symbolic(numeric_derivative(phi, u, x, εs, order, θ), true, [], true)
@register_symbolic(numeric_derivative(phi, x, εs, order, θ), true, [], true)

0 comments on commit 0011392

Please sign in to comment.