From 0011392d4446eb7c71eea075ce179d2bc4e81a63 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Mon, 19 Jun 2023 13:43:25 +0100 Subject: [PATCH] add multioutput --- src/discretize.jl | 19 ++++++++++++++++++- src/loss_function_generation.jl | 19 ++++++++++--------- src/pinn_types.jl | 29 ++++++++++++++++------------- 3 files changed, 44 insertions(+), 23 deletions(-) diff --git a/src/discretize.jl b/src/discretize.jl index 6f76cc7f8..0b5cbc726 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -264,6 +264,23 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem, phi.st) end + if multioutput + dvs = v.ū + acum = [0; accumulate(+, map(length, init_params))] + sep = [(acum[i] + 1):acum[i + 1] for i in 1:(length(acum) - 1)] + phimap = map(enumerate(dvs)) do (i, dv) + if (phi isa Vector && phi[1].f isa Optimisers.Restructure) || + (!(phi isa Vector) && phi.f isa Optimisers.Restructure) + # Flux.Chain + dv => (coord, expr_θ) -> phi[i](coord, expr_θ[sep[i]]) + else # Lux.AbstractExplicitLayer + dv => (coord, expr_θ) -> phi[i](coord, expr_θ.depvar.$(dv)) + end + end |> Dict + else + phimap = nothing + end + eltypeθ = eltype(flat_init_params) if adaloss === nothing @@ -276,7 +293,7 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem, pinnrep = PINNRepresentation(eqs, bcs, domains, eq_params, defaults, default_p, param_estim, additional_loss, adaloss, v, logger, multioutput, iteration, init_params, flat_init_params, phi, - derivative, + phimap, derivative, strategy, eqdata, nothing, nothing, nothing, nothing) #integral = get_numeric_integral(pinnrep) diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index ce7f7aa6d..928be8049 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -8,7 +8,7 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; integrand = nothing, transformation_vars = nothing) @unpack varmap, eqdata, - phi, derivative, integral, + phi, phimap, derivative, integral, multioutput, init_params, strategy, eq_params, param_estim, default_p = pinnrep @@ -77,7 +77,11 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; end function build_loss_function(pinnrep, eqs) - @unpack eq_params, param_estim, default_p, phi, derivative, integral = pinnrep + @unpack eq_params, param_estim, default_p, phi, phimap, multioutput, derivative, integral = pinnrep + + if multioutput + phi = phimap + end _loss_function = build_symbolic_loss_function(pinnrep, eqs, eq_params = eq_params, @@ -134,8 +138,7 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative # Orthodox derivatives n(w) = length(arguments(w)) rs = reduce(vcat, [reduce(vcat, [[@rule $((Differential(x)^d)(w)) => - derivative(phi, - ufunc, coord, + derivative(ufunc(w, coord, θ, phi), coord, [get_ε(n(w), j, eltypeθ, i) for i in 1:d], d, θ) @@ -152,17 +155,15 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative ε1 = [get_ε(n(w), j, eltypeθ, i) for i in 1:2] ε2 = [get_ε(n(w), k, eltypeθ, i) for i in 1:2] [@rule $((Differential(x))((Differential(y))(w))) => - derivative(phi, - (cord_, θ_, phi_) -> - derivative(phi_, ufunc, cord_, - ε2, 1, θ_), + derivative((cord_, θ_) -> derivative(ufunc(w, coord, θ, phi), cord_, + ε2, 1, θ_), coord, ε1, 1, θ)] end end end end vr = mapreduce(vcat, dvs, init = []) do w - @rule w => ufunc(coord, θ, phi) + @rule w => ufunc(w, coord, θ, phi)(coord, θ) end return [mx; rs; vr] diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 281e562f5..0fae251df 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -239,6 +239,10 @@ mutable struct PINNRepresentation The representation of the test function of the PDE solution """ phi::Any + """ + the map of vars to chains + """ + phimap::Any """ The function used for computing the derivative """ @@ -346,12 +350,11 @@ function (f::Phi{<:Optimisers.Restructure})(x, θ) f.f(θ)(adapt(parameterless_type(θ), x)) end -ufunc(cord, θ, phi) = phi(cord, θ) -@register_symbolic ufunc(cord, θ, phi) +ufunc(u, cord, θ, phi) = phi isa Dict ? phi[u](cord, θ) : phi(cord, θ) # the method to calculate the derivative -function numeric_derivative(phi, u, x, εs, order, θ) +function numeric_derivative(phi, x, εs, order, θ) _type = parameterless_type(ComponentArrays.getdata(θ)) ε = εs[order] @@ -366,31 +369,31 @@ function numeric_derivative(phi, u, x, εs, order, θ) if order > 4 || any(x -> x != εs[1], εs) @show "me" - return (numeric_derivative(phi, u, x .+ ε, @view(εs[1:(end-1)]), order - 1, θ) + return (numeric_derivative(phi, x .+ ε, @view(εs[1:(end-1)]), order - 1, θ) .- - numeric_derivative(phi, u, x .- ε, @view(εs[1:(end-1)]), order - 1, θ)) .* + numeric_derivative(phi, x .- ε, @view(εs[1:(end-1)]), order - 1, θ)) .* _epsilon ./ 2 elseif order == 4 @show "me4" - return (u(x .+ 2 .* ε, θ, phi) .- 4 .* u(x .+ ε, θ, phi) + return (phi(x .+ 2 .* ε, θ) .- 4 .* phi(x .+ ε, θ) .+ - 6 .* u(x, θ, phi) + 6 .* phi(x, θ) .- - 4 .* u(x .- ε, θ, phi) .+ u(x .- 2 .* ε, θ, phi)) .* _epsilon^4 + 4 .* phi(x .- ε, θ) .+ phi(x .- 2 .* ε, θ)) .* _epsilon^4 elseif order == 3 @show "me3" - return (u(x .+ 2 .* ε, θ, phi) .- 2 .* u(x .+ ε, θ, phi) .+ 2 .* u(x .- ε, θ, phi) + return (phi(x .+ 2 .* ε, θ) .- 2 .* phi(x .+ ε, θ) .+ 2 .* phi(x .- ε, θ) - - u(x .- 2 .* ε, θ, phi)) .* _epsilon^3 ./ 2 + phi(x .- 2 .* ε, θ)) .* _epsilon^3 ./ 2 elseif order == 2 @show "me2" - return (u(x .+ ε, θ, phi) .+ u(x .- ε, θ, phi) .- 2 .* u(x, θ, phi)) .* _epsilon^2 + return (phi(x .+ ε, θ) .+ phi(x .- ε, θ) .- 2 .* phi(x, θ)) .* _epsilon^2 elseif order == 1 @show "me1" - return (u(x .+ ε, θ, phi) .- u(x .- ε, θ, phi)) .* _epsilon ./ 2 + return (phi(x .+ ε, θ) .- phi(x .- ε, θ)) .* _epsilon ./ 2 else error("This shouldn't happen! Got an order of $(order).") end end # Hacky workaround for metaprogramming with symbolics -@register_symbolic(numeric_derivative(phi, u, x, εs, order, θ), true, [], true) +@register_symbolic(numeric_derivative(phi, x, εs, order, θ), true, [], true)