From 865947ddcad757a16d66a98cdd334845c1a6920e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 1 Oct 2024 14:53:17 +0200 Subject: [PATCH] Remove last allocations --- libs/HMMTest/src/allocations.jl | 3 ++- src/types/hmm.jl | 24 +++++++++++++++++------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/libs/HMMTest/src/allocations.jl b/libs/HMMTest/src/allocations.jl index 6ba7d74..8deedaf 100644 --- a/libs/HMMTest/src/allocations.jl +++ b/libs/HMMTest/src/allocations.jl @@ -6,6 +6,7 @@ function test_allocations( seq_ends::AbstractVectorOrNTuple{Int}, hmm_guess::Union{Nothing,AbstractHMM}=nothing, ) + # making seq_ends a tuple disables multithreading seq_ends = ntuple(k -> seq_ends[k], Val(min(2, length(seq_ends)))) control_seq = control_seq[1:last(seq_ends)] @@ -51,7 +52,7 @@ function test_allocations( allocs_bw = @ballocated fit!( hmm_guess_copy, $fb_storage, $obs_seq, $control_seq; seq_ends=$seq_ends ) evals = 1 samples = 1 setup = (hmm_guess_copy = deepcopy($hmm_guess)) - @test_broken allocs_bw == 0 + @test allocs_bw == 0 end end end diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 60d7bf1..9c1f47b 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -61,13 +61,23 @@ function StatsAPI.fit!( ) (; γ, ξ) = fb_storage # Fit states - @threads for k in eachindex(seq_ends) - t1, t2 = seq_limits(seq_ends, k) - # use ξ[t2] as scratch space since it is zero anyway - scratch = ξ[t2] - fill!(scratch, zero(eltype(scratch))) - for t in t1:(t2 - 1) - scratch .+= ξ[t] + if seq_ends isa NTuple + for k in eachindex(seq_ends) + t1, t2 = seq_limits(seq_ends, k) + scratch = ξ[t2] # use ξ[t2] as scratch space since it is zero anyway + fill!(scratch, zero(eltype(scratch))) + for t in t1:(t2 - 1) + scratch .+= ξ[t] + end + end + else + @threads for k in eachindex(seq_ends) + t1, t2 = seq_limits(seq_ends, k) + scratch = ξ[t2] # use ξ[t2] as scratch space since it is zero anyway + fill!(scratch, zero(eltype(scratch))) + for t in t1:(t2 - 1) + scratch .+= ξ[t] + end end end fill!(hmm.init, zero(eltype(hmm.init)))