diff --git a/Manifest.toml b/Manifest.toml index ae83632e6..a72ab9797 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.10.0-rc1" manifest_format = "2.0" -project_hash = "33d5f46deb395b3db16d766ef39217d71a80867a" +project_hash = "461882012e6e993286bb5accc7e9d45749a2da45" [[deps.ADTypes]] git-tree-sha1 = "332e5d7baeff8497b923b730b994fa480601efc7" @@ -250,9 +250,9 @@ version = "1.9.1" [[deps.DiffEqBase]] deps = ["ArrayInterface", "DataStructures", "DocStringExtensions", "EnumX", "EnzymeCore", "FastBroadcast", "ForwardDiff", "FunctionWrappers", "FunctionWrappersWrappers", "LinearAlgebra", "Logging", "Markdown", "MuladdMacro", "Parameters", "PreallocationTools", "PrecompileTools", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Static", "StaticArraysCore", "Statistics", "Tricks", "TruncatedStacktraces"] -git-tree-sha1 = "de4709e30bd5490435122c4b415b90a812c23fbf" +git-tree-sha1 = "5e365e0744ae1fdd44e62416343e7bfe848999d8" uuid = "2b5f629d-d688-5b77-993f-72d75c75574e" -version = "6.138.1" +version = "6.139.0" [deps.DiffEqBase.extensions] DiffEqBaseChainRulesCoreExt = "ChainRulesCore" @@ -641,9 +641,9 @@ version = "6.4.0" [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "a84f8f1e8caaaa4e3b4c101306b9e801d3883ace" +git-tree-sha1 = "98eaee04d96d973e79c25d49167668c5c8fb50e2" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.27+0" +version = "0.0.27+1" [[deps.LatticeRules]] deps = ["Random"] @@ -1128,9 +1128,9 @@ version = "1.3.4" [[deps.RecursiveArrayTools]] deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "Requires", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] -git-tree-sha1 = "fa453b42ba1623bd2e70260bf44dac850a3430a7" +git-tree-sha1 = "d7087c013e8a496ff396bae843b1e16d9a30ede8" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" -version = "2.39.0" +version = "2.38.10" [deps.RecursiveArrayTools.extensions] RecursiveArrayToolsMeasurementsExt = "Measurements" @@ -1208,9 +1208,9 @@ version = "0.6.42" [[deps.SciMLBase]] deps = ["ADTypes", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FillArrays", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "QuasiMonteCarlo", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables", "TruncatedStacktraces"] -git-tree-sha1 = "dd2d18b981d09a2376ba49c5fab480f497992c88" +git-tree-sha1 = "baa0f858af55ea937183c988bd4b3e79e0bf699a" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -version = "2.8.0" +version = "2.8.1" [deps.SciMLBase.extensions] SciMLBaseChainRulesCoreExt = "ChainRulesCore" @@ -1309,9 +1309,9 @@ version = "1.10.0" [[deps.SparseDiffTools]] deps = ["ADTypes", "Adapt", "ArrayInterface", "Compat", "DataStructures", "FiniteDiff", "ForwardDiff", "Graphs", "LinearAlgebra", "PackageExtensionCompat", "Random", "Reexport", "SciMLOperators", "Setfield", "SparseArrays", "StaticArrayInterface", "StaticArrays", "Tricks", "UnPack", "VertexSafeGraphs"] -git-tree-sha1 = "e162b74fd1ce6d371ff5c584b53e34538edb9212" +git-tree-sha1 = "49068dceed7febe32afe2d7b874172a7c3198cb3" uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804" -version = "2.11.0" +version = "2.12.0" [deps.SparseDiffTools.extensions] SparseDiffToolsEnzymeExt = "Enzyme" @@ -1484,11 +1484,11 @@ version = "0.5.23" [[deps.Tracker]] deps = ["Adapt", "DiffRules", "ForwardDiff", "Functors", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NNlib", "NaNMath", "Optimisers", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics"] -git-tree-sha1 = "8e895c98c27a4203a8061dd8c10b010269e8d3a6" +git-tree-sha1 = "752daa5bbd9721b0566e39cdc75cffdc3ef5593d" repo-rev = "ap/ambiguous" repo-url = "https://github.com/avik-pal/Tracker.jl" uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -version = "0.2.29" +version = "0.2.30" weakdeps = ["PDMats"] [deps.Tracker.extensions] diff --git a/Project.toml b/Project.toml index c63b168d5..df1c7770d 100644 --- a/Project.toml +++ b/Project.toml @@ -12,15 +12,12 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -31,18 +28,26 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] +ADTypes = "0.2" Adapt = "3" ChainRulesCore = "1" +ComponentArrays = "0.15.5" +ConcreteStructs = "0.2" DiffEqBase = "6.41" Distributions = "0.23, 0.24, 0.25" DistributionsAD = "0.6" ForwardDiff = "0.10" Functors = "0.4" +LinearAlgebra = "<0.0.1, 1" +Lux = "0.5.5" LuxCore = "0.1" +PrecompileTools = "1" +Random = "<0.0.1, 1" RecursiveArrayTools = "2" Reexport = "0.2, 1" SciMLBase = "1, 2" SciMLSensitivity = "7" +Tracker = "0.2.30" Zygote = "0.5, 0.6" ZygoteRules = "0.2" julia = "1.9" diff --git a/src/DiffEqFlux.jl b/src/DiffEqFlux.jl index e623f6a23..5c895d657 100644 --- a/src/DiffEqFlux.jl +++ b/src/DiffEqFlux.jl @@ -17,7 +17,7 @@ import ChainRulesCore as CRC import LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer import Lux.Experimental: StatefulLuxLayer -@reexport using ADTypes, Lux, SciMLSensitivity +@reexport using ADTypes, Lux # FIXME: Type Piracy function CRC.rrule(::Type{Tridiagonal}, dl, d, du) @@ -52,4 +52,11 @@ export collocate_data export multiple_shoot +# Reexporting only certain functions from SciMLSensitivity +export BacksolveAdjoint, QuadratureAdjoint, GaussAdjoint, InterpolatingAdjoint, + TrackerAdjoint, ZygoteAdjoint, ReverseDiffAdjoint, ForwardSensitivity, + ForwardDiffSensitivity, ForwardDiffOverAdjoint, SteadyStateAdjoint, + ForwardLSS, AdjointLSS, NILSS, NILSAS +export TrackerVJP, ZygoteVJP, EnzymeVJP, ReverseDiffVJP + end diff --git a/test/runtests.jl b/test/runtests.jl index 2a09a5e20..64dd2b813 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -51,7 +51,7 @@ const is_CI = haskey(ENV, "CI") end end - if GROUP == "Newton" + if GROUP == "All" || GROUP == "Newton" @safetestset "Newton Neural ODE Tests" begin include("newton_neural_ode.jl") end @@ -69,9 +69,19 @@ const is_CI = haskey(ENV, "CI") end end - @safetestset "Aqua Q/A" begin - using Aqua, DiffEqFlux + if GROUP == "All" || GROUP == "Aqua" + @safetestset "Aqua Q/A" begin + using Aqua, DiffEqFlux - Aqua.test_all(DiffEqFlux; ambiguities = false) + # TODO: Enable persistent tasks once the downstream PRs are merged + Aqua.test_all(DiffEqFlux; ambiguities = false, piracies = false, + persistent_tasks = false) + + Aqua.test_ambiguities(DiffEqFlux; recursive = false) + + # FIXME: Remove Tridiagonal piracy after + # https://github.com/JuliaDiff/ChainRules.jl/issues/713 is merged! + Aqua.test_piracies(DiffEqFlux; treat_as_own = [LinearAlgebra.Tridiagonal]) + end end end diff --git a/test/spline_layer_test.jl b/test/spline_layer_test.jl index c2ea25fa8..535c0fab5 100644 --- a/test/spline_layer_test.jl +++ b/test/spline_layer_test.jl @@ -1,5 +1,5 @@ -using DiffEqFlux, ComponentArrays, - Zygote, DataInterpolations, Distributions, Optimization, LinearAlgebra, Random, Test +using DiffEqFlux, ComponentArrays, Zygote, DataInterpolations, Distributions, Optimization, + OptimizationOptimisers, LinearAlgebra, Random, Test function run_test(f, layer, atol) ps, st = Lux.setup(Xoshiro(0), layer)