diff --git a/Project.toml b/Project.toml index 849f5be..8e8570d 100644 --- a/Project.toml +++ b/Project.toml @@ -16,32 +16,36 @@ OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" +SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" [compat] ADTypes = "1" ChainRulesCore = "1" -DifferentiationInterface = "0.5.3" +DifferentiationInterface = "0.6.13" Distributions = "0.25" FiniteDiff = "2" ForwardDiff = "0.10" HCubature = "1" LinearAlgebra = "1" -Optimization = "3.25" -OptimizationOptimJL = "0.3" +Optimization = "4.0.3" +OptimizationOptimJL = "0.4.1" Reexport = "1" ReverseDiff = "1" SparseArrays = "1" SparseConnectivityTracer = "0.5, 0.6" +SparseMatrixColorings = "0.4.7" +StableRNGs = "1.0.2" Zygote = "0.6" -julia = "1.9" +julia = "1.10" [extras] ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ChainRulesTestUtils", "Distributions", "Random", "ReverseDiff", "Test", "Zygote"] +test = ["ChainRulesTestUtils", "Distributions", "Random", "ReverseDiff", "StableRNGs", "Test", "Zygote"] diff --git a/src/MarginalLogDensities.jl b/src/MarginalLogDensities.jl index f51669f..fd077c2 100644 --- a/src/MarginalLogDensities.jl +++ b/src/MarginalLogDensities.jl @@ -7,6 +7,7 @@ import ForwardDiff, FiniteDiff @reexport using DifferentiationInterface @reexport using ADTypes @reexport using SparseConnectivityTracer +@reexport using SparseMatrixColorings using LinearAlgebra using SparseArrays using ChainRulesCore @@ -25,9 +26,8 @@ export MarginalLogDensity, cached_hessian, merge_parameters, split_parameters, - optimize_marginal!, - # hessdiag, - get_hessian_sparsity + optimize_marginal! + # hessdiag abstract type AbstractMarginalizer end @@ -172,7 +172,7 @@ struct MarginalLogDensity{ TC<:OptimizationCache, TH<:AbstractMatrix, TB<:ADTypes.AbstractADType, - TE<:DifferentiationInterface.HessianExtras + TE } logdensity::TF u::TU @@ -185,7 +185,7 @@ struct MarginalLogDensity{ cache::TC H::TH hess_adtype::TB - hess_extras::TE + hess_prep::TE end @@ -204,15 +204,16 @@ function MarginalLogDensity(logdensity, u, iw, data=(), method=LaplaceApprox(); if isnothing(hess_adtype) hess_adtype = AutoSparse( + # TODO: AutoForwardDiff would be much faster SecondOrder(AutoFiniteDiff(), method.adtype), sparsity_detector, coloring_algorithm ) end - extras = prepare_hessian(w -> f(w, p2), hess_adtype, w) - H = hessian(w -> f(w, p2), hess_adtype, w, extras) + prep = prepare_hessian(f, hess_adtype, w, Constant(p2)) + H = hessian(f, prep, hess_adtype, w, Constant(p2)) return MarginalLogDensity(logdensity, u, data, iv, iw, method, f_opt, prob, cache, - H, hess_adtype, extras) + H, hess_adtype, prep) end function Base.show(io::IO, mld::MarginalLogDensity) @@ -285,7 +286,7 @@ function optimize_marginal!(mld, p2) end function modal_hessian!(mld::MarginalLogDensity, w, p2) - hessian!(w -> mld.f_opt(w, p2), mld.H, mld.hess_adtype, w, mld.hess_extras) + hessian!(mld.f_opt, mld.H, mld.hess_prep, mld.hess_adtype, w, Constant(p2)) return mld.H end diff --git a/test/runtests.jl b/test/runtests.jl index 76cc54f..99f2402 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,10 +5,10 @@ using Optimization, OptimizationOptimJL using FiniteDiff, ForwardDiff, ReverseDiff, Zygote using LinearAlgebra, SparseArrays using HCubature -using Random +using StableRNGs using ChainRulesTestUtils -Random.seed!(15950) +rng = StableRNG(15) N = 3 μ = ones(N) @@ -18,7 +18,7 @@ ld(u, p) = logpdf(d, u) iw = [1, 3] iv = [2] dmarginal = Normal(1.0, σ) -u = randn(N) +u = randn(rng, N) v = u[iv] w = u[iw] @@ -90,7 +90,7 @@ end mld_cubature2 = MarginalLogDensity(ld, u, iw, (), Cubature()) @test -mld_laplace.f_opt(x[iw], (p=(), v=x[iv])) == ld(x, ()) - prob = OptimizationProblem(mld_laplace.f_opt, randn(2), (p=(), v=x[iv])) + prob = OptimizationProblem(mld_laplace.f_opt, randn(rng, 2), (p=(), v=x[iv])) sol = solve(prob, BFGS()) @test all(sol.u .≈ μ[iw]) @@ -148,7 +148,7 @@ end ld(u, p) = logpdf(MvNormal(p.μ, p.σ * I), u) iv = 50:60 iw = setdiff(1:N, iv) - u = randn(N) + u = randn(rng, N) v = u[iv] w = u[iw] p = (;μ, σ) @@ -174,12 +174,12 @@ end categories = 1:ncategories μ0 = 5.0 σ0 = 5.0 - aa = rand(Normal(μ0, σ0), ncategories) + aa = rand(rng, Normal(μ0, σ0), ncategories) b = 4.5 σ = 0.5 category = repeat(categories, inner=200) n = length(category) - x = rand(Uniform(-1, 1), n) + x = rand(rng, Uniform(-1, 1), n) μ = [aa[category[i]] + b * x[i] for i in 1:n] y = rand.(Normal.(μ, σ)) @@ -212,7 +212,7 @@ end @test all(isapprox.(opt_sol1.u, opt_sol2.u)) opt_sol1_1 = solve(opt_prob1, LBFGS()) - @test all(isapprox.(opt_sol1.u, opt_sol1_1.u, atol=0.01)) + @test all(isapprox.(opt_sol1.u, opt_sol1_1.u, rtol=0.01)) # opt_prob3 = OptimizationProblem(mld_cubature, v0) # opt_sol3 = solve(opt_prob3, NelderMead())