Skip to content

Commit

Permalink
Final benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Feb 29, 2024
1 parent 2f975b6 commit 94d4e48
Show file tree
Hide file tree
Showing 12 changed files with 66 additions and 54 deletions.
16 changes: 8 additions & 8 deletions benchmark/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[deps.BenchmarkTools]]
deps = ["JSON", "Logging", "Printf", "Profile", "Statistics", "UUIDs"]
git-tree-sha1 = "f1f03a9fa24271160ed7e73051fba3c1a759b53f"
git-tree-sha1 = "f1dff6729bc61f4d49e140da1af55dcd1ac97b2f"
uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
version = "1.4.0"
version = "1.5.0"

[[deps.CSV]]
deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"]
Expand All @@ -39,9 +39,9 @@ version = "0.5.1"

[[deps.ChainRulesCore]]
deps = ["Compat", "LinearAlgebra"]
git-tree-sha1 = "892b245fdec1c511906671b6a5e1bafa38a727c1"
git-tree-sha1 = "aef70bb349b20aa81a82a19704c3ef339d4ee494"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.22.0"
version = "1.22.1"
weakdeps = ["SparseArrays"]

[deps.ChainRulesCore.extensions]
Expand All @@ -55,9 +55,9 @@ version = "0.7.4"

[[deps.Compat]]
deps = ["TOML", "UUIDs"]
git-tree-sha1 = "d2c021fbdde94f6cdaa799639adfeeaa17fd67f5"
git-tree-sha1 = "c955881e3c981181362ae4088b35995446298b80"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "4.13.0"
version = "4.14.0"
weakdeps = ["Dates", "LinearAlgebra"]

[deps.Compat.extensions]
Expand Down Expand Up @@ -363,9 +363,9 @@ version = "1.2.0"

[[deps.Preferences]]
deps = ["TOML"]
git-tree-sha1 = "00805cd429dcb4870060ff49ef443486c262e38e"
git-tree-sha1 = "9e8fed0505b0c15b4c1295fd59ea47b411c019cf"
uuid = "21216c6a-2e73-6563-6e65-726566657250"
version = "1.4.1"
version = "1.4.2"

