Skip to content

Commit

Permalink
Merge pull request #853 from sathvikbhagavan/patch-1
Browse files Browse the repository at this point in the history
fix: kernel functions
  • Loading branch information
ChrisRackauckas authored Sep 19, 2023
2 parents 7c8f3e4 + 41d1031 commit 7e3cb5f
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 40 deletions.
27 changes: 16 additions & 11 deletions src/collocation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,26 +36,26 @@ function calckernel(::TriangularKernel,t)
end

function calckernel(::QuarticKernel,t)
if abs(t)>0
if abs(t) > 1
return 0
else
return (15*(1-t^2)^2)/16
end
end

function calckernel(::TriweightKernel,t)
if abs(t)>0
if abs(t) > 1
return 0
else
return (35*(1-t^2)^3)/32
end
end

function calckernel(::TricubeKernel,t)
if abs(t)>0
if abs(t) > 1
return 0
else
return (70*(1-abs(t)^3)^3)/80
return (70*(1-abs(t)^3)^3)/81
end
end

Expand All @@ -64,7 +64,7 @@ function calckernel(::GaussianKernel,t)
end

function calckernel(::CosineKernel,t)
if abs(t)>0
if abs(t) > 1
return 0
else
return*cos*t/2))/4
Expand Down Expand Up @@ -92,14 +92,14 @@ function construct_t2(t,tpoints)
end

function construct_w(t,tpoints,h,kernel)
W = @. calckernel((kernel,),(tpoints-t)/h)/h
W = @. calckernel((kernel,),((tpoints-t)/(tpoints[end]-tpoints[begin]))/h)/h
Diagonal(W)
end


"""
```julia
u′,u = collocate_data(data,tpoints,kernel=SigmoidKernel())
u′,u = collocate_data(data,tpoints,kernel=TriangularKernel(),bandwidth=nothing)
u′,u = collocate_data(data,tpoints,tpoints_sample,interp,args...)
```
Expand Down Expand Up @@ -128,24 +128,29 @@ Additionally, we can use interpolation methods from
data from intermediate timesteps. In this case, pass any of the methods like
`QuadraticInterpolation` as `interp`, and the timestamps to sample from as `tpoints_sample`.
"""
function collocate_data(data,tpoints,kernel=TriangularKernel())
function collocate_data(data, tpoints, kernel=TriangularKernel(), bandwidth=nothing)
_one = oneunit(first(data))
_zero = zero(first(data))
e1 = [_one;_zero]
e2 = [_zero;_one;_zero]
n = length(tpoints)
h = (n^(-1/5))*(n^(-3/35))*((log(n))^(-1/16))
bandwidth = isnothing(bandwidth) ? (n^(-1/5))*(n^(-3/35))*((log(n))^(-1/16)) : bandwidth

Wd = similar(data, n, size(data,1))
WT1 = similar(data, n, 2)
WT2 = similar(data, n, 3)
T2WT2 = similar(data, 3, 3)
T1WT1 = similar(data, 2, 2)
x = map(tpoints) do _t
T1 = construct_t1(_t,tpoints)
T2 = construct_t2(_t,tpoints)
W = construct_w(_t,tpoints,h,kernel)
W = construct_w(_t,tpoints,bandwidth,kernel)
mul!(Wd,W,data')
mul!(WT1,W,T1)
mul!(WT2,W,T2)
mul!(T2WT2,T2',WT2)
mul!(T1WT1,T1',WT1)
(det(T2WT2) 0.0 || det(T1WT1) 0.0) && error("Collocation failed with bandwidth $bandwidth. Please choose a higher bandwidth")
(e2'*((T2'*WT2)\T2'))*Wd,(e1'*((T1'*WT1)\T1'))*Wd
end
estimated_derivative = reduce(hcat,transpose.(first.(x)))
Expand All @@ -163,7 +168,7 @@ function collocate_data(data::AbstractMatrix{T},tpoints::AbstractVector{T},
tpoints_sample::AbstractVector{T},interp,args...) where T
u = zeros(T,size(data, 1),length(tpoints_sample))
du = zeros(T,size(data, 1),length(tpoints_sample))
for d1 in 1:size(data,1)
for d1 in axes(data, 1)
interpolation = interp(data[d1,:],tpoints,args...)
u[d1,:] .= interpolation.(tpoints_sample)
du[d1,:] .= DataInterpolations.derivative.((interpolation,), tpoints_sample)
Expand Down
67 changes: 67 additions & 0 deletions test/collocation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using DiffEqFlux, OrdinaryDiffEq, Test

bounded_support_kernels = [
EpanechnikovKernel(),
UniformKernel(),
TriangularKernel(),
QuarticKernel(),
TriweightKernel(),
TricubeKernel(),
CosineKernel(),
]

unbounded_support_kernels =
[GaussianKernel(), LogisticKernel(), SigmoidKernel(), SilvermanKernel()]

@testset "Kernel Functions" begin
ts = collect(-5.0:0.1:5.0)
@testset "Kernels with support from -1 to 1" begin
minus_one_index = findfirst(x -> ==(x, -1.0), ts)
plus_one_index = findfirst(x -> ==(x, 1.0), ts)
@testset "$kernel" for (kernel, x0) in zip(
bounded_support_kernels,
[0.75, 0.50, 1.0, 15.0 / 16.0, 35.0 / 32.0, 70.0 / 81.0, pi / 4.0],
)
ws = DiffEqFlux.calckernel.((kernel,), ts)
# t < -1
@test all(ws[1:minus_one_index-1] .== 0.0)
# t > 1
@test all(ws[plus_one_index+1:end] .== 0.0)
# -1 < t <1
@test all(ws[minus_one_index+1:plus_one_index-1] .> 0.0)
# t = 0
@test DiffEqFlux.calckernel(kernel, 0.0) == x0
end
end
@testset "Kernels with unbounded support" begin
@testset "$kernel" for (kernel, x0) in zip(
unbounded_support_kernels,
[1 / (sqrt(2 * pi)), 0.25, 1 / pi, 1 / (2 * sqrt(2))],
)
# t = 0
@test DiffEqFlux.calckernel(kernel, 0.0) == x0
end
end
end

@testset "Collocation of data" begin
function f(u, p, t)
p .* u
end
rc = 2
ps = repeat([-0.001], rc)
tspan = (0.0, 50.0)
u0 = 3.4 .+ ones(rc)
t = collect(range(minimum(tspan), stop = maximum(tspan), length = 1000))
prob = ODEProblem(f, u0, tspan, ps)
data = Array(solve(prob, Tsit5(), saveat = t, abstol = 1e-12, reltol = 1e-12))
@testset "$kernel" for kernel in
[bounded_support_kernels..., unbounded_support_kernels...]
u′, u = collocate_data(data, t, kernel, 0.003)
@test sum(abs2, u - data) < 1e-8
end
@testset "$kernel" for kernel in [bounded_support_kernels...]
# Errors out as the bandwidth is too low
@test_throws ErrorException collocate_data(data, t, kernel, 0.001)
end
end
27 changes: 0 additions & 27 deletions test/collocation_regression.jl

This file was deleted.

4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ const is_CI = haskey(ENV, "CI")

@time begin
if GROUP == "All" || GROUP == "DiffEqFlux" || GROUP == "Layers"
@safetestset "Collocation Regression" begin
include("collocation_regression.jl")
@safetestset "Collocation" begin
include("collocation.jl")
end
@safetestset "Stiff Nested AD Tests" begin
include("stiff_nested_ad.jl")
Expand Down

0 comments on commit 7e3cb5f

Please sign in to comment.