Skip to content

Commit

Permalink
change to x(0)
Browse files Browse the repository at this point in the history
  • Loading branch information
xtalax committed Jul 31, 2023
1 parent 8c3ad76 commit a949169
Show file tree
Hide file tree
Showing 6 changed files with 226 additions and 178 deletions.
4 changes: 2 additions & 2 deletions src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ using ArrayInterface
import Optim
using DomainSets
using Symbolics
using Symbolics: wrap, unwrap, arguments, operation
using Symbolics: wrap, unwrap, arguments, operation, symtype
using SymbolicUtils
using SymbolicUtils.Code
using SymbolicUtils: Prewalk, Postwalk, Chain
Expand All @@ -34,7 +34,7 @@ import Optimisers
import UnPack: @unpack
import RecursiveArrayTools
import ChainRulesCore, Flux, Lux, ComponentArrays
import ChainRulesCore: @non_differentiable
import ChainRulesCore: @non_differentiable, @ignore_derivatives


RuntimeGeneratedFunctions.init(@__MODULE__)
Expand Down
40 changes: 30 additions & 10 deletions src/discretize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, varmap)
else
dxs = fill(dx, length(domains))
end
@show dxs
spans = [infimum(d.domain):dx:supremum(d.domain) for (d, dx) in zip(domains, dxs)]
dict_var_span = Dict([d.variables => infimum(d.domain):dx:supremum(d.domain)
for (d, dx) in zip(domains, dxs)])
Expand Down Expand Up @@ -69,11 +68,10 @@ training strategy: StochasticTraining, QuasiRandomTraining, QuadratureTraining.
"""
function get_bounds end

function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::AbstractGridfreeStrategy)
function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::QuadratureTraining)
dict_lower_bound = Dict([d.variables => infimum(d.domain) for d in domains])
dict_upper_bound = Dict([d.variables => supremum(d.domain) for d in domains])
pde_args = get_argument(eqs, v)
@show pde_args

pde_lower_bounds = map(pde_args) do pd
span = map(p -> get(dict_lower_bound, p, p), pd)
Expand All @@ -86,7 +84,6 @@ function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::Abstr
pde_bounds = [pde_lower_bounds, pde_upper_bounds]

bound_vars = get_variables(bcs, v)
@show bound_vars

bcs_lower_bounds = map(bound_vars) do bt
map(b -> dict_lower_bound[b], bt)
Expand All @@ -95,9 +92,32 @@ function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::Abstr
map(b -> dict_upper_bound[b], bt)
end
bcs_bounds = [bcs_lower_bounds, bcs_upper_bounds]
@show bcs_bounds pde_bounds
[pde_bounds, bcs_bounds]
end

function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy)
dx = 1 / strategy.points
dict_span = Dict([d.variables => [
infimum(d.domain) + dx,
supremum(d.domain) - dx,
] for d in domains])

# pde_bounds = [[infimum(d.domain),supremum(d.domain)] for d in domains]
pde_args = get_argument(eqs, v)
pde_bounds = map(pde_args) do pde_arg
bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, pde_arg)
bds = eltypeθ.(bds)
bds[1, :], bds[2, :]
end

bound_args = get_argument(bcs, v)
bcs_bounds = map(bound_args) do bound_arg
bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, bound_arg)
bds = eltypeθ.(bds)
bds[1, :], bds[2, :]
end
return pde_bounds, bcs_bounds
end
# TODO: Get this to work with varmap
function get_numeric_integral(pinnrep::PINNRepresentation)
@unpack strategy, multioutput, derivative, varmap = pinnrep
Expand Down Expand Up @@ -268,15 +288,15 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem,
dvs = v.
acum = [0; accumulate(+, map(length, init_params))]
sep = [(acum[i] + 1):acum[i + 1] for i in 1:(length(acum) - 1)]
phimap = map(enumerate(dvs)) do (i, dv)
phi = map(enumerate(dvs)) do (i, dv)
if (phi isa Vector && phi[1].f isa Optimisers.Restructure) ||
(!(phi isa Vector) && phi.f isa Optimisers.Restructure)
# Flux.Chain
dv => (coord, expr_θ) -> phi[i](coord, expr_θ[sep[i]])
(coord, expr_θ) -> phi[i](coord, expr_θ[sep[i]])
else # Lux.AbstractExplicitLayer
dv => (coord, expr_θ) -> phi[i](coord, expr_θ.depvar.$(dv))
(coord, expr_θ) -> phi[i](coord, expr_θ.depvar.$(dv))
end
end |> Dict
end
else
phimap = nothing
end
Expand All @@ -293,7 +313,7 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem,
pinnrep = PINNRepresentation(eqs, bcs, domains, eq_params, defaults, default_p,
param_estim, additional_loss, adaloss, v, logger,
multioutput, iteration, init_params, flat_init_params, phi,
phimap, derivative,
derivative,
strategy, eqdata, nothing, nothing, nothing, nothing)

#integral = get_numeric_integral(pinnrep)
Expand Down
55 changes: 29 additions & 26 deletions src/loss_function_generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq;
integrand = nothing,
transformation_vars = nothing)
@unpack varmap, eqdata,
phi, phimap, derivative, integral,
phi, derivative, integral,
multioutput, init_params, strategy, eq_params,
param_estim, default_p = pinnrep

Expand Down Expand Up @@ -53,10 +53,9 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq;
num_numbers = 0
out = map(enumerate(eq_args)) do (i, x)
if x isa Number
num_numbers += 1
fill(convert(eltypeθ, x), length(cord[[1], :]))
fill(convert(eltypeθ, x), size(cord[[1], :]))
else
cord[[i-num_numbers], :]
cord[[i], :]
end
end
if out === nothing
Expand All @@ -67,21 +66,15 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq;
end

full_loss_func = (cord, θ, phi, p) -> begin
coords = get_coords(cord)
@show coords
combinedcoords = reduce(vcat, coords, init = [])
@show combinedcoords
loss_function(coords, combinedcoords, θ, phi, get_ps(θ))
coords = [[nothing]]
@ignore_derivatives coords = get_coords(cord)
loss_function(coords, θ, phi, get_ps(θ))
end
return full_loss_func
end

function build_loss_function(pinnrep, eqs)
@unpack eq_params, param_estim, default_p, phi, phimap, multioutput, derivative, integral = pinnrep

if multioutput
phi = phimap
end
@unpack eq_params, param_estim, default_p, phi, multioutput, derivative, integral = pinnrep

_loss_function = build_symbolic_loss_function(pinnrep, eqs,
eq_params = eq_params,
Expand All @@ -106,15 +99,19 @@ end
function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = false,
dict_transformation_vars = nothing,
transformation_vars = nothing)
@unpack varmap, eqdata, derivative, integral, flat_init_params = pinnrep
@unpack varmap, eqdata, derivative, integral, flat_init_params, multioutput = pinnrep
eltypeθ = eltype(flat_init_params)

ex_vars = get_depvars(term, varmap.depvar_ops)
ignore = vcat(operation.(ex_vars), getindex, Differential, Integral, ~)

dummyvars = @variables phi(..), θ_SYMBOL, coord
if multioutput
dummyvars = @variables phi[1:length(varmap.ū)](..), θ_SYMBOL
else
dummyvars = @variables phi(..), θ_SYMBOL
end

dummyvars = unwrap.(dummyvars)
deriv_rules = generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap)
deriv_rules = generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap, multioutput)

ch = Prewalk(Chain(deriv_rules))

Expand All @@ -124,21 +121,27 @@ function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = fa
ps = DestructuredArgs(varmap.ps)


args = [sym_coords, coord, θ_SYMBOL, phi, ps]
args = [sym_coords, θ_SYMBOL, phi, ps]

ex = Func(args, [], expr) |> toexpr |> _dot_

@show ex
f = @RuntimeGeneratedFunction ex
return f
end

function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap)
phi, θ, coord = dummyvars
function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap, multioutput)
phi, θ = dummyvars
if symtype(phi) isa AbstractArray
phi = collect(phi)
end

dvs = get_depvars(term, varmap.depvar_ops)
@show dvs
@show eltypeθ
# Orthodox derivatives
n(w) = length(arguments(w))
rs = reduce(vcat, [reduce(vcat, [[@rule $((Differential(x)^d)(w)) =>
derivative(ufunc(w, coord, θ, phi), coord,
derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ),
[get_ε(n(w),
j, eltypeθ, i) for i in 1:d],
d, θ)
Expand All @@ -155,15 +158,15 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative
ε1 = [get_ε(n(w), j, eltypeθ, i) for i in 1:2]
ε2 = [get_ε(n(w), k, eltypeθ, i) for i in 1:2]
[@rule $((Differential(x))((Differential(y))(w))) =>
derivative((cord_, θ_) -> derivative(ufunc(w, coord, θ, phi), cord_,
derivative((coord_, θ_) -> derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ),
ε2, 1, θ_),
coord, ε1, 1, θ)]
reducevcat(arguments(w), eltypeθ), ε1, 1, θ)]
end
end
end
end
vr = mapreduce(vcat, dvs, init = []) do w
@rule w => ufunc(w, coord, θ, phi)(coord, θ)
@rule w => ufunc(w, phi, varmap)(reducevcat(arguments(w), eltypeθ), θ)
end

return [mx; rs; vr]
Expand Down
59 changes: 47 additions & 12 deletions src/pinn_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,6 @@ mutable struct PINNRepresentation
The representation of the test function of the PDE solution
"""
phi::Any
"""
the map of vars to chains
"""
phimap::Any
"""
The function used for computing the derivative
"""
Expand Down Expand Up @@ -350,9 +346,6 @@ function (f::Phi{<:Optimisers.Restructure})(x, θ)
f.f(θ)(adapt(parameterless_type(θ), x))
end

