Skip to content

Commit

Permalink
Merge pull request #31 from gdalle/gd/di06
Browse files Browse the repository at this point in the history
Transition to DI v0.6
  • Loading branch information
ElOceanografo authored Oct 15, 2024
2 parents 64d3b3e + 1ae0a77 commit 9304bfb
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 22 deletions.
14 changes: 9 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,36 @@ OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"

[compat]
ADTypes = "1"
ChainRulesCore = "1"
DifferentiationInterface = "0.6.13"
Distributions = "0.25"
FiniteDiff = "2"
ForwardDiff = "0.10"
HCubature = "1"
LinearAlgebra = "1"
Optimization = "4.0.3"
OptimizationOptimJL = "0.4.1"
Reexport = "1"
ReverseDiff = "1"
SparseArrays = "1"
SparseConnectivityTracer = "0.5, 0.6"
SparseMatrixColorings = "0.4.7"
StableRNGs = "1.0.2"
Zygote = "0.6"
julia = "1.10"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ChainRulesTestUtils", "Distributions", "Random", "ReverseDiff", "Test", "Zygote"]
test = ["ChainRulesTestUtils", "Distributions", "Random", "ReverseDiff", "StableRNGs", "Test", "Zygote"]
19 changes: 10 additions & 9 deletions src/MarginalLogDensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import ForwardDiff, FiniteDiff
@reexport using DifferentiationInterface
@reexport using ADTypes
@reexport using SparseConnectivityTracer
@reexport using SparseMatrixColorings
using LinearAlgebra
using SparseArrays
using ChainRulesCore
Expand All @@ -25,9 +26,8 @@ export MarginalLogDensity,
cached_hessian,
merge_parameters,
split_parameters,
optimize_marginal!,
# hessdiag,
get_hessian_sparsity
optimize_marginal!
# hessdiag

abstract type AbstractMarginalizer end

Expand Down Expand Up @@ -172,7 +172,7 @@ struct MarginalLogDensity{
TC<:OptimizationCache,
TH<:AbstractMatrix,
TB<:ADTypes.AbstractADType,
TE<:DifferentiationInterface.HessianExtras
TE
}
logdensity::TF
u::TU
Expand All @@ -185,7 +185,7 @@ struct MarginalLogDensity{
cache::TC
H::TH
hess_adtype::TB
hess_extras::TE
hess_prep::TE
end


Expand All @@ -204,15 +204,16 @@ function MarginalLogDensity(logdensity, u, iw, data=(), method=LaplaceApprox();

if isnothing(hess_adtype)
hess_adtype = AutoSparse(
# TODO: AutoForwardDiff would be much faster
SecondOrder(AutoFiniteDiff(), method.adtype),
sparsity_detector,
coloring_algorithm
)
end
extras = prepare_hessian(w -> f(w, p2), hess_adtype, w)
H = hessian(w -> f(w, p2), hess_adtype, w, extras)
prep = prepare_hessian(f, hess_adtype, w, Constant(p2))
H = hessian(f, prep, hess_adtype, w, Constant(p2))
return MarginalLogDensity(logdensity, u, data, iv, iw, method, f_opt, prob, cache,
H, hess_adtype, extras)
H, hess_adtype, prep)
end

function Base.show(io::IO, mld::MarginalLogDensity)
Expand Down Expand Up @@ -285,7 +286,7 @@ function optimize_marginal!(mld, p2)
end

"""
    modal_hessian!(mld::MarginalLogDensity, w, p2)

Evaluate the Hessian of `mld.f_opt` at the point `w`, passing `p2` as a
non-differentiated `Constant` argument, and store the result in-place in
`mld.H`. Uses the preparation object `mld.hess_prep` and AD backend
`mld.hess_adtype` carried by `mld`.

Returns the updated Hessian matrix `mld.H`.
"""
function modal_hessian!(mld::MarginalLogDensity, w, p2)
    # DifferentiationInterface v0.6 calling convention: the preparation
    # object precedes the backend, and extra arguments are wrapped in
    # `Constant` instead of being captured in a closure. (The flattened
    # diff showed both the old v0.5 closure-based call and this one; only
    # the v0.6 form is kept.)
    hessian!(mld.f_opt, mld.H, mld.hess_prep, mld.hess_adtype, w, Constant(p2))
    return mld.H
end

Expand Down
16 changes: 8 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ using Optimization, OptimizationOptimJL
using FiniteDiff, ForwardDiff, ReverseDiff, Zygote
using LinearAlgebra, SparseArrays
using HCubature
using Random
using StableRNGs
using ChainRulesTestUtils

Random.seed!(15950)
rng = StableRNG(15)

N = 3
μ = ones(N)
Expand All @@ -18,7 +18,7 @@ ld(u, p) = logpdf(d, u)
iw = [1, 3]
iv = [2]
dmarginal = Normal(1.0, σ)
u = randn(N)
u = randn(rng, N)
v = u[iv]
w = u[iw]

Expand Down Expand Up @@ -90,7 +90,7 @@ end
mld_cubature2 = MarginalLogDensity(ld, u, iw, (), Cubature())

@test -mld_laplace.f_opt(x[iw], (p=(), v=x[iv])) == ld(x, ())
prob = OptimizationProblem(mld_laplace.f_opt, randn(2), (p=(), v=x[iv]))
prob = OptimizationProblem(mld_laplace.f_opt, randn(rng, 2), (p=(), v=x[iv]))
sol = solve(prob, BFGS())
@test all(sol.u .≈ μ[iw])

Expand Down Expand Up @@ -148,7 +148,7 @@ end
ld(u, p) = logpdf(MvNormal(p.μ, p.σ * I), u)
iv = 50:60
iw = setdiff(1:N, iv)
u = randn(N)
u = randn(rng, N)
v = u[iv]
w = u[iw]
p = (;μ, σ)
Expand All @@ -174,12 +174,12 @@ end
categories = 1:ncategories
μ0 = 5.0
σ0 = 5.0
aa = rand(Normal(μ0, σ0), ncategories)
aa = rand(rng, Normal(μ0, σ0), ncategories)
b = 4.5
σ = 0.5
category = repeat(categories, inner=200)
n = length(category)
x = rand(Uniform(-1, 1), n)
x = rand(rng, Uniform(-1, 1), n)
μ = [aa[category[i]] + b * x[i] for i in 1:n]
y = rand.(Normal.(μ, σ))

Expand Down Expand Up @@ -212,7 +212,7 @@ end
@test all(isapprox.(opt_sol1.u, opt_sol2.u))

opt_sol1_1 = solve(opt_prob1, LBFGS())
@test all(isapprox.(opt_sol1.u, opt_sol1_1.u, atol=0.01))
@test all(isapprox.(opt_sol1.u, opt_sol1_1.u, rtol=0.01))

# opt_prob3 = OptimizationProblem(mld_cubature, v0)
# opt_sol3 = solve(opt_prob3, NelderMead())
Expand Down

0 comments on commit 9304bfb

Please sign in to comment.