From 193fda3ccf70681b22a6f6c95741f11a16abe4d0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 18 May 2024 18:07:00 -0400 Subject: [PATCH] Update FFJORD to use Lux AD calls --- .buildkite/pipeline.yml | 2 +- Project.toml | 12 ++- README.md | 2 +- docs/pages.jl | 34 ++++-- docs/src/index.md | 2 +- src/DiffEqFlux.jl | 6 +- src/ffjord.jl | 180 +++++++++++++------------------- test/{cnf_t.jl => cnf_tests.jl} | 29 +++-- test/mnist_tests.jl | 4 +- 9 files changed, 130 insertions(+), 141 deletions(-) rename test/{cnf_t.jl => cnf_tests.jl} (87%) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index d57099843..0969c8a4a 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -12,7 +12,7 @@ steps: # Don't run Buildkite if the commit message includes the text [skip tests] if: build.message !~ /\[skip tests\]/ env: - RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKERS: 1 # These tests require quite a lot of GPU memory GROUP: CUDA DATADEPS_ALWAYS_ACCEPT: 'true' JULIA_PKG_SERVER: "" # it often struggles with our large artifacts diff --git a/Project.toml b/Project.toml index 147920264..8956f794e 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,6 @@ version = "3.4.0" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -23,6 +22,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" @@ -34,7 +34,7 @@ Aqua = "0.8.7" BenchmarkTools = "1.5.0" CUDA = "5.3.4" ChainRulesCore = "1" -ComponentArrays = "0.15.5" +ComponentArrays = "0.15.12" ConcreteStructs = "0.2" DataInterpolations = "5.0.0" DelayDiffEq = "5.47.3" @@ -44,6 +44,7 @@ Distances = "0.10.11" Distributed = "1.10" Distributions = "0.25" DistributionsAD = "0.6" +ExplicitImports = "1.4.4" Flux = "0.14.15" ForwardDiff = "0.10" Functors = "0.4" @@ -71,8 +72,9 @@ ReverseDiff = "1.15.3" SafeTestsets = "0.1.0" SciMLBase = "1, 2" SciMLSensitivity = "7" +Setfield = "1.1.1" StaticArrays = "1.9.4" -Statistics = "1.11.1" +Statistics = "1.10" StochasticDiffEq = "6.65.1" Test = "1.10" Tracker = "0.2.29" @@ -84,11 +86,13 @@ julia = "1.10" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d" @@ -112,4 +116,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = 
["Aqua", "BenchmarkTools", "CUDA", "DataInterpolations", "DelayDiffEq", "DiffEqCallbacks", "Distances", "Distributed", "Flux", "LuxCUDA", "MLDataUtils", "MLDatasets", "NLopt", "NNlib", "OneHotArrays", "Optimisers", "Optimization", "OptimizationOptimJL", "OptimizationOptimisers", "OrdinaryDiffEq", "Printf", "ReTestItems", "RecursiveArrayTools", "ReverseDiff", "SafeTestsets", "StaticArrays", "Statistics", "StochasticDiffEq", "Test"] +test = ["Aqua", "BenchmarkTools", "CUDA", "ComponentArrays", "DataInterpolations", "DelayDiffEq", "DiffEqCallbacks", "Distances", "Distributed", "ExplicitImports", "Flux", "LuxCUDA", "MLDataUtils", "MLDatasets", "NLopt", "NNlib", "OneHotArrays", "Optimisers", "Optimization", "OptimizationOptimJL", "OptimizationOptimisers", "OrdinaryDiffEq", "Printf", "ReTestItems", "RecursiveArrayTools", "ReverseDiff", "SafeTestsets", "StaticArrays", "Statistics", "StochasticDiffEq", "Test"] diff --git a/README.md b/README.md index 351bc663b..e36f115be 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ by helping users put diffeq solvers into neural networks. This package utilizes [Scientific Machine Learning](https://www.stochasticlifestyle.com/the-essential-tools-of-scientific-machine-learning-scientific-ml/), specifically neural differential equations to add physical information into traditional machine learning. > [!NOTE] -> We maintain backwards compatibility with [Flux.jl](https://docs.sciml.ai/Flux/stable/) via [FromFluxAdaptor()](https://lux.csail.mit.edu/stable/api/Lux/flux_to_lux#FromFluxAdaptor()) +> We maintain backwards compatibility with [Flux.jl](https://docs.sciml.ai/Flux/stable/) via [FromFluxAdaptor()](https://lux.csail.mit.edu/stable/api/Lux/interop#Lux.FromFluxAdaptor) ## Tutorials and Documentation diff --git a/docs/pages.jl b/docs/pages.jl index 36b8854cc..04bfaa208 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -1,17 +1,31 @@ +#! 
format: off pages = [ "DiffEqFlux.jl: High Level Scientific Machine Learning (SciML) Pre-Built Architectures" => "index.md", "Differential Equation Machine Learning Tutorials" => Any[ - "examples/neural_ode.md", "examples/GPUs.md", - "examples/mnist_neural_ode.md", "examples/mnist_conv_neural_ode.md", - "examples/augmented_neural_ode.md", "examples/neural_sde.md", - "examples/collocation.md", "examples/normalizing_flows.md", - "examples/hamiltonian_nn.md", "examples/tensor_layer.md", - "examples/multiple_shooting.md", "examples//neural_ode_weather_forecast.md"], - "Layer APIs" => Any["Classical Basis Layers" => "layers/BasisLayers.md", + "examples/neural_ode.md", + "examples/GPUs.md", + "examples/mnist_neural_ode.md", + "examples/mnist_conv_neural_ode.md", + "examples/augmented_neural_ode.md", + "examples/neural_sde.md", + "examples/collocation.md", + "examples/normalizing_flows.md", + "examples/hamiltonian_nn.md", + "examples/tensor_layer.md", + "examples/multiple_shooting.md", + "examples/neural_ode_weather_forecast.md" + ], + "Layer APIs" => Any[ + "Classical Basis Layers" => "layers/BasisLayers.md", "Tensor Product Layer" => "layers/TensorLayer.md", "Continuous Normalizing Flows Layer" => "layers/CNFLayer.md", "Spline Layer" => "layers/SplineLayer.md", "Neural Differential Equation Layers" => "layers/NeuralDELayers.md", - "Hamiltonian Neural Network Layer" => "layers/HamiltonianNN.md"], - "Utility Function APIs" => Any["Smoothed Collocation" => "utilities/Collocation.md", - "Multiple Shooting Functionality" => "utilities/MultipleShooting.md"]] + "Hamiltonian Neural Network Layer" => "layers/HamiltonianNN.md" + ], + "Utility Function APIs" => Any[ + "Smoothed Collocation" => "utilities/Collocation.md", + "Multiple Shooting Functionality" => "utilities/MultipleShooting.md" + ] +] +#! format: on diff --git a/docs/src/index.md b/docs/src/index.md index 70745390d..d1d5f027f 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -19,7 +19,7 @@ The approach of this package is the easy and efficient training of [Neural Ordinary Differential Equations](https://arxiv.org/abs/1806.07366) and its variants. DiffEqFlux.jl provides architectures which match the interfaces of machine learning libraries such as [Flux.jl](https://docs.sciml.ai/Flux/stable/) -and [Lux.jl](https://lux.csail.mit.edu/stable/api/) +and [Lux.jl](https://lux.csail.mit.edu/stable/) to make it easy to build continuous-time machine learning layers into larger machine learning applications. diff --git a/src/DiffEqFlux.jl b/src/DiffEqFlux.jl index 7d2eee5ba..434f24524 100644 --- a/src/DiffEqFlux.jl +++ b/src/DiffEqFlux.jl @@ -5,15 +5,14 @@ using PrecompileTools: @recompile_invalidations @recompile_invalidations begin using ADTypes: ADTypes, AutoForwardDiff, AutoZygote using ChainRulesCore: ChainRulesCore - using ComponentArrays: ComponentArray using ConcreteStructs: @concrete using Distributions: Distributions, ContinuousMultivariateDistribution, Distribution, logpdf using DistributionsAD: DistributionsAD using ForwardDiff: ForwardDiff using Functors: Functors, fmap - using LinearAlgebra: LinearAlgebra, Diagonal, det, diagind, mul! - using Lux: Lux, Chain, Dense, StatefulLuxLayer, FromFluxAdaptor + using LinearAlgebra: LinearAlgebra, Diagonal, det, tr, mul! + using Lux: Lux, Chain, Dense, StatefulLuxLayer, FromFluxAdaptor, ⊠ using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer using Random: Random, AbstractRNG, randn! 
using Reexport: @reexport @@ -26,6 +25,7 @@ using PrecompileTools: @recompile_invalidations NILSS, QuadratureAdjoint, ReverseDiffAdjoint, ReverseDiffVJP, SteadyStateAdjoint, TrackerAdjoint, TrackerVJP, ZygoteAdjoint, ZygoteVJP + using Setfield: @set using Tracker: Tracker using Zygote: Zygote end diff --git a/src/ffjord.jl b/src/ffjord.jl index b78b156b1..eab5e1b8b 100644 --- a/src/ffjord.jl +++ b/src/ffjord.jl @@ -1,31 +1,37 @@ abstract type CNFLayer <: LuxCore.AbstractExplicitContainerLayer{(:model,)} end """ - FFJORD(model, tspan, input_dims, args...; ad = AutoForwardDiff(), - basedist = nothing, kwargs...) + FFJORD(model, tspan, input_dims, args...; ad = nothing, basedist = nothing, kwargs...) -Constructs a continuous-time recurrent neural network, also known as a neural -ordinary differential equation (neural ODE), with fast gradient calculation -via adjoints [1] and specialized for density estimation based on continuous -normalizing flows (CNF) [2] with a stochastic approach [2] for the computation of the trace -of the dynamics' jacobian. At a high level this corresponds to the following steps: +Constructs a continuous-time recurrent neural network, also known as a neural ordinary +differential equation (neural ODE), with fast gradient calculation via adjoints [1] and +specialized for density estimation based on continuous normalizing flows (CNF) [2] with a +stochastic approach [2] for the computation of the trace of the dynamics' jacobian. At a +high level this corresponds to the following steps: - 1. Parameterize the variable of interest x(t) as a function f(z, θ, t) of a base variable z(t) with known density p\\_z. - 2. Use the transformation of variables formula to predict the density p\\_x as a function of the density p\\_z and the trace of the Jacobian of f. - 3. Choose the parameter θ to minimize a loss function of p\\_x (usually the negative likelihood of the data). + 1. Parameterize the variable of interest x(t) as a function f(z, θ, t) of a base variable + z(t) with known density p\\_z. + 2. Use the transformation of variables formula to predict the density p\\_x as a function + of the density p\\_z and the trace of the Jacobian of f. + 3. Choose the parameter θ to minimize a loss function of p\\_x (usually the negative + likelihood of the data). -After these steps one may use the NN model and the learned θ to predict the density p\\_x for new values of x. +After these steps one may use the NN model and the learned θ to predict the density p\\_x +for new values of x. Arguments: - - `model`: A `Flux.Chain` or `Lux.AbstractExplicitLayer` neural network that defines the dynamics of the model. + - `model`: A `Flux.Chain` or `Lux.AbstractExplicitLayer` neural network that defines the + dynamics of the model. - `basedist`: Distribution of the base variable. Set to the unit normal by default. - `input_dims`: Input Dimensions of the model. - `tspan`: The timespan to be solved on. - `args`: Additional arguments splatted to the ODE solver. See the [Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/) documentation for more details. - - `ad`: The automatic differentiation method to use for the internal jacobian trace. Defaults to `AutoForwardDiff()`. + - `ad`: The automatic differentiation method to use for the internal jacobian trace. + Defaults to `AutoForwardDiff()` if full jacobian needs to be computed, i.e. + `monte_carlo = false`. Else we use `AutoZygote()`. - `kwargs`: Additional arguments splatted to the ODE solver. 
See the [Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/) documentation for more details. @@ -34,9 +40,13 @@ References: [1] Pontryagin, Lev Semenovich. Mathematical theory of optimal processes. CRC press, 1987. -[2] Chen, Ricky TQ, Yulia Rubanova, Jesse Bettencourt, and David Duvenaud. "Neural ordinary differential equations." In Proceedings of the 32nd International Conference on Neural Information Processing Systems, pp. 6572-6583. 2018. +[2] Chen, Ricky TQ, Yulia Rubanova, Jesse Bettencourt, and David Duvenaud. "Neural ordinary +differential equations." In Proceedings of the 32nd International Conference on Neural +Information Processing Systems, pp. 6572-6583. 2018. -[3] Grathwohl, Will, Ricky TQ Chen, Jesse Bettencourt, Ilya Sutskever, and David Duvenaud. "Ffjord: Free-form continuous dynamics for scalable reversible generative models." arXiv preprint arXiv:1810.01367 (2018). +[3] Grathwohl, Will, Ricky TQ Chen, Jesse Bettencourt, Ilya Sutskever, and David Duvenaud. +"Ffjord: Free-form continuous dynamics for scalable reversible generative models." arXiv +preprint arXiv:1810.01367 (2018). """ @concrete struct FFJORD{M <: AbstractExplicitLayer, D <: Union{Nothing, Distribution}} <: CNFLayer @@ -54,91 +64,55 @@ function LuxCore.initialstates(rng::AbstractRNG, n::FFJORD) regularize = false, monte_carlo = true) end -function FFJORD(model, tspan, input_dims, args...; - ad = AutoForwardDiff(), basedist = nothing, kwargs...) +function FFJORD( + model, tspan, input_dims, args...; ad = nothing, basedist = nothing, kwargs...) !(model isa AbstractExplicitLayer) && (model = FromFluxAdaptor()(model)) return FFJORD(model, basedist, ad, input_dims, tspan, args, kwargs) end -function __jacobian_with_ps(model, psax, N, x) - function __jacobian_closure(psx) - x_ = reshape(psx[1:N], size(x)) - ps = ComponentArray(psx[(N + 1):end], psax) - return vec(model(x_, ps)) - end -end - -function __jacobian( - ::AutoForwardDiff{nothing}, model, x::AbstractMatrix, ps::ComponentArray) - psd = getdata(ps) - psx = vcat(vec(x), psd) - N = length(x) - J = ForwardDiff.jacobian(__jacobian_with_ps(model, getaxes(ps), N, x), psx) - return reshape(view(J, :, 1:N), :, size(x, 1), size(x, 2)) -end - -function __jacobian(::AutoForwardDiff{CS}, model, x::AbstractMatrix, ps) where {CS} - chunksize = CS === nothing ? 
ForwardDiff.pickchunksize(length(x)) : CS - __f = Base.Fix2(model, ps) - cfg = ForwardDiff.JacobianConfig(__f, x, ForwardDiff.Chunk{chunksize}()) - return reshape(ForwardDiff.jacobian(__f, x, cfg), :, size(x, 1), size(x, 2)) -end - -function __jacobian(::AutoZygote, model, x::AbstractMatrix, ps) - y, pb_f = Zygote.pullback(vec ∘ model, x, ps) - z = ChainRulesCore.@ignore_derivatives fill!(similar(y), __one(y)) - J = Zygote.Buffer(x, size(y, 1), size(x, 1), size(x, 2)) - for i in 1:size(y, 1) - ChainRulesCore.@ignore_derivatives z[i, :] .= __one(x) - J[i, :, :] = pb_f(z)[1] - ChainRulesCore.@ignore_derivatives z[i, :] .= __zero(x) - end - return copy(J) +@inline function __trace_batched(x::AbstractArray{T, 3}) where {T} + return mapreduce(tr, vcat, eachslice(x; dims = 3); init = similar(x, 0)) end -__one(::T) where {T <: Real} = one(T) -__one(x::T) where {T <: AbstractArray} = __one(first(x)) -__one(::Tracker.TrackedReal{T}) where {T <: Real} = one(T) - -__zero(::T) where {T <: Real} = zero(T) -__zero(x::T) where {T <: AbstractArray} = __zero(first(x)) -__zero(::Tracker.TrackedReal{T}) where {T <: Real} = zero(T) - -function _jacobian(ad, model, x, ps) - if ndims(x) == 1 - x_ = reshape(x, :, 1) - elseif ndims(x) > 2 - x_ = reshape(x, :, size(x, ndims(x))) - else - x_ = x - end - return __jacobian(ad, model, x_, ps) -end - -# This implementation constructs the final trace vector on the correct device -function __trace_batched(x::AbstractArray{T, 3}) where {T} - __diag(x) = reshape(@view(x[diagind(x)]), :, 1) - return sum(reduce(hcat, __diag.(eachslice(x; dims = 3))); dims = 1) -end +@inline __norm_batched(x) = sqrt.(sum(abs2, x; dims = 1:(ndims(x) - 1))) -__norm_batched(x) = sqrt.(sum(abs2, x; dims = 1:(ndims(x) - 1))) - -function __ffjord(model, u, p, ad = AutoForwardDiff(), - regularize::Bool = false, monte_carlo::Bool = true) - N = ndims(u) +function __ffjord(_model::StatefulLuxLayer, u::AbstractArray{T, N}, p, ad = nothing, + regularize::Bool = false, monte_carlo::Bool = true) where {T, N} L = size(u, N - 1) z = selectdim(u, N - 1, 1:(L - ifelse(regularize, 3, 1))) + model = @set(_model.ps=p) + mz = model(z, p) + @assert size(mz) == size(z) if monte_carlo - mz, pb_f = Zygote.pullback(model, z, p) - e = CRC.@ignore_derivatives randn!(similar(mz)) - eJ = first(pb_f(e)) - trace_jac = sum(eJ .* e; dims = 1:(N - 1)) - else - mz = model(z, p) - J = _jacobian(ad, model, z, p) - trace_jac = __trace_batched(J) + ad = ad === nothing ? AutoZygote() : ad e = CRC.@ignore_derivatives randn!(similar(mz)) - eJ = vec(e)' * reshape(J, size(J, 1), :) + if ad isa AutoForwardDiff + @assert !regularize "If `regularize = true`, then use `AutoZygote` instead." + Je = Lux.jacobian_vector_product(model, AutoForwardDiff(), z, e) + trace_jac = dropdims( + sum(reshape(e, 1, :, size(e, N)) ⊠ reshape(Je, :, 1, size(Je, N)); + dims = (1, 2)); + dims = (1, 2)) + elseif ad isa AutoZygote + eJ = Lux.vector_jacobian_product(model, AutoZygote(), z, e) + trace_jac = dropdims( + sum(reshape(eJ, 1, :, size(eJ, N)) ⊠ reshape(e, :, 1, size(e, N)); + dims = (1, 2)); + dims = (1, 2)) + else + error("`ad` must be `nothing` or `AutoForwardDiff` or `AutoZygote`.") + end + trace_jac = reshape(trace_jac, ntuple(i -> 1, N - 1)..., :) + else # We can use the batched jacobian since we only care about the trace + ad = ad === nothing ? 
AutoForwardDiff() : ad + if ad isa AutoForwardDiff || ad isa AutoZygote + J = Lux.batched_jacobian(model, ad, z) + trace_jac = reshape(__trace_batched(J), ntuple(i -> 1, N - 1)..., :) + e = CRC.@ignore_derivatives randn!(similar(mz)) + eJ = reshape(reshape(e, 1, :, size(e, N)) ⊠ J, size(z)) + else + error("`ad` must be `nothing` or `AutoForwardDiff` or `AutoZygote`.") + end end if regularize return cat(mz, -trace_jac, sum(abs2, mz; dims = 1:(N - 1)), @@ -150,8 +124,8 @@ end (n::FFJORD)(x, ps, st) = __forward_ffjord(n, x, ps, st) -function __forward_ffjord(n::FFJORD, x, ps, st) - N, S, T = ndims(x), size(x), eltype(x) +function __forward_ffjord(n::FFJORD, x::AbstractArray{T, N}, ps, st) where {T, N} + S = size(x) (; regularize, monte_carlo) = st sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()) @@ -174,8 +148,7 @@ function __forward_ffjord(n::FFJORD, x, ps, st) if regularize λ₁ = selectdim(pred, N, (L - 1):(L - 1)) λ₂ = selectdim(pred, N, L:L) - else - # For Type Stability + else # For Type Stability λ₁ = λ₂ = delta_logp end @@ -186,7 +159,6 @@ function __forward_ffjord(n::FFJORD, x, ps, st) logpz = logpdf(n.basedist, z) end logpx = reshape(logpz, 1, S[N]) .- delta_logp - return (logpx, λ₁, λ₂), (; model = model.st, regularize, monte_carlo) end @@ -197,17 +169,10 @@ function __backward_ffjord(::Type{T1}, n::FFJORD, n_samples::Int, ps, st, rng) w px = n.basedist if px === nothing - if rng === nothing - x = randn(T1, (n.input_dims..., n_samples)) - else - x = randn(rng, T1, (n.input_dims..., n_samples)) - end + x = rng === nothing ? randn(T1, (n.input_dims..., n_samples)) : + randn(rng, T1, (n.input_dims..., n_samples)) else - if rng === nothing - x = rand(px, n_samples) - else - x = rand(rng, px, n_samples) - end + x = rng === nothing ? rand(px, n_samples) : rand(rng, px, n_samples) end N, S, T = ndims(x), size(x), eltype(x) @@ -249,13 +214,12 @@ end Base.length(d::FFJORDDistribution) = prod(d.model.input_dims) Base.eltype(d::FFJORDDistribution) = __eltype(d.ps) -__eltype(ps::ComponentArray) = __eltype(getdata(ps)) __eltype(x::AbstractArray) = eltype(x) -function __eltype(x::NamedTuple) +function __eltype(x) T = Ref(Bool) fmap(x) do x_ T[] = promote_type(T[], __eltype(x_)) - x_ + return x_ end return T[] end @@ -268,6 +232,6 @@ function Distributions._logpdf(d::FFJORDDistribution, x::AbstractArray) end function Distributions._rand!( rng::AbstractRNG, d::FFJORDDistribution, x::AbstractArray{<:Real}) - x[:] = __backward_ffjord(eltype(d), d.model, size(x, ndims(x)), d.ps, d.st, rng) + copyto!(x, __backward_ffjord(eltype(d), d.model, size(x, ndims(x)), d.ps, d.st, rng)) return x end diff --git a/test/cnf_t.jl b/test/cnf_tests.jl similarity index 87% rename from test/cnf_t.jl rename to test/cnf_tests.jl index 75276f3a0..322b1647d 100644 --- a/test/cnf_t.jl +++ b/test/cnf_tests.jl @@ -1,10 +1,13 @@ -using DiffEqFlux, Zygote, Distances, Distributions, DistributionsAD, Optimization, - LinearAlgebra, OrdinaryDiffEq, Random, Test, OptimizationOptimisers, Statistics, - ComponentArrays +@testsetup module CNFTestSetup + +using Reexport + +@reexport using Zygote, Distances, Distributions, DistributionsAD, Optimization, + LinearAlgebra, OrdinaryDiffEq, Random, Test, OptimizationOptimisers, + Statistics, ComponentArrays Random.seed!(1999) -## callback to be used by all tests function callback(adtype) return function (p, l) @info "[FFJORD $(nameof(typeof(adtype)))] Loss: $(l)" @@ -12,7 +15,11 @@ function callback(adtype) end end -@testset "Smoke test for FFJORD" begin +export callback + +end + 
+@testitem "Smoke test for FFJORD" setup=[CNFTestSetup] tags=[:advancedneuralde] begin nn = Chain(Dense(1, 1, tanh)) tspan = (0.0f0, 1.0f0) ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5()) @@ -46,7 +53,7 @@ end end end -@testset "Smoke test for FFJORDDistribution (sampling & pdf)" begin +@testitem "Smoke test for FFJORDDistribution (sampling & pdf)" setup=[CNFTestSetup] tags=[:advancedneuralde] begin nn = Chain(Dense(1, 1, tanh)) tspan = (0.0f0, 1.0f0) ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5()) @@ -80,7 +87,7 @@ end @test !isnothing(rand(ffjord_d, 10)) end -@testset "Test for default base distribution and deterministic trace FFJORD" begin +@testitem "Test for default base distribution and deterministic trace FFJORD" setup=[CNFTestSetup] tags=[:advancedneuralde] begin nn = Chain(Dense(1, 1, tanh)) tspan = (0.0f0, 1.0f0) ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5()) @@ -114,7 +121,7 @@ end @test totalvariation(learned_pdf, actual_pdf) / size(test_data, 2) < 0.9 end -@testset "Test for alternative base distribution and deterministic trace FFJORD" begin +@testitem "Test for alternative base distribution and deterministic trace FFJORD" setup=[CNFTestSetup] tags=[:advancedneuralde] begin nn = Chain(Dense(1, 3, tanh), Dense(3, 1, tanh)) tspan = (0.0f0, 1.0f0) ffjord_mdl = FFJORD( @@ -149,7 +156,7 @@ end @test totalvariation(learned_pdf, actual_pdf) / size(test_data, 2) < 0.25 end -@testset "Test for multivariate distribution and deterministic trace FFJORD" begin +@testitem "Test for multivariate distribution and deterministic trace FFJORD" setup=[CNFTestSetup] tags=[:advancedneuralde] begin nn = Chain(Dense(2, 2, tanh)) tspan = (0.0f0, 1.0f0) ffjord_mdl = FFJORD(nn, tspan, (2,), Tsit5()) @@ -186,12 +193,12 @@ end @test totalvariation(learned_pdf, actual_pdf) / size(test_data, 2) < 0.25 end -@testset "Test for default multivariate distribution and FFJORD with regularizers" begin +@testitem "Test for multivariate distribution and FFJORD with regularizers" setup=[CNFTestSetup] tags=[:advancedneuralde] begin nn = Chain(Dense(2, 2, tanh)) tspan = (0.0f0, 1.0f0) ffjord_mdl = FFJORD(nn, tspan, (2,), Tsit5()) ps, st = Lux.setup(Xoshiro(0), ffjord_mdl) - ps = ComponentArray(ps) + ps = ComponentArray(ps) .* 0.001f0 regularize = true monte_carlo = true diff --git a/test/mnist_tests.jl b/test/mnist_tests.jl index baca766a4..5e9bfabfe 100644 --- a/test/mnist_tests.jl +++ b/test/mnist_tests.jl @@ -27,10 +27,10 @@ function loadmnist(batchsize = bs) # Process images into (H,W,C,BS) batches x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |> gdev - x_train = batchview(x_train, batchsize) + x_train = batchview(x_train[:, :, :, 1:(10 * batchsize)], batchsize) # Onehot and batch the labels y_train = onehot(labels_raw) |> gdev - y_train = batchview(y_train, batchsize) + y_train = batchview(y_train[:, 1:(10 * batchsize)], batchsize) return x_train, y_train end