Skip to content

Commit

Permalink
refactor: DataInterpolations doesn't need to be a main dep
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 11, 2024
1 parent a409957 commit d933b0b
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 18 deletions.
3 changes: 3 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ steps:
version: "1.10"
- JuliaCI/julia-test#v1:
coverage: true
dirs:
- src
- ext
agents:
queue: "juliagpu"
cuda: "*"
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ jobs:
env:
GROUP: ${{ matrix.group }}
- uses: julia-actions/julia-processcoverage@v1
with:
directories: src, ext
- uses: codecov/codecov-action@v4
with:
file: lcov.info
Expand Down
12 changes: 8 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Boltz = "4544d5e4-abc5-4dea-817f-29e4c205d9c8"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Expand All @@ -21,6 +19,12 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"

[weakdeps]
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"

[extensions]
DiffEqFluxDataInterpolationsExt = "DataInterpolations"

[compat]
ADTypes = "1.5"
Aqua = "0.8.7"
Expand All @@ -35,7 +39,6 @@ DiffEqCallbacks = "3.6.2"
Distances = "0.10.11"
Distributed = "1.10"
Distributions = "0.25"
DistributionsAD = "0.6"
ExplicitImports = "1.9"
Flux = "0.14.15"
ForwardDiff = "0.10"
Expand Down Expand Up @@ -71,6 +74,7 @@ julia = "1.10"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Expand Down Expand Up @@ -99,4 +103,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "BenchmarkTools", "ComponentArrays", "DelayDiffEq", "DiffEqCallbacks", "Distances", "Distributed", "ExplicitImports", "ForwardDiff", "Flux", "Hwloc", "InteractiveUtils", "LuxCUDA", "MLDatasets", "NNlib", "OneHotArrays", "Optimisers", "Optimization", "OptimizationOptimJL", "OptimizationOptimisers", "OrdinaryDiffEq", "Printf", "Random", "ReTestItems", "Reexport", "Statistics", "StochasticDiffEq", "Test", "Zygote"]
test = ["Aqua", "BenchmarkTools", "ComponentArrays", "DataInterpolations", "DelayDiffEq", "DiffEqCallbacks", "Distances", "Distributed", "ExplicitImports", "ForwardDiff", "Flux", "Hwloc", "InteractiveUtils", "LuxCUDA", "MLDatasets", "NNlib", "OneHotArrays", "Optimisers", "Optimization", "OptimizationOptimJL", "OptimizationOptimisers", "OrdinaryDiffEq", "Printf", "Random", "ReTestItems", "Reexport", "Statistics", "StochasticDiffEq", "Test", "Zygote"]
19 changes: 19 additions & 0 deletions ext/DiffEqFluxDataInterpolationsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module DiffEqFluxDataInterpolationsExt

using DataInterpolations: DataInterpolations
using DiffEqFlux: DiffEqFlux

@views function DiffEqFlux.collocate_data(
data::AbstractMatrix{T}, tpoints::AbstractVector{T},
tpoints_sample::AbstractVector{T}, interp, args...) where {T}
u = zeros(T, size(data, 1), length(tpoints_sample))
du = zeros(T, size(data, 1), length(tpoints_sample))
for d1 in axes(data, 1)
interpolation = interp(data[d1, :], tpoints, args...)
u[d1, :] .= interpolation.(tpoints_sample)
du[d1, :] .= DataInterpolations.derivative.((interpolation,), tpoints_sample)
end
return du, u
end

end
2 changes: 0 additions & 2 deletions src/DiffEqFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ module DiffEqFlux
using ADTypes: ADTypes, AutoForwardDiff, AutoZygote
using ChainRulesCore: ChainRulesCore
using ConcreteStructs: @concrete
using DataInterpolations: DataInterpolations
using Distributions: Distributions, ContinuousMultivariateDistribution, Distribution, logpdf
using DistributionsAD: DistributionsAD
using LinearAlgebra: LinearAlgebra, Diagonal, det, tr, mul!
using Lux: Lux, Chain, Dense, StatefulLuxLayer, FromFluxAdaptor
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
Expand Down
12 changes: 0 additions & 12 deletions src/collocation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,3 @@ end
du, u = collocate_data(reshape(data, 1, :), tpoints, tpoints_sample, interp, args...)
return du[1, :], u[1, :]
end

@views function collocate_data(data::AbstractMatrix{T}, tpoints::AbstractVector{T},
tpoints_sample::AbstractVector{T}, interp, args...) where {T}
u = zeros(T, size(data, 1), length(tpoints_sample))
du = zeros(T, size(data, 1), length(tpoints_sample))
for d1 in axes(data, 1)
interpolation = interp(data[d1, :], tpoints, args...)
u[d1, :] .= interpolation.(tpoints_sample)
du[d1, :] .= DataInterpolations.derivative.((interpolation,), tpoints_sample)
end
return du, u
end

0 comments on commit d933b0b

Please sign in to comment.