ufunc(u, cord, θ, phi) = phi isa Dict ? phi[u](cord, θ) : phi(cord, θ)


# the method to calculate the derivative
function numeric_derivative(phi, x, εs, order, θ)
_type = parameterless_type(ComponentArrays.getdata(θ))
Expand All @@ -368,32 +361,74 @@ function numeric_derivative(phi, x, εs, order, θ)
# if order 1, this is trivially true

if order > 4 || any(x -> x != εs[1], εs)
@show "me"
return (numeric_derivative(phi, x .+ ε, @view(εs[1:(end-1)]), order - 1, θ)
.-
numeric_derivative(phi, x .- ε, @view(εs[1:(end-1)]), order - 1, θ)) .*
_epsilon ./ 2
elseif order == 4
@show "me4"
return (phi(x .+ 2 .* ε, θ) .- 4 .* phi(x .+ ε, θ)
.+
6 .* phi(x, θ)
.-
4 .* phi(x .- ε, θ) .+ phi(x .- 2 .* ε, θ)) .* _epsilon^4
elseif order == 3
@show "me3"
return (phi(x .+ 2 .* ε, θ) .- 2 .* phi(x .+ ε, θ) .+ 2 .* phi(x .- ε, θ)
-
phi(x .- 2 .* ε, θ)) .* _epsilon^3 ./ 2
elseif order == 2
@show "me2"
return (phi(x .+ ε, θ) .+ phi(x .- ε, θ) .- 2 .* phi(x, θ)) .* _epsilon^2
elseif order == 1
@show "me1"
return (phi(x .+ ε, θ) .- phi(x .- ε, θ)) .* _epsilon ./ 2
else
error("This shouldn't happen! Got an order of $(order).")
end
end
# Hacky workaround for metaprogramming with symbolics
@register_symbolic(numeric_derivative(phi, x, εs, order, θ), true, [], true)