[[deps.PrettyTables]]
deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"]
Expand Down
3 changes: 3 additions & 0 deletions benchmark/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[compat]
BenchmarkTools = "1.5"
1 change: 1 addition & 0 deletions libs/HMMBenchmark/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
BenchmarkTools = "1.5"
julia = "1.9"
16 changes: 8 additions & 8 deletions libs/HMMBenchmark/src/hiddenmarkovmodels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,38 +43,38 @@ function build_benchmarkables(
if "forward" in algos
benchs["forward"] = @benchmarkable begin
forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 100
end evals = 1 samples = 20
end
if "forward!" in algos
benchs["forward!"] = @benchmarkable begin
forward!(f_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 100 setup = (
end evals = 1 samples = 20 setup = (
f_storage = initialize_forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
)
end

if "viterbi" in algos
benchs["viterbi"] = @benchmarkable begin
viterbi($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 100
end evals = 1 samples = 20
end
if "viterbi!" in algos
benchs["viterbi!"] = @benchmarkable begin
viterbi!(v_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 100 setup = (
end evals = 1 samples = 20 setup = (
v_storage = initialize_viterbi($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
)
end

if "forward_backward" in algos
benchs["forward_backward"] = @benchmarkable begin
forward_backward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 100
end evals = 1 samples = 20
end
if "forward_backward!" in algos
benchs["forward_backward!"] = @benchmarkable begin
forward_backward!(fb_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 100 setup = (
end evals = 1 samples = 20 setup = (
fb_storage = initialize_forward_backward(
$hmm, $obs_seq, $control_seq; seq_ends=$seq_ends
)
Expand All @@ -92,7 +92,7 @@ function build_benchmarkables(
atol=-Inf,
loglikelihood_increasing=false,
)
end evals = 1 samples = 100
end evals = 1 samples = 20
end
if "baum_welch!" in algos
benchs["baum_welch!"] = @benchmarkable begin
Expand All @@ -107,7 +107,7 @@ function build_benchmarkables(
atol=-Inf,
loglikelihood_increasing=false,
)
end evals = 1 samples = 100 setup = (
end evals = 1 samples = 20 setup = (
hmm_guess = build_model($implem, $instance, $params);
fb_storage = initialize_forward_backward(
hmm_guess, $obs_seq, $control_seq; seq_ends=$seq_ends
Expand Down
5 changes: 4 additions & 1 deletion libs/HMMComparison/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.1.0"

[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -13,9 +14,11 @@ HMMBenchmark = "557005d5-2e4a-43f9-8aa7-ba8df2d03179"
HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[compat]
BenchmarkTools = "1.5"
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,16 @@ algos = ["forward", "viterbi", "forward_backward", "baum_welch"]

instances = Instance[]

for nb_states in 2:2:16
for nb_states in 2:2:10
push!(
instances,
Instance(;
custom_dist=true,
custom_dist=false,
sparse=false,
nb_states=nb_states,
obs_dim=1,
seq_length=200,
nb_seqs=100,
seq_length=100,
nb_seqs=50,
bw_iter=5,
),
)
Expand Down
39 changes: 22 additions & 17 deletions libs/HMMComparison/experiments/plots.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using DataFrames
using Plots
using CairoMakie
using HMMComparison

data = read_results(joinpath(@__DIR__, "results", "results.csv"))
Expand All @@ -15,33 +15,38 @@ implems = [
]
algos = ["viterbi", "forward", "forward_backward", "baum_welch"]

markershapes = [:star5, :circle, :diamond, :hexagon, :pentagon, :utriangle]
markers = [:star5, :circle, :diamond, :hexagon, :pentagon]
linestyles = [nothing, :dot, :dash, :dashdot, :dashdotdot]

for algo in algos
pl = plot(;
fig = Figure(size=(900, 700))
ax = nothing
for (k, algo) in enumerate(algos)
ax = Axis(
fig[fld1(k, 2), mod1(k, 2)];
title=algo,
size=(800, 400),
yscale=:log,
xlabel="nb states",
ylabel="runtime (s)",
yscale=log10,
xticks=unique(data[!, :nb_states]),
legend=:outerright,
margin=5Plots.mm,
yminorticksvisible = true,
yminorgridvisible = true,
yminorticks = IntervalsBetween(5)
)

for (i, implem) in enumerate(implems)
subdata = data[(data[!, :algo] .== algo) .& (data[!, :implem] .== implem), :]
plot!(
pl,
scatterlines!(
ax,
subdata[!, :nb_states],
subdata[!, :time_median] ./ 1e9;
label=implem,
markershape=markershapes[i],
markerstrokecolor=:auto,
markersize=5,
linestyle=:auto,
linewidth=2,
linestyle=linestyles[i],
marker=markers[i],
markersize=15,
label=implem,
)
end
display(pl)
savefig(pl, joinpath(@__DIR__, "results", "$(algo).pdf"))
end
Legend(fig[3, 1:2], ax, orientation = :horizontal)
fig
save(joinpath(@__DIR__, "results", "benchmark.pdf"), fig)
8 changes: 4 additions & 4 deletions libs/HMMComparison/src/dynamax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ function HMMBenchmark.build_benchmarkables(
filter_vmap = jax.jit(jax.vmap(hmm.filter; in_axes=pylist((pybuiltins.None, 0))))
benchs["forward"] = @benchmarkable begin
$(filter_vmap)($dyn_params, $obs_tens_jax_py)
end evals = 1 samples = 100
end evals = 1 samples = 20
end

if "viterbi" in algos
Expand All @@ -57,7 +57,7 @@ function HMMBenchmark.build_benchmarkables(
)
benchs["viterbi"] = @benchmarkable begin
$(most_likely_states_vmap)($dyn_params, $obs_tens_jax_py)
end evals = 1 samples = 100
end evals = 1 samples = 20
end

if "forward_backward" in algos
Expand All @@ -66,7 +66,7 @@ function HMMBenchmark.build_benchmarkables(
)
benchs["forward_backward"] = @benchmarkable begin
$(smoother_vmap)($dyn_params, $obs_tens_jax_py)
end evals = 1 samples = 100
end evals = 1 samples = 20
end

if "baum_welch" in algos
Expand All @@ -78,7 +78,7 @@ function HMMBenchmark.build_benchmarkables(
num_iters=$bw_iter,
verbose=false,
)
end evals = 1 samples = 100 setup = (
end evals = 1 samples = 20 setup = (
tup = build_model($implem, $instance, $params);
hmm_guess = tup[1];
dyn_params_guess = tup[2];
Expand Down
8 changes: 4 additions & 4 deletions libs/HMMComparison/src/hmmbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,29 +41,29 @@ function HMMBenchmark.build_benchmarkables(
@threads for k in eachindex($obs_mats)
HMMBase.forward($hmm, $obs_mats[k])
end
end evals = 1 samples = 100
end evals = 1 samples = 20
end

if "viterbi" in algos
benchs["viterbi"] = @benchmarkable begin
@threads for k in eachindex($obs_mats)
HMMBase.viterbi($hmm, $obs_mats[k])
end
end evals = 1 samples = 100
end evals = 1 samples = 20
end

if "forward_backward" in algos
benchs["forward_backward"] = @benchmarkable begin
@threads for k in eachindex($obs_mats)
HMMBase.posteriors($hmm, $obs_mats[k])
end
end evals = 1 samples = 100
end evals = 1 samples = 20
end

if "baum_welch" in algos
benchs["baum_welch"] = @benchmarkable begin
HMMBase.fit_mle($hmm, $obs_mat_concat; maxiter=$bw_iter, tol=-Inf)
end evals = 1 samples = 100
end evals = 1 samples = 20
end

return benchs
Expand Down
8 changes: 4 additions & 4 deletions libs/HMMComparison/src/hmmlearn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,25 @@ function HMMBenchmark.build_benchmarkables(
if "forward" in algos
benchs["forward"] = @benchmarkable begin
$(hmm.score)($obs_mat_concat_py, $obs_mat_len_py)
end evals = 1 samples = 100
end evals = 1 samples = 20
end

if "viterbi" in algos
benchs["viterbi"] = @benchmarkable begin
$(hmm.decode)($obs_mat_concat_py, $obs_mat_len_py)
end evals = 1 samples = 100
end evals = 1 samples = 20
end

if "forward_backward" in algos
benchs["forward_backward"] = @benchmarkable begin
$(hmm.predict_proba)($obs_mat_concat_py, $obs_mat_len_py)
end evals = 1 samples = 100
end evals = 1 samples = 20
end

if "baum_welch" in algos
benchs["baum_welch"] = @benchmarkable begin
hmm_guess.fit($obs_mat_concat_py, $obs_mat_len_py)
end evals = 1 samples = 100 setup = (
end evals = 1 samples = 20 setup = (
hmm_guess = build_model($implem, $instance, $params)
)
end
Expand Down
6 changes: 3 additions & 3 deletions libs/HMMComparison/src/pomegranate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,19 @@ function HMMBenchmark.build_benchmarkables(
if "forward" in algos
benchs["forward"] = @benchmarkable begin
$(hmm.forward)($obs_tens_torch_py)
end evals = 1 samples = 100
end evals = 1 samples = 20
end

if "forward_backward" in algos
benchs["forward_backward"] = @benchmarkable begin
$(hmm.forward_backward)($obs_tens_torch_py)
end evals = 1 samples = 100
end evals = 1 samples = 20
end

if "baum_welch" in algos
benchs["baum_welch"] = @benchmarkable begin
hmm_guess.fit($obs_tens_torch_py)
end evals = 1 samples = 100 setup = (
end evals = 1 samples = 20 setup = (
hmm_guess = build_model($implem, $instance, $params)
)
end
Expand Down
2 changes: 1 addition & 1 deletion libs/HMMComparison/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ rng = Random.default_rng()

@testset "HMMComparison" begin
instance = Instance(;
custom_dist=false, sparse=false, nb_states=5, obs_dim=10, seq_length=25, nb_seqs=10
custom_dist=true, sparse=false, nb_states=5, obs_dim=10, seq_length=25, nb_seqs=10
)
params = build_params(rng, instance)
data = build_data(rng, instance)
Expand Down

0 comments on commit 94d4e48

Please sign in to comment.