Skip to content

Commit

Permalink
Update FFJORD to use Lux AD calls
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed May 19, 2024
1 parent 99f8a2e commit 31ff571
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 145 deletions.
2 changes: 1 addition & 1 deletion .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ steps:
# Don't run Buildkite if the commit message includes the text [skip tests]
if: build.message !~ /\[skip tests\]/
env:
RETESTITEMS_NWORKERS: 4
RETESTITEMS_NWORKERS: 1 # These tests require quite a lot of GPU memory
GROUP: CUDA
DATADEPS_ALWAYS_ACCEPT: 'true'
JULIA_PKG_SERVER: "" # it often struggles with our large artifacts
Expand Down
12 changes: 8 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ version = "3.4.0"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -23,6 +22,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
Expand All @@ -34,7 +34,7 @@ Aqua = "0.8.7"
BenchmarkTools = "1.5.0"
CUDA = "5.3.4"
ChainRulesCore = "1"
ComponentArrays = "0.15.5"
ComponentArrays = "0.15.12"
ConcreteStructs = "0.2"
DataInterpolations = "5.0.0"
DelayDiffEq = "5.47.3"
Expand All @@ -44,6 +44,7 @@ Distances = "0.10.11"
Distributed = "1.10"
Distributions = "0.25"
DistributionsAD = "0.6"
ExplicitImports = "1.4.4"
Flux = "0.14.15"
ForwardDiff = "0.10"
Functors = "0.4"
Expand Down Expand Up @@ -71,8 +72,9 @@ ReverseDiff = "1.15.3"
SafeTestsets = "0.1.0"
SciMLBase = "1, 2"
SciMLSensitivity = "7"
Setfield = "1.1.1"
StaticArrays = "1.9.4"
Statistics = "1.11.1"
Statistics = "1.10"
StochasticDiffEq = "6.65.1"
Test = "1.10"
Tracker = "0.2.29"
Expand All @@ -84,11 +86,13 @@ julia = "1.10"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
Expand All @@ -112,4 +116,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "BenchmarkTools", "CUDA", "DataInterpolations", "DelayDiffEq", "DiffEqCallbacks", "Distances", "Distributed", "Flux", "LuxCUDA", "MLDataUtils", "MLDatasets", "NLopt", "NNlib", "OneHotArrays", "Optimisers", "Optimization", "OptimizationOptimJL", "OptimizationOptimisers", "OrdinaryDiffEq", "Printf", "ReTestItems", "RecursiveArrayTools", "ReverseDiff", "SafeTestsets", "StaticArrays", "Statistics", "StochasticDiffEq", "Test"]
test = ["Aqua", "BenchmarkTools", "CUDA", "ComponentArrays", "DataInterpolations", "DelayDiffEq", "DiffEqCallbacks", "Distances", "Distributed", "ExplicitImports", "Flux", "LuxCUDA", "MLDataUtils", "MLDatasets", "NLopt", "NNlib", "OneHotArrays", "Optimisers", "Optimization", "OptimizationOptimJL", "OptimizationOptimisers", "OrdinaryDiffEq", "Printf", "ReTestItems", "RecursiveArrayTools", "ReverseDiff", "SafeTestsets", "StaticArrays", "Statistics", "StochasticDiffEq", "Test"]
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ by helping users put diffeq solvers into neural networks. This package utilizes
[Scientific Machine Learning](https://www.stochasticlifestyle.com/the-essential-tools-of-scientific-machine-learning-scientific-ml/), specifically neural differential equations to add physical information into traditional machine learning.

> [!NOTE]
> We maintain backwards compatibility with [Flux.jl](https://docs.sciml.ai/Flux/stable/) via [FromFluxAdaptor()](https://lux.csail.mit.edu/stable/api/Lux/flux_to_lux#FromFluxAdaptor())
> We maintain backwards compatibility with [Flux.jl](https://docs.sciml.ai/Flux/stable/) via [FromFluxAdaptor()](https://lux.csail.mit.edu/stable/api/Lux/interop#Lux.FromFluxAdaptor)
## Tutorials and Documentation

Expand Down
34 changes: 24 additions & 10 deletions docs/pages.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,31 @@
#! format: off
pages = [
"DiffEqFlux.jl: High Level Scientific Machine Learning (SciML) Pre-Built Architectures" => "index.md",
"Differential Equation Machine Learning Tutorials" => Any[
"examples/neural_ode.md", "examples/GPUs.md",
"examples/mnist_neural_ode.md", "examples/mnist_conv_neural_ode.md",
"examples/augmented_neural_ode.md", "examples/neural_sde.md",
"examples/collocation.md", "examples/normalizing_flows.md",
"examples/hamiltonian_nn.md", "examples/tensor_layer.md",
"examples/multiple_shooting.md", "examples//neural_ode_weather_forecast.md"],
"Layer APIs" => Any["Classical Basis Layers" => "layers/BasisLayers.md",
"examples/neural_ode.md",
"examples/GPUs.md",
"examples/mnist_neural_ode.md",
"examples/mnist_conv_neural_ode.md",
"examples/augmented_neural_ode.md",
"examples/neural_sde.md",
"examples/collocation.md",
"examples/normalizing_flows.md",
"examples/hamiltonian_nn.md",
"examples/tensor_layer.md",
"examples/multiple_shooting.md",
"examples/neural_ode_weather_forecast.md"
],
"Layer APIs" => Any[
"Classical Basis Layers" => "layers/BasisLayers.md",
"Tensor Product Layer" => "layers/TensorLayer.md",
"Continuous Normalizing Flows Layer" => "layers/CNFLayer.md",
"Spline Layer" => "layers/SplineLayer.md",
"Neural Differential Equation Layers" => "layers/NeuralDELayers.md",
"Hamiltonian Neural Network Layer" => "layers/HamiltonianNN.md"],
"Utility Function APIs" => Any["Smoothed Collocation" => "utilities/Collocation.md",
"Multiple Shooting Functionality" => "utilities/MultipleShooting.md"]]
"Hamiltonian Neural Network Layer" => "layers/HamiltonianNN.md"
],
"Utility Function APIs" => Any[
"Smoothed Collocation" => "utilities/Collocation.md",
"Multiple Shooting Functionality" => "utilities/MultipleShooting.md"
]
]
#! format: on
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ The approach of this package is the easy and efficient training of
[Neural Ordinary Differential Equations](https://arxiv.org/abs/1806.07366) and its variants.
DiffEqFlux.jl provides architectures which match the interfaces of
machine learning libraries such as [Flux.jl](https://docs.sciml.ai/Flux/stable/)
and [Lux.jl](https://lux.csail.mit.edu/stable/api/)
and [Lux.jl](https://lux.csail.mit.edu/stable/)
to make it easy to build continuous-time machine learning layers
into larger machine learning applications.

Expand Down
6 changes: 3 additions & 3 deletions src/DiffEqFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@ using PrecompileTools: @recompile_invalidations
@recompile_invalidations begin
using ADTypes: ADTypes, AutoForwardDiff, AutoZygote
using ChainRulesCore: ChainRulesCore
using ComponentArrays: ComponentArray
using ConcreteStructs: @concrete
using Distributions: Distributions, ContinuousMultivariateDistribution, Distribution,
logpdf
using DistributionsAD: DistributionsAD
using ForwardDiff: ForwardDiff
using Functors: Functors, fmap
using LinearAlgebra: LinearAlgebra, Diagonal, det, diagind, mul!
using Lux: Lux, Chain, Dense, StatefulLuxLayer, FromFluxAdaptor
using LinearAlgebra: LinearAlgebra, Diagonal, det, tr, mul!
using Lux: Lux, Chain, Dense, StatefulLuxLayer, FromFluxAdaptor,
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
using Random: Random, AbstractRNG, randn!
using Reexport: @reexport
Expand All @@ -26,6 +25,7 @@ using PrecompileTools: @recompile_invalidations
NILSS, QuadratureAdjoint, ReverseDiffAdjoint, ReverseDiffVJP,
SteadyStateAdjoint, TrackerAdjoint, TrackerVJP, ZygoteAdjoint,
ZygoteVJP
using Setfield: @set
using Tracker: Tracker
using Zygote: Zygote
end
Expand Down
Loading

0 comments on commit 31ff571

Please sign in to comment.