Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Oct 1, 2024
1 parent ad7b618 commit 670be12
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 22 deletions.
9 changes: 6 additions & 3 deletions libs/HMMTest/src/allocations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ function test_allocations(
seq_ends::AbstractVectorOrNTuple{Int},
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
)
seq_ends = ntuple(k -> seq_ends[k], Val(min(2, length(seq_ends))))
control_seq = control_seq[1:last(seq_ends)]

@testset "Allocations" begin
obs_seq = mapreduce(vcat, eachindex(seq_ends)) do k
t1, t2 = seq_limits(seq_ends, k)
Expand All @@ -18,23 +21,23 @@ function test_allocations(

f_storage = HMMs.initialize_forward(hmm, obs_seq, control_seq; seq_ends)
allocs_f = @ballocated HMMs.forward!(
$f_storage, $hmm, $obs_seq, $control_seq, $t1, $t2
$f_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends
) evals = 1 samples = 1
@test allocs_f == 0

## Viterbi

v_storage = HMMs.initialize_viterbi(hmm, obs_seq, control_seq; seq_ends)
allocs_v = @ballocated HMMs.viterbi!(
$v_storage, $hmm, $obs_seq, $control_seq, $t1, $t2
$v_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends
) evals = 1 samples = 1
@test allocs_v == 0

## Forward-backward

fb_storage = HMMs.initialize_forward_backward(hmm, obs_seq, control_seq; seq_ends)
allocs_fb = @ballocated HMMs.forward_backward!(
$fb_storage, $hmm, $obs_seq, $control_seq, $t1, $t2
$fb_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends
) evals = 1 samples = 1
@test allocs_fb == 0

Expand Down
9 changes: 3 additions & 6 deletions src/inference/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,7 @@ function initialize_forward(
return ForwardStorage(α, logL, B, c)
end

"""
$(SIGNATURES)
"""
function forward!(
function _forward!(
storage::ForwardOrForwardBackwardStorage,
hmm::AbstractHMM,
obs_seq::AbstractVector,
Expand Down Expand Up @@ -124,11 +121,11 @@ function forward!(
)
if seq_ends isa NTuple
for k in eachindex(seq_ends)
forward!(storage, hmm, obs_seq, control_seq, seq_ends, k)
_forward!(storage, hmm, obs_seq, control_seq, seq_ends, k)

Check warning on line 124 in src/inference/forward.jl

View check run for this annotation

Codecov / codecov/patch

src/inference/forward.jl#L124

Added line #L124 was not covered by tests
end
else
@threads for k in eachindex(seq_ends)
forward!(storage, hmm, obs_seq, control_seq, seq_ends, k)
_forward!(storage, hmm, obs_seq, control_seq, seq_ends, k)
end
end
return nothing
Expand Down
11 changes: 4 additions & 7 deletions src/inference/forward_backward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@ function initialize_forward_backward(
return ForwardBackwardStorage{R,M}(γ, ξ, logL, B, α, c, β, Bβ)
end

"""
$(SIGNATURES)
"""
function forward_backward!(
function _forward_backward!(
storage::ForwardBackwardStorage{R},
hmm::AbstractHMM,
obs_seq::AbstractVector,
Expand All @@ -45,7 +42,7 @@ function forward_backward!(
t1, t2 = seq_limits(seq_ends, k)

# Forward (fill B, α, c and logL)
forward!(storage, hmm, obs_seq, control_seq, t1, t2)
_forward!(storage, hmm, obs_seq, control_seq, seq_ends, k)

# Backward
β[:, t2] .= c[t2]
Expand Down Expand Up @@ -85,13 +82,13 @@ function forward_backward!(
)
if seq_ends isa NTuple
for k in eachindex(seq_ends)
forward_backward!(
_forward_backward!(

Check warning on line 85 in src/inference/forward_backward.jl

View check run for this annotation

Codecov / codecov/patch

src/inference/forward_backward.jl#L85

Added line #L85 was not covered by tests
storage, hmm, obs_seq, control_seq, seq_ends, k; transition_marginals
)
end
else
@threads for k in eachindex(seq_ends)
forward_backward!(
_forward_backward!(
storage, hmm, obs_seq, control_seq, seq_ends, k; transition_marginals
)
end
Expand Down
9 changes: 3 additions & 6 deletions src/inference/viterbi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@ function initialize_viterbi(
return ViterbiStorage(q, logL, logB, ϕ, ψ)
end

"""
$(SIGNATURES)
"""
function viterbi!(
function _viterbi!(
storage::ViterbiStorage{R},
hmm::AbstractHMM,
obs_seq::AbstractVector,
Expand Down Expand Up @@ -88,11 +85,11 @@ function viterbi!(
) where {R}
if seq_ends isa NTuple
for k in eachindex(seq_ends)
viterbi!(storage, hmm, obs_seq, control_seq, seq_ends, k)
_viterbi!(storage, hmm, obs_seq, control_seq, seq_ends, k)

Check warning on line 88 in src/inference/viterbi.jl

View check run for this annotation

Codecov / codecov/patch

src/inference/viterbi.jl#L88

Added line #L88 was not covered by tests
end
else
@threads for k in eachindex(seq_ends)
viterbi!(storage, hmm, obs_seq, control_seq, seq_ends, k)
_viterbi!(storage, hmm, obs_seq, control_seq, seq_ends, k)
end
end
return nothing
Expand Down

0 comments on commit 670be12

Please sign in to comment.