Skip to content

Commit

Permalink
Disable multithreading when seq_ends is passed as a tuple (#113)
Browse files Browse the repository at this point in the history
* Fix loglikelihood increase check in Baum-Welch

* Disable multithreading when `seq_ends` is given as a tuple

* Remove seq_ends typing in examples
  • Loading branch information
gdalle authored Sep 29, 2024
1 parent 3276e3c commit 9802bad
Show file tree
Hide file tree
Showing 16 changed files with 93 additions and 64 deletions.
2 changes: 1 addition & 1 deletion examples/controlled.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function StatsAPI.fit!(
fb_storage::HMMs.ForwardBackwardStorage,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends,
) where {T}
(; γ, ξ) = fb_storage
N = length(hmm)
Expand Down
2 changes: 1 addition & 1 deletion examples/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ function StatsAPI.fit!(
hmm::PriorHMM,
fb_storage::HiddenMarkovModels.ForwardBackwardStorage,
obs_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends,
)
## initialize to defaults without observations
hmm.init .= 0
Expand Down
2 changes: 1 addition & 1 deletion examples/temporal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ function StatsAPI.fit!(
fb_storage::HMMs.ForwardBackwardStorage,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends,
) where {T}
(; γ, ξ) = fb_storage
L, N = period(hmm), length(hmm)
Expand Down
1 change: 1 addition & 0 deletions libs/HMMTest/src/HMMTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module HMMTest

using BenchmarkTools: @ballocated
using HiddenMarkovModels
using HiddenMarkovModels: AbstractVectorOrNTuple
import HiddenMarkovModels as HMMs
using HMMBase: HMMBase
using JET: @test_opt, @test_call
Expand Down
2 changes: 1 addition & 1 deletion libs/HMMTest/src/allocations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ function test_allocations(
rng::AbstractRNG,
hmm::AbstractHMM,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
)
@testset "Allocations" begin
Expand Down
2 changes: 1 addition & 1 deletion libs/HMMTest/src/coherence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ function test_coherent_algorithms(
rng::AbstractRNG,
hmm::AbstractHMM,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
atol::Real=0.05,
init::Bool=true,
Expand Down
2 changes: 1 addition & 1 deletion libs/HMMTest/src/jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ function test_type_stability(
rng::AbstractRNG,
hmm::AbstractHMM,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
)
@testset "Type stability" begin
Expand Down
6 changes: 3 additions & 3 deletions src/inference/baum_welch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function baum_welch!(
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
atol::Real,
max_iterations::Integer,
loglikelihood_increasing::Bool,
Expand Down Expand Up @@ -55,7 +55,7 @@ function baum_welch(
hmm_guess::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
atol=1e-5,
max_iterations=100,
loglikelihood_increasing=true,
Expand Down Expand Up @@ -85,7 +85,7 @@ function StatsAPI.fit!(
fb_storage::ForwardBackwardStorage,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
)
return fit!(hmm, fb_storage, obs_seq; seq_ends)
end
4 changes: 2 additions & 2 deletions src/inference/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function _params_and_loglikelihoods(
hmm::AbstractHMM,
obs_seq::Vector,
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
)
init = initialization(hmm)
trans_by_time = mapreduce(_dcat, eachindex(control_seq)) do t
Expand All @@ -22,7 +22,7 @@ function ChainRulesCore.rrule(
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
)
_, pullback = rrule_via_ad(
rc, _params_and_loglikelihoods, hmm, obs_seq, control_seq; seq_ends
Expand Down
51 changes: 43 additions & 8 deletions src/inference/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,35 @@ struct ForwardStorage{R}
c::Vector{R}
end

"""
$(TYPEDEF)
# Fields
Only the fields with a description are part of the public API.
$(TYPEDFIELDS)
"""
struct ForwardBackwardStorage{R,M<:AbstractMatrix{R}}
"posterior state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`"
γ::Matrix{R}
"posterior transition marginals `ξ[t][i,j] = ℙ(X[t]=i, X[t+1]=j | Y[1:T])`"
ξ::Vector{M}
"one loglikelihood per observation sequence"
logL::Vector{R}
B::Matrix{R}
α::Matrix{R}
c::Vector{R}
β::Matrix{R}
::Matrix{R}
end

Base.eltype(::ForwardStorage{R}) where {R} = R
Base.eltype(::ForwardBackwardStorage{R}) where {R} = R

const ForwardOrForwardBackwardStorage{R} = Union{
ForwardStorage{R},ForwardBackwardStorage{R}
}

"""
$(SIGNATURES)
Expand All @@ -25,7 +53,7 @@ function initialize_forward(
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
)
N, T, K = length(hmm), length(obs_seq), length(seq_ends)
R = eltype(hmm, obs_seq[1], control_seq[1])
Expand All @@ -40,7 +68,7 @@ end
$(SIGNATURES)
"""
function forward!(
storage,
storage::ForwardOrForwardBackwardStorage,
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector,
Expand Down Expand Up @@ -88,16 +116,23 @@ end
$(SIGNATURES)
"""
function forward!(
storage,
storage::ForwardOrForwardBackwardStorage,
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
)
(; logL) = storage
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;)
if seq_ends isa NTuple
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;)
end
else
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;)
end
end
return nothing
end
Expand All @@ -113,7 +148,7 @@ function forward(
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
)
storage = initialize_forward(hmm, obs_seq, control_seq; seq_ends)
forward!(storage, hmm, obs_seq, control_seq; seq_ends)
Expand Down
54 changes: 19 additions & 35 deletions src/inference/forward_backward.jl
Original file line number Diff line number Diff line change
@@ -1,36 +1,11 @@
"""
$(TYPEDEF)
# Fields
Only the fields with a description are part of the public API.
$(TYPEDFIELDS)
"""
struct ForwardBackwardStorage{R,M<:AbstractMatrix{R}}
"posterior state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`"
γ::Matrix{R}
"posterior transition marginals `ξ[t][i,j] = ℙ(X[t]=i, X[t+1]=j | Y[1:T])`"
ξ::Vector{M}
"one loglikelihood per observation sequence"
logL::Vector{R}
B::Matrix{R}
α::Matrix{R}
c::Vector{R}
β::Matrix{R}
::Matrix{R}
end

Base.eltype(::ForwardBackwardStorage{R}) where {R} = R

"""
$(SIGNATURES)
"""
function initialize_forward_backward(
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
transition_marginals=true,
)
N, T, K = length(hmm), length(obs_seq), length(seq_ends)
Expand Down Expand Up @@ -100,19 +75,28 @@ end
$(SIGNATURES)
"""
function forward_backward!(
storage::ForwardBackwardStorage{R},
storage::ForwardBackwardStorage,
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
transition_marginals::Bool=true,
) where {R}
)
(; logL) = storage
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward_backward!(
storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals
)
if seq_ends isa NTuple
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward_backward!(
storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals
)
end
else
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward_backward!(
storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals
)
end
end
return nothing
end
Expand All @@ -128,7 +112,7 @@ function forward_backward(
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
)
transition_marginals = false
storage = initialize_forward_backward(
Expand Down
4 changes: 2 additions & 2 deletions src/inference/logdensity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function DensityInterface.logdensityof(
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
)
_, logL = forward(hmm, obs_seq, control_seq; seq_ends)
return sum(logL)
Expand All @@ -23,7 +23,7 @@ function joint_logdensityof(
obs_seq::AbstractVector,
state_seq::AbstractVector,
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
)
R = eltype(hmm, obs_seq[1], control_seq[1])
logL = zero(R)
Expand Down
19 changes: 13 additions & 6 deletions src/inference/viterbi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function initialize_viterbi(
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
)
N, T, K = length(hmm), length(obs_seq), length(seq_ends)
R = eltype(hmm, obs_seq[1], control_seq[1])
Expand Down Expand Up @@ -85,12 +85,19 @@ function viterbi!(
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
) where {R}
(; logL) = storage
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;)
if seq_ends isa NTuple
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;)
end
else
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;)
end
end
return nothing
end
Expand All @@ -106,7 +113,7 @@ function viterbi(
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
)
storage = initialize_viterbi(hmm, obs_seq, control_seq; seq_ends)
viterbi!(storage, hmm, obs_seq, control_seq; seq_ends)
Expand Down
2 changes: 1 addition & 1 deletion src/types/hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ function StatsAPI.fit!(
hmm::HMM,
fb_storage::ForwardBackwardStorage,
obs_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
)
(; γ, ξ) = fb_storage
# Fit states
Expand Down
2 changes: 1 addition & 1 deletion src/utils/limits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ $(SIGNATURES)
Return a tuple `(t1, t2)` giving the begin and end indices of subsequence `k` within a set of sequences ending at `seq_ends`.
"""
function seq_limits(seq_ends::AbstractVector{Int}, k::Integer)
function seq_limits(seq_ends::AbstractVectorOrNTuple{Int}, k::Integer)
if k == 1
return 1, seq_ends[k]
else
Expand Down
2 changes: 2 additions & 0 deletions src/utils/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
const AbstractVectorOrNTuple{T} = Union{AbstractVector{T},NTuple{N,T}} where {N}

sum_to_one!(x) = ldiv!(sum(x), x)

mynonzeros(x::AbstractArray) = x
Expand Down

0 comments on commit 9802bad

Please sign in to comment.