Skip to content

Commit

Permalink
Add LagrangeBasis (#103)
Browse files Browse the repository at this point in the history
* add LagrangeBasis

* add LagrangeBasis

* fix example

* add test for example

* format

* put polynomials into LagrangeBasis

* add unit tests for LagrangeBasis

* format

* clarify docstrings

* add least squares test with LagrangeBasis
  • Loading branch information
JoshuaLampert authored Nov 6, 2024
1 parent 45e52be commit 366e9b8
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 12 deletions.
28 changes: 28 additions & 0 deletions examples/interpolation/interpolation_2d_Lagrange_basis.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using KernelInterpolation
using Plots

# interpolate Franke function
function f(x)
0.75 * exp(-0.25 * ((9 * x[1] - 2)^2 + (9 * x[2] - 2)^2)) +
0.75 * exp(-(9 * x[1] + 1)^2 / 49 - (9 * x[2] + 1) / 10) +
0.5 * exp(-0.25 * ((9 * x[1] - 7)^2 + (9 * x[2] - 3)^2)) -
0.2 * exp(-(9 * x[1] - 4)^2 - (9 * x[2] - 7)^2)
end

n = 50
nodeset = random_hypercube(n; dim = 2)
values = f.(nodeset)

kernel = ThinPlateSplineKernel{dim(nodeset)}()
# Computing the Lagrange basis is expensive, but interpolation with it is cheap
basis = LagrangeBasis(nodeset, kernel)
itp = interpolate(basis, values)

N = 20
many_nodes = homogeneous_hypercube(N; dim = 2)

p1 = plot(many_nodes, itp)
plot!(p1, many_nodes, f, st = :surface)

p2 = plot(itp, st = :heatmap)
plot(p1, p2, layout = (2, 1))
7 changes: 5 additions & 2 deletions src/KernelInterpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ module KernelInterpolation

using DiffEqCallbacks: PeriodicCallback, PeriodicCallbackAffect
using ForwardDiff: ForwardDiff
using LinearAlgebra: Symmetric, norm, tr, muladd, dot, diagind
using LinearAlgebra: Symmetric, I, norm, tr, muladd, dot, diagind
using Printf: @sprintf
using ReadVTK: VTKFile, get_points, get_point_data, get_data
using RecipesBase: RecipesBase, @recipe, @series
Expand All @@ -29,6 +29,9 @@ using TypedPolynomials: Variable, monomials, degree
using WriteVTK: WriteVTK, vtk_grid, paraview_collection, MeshCell, VTKCellTypes,
CollectionFile

# Define the AbstractInterpolation already here because they are needed in basis.jl
abstract type AbstractInterpolation{Basis, Dim, RealT} end

include("kernels/kernels.jl")
include("nodes.jl")
include("basis.jl")
Expand All @@ -49,7 +52,7 @@ export GaussKernel, MultiquadricKernel, InverseMultiquadricKernel,
RadialCharacteristicKernel, MaternKernel, Matern12Kernel, Matern32Kernel,
Matern52Kernel, Matern72Kernel, RieszKernel,
TransformationKernel, ProductKernel, SumKernel
export StandardBasis
export StandardBasis, LagrangeBasis
export phi, Phi, order
export PartialDerivative, Gradient, Laplacian, EllipticOperator
export PoissonEquation, EllipticEquation, AdvectionEquation, HeatEquation,
Expand Down
51 changes: 51 additions & 0 deletions src/basis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,54 @@ struct StandardBasis{Kernel} <: AbstractBasis
end

Base.getindex(basis::StandardBasis, i) = x -> basis.kernel(x, centers(basis)[i])

@doc raw"""
LagrangeBasis(centers, kernel, m = order(kernel))
The Lagrange (or cardinal) basis with respect to a kernel and a [`NodeSet`](@ref) of `centers`. This basis
already includes polynomial augmentation of degree `m` defaulting to `order(kernel)`. The basis functions are given such that
```math
b_j(x_i) = \delta_{ij},
```
which means that the [`kernel_matrix`](@ref) of this basis is the identity matrix making it suitable for interpolation. Since the
basis already includes polynomials no additional polynomial augmentation is needed for interpolation with this basis.
"""
struct LagrangeBasis{Kernel, I <: AbstractInterpolation, Monomials, PolyVars} <:
AbstractBasis
centers::NodeSet
kernel::Kernel
basis_functions::Vector{I}
ps::Monomials
xx::PolyVars
function LagrangeBasis(centers::NodeSet, kernel::Kernel;
m = order(kernel)) where {Kernel}
if dim(kernel) != dim(centers)
throw(DimensionMismatch("The dimension of the kernel and the centers must be the same"))
end
K = length(centers)
values = zeros(K)
values[1] = 1.0
b = interpolate(centers, values, kernel; m = m)
basis_functions = Vector{typeof(b)}(undef, K)
basis_functions[1] = b
for i in 2:K
values[i - 1] = 0.0
values[i] = 1.0
basis_functions[i] = interpolate(centers, values, kernel; m = m)
end
# All basis functions have same polynomials
ps = first(basis_functions).ps
xx = first(basis_functions).xx
new{typeof(kernel), eltype(basis_functions), typeof(ps), typeof(xx)}(centers,
kernel,
basis_functions,
ps, xx)
end
end

Base.getindex(basis::LagrangeBasis, i) = x -> basis.basis_functions[i](x)
Base.collect(basis::LagrangeBasis) = basis.basis_functions
# Polynomials are already inherently defined included in the basis
order(::LagrangeBasis) = 0
13 changes: 5 additions & 8 deletions src/interpolation.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
abstract type AbstractInterpolation{Basis, Dim, RealT} end

@doc raw"""
Interpolation
Expand Down Expand Up @@ -195,16 +193,15 @@ end
# Evaluate interpolant
function (itp::Interpolation)(x)
s = 0
kernel = interpolation_kernel(itp)
xis = centers(itp)
bas = basis(itp)
c = kernel_coefficients(itp)
d = polynomial_coefficients(itp)
ps = polynomial_basis(itp)
xx = polyvars(itp)
for j in eachindex(c)
s += c[j] * kernel(x, xis[j])
s += c[j] * bas[j](x)
end

d = polynomial_coefficients(itp)
ps = polynomial_basis(itp)
xx = polyvars(itp)
for k in eachindex(d)
s += d[k] * ps[k](xx => x)
end
Expand Down
17 changes: 15 additions & 2 deletions src/kernel_matrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ end
interpolation_matrix(basis, ps, regularization)
Return the interpolation matrix for the `basis`, polynomials `ps`, and `regularization`.
The interpolation matrix is defined as
For the [`StandardBasis`](@ref), the interpolation matrix is defined as
```math
A = \begin{pmatrix}K & P\\P^T & 0\end{pmatrix},
```
Expand All @@ -80,6 +80,12 @@ function interpolation_matrix(basis::AbstractBasis, ps,
return Symmetric(system_matrix)
end

# This should be the same as `kernel_matrix(basis)`
function interpolation_matrix(::LagrangeBasis, ps,
::AbstractRegularization = NoRegularization())
return I
end

function interpolation_matrix(centers::NodeSet, kernel::AbstractKernel, ps,
regularization::AbstractRegularization = NoRegularization())
interpolation_matrix(StandardBasis(centers, kernel), ps, regularization)
Expand All @@ -90,7 +96,7 @@ end
least_squares_matrix(centers, nodeset, kernel, ps, regularization = NoRegularization())
Return the least squares matrix for the `basis`, `nodeset`, polynomials `ps`, and `regularization`.
The least squares matrix is defined as
For the [`StandardBasis`](@ref), the least squares matrix is defined as
```math
A = \begin{pmatrix}K & P_1\\P_2^T & 0\end{pmatrix},
```
Expand All @@ -110,6 +116,13 @@ function least_squares_matrix(basis::AbstractBasis, nodeset::NodeSet, ps,
return system_matrix
end

function least_squares_matrix(basis::LagrangeBasis, nodeset::NodeSet, ps,
regularization::AbstractRegularization = NoRegularization())
k_matrix = kernel_matrix(basis, nodeset)
regularize!(k_matrix, regularization)
return k_matrix
end

function least_squares_matrix(centers::NodeSet, nodeset::NodeSet, kernel::AbstractKernel,
ps,
regularization::AbstractRegularization = NoRegularization())
Expand Down
9 changes: 9 additions & 0 deletions test/test_examples_interpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,15 @@ end
ns=5:10)
end

@testitem "interpolation_2d_Lagrange_basis.jl" setup=[
Setup,
AdditionalImports,
InterpolationExamples
] begin
@test_include_example(joinpath(EXAMPLES_DIR, "interpolation_2d_Lagrange_basis.jl"),
l2=0.40362797382569787, linf=0.06797693848658759)
end

@testitem "least_squares_2d.jl" setup=[
Setup,
AdditionalImports,
Expand Down
54 changes: 54 additions & 0 deletions test/test_unit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,37 @@ end
for (i, b) in enumerate(basis)
@test b.(nodeset) == basis_functions[i].(nodeset)
end

kernel = ThinPlateSplineKernel{dim(nodeset)}()
basis = @test_nowarn LagrangeBasis(nodeset, kernel)
@test_throws DimensionMismatch LagrangeBasis(nodeset,
GaussKernel{1}(shape_parameter = 0.5))
@test_nowarn println(basis)
@test_nowarn display(basis)
A = kernel_matrix(basis)
@test isapprox(stack(basis.(nodeset)), A)
@test isapprox(A, I)
basis_functions = collect(basis)
for (i, b) in enumerate(basis)
@test b.(nodeset) == basis_functions[i].(nodeset)
end
# Test for Theorem 11.1 in Wendland's book
stdbasis = StandardBasis(nodeset, kernel)
R(x) = stdbasis(x)
function S(x)
v = zeros(length(basis.ps))
for i in eachindex(v)
v[i] = basis.ps[i](basis.xx => x)
end
return v
end
b(x) = [R(x); S(x)]
K = KernelInterpolation.interpolation_matrix(stdbasis, basis.ps)
x = rand(dim(nodeset))
uv = K \ b(x)
u = basis(x)
# Difficult to test for v
@test isapprox(u, uv[1:length(u)])
end

@testitem "Interpolation" setup=[Setup, AdditionalImports] begin
Expand Down Expand Up @@ -750,6 +781,29 @@ end
@test isapprox(itp([0.5, 0.5]), 1.0)
@test isapprox(kernel_norm(itp), 0.0, atol = 1e-15)

# Least squares with LagrangeBasis (not really recommended because you still need to solve a linear system)
basis = LagrangeBasis(centers, kernel)
basis_functions = collect(basis)
# There is no RBF part
for b in basis_functions
@test isapprox(kernel_coefficients(b), zeros(length(centers)))
end
itp = interpolate(basis, ff, nodes)
coeffs = coefficients(itp)
# Polynomial coefficients add up correctly
expected_coefficients = [
0.0,
1.0,
1.0]
for i in eachindex(expected_coefficients)
coeff = 0.0
for (j, b) in enumerate(basis_functions)
coeff += coeffs[j] * polynomial_coefficients(b)[i]
end
@test isapprox(coeff, expected_coefficients[i], atol = 1e-15)
end
@test isapprox(itp([0.5, 0.5]), 1.0)

# 1D interpolation and evaluation
nodes = NodeSet(LinRange(0.0, 1.0, 10))
f(x) = sinpi(x[1])
Expand Down

0 comments on commit 366e9b8

Please sign in to comment.