Skip to content

Commit

Permalink
Upgrades all around (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Sep 5, 2023
1 parent bca1bab commit c1f18c7
Show file tree
Hide file tree
Showing 20 changed files with 334 additions and 183 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
/Manifest.toml
/docs/build/
**/.CondaPkg
**/scratchpad.jl
scratchpad.jl
**/__pycache__
*.ipynb
*.svg
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
HMMBase = "b2b3ca75-8444-5ffa-85e6-af70e2b64fe7"

[extensions]
HiddenMarkovModelsChainRulesCoreExt = "ChainRulesCore"
HiddenMarkovModelsHMMBaseExt = "HMMBase"

[compat]
Expand Down
5 changes: 1 addition & 4 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@ HMMs

## Types

```@docs
HMMs.AbstractModel
```

### Markov chains

```@docs
Expand Down Expand Up @@ -43,6 +39,7 @@ obs_distribution

```@docs
logdensityof
forward
viterbi
forward_backward
```
Expand Down
42 changes: 42 additions & 0 deletions ext/HiddenMarkovModelsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
module HiddenMarkovModelsChainRulesCoreExt

using ChainRulesCore:
ChainRulesCore, NoTangent, ZeroTangent, RuleConfig, rrule_via_ad, @not_implemented
using DensityInterface: logdensityof
using HiddenMarkovModels
using SimpleUnPack

function _params_and_loglikelihoods(hmm::AbstractHMM, obs_seq)
p = initial_distribution(hmm)
A = transition_matrix(hmm)
logB = HiddenMarkovModels.loglikelihoods(hmm, obs_seq)
return p, A, logB
end

function ChainRulesCore.rrule(
rc::RuleConfig, ::typeof(logdensityof), hmm::AbstractHMM, obs_seq
)
error("Chain rule not yet fully implemented")
(p, A, logB), pullback = rrule_via_ad(rc, _params_and_loglikelihoods, hmm, obs_seq)
y = exp.(logB)
fb = forward_backward(p, A, logB)
logL = HiddenMarkovModels.loglikelihood(fb)
@unpack α, β, γ, c, maxlogB = fb
T = length(obs_seq)

function logdensityof_hmm_pullback(ΔlogL)
# Source: https://idiap.github.io/HMMGradients.jl/stable/1_intro/
# TODO: adapt formulas with our logsumexp trick
Δp = @not_implemented("todo")
ΔA = @not_implemented("todo")
ΔlogB = ΔlogL .* fb.γ

Δlogdensityof = NoTangent()
_, Δhmm, Δobs_seq = pullback((Δp, ΔA, ΔlogB))
return Δlogdensityof, Δhmm, Δobs_seq
end

return logL, logdensityof_hmm_pullback
end

end
25 changes: 13 additions & 12 deletions src/HiddenMarkovModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ Alias for the module HiddenMarkovModels.
const HMMs = HiddenMarkovModels

using Base.Threads: @threads
using ChainRulesCore: ChainRulesCore
using DensityInterface:
DensityInterface, DensityKind, HasDensity, NoDensity, densityof, logdensityof
using Distributions:
Expand All @@ -38,12 +37,10 @@ export AbstractHiddenMarkovModel, AbstractHMM
export HiddenMarkovModel, HMM
export rand_prob_vec, rand_trans_mat
export initial_distribution, transition_matrix, obs_distribution
export logdensityof, viterbi, forward_backward, baum_welch
export logdensityof, viterbi, forward, forward_backward, baum_welch
export fit, fit!
export LightDiagNormal

include("types/abstract_model.jl")
include("types/interface.jl")
include("types/abstract_mc.jl")
include("types/mc.jl")
include("types/abstract_hmm.jl")
Expand All @@ -56,10 +53,9 @@ include("utils/fit.jl")
include("utils/lightdiagnormal.jl")

include("inference/loglikelihoods.jl")
include("inference/forward_backward_storage.jl")
include("inference/forward_backward.jl")
include("inference/logdensity.jl")
include("inference/forward.jl")
include("inference/viterbi.jl")
include("inference/forward_backward.jl")
include("inference/sufficient_stats.jl")
include("inference/baum_welch.jl")

Expand All @@ -68,6 +64,9 @@ if !isdefined(Base, :get_extension)
@require HMMBase = "b2b3ca75-8444-5ffa-85e6-af70e2b64fe7" include(
"../ext/HiddenMarkovModelsHMMBaseExt.jl"
)
@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include(
"../ext/HiddenMarkovModelsChainRulesCoreExt.jl"
)
end
end

Expand All @@ -78,11 +77,13 @@ end
dists = [LightDiagNormal(randn(D), ones(D)) for i in 1:N]
hmm = HMM(p, A, dists)

@unpack state_seq, obs_seq = rand(hmm, T)
logdensityof(hmm, obs_seq)
viterbi(hmm, obs_seq)
forward_backward(hmm, obs_seq)
baum_welch(hmm, obs_seq; max_iterations=2, atol=-Inf)
obs_seqs = [last(rand(hmm, T)) for _ in 1:3]
nb_seqs = 3
logdensityof(hmm, obs_seqs, nb_seqs)
forward(hmm, obs_seqs, nb_seqs)
viterbi(hmm, obs_seqs, nb_seqs)
forward_backward(hmm, obs_seqs, nb_seqs)
baum_welch(hmm, obs_seqs, nb_seqs; max_iterations=2, atol=-Inf)
end

end
22 changes: 17 additions & 5 deletions src/inference/baum_welch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function baum_welch!(
fbs = Vector{typeof(fb)}(undef, length(obs_seqs))
@threads for k in eachindex(obs_seqs)
logBs[k] = loglikelihoods(hmm, obs_seqs[k])
fbs[k] = forward_backward_from_loglikelihoods(hmm, logBs[k])
fbs[k] = forward_backward(hmm, logBs[k])
end

init_count, trans_count = initialize_states_stats(fbs)
Expand Down Expand Up @@ -54,9 +54,15 @@ end
atol, max_iterations, check_loglikelihood_increasing
)
Apply the Baum-Welch algorithm to estimate the parameters of an HMM and return a tuple `(hmm, logL_evolution)`.
Apply the Baum-Welch algorithm to estimate the parameters of an HMM starting from `hmm_init`.
The procedure is based on a single observation sequence and initialized with `hmm_init`.
Return a tuple `(hmm_est, logL_evolution)`.
# Keyword arguments
- `atol`: Minimum loglikelihood increase at an iteration of the algorithm (otherwise the algorithm is deemed to have converged)
- `max_iterations`: Maximum number of iterations of the algorithm
- `check_loglikelihood_increasing`: Whether to throw an error if the loglikelihood decreases
"""
function baum_welch(
hmm_init::AbstractHMM,
Expand All @@ -78,12 +84,18 @@ end
atol, max_iterations, check_loglikelihood_increasing
)
Apply the Baum-Welch algorithm to estimate the parameters of an HMM and return a tuple `(hmm, logL_evolution)`.
Apply the Baum-Welch algorithm to estimate the parameters of an HMM starting from `hmm_init`, based on `nb_seqs` observation sequences.
The procedure is based on multiple observation sequences and initialized with `hmm_init`.
Return a tuple `(hmm_est, logL_evolution)`.
!!! warning "Multithreading"
This function is parallelized across sequences.
# Keyword arguments
- `atol`: Minimum loglikelihood increase at an iteration of the algorithm (otherwise the algorithm is deemed to have converged)
- `max_iterations`: Maximum number of iterations of the algorithm
- `check_loglikelihood_increasing`: Whether to throw an error if the loglikelihood decreases
"""
function baum_welch(
hmm_init::AbstractHMM,
Expand Down
57 changes: 51 additions & 6 deletions src/inference/logdensity.jl → src/inference/forward.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function forward_light!(αₜ, αₜ₊₁, logb, p, A, hmm::AbstractHMM, obs_seq)
function forward!(αₜ, αₜ₊₁, logb, p, A, hmm::AbstractHMM, obs_seq)
T = length(obs_seq)
loglikelihoods_vec!(logb, hmm, obs_seq[1])
m = maximum(logb)
Expand All @@ -20,11 +20,16 @@ function forward_light!(αₜ, αₜ₊₁, logb, p, A, hmm::AbstractHMM, obs_se
end

"""
logdensityof(hmm, obs_seq)
forward(hmm, obs_seq)
Apply the forward algorithm to compute the loglikelihood of a single observation sequence for an HMM.
Apply the forward algorithm to an HMM.
Return a tuple `(α, logL)` where
- `logL` is the loglikelihood of the sequence
- `α[i]` is the posterior probability of state `i` at the end of the sequence.
"""
function DensityInterface.logdensityof(hmm::AbstractHMM, obs_seq)
function forward(hmm::AbstractHMM, obs_seq)
N = length(hmm)
p = initial_distribution(hmm)
A = transition_matrix(hmm)
Expand All @@ -33,15 +38,54 @@ function DensityInterface.logdensityof(hmm::AbstractHMM, obs_seq)
R = promote_type(eltype(p), eltype(A), eltype(logb))
αₜ = Vector{R}(undef, N)
αₜ₊₁ = Vector{R}(undef, N)
logL = forward_light!(αₜ, αₜ₊₁, logb, p, A, hmm, obs_seq)
return logL
logL = forward!(αₜ, αₜ₊₁, logb, p, A, hmm, obs_seq)
return αₜ, logL
end

"""
forward(hmm, obs_seqs, nb_seqs)
Apply the forward algorithm to an HMM, based on multiple observation sequences.
Return a vector of tuples `(αₖ, logLₖ)`, where
- `logLₖ` is the loglikelihood of sequence `k`
- `αₖ[i]` is the posterior probability of state `i` at the end of sequence `k`
!!! warning "Multithreading"
This function is parallelized across sequences.
"""
function forward(hmm::AbstractHMM, obs_seqs, nb_seqs::Integer)
if nb_seqs != length(obs_seqs)
throw(ArgumentError("nb_seqs != length(obs_seqs)"))
end
f1 = forward(hmm, first(obs_seqs))
fs = Vector{typeof(f1)}(undef, nb_seqs)
fs[1] = f1
@threads for k in 2:nb_seqs
fs[k] = forward(hmm, obs_seqs[k])
end
return fs
end

"""
logdensityof(hmm, obs_seq)
Apply the forward algorithm to compute the loglikelihood of a single observation sequence for an HMM.
Return a number.
"""
function DensityInterface.logdensityof(hmm::AbstractHMM, obs_seq)
return last(forward(hmm, obs_seq))
end

"""
logdensityof(hmm, obs_seqs, nb_seqs)
Apply the forward algorithm to compute the total loglikelihood of multiple observation sequences for an HMM.
Return a number.
!!! warning "Multithreading"
This function is parallelized across sequences.
"""
Expand All @@ -51,6 +95,7 @@ function DensityInterface.logdensityof(hmm::AbstractHMM, obs_seqs, nb_seqs::Inte
end
logL1 = logdensityof(hmm, first(obs_seqs))
logLs = Vector{typeof(logL1)}(undef, nb_seqs)
logLs[1] = logL1
@threads for k in 2:nb_seqs
logLs[k] = logdensityof(hmm, obs_seqs[k])
end
Expand Down
Loading

0 comments on commit c1f18c7

Please sign in to comment.