Skip to content

Commit

Permalink
Fix Aqua tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 15, 2023
1 parent cf55104 commit f2bd8cd
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 23 deletions.
26 changes: 13 additions & 13 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.0-rc1"
manifest_format = "2.0"
project_hash = "33d5f46deb395b3db16d766ef39217d71a80867a"
project_hash = "461882012e6e993286bb5accc7e9d45749a2da45"

[[deps.ADTypes]]
git-tree-sha1 = "332e5d7baeff8497b923b730b994fa480601efc7"
Expand Down Expand Up @@ -250,9 +250,9 @@ version = "1.9.1"

[[deps.DiffEqBase]]
deps = ["ArrayInterface", "DataStructures", "DocStringExtensions", "EnumX", "EnzymeCore", "FastBroadcast", "ForwardDiff", "FunctionWrappers", "FunctionWrappersWrappers", "LinearAlgebra", "Logging", "Markdown", "MuladdMacro", "Parameters", "PreallocationTools", "PrecompileTools", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Static", "StaticArraysCore", "Statistics", "Tricks", "TruncatedStacktraces"]
git-tree-sha1 = "de4709e30bd5490435122c4b415b90a812c23fbf"
git-tree-sha1 = "5e365e0744ae1fdd44e62416343e7bfe848999d8"
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
version = "6.138.1"
version = "6.139.0"

[deps.DiffEqBase.extensions]
DiffEqBaseChainRulesCoreExt = "ChainRulesCore"
Expand Down Expand Up @@ -641,9 +641,9 @@ version = "6.4.0"

[[deps.LLVMExtra_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
git-tree-sha1 = "a84f8f1e8caaaa4e3b4c101306b9e801d3883ace"
git-tree-sha1 = "98eaee04d96d973e79c25d49167668c5c8fb50e2"
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
version = "0.0.27+0"
version = "0.0.27+1"

[[deps.LatticeRules]]
deps = ["Random"]
Expand Down Expand Up @@ -1128,9 +1128,9 @@ version = "1.3.4"

[[deps.RecursiveArrayTools]]
deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "Requires", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"]
git-tree-sha1 = "fa453b42ba1623bd2e70260bf44dac850a3430a7"
git-tree-sha1 = "d7087c013e8a496ff396bae843b1e16d9a30ede8"
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
version = "2.39.0"
version = "2.38.10"

[deps.RecursiveArrayTools.extensions]
RecursiveArrayToolsMeasurementsExt = "Measurements"
Expand Down Expand Up @@ -1208,9 +1208,9 @@ version = "0.6.42"

[[deps.SciMLBase]]
deps = ["ADTypes", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FillArrays", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "QuasiMonteCarlo", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables", "TruncatedStacktraces"]
git-tree-sha1 = "dd2d18b981d09a2376ba49c5fab480f497992c88"
git-tree-sha1 = "baa0f858af55ea937183c988bd4b3e79e0bf699a"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
version = "2.8.0"
version = "2.8.1"

[deps.SciMLBase.extensions]
SciMLBaseChainRulesCoreExt = "ChainRulesCore"
Expand Down Expand Up @@ -1309,9 +1309,9 @@ version = "1.10.0"

[[deps.SparseDiffTools]]
deps = ["ADTypes", "Adapt", "ArrayInterface", "Compat", "DataStructures", "FiniteDiff", "ForwardDiff", "Graphs", "LinearAlgebra", "PackageExtensionCompat", "Random", "Reexport", "SciMLOperators", "Setfield", "SparseArrays", "StaticArrayInterface", "StaticArrays", "Tricks", "UnPack", "VertexSafeGraphs"]
git-tree-sha1 = "e162b74fd1ce6d371ff5c584b53e34538edb9212"
git-tree-sha1 = "49068dceed7febe32afe2d7b874172a7c3198cb3"
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
version = "2.11.0"
version = "2.12.0"

[deps.SparseDiffTools.extensions]
SparseDiffToolsEnzymeExt = "Enzyme"
Expand Down Expand Up @@ -1484,11 +1484,11 @@ version = "0.5.23"

[[deps.Tracker]]
deps = ["Adapt", "DiffRules", "ForwardDiff", "Functors", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NNlib", "NaNMath", "Optimisers", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics"]
git-tree-sha1 = "8e895c98c27a4203a8061dd8c10b010269e8d3a6"
git-tree-sha1 = "752daa5bbd9721b0566e39cdc75cffdc3ef5593d"
repo-rev = "ap/ambiguous"
repo-url = "https://github.com/avik-pal/Tracker.jl"
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
version = "0.2.29"
version = "0.2.30"
weakdeps = ["PDMats"]

[deps.Tracker.extensions]
Expand Down
11 changes: 8 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,12 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand All @@ -31,18 +28,26 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
ADTypes = "0.2"
Adapt = "3"
ChainRulesCore = "1"
ComponentArrays = "0.15.5"
ConcreteStructs = "0.2"
DiffEqBase = "6.41"
Distributions = "0.23, 0.24, 0.25"
DistributionsAD = "0.6"
ForwardDiff = "0.10"
Functors = "0.4"
LinearAlgebra = "<0.0.1, 1"
Lux = "0.5.5"
LuxCore = "0.1"
PrecompileTools = "1"
Random = "<0.0.1, 1"
RecursiveArrayTools = "2"
Reexport = "0.2, 1"
SciMLBase = "1, 2"
SciMLSensitivity = "7"
Tracker = "0.2.30"
Zygote = "0.5, 0.6"
ZygoteRules = "0.2"
julia = "1.9"
9 changes: 8 additions & 1 deletion src/DiffEqFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import ChainRulesCore as CRC
import LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer
import Lux.Experimental: StatefulLuxLayer

@reexport using ADTypes, Lux, SciMLSensitivity
@reexport using ADTypes, Lux

# FIXME: Type Piracy
function CRC.rrule(::Type{Tridiagonal}, dl, d, du)
Expand Down Expand Up @@ -52,4 +52,11 @@ export collocate_data

export multiple_shoot

# Reexporting only certain functions from SciMLSensitivity
export BacksolveAdjoint, QuadratureAdjoint, GaussAdjoint, InterpolatingAdjoint,
TrackerAdjoint, ZygoteAdjoint, ReverseDiffAdjoint, ForwardSensitivity,
ForwardDiffSensitivity, ForwardDiffOverAdjoint, SteadyStateAdjoint,
ForwardLSS, AdjointLSS, NILSS, NILSAS
export TrackerVJP, ZygoteVJP, EnzymeVJP, ReverseDiffVJP

end
18 changes: 14 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ const is_CI = haskey(ENV, "CI")
end
end

if GROUP == "Newton"
if GROUP == "All" || GROUP == "Newton"
@safetestset "Newton Neural ODE Tests" begin
include("newton_neural_ode.jl")
end
Expand All @@ -69,9 +69,19 @@ const is_CI = haskey(ENV, "CI")
end
end

@safetestset "Aqua Q/A" begin
using Aqua, DiffEqFlux
if GROUP == "All" || GROUP == "Aqua"
@safetestset "Aqua Q/A" begin
using Aqua, DiffEqFlux

Aqua.test_all(DiffEqFlux; ambiguities = false)
# TODO: Enable persistent tasks once the downstream PRs are merged
Aqua.test_all(DiffEqFlux; ambiguities = false, piracies = false,
persistent_tasks = false)

Aqua.test_ambiguities(DiffEqFlux; recursive = false)

# FIXME: Remove Tridiagonal piracy after
# https://github.com/JuliaDiff/ChainRules.jl/issues/713 is merged!
Aqua.test_piracies(DiffEqFlux; treat_as_own = [LinearAlgebra.Tridiagonal])
end
end
end
4 changes: 2 additions & 2 deletions test/spline_layer_test.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using DiffEqFlux, ComponentArrays,
Zygote, DataInterpolations, Distributions, Optimization, LinearAlgebra, Random, Test
using DiffEqFlux, ComponentArrays, Zygote, DataInterpolations, Distributions, Optimization,
OptimizationOptimisers, LinearAlgebra, Random, Test

function run_test(f, layer, atol)
ps, st = Lux.setup(Xoshiro(0), layer)
Expand Down

0 comments on commit f2bd8cd

Please sign in to comment.