diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 27c928896..0fda877fc 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -329,6 +329,27 @@ using Turing @test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.001 end + @testset "getstepsize: Turing.jl#2400" begin + algs = [ + HMC(0.1, 10), + HMCDA(0.8, 0.75), + NUTS(0.5), + NUTS(0, 0.5), + ] + @testset "$(alg)" for alg in algs + # Construct a HMC state by taking a single step + spl = Sampler(alg, gdemo_default) + hmc_state = DynamicPPL.initialstep( + Random.default_rng(), + gdemo_default, + spl, + DynamicPPL.VarInfo(gdemo_default) + )[2] + # Check that we can obtain the current step size + @test Turing.Inference.getstepsize(spl, hmc_state) isa Float64 + end + end + @testset "Check ADType" begin alg = HMC(0.1, 10; adtype=adbackend) m = DynamicPPL.contextualize(