Skip to content

Commit

Permalink
Remove last allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Oct 1, 2024
1 parent 670be12 commit 865947d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
3 changes: 2 additions & 1 deletion libs/HMMTest/src/allocations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ function test_allocations(
seq_ends::AbstractVectorOrNTuple{Int},
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
)
# making seq_ends a tuple disables multithreading
seq_ends = ntuple(k -> seq_ends[k], Val(min(2, length(seq_ends))))
control_seq = control_seq[1:last(seq_ends)]

Expand Down Expand Up @@ -51,7 +52,7 @@ function test_allocations(
allocs_bw = @ballocated fit!(
hmm_guess_copy, $fb_storage, $obs_seq, $control_seq; seq_ends=$seq_ends
) evals = 1 samples = 1 setup = (hmm_guess_copy = deepcopy($hmm_guess))
@test_broken allocs_bw == 0
@test allocs_bw == 0
end
end
end
24 changes: 17 additions & 7 deletions src/types/hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,23 @@ function StatsAPI.fit!(
)
(; γ, ξ) = fb_storage
# Fit states
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
# use ξ[t2] as scratch space since it is zero anyway
scratch = ξ[t2]
fill!(scratch, zero(eltype(scratch)))
for t in t1:(t2 - 1)
scratch .+= ξ[t]
if seq_ends isa NTuple
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
scratch = ξ[t2] # use ξ[t2] as scratch space since it is zero anyway
fill!(scratch, zero(eltype(scratch)))
for t in t1:(t2 - 1)
scratch .+= ξ[t]
end
end

Check warning on line 72 in src/types/hmm.jl

View check run for this annotation

Codecov / codecov/patch

src/types/hmm.jl#L65-L72

Added lines #L65 - L72 were not covered by tests
else
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
scratch = ξ[t2] # use ξ[t2] as scratch space since it is zero anyway
fill!(scratch, zero(eltype(scratch)))
for t in t1:(t2 - 1)
scratch .+= ξ[t]
end
end
end
fill!(hmm.init, zero(eltype(hmm.init)))
Expand Down

0 comments on commit 865947d

Please sign in to comment.