From 94d4e4814d13c7093d4ea4abab47b7610ae2b9e2 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 29 Feb 2024 14:27:41 +0100 Subject: [PATCH] Final benchmarks --- benchmark/Manifest.toml | 16 ++++---- benchmark/Project.toml | 3 ++ libs/HMMBenchmark/Project.toml | 1 + libs/HMMBenchmark/src/hiddenmarkovmodels.jl | 16 ++++---- libs/HMMComparison/Project.toml | 5 ++- .../{performance.jl => measurements.jl} | 8 ++-- libs/HMMComparison/experiments/plots.jl | 39 +++++++++++-------- libs/HMMComparison/src/dynamax.jl | 8 ++-- libs/HMMComparison/src/hmmbase.jl | 8 ++-- libs/HMMComparison/src/hmmlearn.jl | 8 ++-- libs/HMMComparison/src/pomegranate.jl | 6 +-- libs/HMMComparison/test/runtests.jl | 2 +- 12 files changed, 66 insertions(+), 54 deletions(-) rename libs/HMMComparison/experiments/{performance.jl => measurements.jl} (93%) diff --git a/benchmark/Manifest.toml b/benchmark/Manifest.toml index 4d1de94f..b6079e8d 100644 --- a/benchmark/Manifest.toml +++ b/benchmark/Manifest.toml @@ -21,9 +21,9 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" [[deps.BenchmarkTools]] deps = ["JSON", "Logging", "Printf", "Profile", "Statistics", "UUIDs"] -git-tree-sha1 = "f1f03a9fa24271160ed7e73051fba3c1a759b53f" +git-tree-sha1 = "f1dff6729bc61f4d49e140da1af55dcd1ac97b2f" uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -version = "1.4.0" +version = "1.5.0" [[deps.CSV]] deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] @@ -39,9 +39,9 @@ version = "0.5.1" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "892b245fdec1c511906671b6a5e1bafa38a727c1" +git-tree-sha1 = "aef70bb349b20aa81a82a19704c3ef339d4ee494" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.22.0" +version = "1.22.1" weakdeps = ["SparseArrays"] [deps.ChainRulesCore.extensions] @@ -55,9 +55,9 @@ version = "0.7.4" [[deps.Compat]] deps = ["TOML", "UUIDs"] -git-tree-sha1 = "d2c021fbdde94f6cdaa799639adfeeaa17fd67f5" +git-tree-sha1 = "c955881e3c981181362ae4088b35995446298b80" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.13.0" +version = "4.14.0" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -363,9 +363,9 @@ version = "1.2.0" [[deps.Preferences]] deps = ["TOML"] -git-tree-sha1 = "00805cd429dcb4870060ff49ef443486c262e38e" +git-tree-sha1 = "9e8fed0505b0c15b4c1295fd59ea47b411c019cf" uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.1" +version = "1.4.2" [[deps.PrettyTables]] deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] diff --git a/benchmark/Project.toml b/benchmark/Project.toml index e3641866..8bafaa15 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -5,3 +5,6 @@ HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" + +[compat] +BenchmarkTools = "1.5" diff --git a/libs/HMMBenchmark/Project.toml b/libs/HMMBenchmark/Project.toml index 48ccdc9c..7abc70e5 100644 --- a/libs/HMMBenchmark/Project.toml +++ b/libs/HMMBenchmark/Project.toml @@ -18,4 +18,5 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] +BenchmarkTools = "1.5" julia = "1.9" diff --git a/libs/HMMBenchmark/src/hiddenmarkovmodels.jl b/libs/HMMBenchmark/src/hiddenmarkovmodels.jl index 88d6dfa8..c8b54f3f 100644 --- a/libs/HMMBenchmark/src/hiddenmarkovmodels.jl +++ b/libs/HMMBenchmark/src/hiddenmarkovmodels.jl @@ -43,12 +43,12 @@ function build_benchmarkables( if "forward" in algos benchs["forward"] = @benchmarkable begin forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) - end evals = 1 samples = 100 + end evals = 1 samples = 20 end if "forward!" in algos benchs["forward!"] = @benchmarkable begin forward!(f_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) - end evals = 1 samples = 100 setup = ( + end evals = 1 samples = 20 setup = ( f_storage = initialize_forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) ) end @@ -56,12 +56,12 @@ function build_benchmarkables( if "viterbi" in algos benchs["viterbi"] = @benchmarkable begin viterbi($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) - end evals = 1 samples = 100 + end evals = 1 samples = 20 end if "viterbi!" in algos benchs["viterbi!"] = @benchmarkable begin viterbi!(v_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) - end evals = 1 samples = 100 setup = ( + end evals = 1 samples = 20 setup = ( v_storage = initialize_viterbi($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) ) end @@ -69,12 +69,12 @@ function build_benchmarkables( if "forward_backward" in algos benchs["forward_backward"] = @benchmarkable begin forward_backward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) - end evals = 1 samples = 100 + end evals = 1 samples = 20 end if "forward_backward!" in algos benchs["forward_backward!"] = @benchmarkable begin forward_backward!(fb_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) - end evals = 1 samples = 100 setup = ( + end evals = 1 samples = 20 setup = ( fb_storage = initialize_forward_backward( $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends ) @@ -92,7 +92,7 @@ function build_benchmarkables( atol=-Inf, loglikelihood_increasing=false, ) - end evals = 1 samples = 100 + end evals = 1 samples = 20 end if "baum_welch!" in algos benchs["baum_welch!"] = @benchmarkable begin @@ -107,7 +107,7 @@ function build_benchmarkables( atol=-Inf, loglikelihood_increasing=false, ) - end evals = 1 samples = 100 setup = ( + end evals = 1 samples = 20 setup = ( hmm_guess = build_model($implem, $instance, $params); fb_storage = initialize_forward_backward( hmm_guess, $obs_seq, $control_seq; seq_ends=$seq_ends diff --git a/libs/HMMComparison/Project.toml b/libs/HMMComparison/Project.toml index 8e8e3d4f..03586463 100644 --- a/libs/HMMComparison/Project.toml +++ b/libs/HMMComparison/Project.toml @@ -5,6 +5,7 @@ version = "0.1.0" [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -13,9 +14,11 @@ HMMBenchmark = "557005d5-2e4a-43f9-8aa7-ba8df2d03179" HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" + +[compat] +BenchmarkTools = "1.5" diff --git a/libs/HMMComparison/experiments/performance.jl b/libs/HMMComparison/experiments/measurements.jl similarity index 93% rename from libs/HMMComparison/experiments/performance.jl rename to libs/HMMComparison/experiments/measurements.jl index fe6ed1ba..7736f14f 100644 --- a/libs/HMMComparison/experiments/performance.jl +++ b/libs/HMMComparison/experiments/measurements.jl @@ -41,16 +41,16 @@ algos = ["forward", "viterbi", "forward_backward", "baum_welch"] instances = Instance[] -for nb_states in 2:2:16 +for nb_states in 2:2:10 push!( instances, Instance(; - custom_dist=true, + custom_dist=false, sparse=false, nb_states=nb_states, obs_dim=1, - seq_length=200, - nb_seqs=100, + seq_length=100, + nb_seqs=50, bw_iter=5, ), ) diff --git a/libs/HMMComparison/experiments/plots.jl b/libs/HMMComparison/experiments/plots.jl index f922c15f..cd2f91d6 100644 --- a/libs/HMMComparison/experiments/plots.jl +++ b/libs/HMMComparison/experiments/plots.jl @@ -1,5 +1,5 @@ using DataFrames -using Plots +using CairoMakie using HMMComparison data = read_results(joinpath(@__DIR__, "results", "results.csv")) @@ -15,33 +15,38 @@ implems = [ ] algos = ["viterbi", "forward", "forward_backward", "baum_welch"] -markershapes = [:star5, :circle, :diamond, :hexagon, :pentagon, :utriangle] +markers = [:star5, :circle, :diamond, :hexagon, :pentagon] +linestyles = [nothing, :dot, :dash, :dashdot, :dashdotdot] -for algo in algos - pl = plot(; +fig = Figure(size=(900, 700)) +ax = nothing +for (k, algo) in enumerate(algos) + ax = Axis( + fig[fld1(k, 2), mod1(k, 2)]; title=algo, - size=(800, 400), - yscale=:log, xlabel="nb states", ylabel="runtime (s)", + yscale=log10, xticks=unique(data[!, :nb_states]), - legend=:outerright, - margin=5Plots.mm, + yminorticksvisible = true, + yminorgridvisible = true, + yminorticks = IntervalsBetween(5) ) + for (i, implem) in enumerate(implems) subdata = data[(data[!, :algo] .== algo) .& (data[!, :implem] .== implem), :] - plot!( - pl, + scatterlines!( + ax, subdata[!, :nb_states], subdata[!, :time_median] ./ 1e9; - label=implem, - markershape=markershapes[i], - markerstrokecolor=:auto, - markersize=5, - linestyle=:auto, linewidth=2, + linestyle=linestyles[i], + marker=markers[i], + markersize=15, + label=implem, ) end - display(pl) - savefig(pl, joinpath(@__DIR__, "results", "$(algo).pdf")) end +Legend(fig[3, 1:2], ax, orientation = :horizontal) +fig +save(joinpath(@__DIR__, "results", "benchmark.pdf"), fig) diff --git a/libs/HMMComparison/src/dynamax.jl b/libs/HMMComparison/src/dynamax.jl index 4a845911..3776312d 100644 --- a/libs/HMMComparison/src/dynamax.jl +++ b/libs/HMMComparison/src/dynamax.jl @@ -48,7 +48,7 @@ function HMMBenchmark.build_benchmarkables( filter_vmap = jax.jit(jax.vmap(hmm.filter; in_axes=pylist((pybuiltins.None, 0)))) benchs["forward"] = @benchmarkable begin $(filter_vmap)($dyn_params, $obs_tens_jax_py) - end evals = 1 samples = 100 + end evals = 1 samples = 20 end if "viterbi" in algos @@ -57,7 +57,7 @@ function HMMBenchmark.build_benchmarkables( ) benchs["viterbi"] = @benchmarkable begin $(most_likely_states_vmap)($dyn_params, $obs_tens_jax_py) - end evals = 1 samples = 100 + end evals = 1 samples = 20 end if "forward_backward" in algos @@ -66,7 +66,7 @@ function HMMBenchmark.build_benchmarkables( ) benchs["forward_backward"] = @benchmarkable begin $(smoother_vmap)($dyn_params, $obs_tens_jax_py) - end evals = 1 samples = 100 + end evals = 1 samples = 20 end if "baum_welch" in algos @@ -78,7 +78,7 @@ function HMMBenchmark.build_benchmarkables( num_iters=$bw_iter, verbose=false, ) - end evals = 1 samples = 100 setup = ( + end evals = 1 samples = 20 setup = ( tup = build_model($implem, $instance, $params); hmm_guess = tup[1]; dyn_params_guess = tup[2]; diff --git a/libs/HMMComparison/src/hmmbase.jl b/libs/HMMComparison/src/hmmbase.jl index 4608071a..33144ad0 100644 --- a/libs/HMMComparison/src/hmmbase.jl +++ b/libs/HMMComparison/src/hmmbase.jl @@ -41,7 +41,7 @@ function HMMBenchmark.build_benchmarkables( @threads for k in eachindex($obs_mats) HMMBase.forward($hmm, $obs_mats[k]) end - end evals = 1 samples = 100 + end evals = 1 samples = 20 end if "viterbi" in algos @@ -49,7 +49,7 @@ function HMMBenchmark.build_benchmarkables( @threads for k in eachindex($obs_mats) HMMBase.viterbi($hmm, $obs_mats[k]) end - end evals = 1 samples = 100 + end evals = 1 samples = 20 end if "forward_backward" in algos @@ -57,13 +57,13 @@ function HMMBenchmark.build_benchmarkables( @threads for k in eachindex($obs_mats) HMMBase.posteriors($hmm, $obs_mats[k]) end - end evals = 1 samples = 100 + end evals = 1 samples = 20 end if "baum_welch" in algos benchs["baum_welch"] = @benchmarkable begin HMMBase.fit_mle($hmm, $obs_mat_concat; maxiter=$bw_iter, tol=-Inf) - end evals = 1 samples = 100 + end evals = 1 samples = 20 end return benchs diff --git a/libs/HMMComparison/src/hmmlearn.jl b/libs/HMMComparison/src/hmmlearn.jl index 610c9129..dd951d02 100644 --- a/libs/HMMComparison/src/hmmlearn.jl +++ b/libs/HMMComparison/src/hmmlearn.jl @@ -45,25 +45,25 @@ function HMMBenchmark.build_benchmarkables( if "forward" in algos benchs["forward"] = @benchmarkable begin $(hmm.score)($obs_mat_concat_py, $obs_mat_len_py) - end evals = 1 samples = 100 + end evals = 1 samples = 20 end if "viterbi" in algos benchs["viterbi"] = @benchmarkable begin $(hmm.decode)($obs_mat_concat_py, $obs_mat_len_py) - end evals = 1 samples = 100 + end evals = 1 samples = 20 end if "forward_backward" in algos benchs["forward_backward"] = @benchmarkable begin $(hmm.predict_proba)($obs_mat_concat_py, $obs_mat_len_py) - end evals = 1 samples = 100 + end evals = 1 samples = 20 end if "baum_welch" in algos benchs["baum_welch"] = @benchmarkable begin hmm_guess.fit($obs_mat_concat_py, $obs_mat_len_py) - end evals = 1 samples = 100 setup = ( + end evals = 1 samples = 20 setup = ( hmm_guess = build_model($implem, $instance, $params) ) end diff --git a/libs/HMMComparison/src/pomegranate.jl b/libs/HMMComparison/src/pomegranate.jl index d5939919..aa0fa0bd 100644 --- a/libs/HMMComparison/src/pomegranate.jl +++ b/libs/HMMComparison/src/pomegranate.jl @@ -57,19 +57,19 @@ function HMMBenchmark.build_benchmarkables( if "forward" in algos benchs["forward"] = @benchmarkable begin $(hmm.forward)($obs_tens_torch_py) - end evals = 1 samples = 100 + end evals = 1 samples = 20 end if "forward_backward" in algos benchs["forward_backward"] = @benchmarkable begin $(hmm.forward_backward)($obs_tens_torch_py) - end evals = 1 samples = 100 + end evals = 1 samples = 20 end if "baum_welch" in algos benchs["baum_welch"] = @benchmarkable begin hmm_guess.fit($obs_tens_torch_py) - end evals = 1 samples = 100 setup = ( + end evals = 1 samples = 20 setup = ( hmm_guess = build_model($implem, $instance, $params) ) end diff --git a/libs/HMMComparison/test/runtests.jl b/libs/HMMComparison/test/runtests.jl index 6799f710..8929d7f4 100644 --- a/libs/HMMComparison/test/runtests.jl +++ b/libs/HMMComparison/test/runtests.jl @@ -6,7 +6,7 @@ rng = Random.default_rng() @testset "HMMComparison" begin instance = Instance(; - custom_dist=false, sparse=false, nb_states=5, obs_dim=10, seq_length=25, nb_seqs=10 + custom_dist=true, sparse=false, nb_states=5, obs_dim=10, seq_length=25, nb_seqs=10 ) params = build_params(rng, instance) data = build_data(rng, instance)