diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index 3e0d3fe1b..9bbd316a3 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -22,7 +22,7 @@ using ArrayInterface import Optim using DomainSets using Symbolics -using Symbolics: wrap, unwrap, arguments, operation +using Symbolics: wrap, unwrap, arguments, operation, symtype using SymbolicUtils using SymbolicUtils.Code using SymbolicUtils: Prewalk, Postwalk, Chain @@ -34,7 +34,7 @@ import Optimisers import UnPack: @unpack import RecursiveArrayTools import ChainRulesCore, Flux, Lux, ComponentArrays -import ChainRulesCore: @non_differentiable +import ChainRulesCore: @non_differentiable, @ignore_derivatives RuntimeGeneratedFunctions.init(@__MODULE__) diff --git a/src/discretize.jl b/src/discretize.jl index 0b5cbc726..aa6fe2f52 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -15,7 +15,6 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, varmap) else dxs = fill(dx, length(domains)) end - @show dxs spans = [infimum(d.domain):dx:supremum(d.domain) for (d, dx) in zip(domains, dxs)] dict_var_span = Dict([d.variables => infimum(d.domain):dx:supremum(d.domain) for (d, dx) in zip(domains, dxs)]) @@ -69,11 +68,10 @@ training strategy: StochasticTraining, QuasiRandomTraining, QuadratureTraining. """ function get_bounds end -function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::AbstractGridfreeStrategy) +function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::QuadratureTraining) dict_lower_bound = Dict([d.variables => infimum(d.domain) for d in domains]) dict_upper_bound = Dict([d.variables => supremum(d.domain) for d in domains]) pde_args = get_argument(eqs, v) - @show pde_args pde_lower_bounds = map(pde_args) do pd span = map(p -> get(dict_lower_bound, p, p), pd) @@ -86,7 +84,6 @@ function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::Abstr pde_bounds = [pde_lower_bounds, pde_upper_bounds] bound_vars = get_variables(bcs, v) - @show bound_vars bcs_lower_bounds = map(bound_vars) do bt map(b -> dict_lower_bound[b], bt) @@ -95,9 +92,32 @@ function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::Abstr map(b -> dict_upper_bound[b], bt) end bcs_bounds = [bcs_lower_bounds, bcs_upper_bounds] - @show bcs_bounds pde_bounds [pde_bounds, bcs_bounds] end + +function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy) + dx = 1 / strategy.points + dict_span = Dict([d.variables => [ + infimum(d.domain) + dx, + supremum(d.domain) - dx, + ] for d in domains]) + + # pde_bounds = [[infimum(d.domain),supremum(d.domain)] for d in domains] + pde_args = get_argument(eqs, v) + pde_bounds = map(pde_args) do pde_arg + bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, pde_arg) + bds = eltypeθ.(bds) + bds[1, :], bds[2, :] + end + + bound_args = get_argument(bcs, v) + bcs_bounds = map(bound_args) do bound_arg + bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, bound_arg) + bds = eltypeθ.(bds) + bds[1, :], bds[2, :] + end + return pde_bounds, bcs_bounds +end # TODO: Get this to work with varmap function get_numeric_integral(pinnrep::PINNRepresentation) @unpack strategy, multioutput, derivative, varmap = pinnrep @@ -268,15 +288,15 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem, dvs = v.ū acum = [0; accumulate(+, map(length, init_params))] sep = [(acum[i] + 1):acum[i + 1] for i in 1:(length(acum) - 1)] - phimap = map(enumerate(dvs)) do (i, dv) + phi = map(enumerate(dvs)) do (i, dv) if (phi isa Vector && phi[1].f isa Optimisers.Restructure) || (!(phi isa Vector) && phi.f isa Optimisers.Restructure) # Flux.Chain - dv => (coord, expr_θ) -> phi[i](coord, expr_θ[sep[i]]) + (coord, expr_θ) -> phi[i](coord, expr_θ[sep[i]]) else # Lux.AbstractExplicitLayer - dv => (coord, expr_θ) -> phi[i](coord, expr_θ.depvar.$(dv)) + (coord, expr_θ) -> phi[i](coord, expr_θ.depvar.$(dv)) end - end |> Dict + end else phimap = nothing end @@ -293,7 +313,7 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem, pinnrep = PINNRepresentation(eqs, bcs, domains, eq_params, defaults, default_p, param_estim, additional_loss, adaloss, v, logger, multioutput, iteration, init_params, flat_init_params, phi, - phimap, derivative, + derivative, strategy, eqdata, nothing, nothing, nothing, nothing) #integral = get_numeric_integral(pinnrep) diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index 928be8049..79e17bbb7 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -8,7 +8,7 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; integrand = nothing, transformation_vars = nothing) @unpack varmap, eqdata, - phi, phimap, derivative, integral, + phi, derivative, integral, multioutput, init_params, strategy, eq_params, param_estim, default_p = pinnrep @@ -53,10 +53,9 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; num_numbers = 0 out = map(enumerate(eq_args)) do (i, x) if x isa Number - num_numbers += 1 - fill(convert(eltypeθ, x), length(cord[[1], :])) + fill(convert(eltypeθ, x), size(cord[[1], :])) else - cord[[i-num_numbers], :] + cord[[i], :] end end if out === nothing @@ -67,21 +66,15 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; end full_loss_func = (cord, θ, phi, p) -> begin - coords = get_coords(cord) - @show coords - combinedcoords = reduce(vcat, coords, init = []) - @show combinedcoords - loss_function(coords, combinedcoords, θ, phi, get_ps(θ)) + coords = [[nothing]] + @ignore_derivatives coords = get_coords(cord) + loss_function(coords, θ, phi, get_ps(θ)) end return full_loss_func end function build_loss_function(pinnrep, eqs) - @unpack eq_params, param_estim, default_p, phi, phimap, multioutput, derivative, integral = pinnrep - - if multioutput - phi = phimap - end + @unpack eq_params, param_estim, default_p, phi, multioutput, derivative, integral = pinnrep _loss_function = build_symbolic_loss_function(pinnrep, eqs, eq_params = eq_params, @@ -106,15 +99,19 @@ end function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = false, dict_transformation_vars = nothing, transformation_vars = nothing) - @unpack varmap, eqdata, derivative, integral, flat_init_params = pinnrep + @unpack varmap, eqdata, derivative, integral, flat_init_params, multioutput = pinnrep eltypeθ = eltype(flat_init_params) ex_vars = get_depvars(term, varmap.depvar_ops) - ignore = vcat(operation.(ex_vars), getindex, Differential, Integral, ~) - dummyvars = @variables phi(..), θ_SYMBOL, coord + if multioutput + dummyvars = @variables phi[1:length(varmap.ū)](..), θ_SYMBOL + else + dummyvars = @variables phi(..), θ_SYMBOL + end + dummyvars = unwrap.(dummyvars) - deriv_rules = generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap) + deriv_rules = generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap, multioutput) ch = Prewalk(Chain(deriv_rules)) @@ -124,21 +121,27 @@ function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = fa ps = DestructuredArgs(varmap.ps) - args = [sym_coords, coord, θ_SYMBOL, phi, ps] + args = [sym_coords, θ_SYMBOL, phi, ps] ex = Func(args, [], expr) |> toexpr |> _dot_ + + @show ex f = @RuntimeGeneratedFunction ex return f end -function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap) - phi, θ, coord = dummyvars +function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap, multioutput) + phi, θ = dummyvars + if symtype(phi) isa AbstractArray + phi = collect(phi) + end + dvs = get_depvars(term, varmap.depvar_ops) - @show dvs + @show eltypeθ # Orthodox derivatives n(w) = length(arguments(w)) rs = reduce(vcat, [reduce(vcat, [[@rule $((Differential(x)^d)(w)) => - derivative(ufunc(w, coord, θ, phi), coord, + derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ), [get_ε(n(w), j, eltypeθ, i) for i in 1:d], d, θ) @@ -155,15 +158,15 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative ε1 = [get_ε(n(w), j, eltypeθ, i) for i in 1:2] ε2 = [get_ε(n(w), k, eltypeθ, i) for i in 1:2] [@rule $((Differential(x))((Differential(y))(w))) => - derivative((cord_, θ_) -> derivative(ufunc(w, coord, θ, phi), cord_, + derivative((coord_, θ_) -> derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ), ε2, 1, θ_), - coord, ε1, 1, θ)] + reducevcat(arguments(w), eltypeθ), ε1, 1, θ)] end end end end vr = mapreduce(vcat, dvs, init = []) do w - @rule w => ufunc(w, coord, θ, phi)(coord, θ) + @rule w => ufunc(w, phi, varmap)(reducevcat(arguments(w), eltypeθ), θ) end return [mx; rs; vr] diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 0fae251df..a4a03ddc5 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -239,10 +239,6 @@ mutable struct PINNRepresentation The representation of the test function of the PDE solution """ phi::Any - """ - the map of vars to chains - """ - phimap::Any """ The function used for computing the derivative """ @@ -350,9 +346,6 @@ function (f::Phi{<:Optimisers.Restructure})(x, θ) f.f(θ)(adapt(parameterless_type(θ), x)) end -ufunc(u, cord, θ, phi) = phi isa Dict ? phi[u](cord, θ) : phi(cord, θ) - - # the method to calculate the derivative function numeric_derivative(phi, x, εs, order, θ) _type = parameterless_type(ComponentArrays.getdata(θ)) @@ -368,28 +361,23 @@ function numeric_derivative(phi, x, εs, order, θ) # if order 1, this is trivially true if order > 4 || any(x -> x != εs[1], εs) - @show "me" return (numeric_derivative(phi, x .+ ε, @view(εs[1:(end-1)]), order - 1, θ) .- numeric_derivative(phi, x .- ε, @view(εs[1:(end-1)]), order - 1, θ)) .* _epsilon ./ 2 elseif order == 4 - @show "me4" return (phi(x .+ 2 .* ε, θ) .- 4 .* phi(x .+ ε, θ) .+ 6 .* phi(x, θ) .- 4 .* phi(x .- ε, θ) .+ phi(x .- 2 .* ε, θ)) .* _epsilon^4 elseif order == 3 - @show "me3" return (phi(x .+ 2 .* ε, θ) .- 2 .* phi(x .+ ε, θ) .+ 2 .* phi(x .- ε, θ) - phi(x .- 2 .* ε, θ)) .* _epsilon^3 ./ 2 elseif order == 2 - @show "me2" return (phi(x .+ ε, θ) .+ phi(x .- ε, θ) .- 2 .* phi(x, θ)) .* _epsilon^2 elseif order == 1 - @show "me1" return (phi(x .+ ε, θ) .- phi(x .- ε, θ)) .* _epsilon ./ 2 else error("This shouldn't happen! Got an order of $(order).") @@ -397,3 +385,50 @@ function numeric_derivative(phi, x, εs, order, θ) end # Hacky workaround for metaprogramming with symbolics @register_symbolic(numeric_derivative(phi, x, εs, order, θ), true, [], true) + +function ufunc(u, phi, v) + if symtype(phi) isa AbstractArray + return phi[findfirst(w -> isequal(operation(w), operation(u)), v.ū)] + else + return phi + end +end + +#= +_vcat(x::Number...) = vcat(x...) +_vcat(x::AbstractArray{<:Number}...) = vcat(x...) +function _vcat(x::Union{Number, AbstractArray{<:Number}}...) + example = first(Iterators.filter(e -> !(e isa Number), x)) + dims = (1, size(example)[2:end]...) + x = map(el -> el isa Number ? (typeof(example))(fill(el, dims)) : el, x) + _vcat(x...) +end +_vcat(x...) = vcat(x...) +https://github.com/SciML/NeuralPDE.jl/pull/627/files +=# + + + +function reducevcat(vector, eltypeθ) + if all(x -> x isa Number, vector) + return vector + else + z = findfirst(x -> !(x isa Number), vector) + return rvcat(vector, vector[z], eltypeθ) + end +end + +function rvcat(example, vector, eltypeθ) + isnothing(vector) && return [[nothing]] + return mapreduce(hcat, vector) do x + if x isa Number + out = typeof(example)(fill(convert(eltypeθ, x), size(example))) + out + else + out = x + out + end + end +end + +@register_symbolic(rvcat(vector, example, eltypeθ), true, [], true) \ No newline at end of file diff --git a/src/symbolic_utilities.jl b/src/symbolic_utilities.jl index 15cbb5e5c..fc81f3020 100644 --- a/src/symbolic_utilities.jl +++ b/src/symbolic_utilities.jl @@ -1,16 +1,5 @@ using Base.Broadcast -function get_limits(domain) - if domain isa AbstractInterval - return [leftendpoint(domain)], [rightendpoint(domain)] - elseif domain isa ProductDomain - return collect(map(leftendpoint, DomainSets.components(domain))), - collect(map(rightendpoint, DomainSets.components(domain))) - end -end - -θ = gensym("θ") - """ Override `Broadcast.__dot__` with `Broadcast.dottable(x::Function) = true` @@ -36,8 +25,8 @@ dottable_(x::Phi) = false _dot_(x) = x function _dot_(x::Expr) dotargs = Base.mapany(_dot_, x.args) - nodot = [:phi, Symbol("NeuralPDE.numeric_derivative")] - if x.head === :call && dottable_(x.args[1]) && all(s -> x.args[1] !== s, nodot) + nodot = [:phi, Symbol("NeuralPDE.numeric_derivative"), NeuralPDE.rvcat] + if x.head === :call && dottable_(x.args[1]) && all(s -> x.args[1] != s, nodot) Expr(:., dotargs[1], Expr(:tuple, dotargs[2:end]...)) elseif x.head === :comparison Expr(:comparison, @@ -217,7 +206,6 @@ function get_argument(eqs, v::VariableMap) f_vars = filter(x -> !isempty(x), _vars) map(first, f_vars) end - @show vars args_ = map(vars) do _vars seen = [] filter(reduce(vcat, arguments.(_vars), init = [])) do x @@ -252,3 +240,5 @@ function get_number(eqs, v::VariableMap) args = get_argument(eqs, v) return map(arg -> filter(x -> x isa Number, arg), args) end + +sym_op(u) = Symbol(operation(u)) \ No newline at end of file diff --git a/test/NNPDE_tests.jl b/test/NNPDE_tests.jl index 4d6a31c27..b4fb8018d 100644 --- a/test/NNPDE_tests.jl +++ b/test/NNPDE_tests.jl @@ -435,117 +435,117 @@ end # plot(p1,p2) end -# ## Example 5, 2d wave equation, neumann boundary condition -# @testset "Example 5, 2d wave equation, neumann boundary condition" begin -# #here we use low level api for build solution -# @parameters x, t -# @variables u(..) -# Dxx = Differential(x)^2 -# Dtt = Differential(t)^2 -# Dt = Differential(t) - -# #2D PDE -# C = 1 -# eq = Dtt(u(x, t)) ~ C^2 * Dxx(u(x, t)) - -# # Initial and boundary conditions -# bcs = [u(0, t) ~ 0.0,# for all t > 0 -# u(1, t) ~ 0.0,# for all t > 0 -# u(x, 0) ~ x * (1.0 - x), #for all 0 < x < 1 -# Dt(u(x, 0)) ~ 0.0] #for all 0 < x < 1] - -# # Space and time domains -# domains = [x ∈ Interval(0.0, 1.0), -# t ∈ Interval(0.0, 1.0)] -# @named pde_system = PDESystem(eq, bcs, domains, [x, t], [u(x, t)]) - -# # Neural network -# chain = Lux.Chain(Lux.Dense(2, 16, Lux.σ), Lux.Dense(16, 16, Lux.σ), Lux.Dense(16, 1)) -# phi = NeuralPDE.Phi(chain) -# derivative = NeuralPDE.numeric_derivative - -# quadrature_strategy = NeuralPDE.QuadratureTraining(quadrature_alg = CubatureJLh(), -# reltol = 1e-3, abstol = 1e-3, -# maxiters = 50, batch = 100) - -# discretization = NeuralPDE.PhysicsInformedNN(chain, quadrature_strategy) -# prob = NeuralPDE.discretize(pde_system, discretization) - -# cb_ = function (p, l) -# println("loss: ", l) -# println("losses: ", map(l -> l(p), loss_functions)) -# return false -# end - -# res = solve(prob, OptimizationOptimJL.BFGS(); maxiters = 500, f_abstol = 10^-6) - -# dx = 0.1 -# xs, ts = [infimum(d.domain):dx:supremum(d.domain) for d in domains] -# function analytic_sol_func(x, t) -# sum([(8 / (k^3 * pi^3)) * sin(k * pi * x) * cos(C * k * pi * t) for k in 1:2:50000]) -# end - -# u_predict = reshape([first(phi([x, t], res.u)) for x in xs for t in ts], -# (length(xs), length(ts))) -# u_real = reshape([analytic_sol_func(x, t) for x in xs for t in ts], -# (length(xs), length(ts))) - -# @test u_predict≈u_real atol=0.1 - -# # diff_u = abs.(u_predict .- u_real) -# # p1 = plot(xs, ts, u_real, linetype=:contourf,title = "analytic"); -# # p2 =plot(xs, ts, u_predict, linetype=:contourf,title = "predict"); -# # p3 = plot(xs, ts, diff_u,linetype=:contourf,title = "error"); -# # plot(p1,p2,p3) -# end -# ## Example 6, pde with mixed derivative -# @testset "Example 6, pde with mixed derivative" begin -# @parameters x y -# @variables u(..) -# Dxx = Differential(x)^2 -# Dyy = Differential(y)^2 -# Dx = Differential(x) -# Dy = Differential(y) - -# eq = Dxx(u(x, y)) + Dx(Dy(u(x, y))) - 2 * Dyy(u(x, y)) ~ -1.0 - -# # Initial and boundary conditions -# bcs = [u(x, 0) ~ x, -# Dy(u(x, 0)) ~ x, -# u(x, 0) ~ Dy(u(x, 0))] - -# # Space and time domains -# domains = [x ∈ Interval(0.0, 1.0), y ∈ Interval(0.0, 1.0)] - -# quadrature_strategy = NeuralPDE.QuadratureTraining() -# # Neural network -# inner = 20 -# chain = Lux.Chain(Lux.Dense(2, inner, Lux.tanh), Lux.Dense(inner, inner, Lux.tanh), -# Lux.Dense(inner, 1)) - -# discretization = NeuralPDE.PhysicsInformedNN(chain, quadrature_strategy) -# @named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)]) - -# prob = NeuralPDE.discretize(pde_system, discretization) - -# res = solve(prob, OptimizationOptimJL.BFGS(); maxiters = 1500) -# @show res.original - -# phi = discretization.phi - -# analytic_sol_func(x, y) = x + x * y + y^2 / 2 -# xs, ys = [infimum(d.domain):0.01:supremum(d.domain) for d in domains] - -# u_predict = reshape([first(phi([x, y], res.u)) for x in xs for y in ys], -# (length(xs), length(ys))) -# u_real = reshape([analytic_sol_func(x, y) for x in xs for y in ys], -# (length(xs), length(ys))) -# diff_u = abs.(u_predict .- u_real) - -# @test u_predict≈u_real rtol=0.1 - -# # p1 = plot(xs, ys, u_real, linetype=:contourf,title = "analytic"); -# # p2 = plot(xs, ys, u_predict, linetype=:contourf,title = "predict"); -# # p3 = plot(xs, ys, diff_u,linetype=:contourf,title = "error"); -# # plot(p1,p2,p3) -# end +## Example 5, 2d wave equation, neumann boundary condition +@testset "Example 5, 2d wave equation, neumann boundary condition" begin + #here we use low level api for build solution + @parameters x, t + @variables u(..) + Dxx = Differential(x)^2 + Dtt = Differential(t)^2 + Dt = Differential(t) + + #2D PDE + C = 1 + eq = Dtt(u(x, t)) ~ C^2 * Dxx(u(x, t)) + + # Initial and boundary conditions + bcs = [u(0, t) ~ 0.0,# for all t > 0 + u(1, t) ~ 0.0,# for all t > 0 + u(x, 0) ~ x * (1.0 - x), #for all 0 < x < 1 + Dt(u(x, 0)) ~ 0.0] #for all 0 < x < 1] + + # Space and time domains + domains = [x ∈ Interval(0.0, 1.0), + t ∈ Interval(0.0, 1.0)] + @named pde_system = PDESystem(eq, bcs, domains, [x, t], [u(x, t)]) + + # Neural network + chain = Lux.Chain(Lux.Dense(2, 16, Lux.σ), Lux.Dense(16, 16, Lux.σ), Lux.Dense(16, 1)) + phi = NeuralPDE.Phi(chain) + derivative = NeuralPDE.numeric_derivative + + quadrature_strategy = NeuralPDE.QuadratureTraining(quadrature_alg = CubatureJLh(), + reltol = 1e-3, abstol = 1e-3, + maxiters = 50, batch = 100) + + discretization = NeuralPDE.PhysicsInformedNN(chain, quadrature_strategy) + prob = NeuralPDE.discretize(pde_system, discretization) + + cb_ = function (p, l) + println("loss: ", l) + println("losses: ", map(l -> l(p), loss_functions)) + return false + end + + res = solve(prob, OptimizationOptimJL.BFGS(); maxiters = 500, f_abstol = 10^-6) + + dx = 0.1 + xs, ts = [infimum(d.domain):dx:supremum(d.domain) for d in domains] + function analytic_sol_func(x, t) + sum([(8 / (k^3 * pi^3)) * sin(k * pi * x) * cos(C * k * pi * t) for k in 1:2:50000]) + end + + u_predict = reshape([first(phi([x, t], res.u)) for x in xs for t in ts], + (length(xs), length(ts))) + u_real = reshape([analytic_sol_func(x, t) for x in xs for t in ts], + (length(xs), length(ts))) + + @test u_predict≈u_real atol=0.1 + + # diff_u = abs.(u_predict .- u_real) + # p1 = plot(xs, ts, u_real, linetype=:contourf,title = "analytic"); + # p2 =plot(xs, ts, u_predict, linetype=:contourf,title = "predict"); + # p3 = plot(xs, ts, diff_u,linetype=:contourf,title = "error"); + # plot(p1,p2,p3) +end +## Example 6, pde with mixed derivative +@testset "Example 6, pde with mixed derivative" begin + @parameters x y + @variables u(..) + Dxx = Differential(x)^2 + Dyy = Differential(y)^2 + Dx = Differential(x) + Dy = Differential(y) + + eq = Dxx(u(x, y)) + Dx(Dy(u(x, y))) - 2 * Dyy(u(x, y)) ~ -1.0 + + # Initial and boundary conditions + bcs = [u(x, 0) ~ x, + Dy(u(x, 0)) ~ x, + u(x, 0) ~ Dy(u(x, 0))] + + # Space and time domains + domains = [x ∈ Interval(0.0, 1.0), y ∈ Interval(0.0, 1.0)] + + quadrature_strategy = NeuralPDE.QuadratureTraining() + # Neural network + inner = 20 + chain = Lux.Chain(Lux.Dense(2, inner, Lux.tanh), Lux.Dense(inner, inner, Lux.tanh), + Lux.Dense(inner, 1)) + + discretization = NeuralPDE.PhysicsInformedNN(chain, quadrature_strategy) + @named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)]) + + prob = NeuralPDE.discretize(pde_system, discretization) + + res = solve(prob, OptimizationOptimJL.BFGS(); maxiters = 1500) + @show res.original + + phi = discretization.phi + + analytic_sol_func(x, y) = x + x * y + y^2 / 2 + xs, ys = [infimum(d.domain):0.01:supremum(d.domain) for d in domains] + + u_predict = reshape([first(phi([x, y], res.u)) for x in xs for y in ys], + (length(xs), length(ys))) + u_real = reshape([analytic_sol_func(x, y) for x in xs for y in ys], + (length(xs), length(ys))) + diff_u = abs.(u_predict .- u_real) + + @test u_predict≈u_real rtol=0.1 + + # p1 = plot(xs, ys, u_real, linetype=:contourf,title = "analytic"); + # p2 = plot(xs, ys, u_predict, linetype=:contourf,title = "predict"); + # p3 = plot(xs, ys, diff_u,linetype=:contourf,title = "error"); + # plot(p1,p2,p3) +end