Skip to content

Commit

Permalink
More fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Sep 29, 2024
1 parent 27b6f3e commit 1c6ee9a
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
1 change: 1 addition & 0 deletions libs/HMMTest/src/coherence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ function test_coherent_algorithms(

if !isnothing(hmm_guess)
hmm_est, logL_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends)
@show diff(logL_evolution)
@test all(>=(0), diff(logL_evolution))
test_equal_hmms(hmm, hmm_guess, control_seq[1:2]; atol, init, flip=true)
test_equal_hmms(hmm, hmm_est, control_seq[1:2]; atol, init)
Expand Down
2 changes: 1 addition & 1 deletion src/inference/baum_welch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ function baum_welch(
seq_ends,
atol,
max_iterations,
loglikelihood_increasing=false,
loglikelihood_increasing,
)
return hmm, logL_evolution
end
Expand Down
14 changes: 7 additions & 7 deletions test/correctness.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ seq_ends = cumsum(length.(control_seqs));

## Uncontrolled

@testset "Normal" begin
@testset verbose = true "Normal" begin
dists = [Normal(μ[1][1]), Normal(μ[2][1])]
dists_guess = [Normal(μ_guess[1][1]), Normal(μ_guess[2][1])]

Expand All @@ -51,7 +51,7 @@ seq_ends = cumsum(length.(control_seqs));
end
end

@testset "DiagNormal" begin
@testset verbose = true "DiagNormal" begin
dists = [MvNormal(μ[1], Diagonal(abs2.(σ))), MvNormal(μ[2], Diagonal(abs2.(σ)))]
dists_guess = [
MvNormal(μ_guess[1], Diagonal(abs2.(σ))), MvNormal(μ_guess[2], Diagonal(abs2.(σ)))
Expand All @@ -68,7 +68,7 @@ end
end
end

@testset "LightCategorical" begin
@testset verbose = true "LightCategorical" begin
dists = [LightCategorical(p[1]), LightCategorical(p[2])]
dists_guess = [LightCategorical(p_guess[1]), LightCategorical(p_guess[2])]

Expand All @@ -82,7 +82,7 @@ end
end
end

@testset "LightDiagNormal" begin
@testset verbose = true "LightDiagNormal" begin
dists = [LightDiagNormal(μ[1], σ), LightDiagNormal(μ[2], σ)]
dists_guess = [LightDiagNormal(μ_guess[1], σ), LightDiagNormal(μ_guess[2], σ)]

Expand All @@ -96,7 +96,7 @@ end
end
end

@testset "Normal (sparse)" begin
@testset verbose = true "Normal (sparse)" begin
dists = [Normal(μ[1][1]), Normal(μ[2][1])]
dists_guess = [Normal(μ_guess[1][1]), Normal(μ_guess[2][1])]

Expand All @@ -112,7 +112,7 @@ end
end
end

@testset "Normal transposed" begin # issue 99
@testset verbose = true "Normal transposed" begin # issue 99
dists = [Normal(μ[1][1]), Normal(μ[2][1])]
dists_guess = [Normal(μ_guess[1][1]), Normal(μ_guess[2][1])]

Expand All @@ -128,7 +128,7 @@ end
end
end

@testset "Normal and Exponential" begin # issue 101
@testset verbose = true "Normal and Exponential" begin # issue 101
dists = [Normal(μ[1][1]), Exponential(1.0)]
dists_guess = [Normal(μ_guess[1][1]), Exponential(0.8)]

Expand Down

0 comments on commit 1c6ee9a

Please sign in to comment.