function ufunc(u, phi, v)
if symtype(phi) isa AbstractArray
return phi[findfirst(w -> isequal(operation(w), operation(u)), v.ū)]
else
return phi
end
end

#=
_vcat(x::Number...) = vcat(x...)
_vcat(x::AbstractArray{<:Number}...) = vcat(x...)
function _vcat(x::Union{Number, AbstractArray{<:Number}}...)
example = first(Iterators.filter(e -> !(e isa Number), x))
dims = (1, size(example)[2:end]...)
x = map(el -> el isa Number ? (typeof(example))(fill(el, dims)) : el, x)
_vcat(x...)
end
_vcat(x...) = vcat(x...)
https://github.com/SciML/NeuralPDE.jl/pull/627/files
=#



function reducevcat(vector, eltypeθ)
if all(x -> x isa Number, vector)
return vector
else
z = findfirst(x -> !(x isa Number), vector)
return rvcat(vector, vector[z], eltypeθ)
end
end

function rvcat(example, vector, eltypeθ)
isnothing(vector) && return [[nothing]]
return mapreduce(hcat, vector) do x
if x isa Number
out = typeof(example)(fill(convert(eltypeθ, x), size(example)))
out
else
out = x
out
end
end
end

@register_symbolic(rvcat(vector, example, eltypeθ), true, [], true)
18 changes: 4 additions & 14 deletions src/symbolic_utilities.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,5 @@
using Base.Broadcast

function get_limits(domain)
if domain isa AbstractInterval
return [leftendpoint(domain)], [rightendpoint(domain)]
elseif domain isa ProductDomain
return collect(map(leftendpoint, DomainSets.components(domain))),
collect(map(rightendpoint, DomainSets.components(domain)))
end
end

θ = gensym("θ")

"""
Override `Broadcast.__dot__` with `Broadcast.dottable(x::Function) = true`
Expand All @@ -36,8 +25,8 @@ dottable_(x::Phi) = false
_dot_(x) = x
function _dot_(x::Expr)
dotargs = Base.mapany(_dot_, x.args)
nodot = [:phi, Symbol("NeuralPDE.numeric_derivative")]
if x.head === :call && dottable_(x.args[1]) && all(s -> x.args[1] !== s, nodot)
nodot = [:phi, Symbol("NeuralPDE.numeric_derivative"), NeuralPDE.rvcat]
if x.head === :call && dottable_(x.args[1]) && all(s -> x.args[1] != s, nodot)
Expr(:., dotargs[1], Expr(:tuple, dotargs[2:end]...))
elseif x.head === :comparison
Expr(:comparison,
Expand Down Expand Up @@ -217,7 +206,6 @@ function get_argument(eqs, v::VariableMap)
f_vars = filter(x -> !isempty(x), _vars)
map(first, f_vars)
end
@show vars
args_ = map(vars) do _vars
seen = []
filter(reduce(vcat, arguments.(_vars), init = [])) do x
Expand Down Expand Up @@ -252,3 +240,5 @@ function get_number(eqs, v::VariableMap)
args = get_argument(eqs, v)
return map(arg -> filter(x -> x isa Number, arg), args)
end

sym_op(u) = Symbol(operation(u))
Loading

0 comments on commit a949169

Please sign in to comment.