Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adaptive Time Step #290

Draft
wants to merge 4 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/Interfaces/stochasticstyles.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
StochasticStyle(v)
StochasticStyle(v::AbstractDVec)

Abstract type. When called as a function it returns the native style of the
generalised vector `v` that determines how simulations are to proceed.
Expand Down Expand Up @@ -34,12 +34,11 @@ and optionally
* [`CompressionStrategy(::StochasticStyle)`](@ref) for vector compression after
annihilations,

See also [`StochasticStyles`](@ref Main.StochasticStyles), [`Interfaces`](@ref).
See also [`StochasticStyles`](@ref Main.StochasticStyles), [`Interfaces`](@ref),
[`AbstractDVec`](@ref).
"""
abstract type StochasticStyle{T} end

StochasticStyle(::AbstractVector{T}) where T = default_style(T)

Base.eltype(::Type{<:StochasticStyle{T}}) where {T} = T
VectorInterface.scalartype(::Type{<:StochasticStyle{T}}) where {T} = T

Expand Down Expand Up @@ -127,6 +126,12 @@ end
Return a tuple of stat names (`Symbol` or `String`) and a tuple of zeros of the same
length. These will be reported as columns in the `DataFrame` returned by
[`ProjectorMonteCarloProblem`](@ref Main.ProjectorMonteCarloProblem).
The names should be unique and not contain spaces or special characters.

For a `StochasticStyle`, the first three stats are the number of
clones, deaths, and zombies.

See also [`StochasticStyle`](@ref), [`CompressionStrategy`](@ref).
"""
step_stats(v) = step_stats(StochasticStyle(v))

Expand Down
2 changes: 1 addition & 1 deletion src/Rimu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ export ReportingStrategy, ReportDFAndInfo, ReportToFile
export ReplicaStrategy, NoStats, AllOverlaps
export PostStepStrategy, Projector, ProjectedEnergy, SignCoherence, WalkerLoneliness, Timer,
SingleParticleDensity, single_particle_density
export TimeStepStrategy, ConstantTimeStep, OvershootControl
export TimeStepStrategy, ConstantTimeStep, AdaptiveTimeStep
export localpart, walkernumber
export smart_logger, default_logger
export ProjectorMonteCarloProblem, SimulationPlan, state_vectors
Expand Down
44 changes: 24 additions & 20 deletions src/StochasticStyles/styles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ IsStochasticInteger() = IsStochasticInteger{Int}()
function step_stats(::IsStochasticInteger{T}) where {T}
z = zero(T)
return (
(:spawn_attempts, :spawns, :deaths, :clones, :zombies),
MultiScalar(0, z, z, z, z),
(:clones, :deaths, :zombies, :spawn_attempts, :spawns),
MultiScalar(z, z, z, 0, z),
)
end
function apply_column!(::IsStochasticInteger, w, op, add, val::Real, boost=1)
clones, deaths, zombies = diagonal_step!(w, op, add, val)
attempts, spawns = spawn!(WithReplacement(), w, op, add, val, boost)
return (attempts, spawns, deaths, clones, zombies)
return (clones, deaths, zombies, attempts, spawns)
end

"""
Expand All @@ -43,7 +43,7 @@ IsStochastic2Pop() = IsStochastic2Pop{Complex{Int}}()
function step_stats(::IsStochastic2Pop{T}) where {T}
z = zero(T)
return (
(:spawns, :deaths, :clones, :zombies),
(:clones, :deaths, :zombies, :spawns),
MultiScalar(z, z, z, z)
)
end
Expand All @@ -59,7 +59,7 @@ function apply_column!(::IsStochastic2Pop, w, op, add, val, boost=1)

clones, deaths, zombies = diagonal_step!(w, op, add, val)

return (spawns, deaths, clones, zombies)
return (clones, deaths, zombies, spawns)
end

const FloatOrComplexFloat = Union{AbstractFloat, Complex{<:AbstractFloat}}
Expand Down Expand Up @@ -91,17 +91,17 @@ function Base.show(io::IO, s::IsDeterministic{T}) where {T}
end
end

function step_stats(::IsDeterministic)
return (:exact_steps,), MultiScalar(0,)
end
function apply_column!(::IsDeterministic, w, op::AbstractMatrix, add, val, boost=1)
w .+= op[:, add] .* val
return (1,)
function step_stats(::IsDeterministic{T}) where {T}
z = zero(T)
return (
(:clones, :deaths, :zombies, :exact_steps,),
MultiScalar(z, z, z, 0)
)
end
function apply_column!(::IsDeterministic, w, op, add, val, boost=1)
diagonal_step!(w, op, add, val)
clones, deaths, zombies = diagonal_step!(w, op, add, val)
spawn!(Exact(), w, op, add, val)
return (1,)
return (clones, deaths, zombies, 1)
end

"""
Expand All @@ -121,12 +121,16 @@ IsStochasticWithThreshold(args...) = IsStochasticWithThreshold{Float64}(args...)
IsStochasticWithThreshold{T}(t=1.0) where {T} = IsStochasticWithThreshold{T}(T(t))

function step_stats(::IsStochasticWithThreshold{T}) where {T}
return ((:spawn_attempts, :spawns), MultiScalar(0, zero(T)))
z = zero(T)
return (
(:clones, :deaths, :zombies, :spawn_attempts, :spawns),
MultiScalar(z, z, z, 0, z)
)
end
function apply_column!(s::IsStochasticWithThreshold, w, op, add, val, boost=1)
diagonal_step!(w, op, add, val, s.threshold)
clones, deaths, zombies = diagonal_step!(w, op, add, val, s.threshold)
attempts, spawns = spawn!(WithReplacement(s.threshold), w, op, add, val, boost)
return (attempts, spawns)
return (clones, deaths, zombies, attempts, spawns)
end

"""
Expand Down Expand Up @@ -203,14 +207,14 @@ CompressionStrategy(s::IsDynamicSemistochastic) = s.compression
function step_stats(::IsDynamicSemistochastic{T}) where {T}
z = zero(T)
return (
(:exact_steps, :inexact_steps, :spawn_attempts, :spawns),
MultiScalar(0, 0, 0, z),
(:clones, :deaths, :zombies, :exact_steps, :inexact_steps, :spawn_attempts, :spawns),
MultiScalar(z, z, z, 0, 0, 0, z),
)
end
function apply_column!(s::IsDynamicSemistochastic, w, op, add, val, boost=1)
diagonal_step!(w, op, add, val, s.proj_threshold)
clones, deaths, zombies = diagonal_step!(w, op, add, val, s.proj_threshold)
exact, inexact, attempts, spawns = spawn!(s.spawning, w, op, add, val, boost)
return (exact, inexact, attempts, spawns)
return (clones, deaths, zombies, exact, inexact, attempts, spawns)
end

default_style(::Type{T}) where {T<:Integer} = IsStochasticInteger{T}()
Expand Down
12 changes: 8 additions & 4 deletions src/fciqmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,21 @@ function advance!(algorithm::FCIQMC, report, state::ReplicaState, s_state::Singl
# pv was mutated and now contains the new vector.
v, pv = (pv, v)

deaths, clones, zombies = step_stat_values[1:3] # stats from the StochasticStyle

# Stats:
tnorm, len = walkernumber_and_length(v)

# Updates
time_step = update_time_step(time_step_strategy, time_step, tnorm)

shift_stats, proceed = update_shift_parameters!(
shift_strategy, shift_parameters, tnorm, v, pv, step, report
new_time_step = update_time_step(
time_step_strategy, time_step, deaths, clones, zombies, tnorm, len
)

@pack! s_state = v, pv, wm

shift_stats, proceed = update_shift_parameters!(
shift_strategy, shift_parameters, new_time_step, tnorm, s_state, step
)
### TO HERE

if step % reporting_interval(state.reporting_strategy) == 0
Expand Down
2 changes: 1 addition & 1 deletion src/projector_monte_carlo_problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ julia> simulation.success[]
true

julia> size(DataFrame(simulation))
(100, 9)
(100, 12)
```

# Further keyword arguments:
Expand Down
58 changes: 37 additions & 21 deletions src/strategies_and_params/shiftstrategy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,10 @@ end
update_shift_parameters!(
s <: ShiftStrategy,
shift_parameters,
new_time_step,
tnorm,
v_new,
v_old,
step,
report
single_state,
step
) -> shift_stats, proceed
Update the `shift_parameters` according to strategy `s`. See [`ShiftStrategy`](@ref).
Returns a named tuple of the shift statistics and a boolean `proceed` indicating whether
Expand Down Expand Up @@ -86,7 +85,8 @@ function DontUpdate(; targetwalkers = nothing, target_walkers = 1_000)
end


function update_shift_parameters!(s::DontUpdate, sp, tnorm, _...)
function update_shift_parameters!(s::DontUpdate, sp, new_time_step, tnorm, _...)
sp.time_step = new_time_step
return (; shift=sp.shift, norm=tnorm), tnorm < s.target_walkers
end

Expand All @@ -109,15 +109,18 @@ function LogUpdateAfterTargetWalkers(; targetwalkers=nothing, target_walkers = 1
return LogUpdateAfterTargetWalkers(target_walkers, ζ)
end

function update_shift_parameters!(s::LogUpdateAfterTargetWalkers, sp, tnorm, _...)
function update_shift_parameters!(
s::LogUpdateAfterTargetWalkers, sp, new_time_step, tnorm, _...
)
@unpack shift, pnorm, time_step, shift_mode = sp
if shift_mode || real(tnorm) > s.target_walkers
shift_mode = true
dτ = time_step
shift -= s.ζ / dτ * log(tnorm / pnorm)
end
pnorm = tnorm
@pack! sp = shift, pnorm, shift_mode
time_step = new_time_step
@pack! sp = shift, pnorm, shift_mode, time_step
return (; shift, norm=tnorm, shift_mode), true
end

Expand All @@ -136,12 +139,13 @@ Base.@kwdef struct LogUpdate <: ShiftStrategy
ζ::Float64 = 0.08 # damping parameter, best left at value of 0.3
end

function update_shift_parameters!(s::LogUpdate, sp, tnorm, _...)
function update_shift_parameters!(s::LogUpdate, sp, new_time_step, tnorm, _...)
@unpack shift, pnorm, time_step = sp
dτ = time_step
shift -= s.ζ / dτ * log(tnorm / pnorm)
pnorm = tnorm
@pack! sp = shift, pnorm
time_step = new_time_step
@pack! sp = shift, pnorm, time_step
return (; shift, norm=tnorm), true
end

Expand Down Expand Up @@ -171,12 +175,13 @@ function DoubleLogUpdate(;targetwalkers = nothing, target_walkers = 1_000, ζ =
return DoubleLogUpdate(target_walkers, ζ, ξ)
end

function update_shift_parameters!(s::DoubleLogUpdate, sp, tnorm, _...)
function update_shift_parameters!(s::DoubleLogUpdate, sp, new_time_step, tnorm, _...)
@unpack shift, pnorm, time_step = sp
dτ = time_step
shift -= s.ξ / dτ * log(tnorm / s.target_walkers) + s.ζ / dτ * log(tnorm / pnorm)
pnorm = tnorm
@pack! sp = shift, pnorm
time_step = new_time_step
@pack! sp = shift, pnorm, time_step
return (; shift, norm=tnorm), true
end

Expand All @@ -202,15 +207,18 @@ function DoubleLogUpdateAfterTargetWalkers(;
return DoubleLogUpdateAfterTargetWalkers(target_walkers, ζ, ξ)
end

function update_shift_parameters!(s::DoubleLogUpdateAfterTargetWalkers, sp, tnorm, _...)
function update_shift_parameters!(
s::DoubleLogUpdateAfterTargetWalkers, sp, new_time_step, tnorm, _...
)
@unpack shift, pnorm, time_step, shift_mode = sp
if shift_mode || real(tnorm) > s.target_walkers
shift_mode = true
dτ = time_step
shift -= s.ξ / dτ * log(tnorm / s.target_walkers) + s.ζ / dτ * log(tnorm / pnorm)
end
pnorm = tnorm
@pack! sp = shift, pnorm, shift_mode
time_step = new_time_step
@pack! sp = shift, pnorm, shift_mode, time_step
return (; shift, norm=tnorm, shift_mode), true
end

Expand Down Expand Up @@ -248,17 +256,21 @@ function DoubleLogSumUpdate(;
DoubleLogSumUpdate(target_walkers, ζ, ξ, α)
end

function update_shift_parameters!(s::DoubleLogSumUpdate, sp, tnorm, v_new, v_old, _...)
function update_shift_parameters!(
s::DoubleLogSumUpdate, sp, new_time_step, tnorm, s_state, _...
)
@unpack shift, pnorm, time_step = sp
@unpack v, pv = s_state
dτ = time_step
tp = DictVectors.UniformProjector() ⋅ v_new
pp = DictVectors.UniformProjector() ⋅ v_old # could be cached
tp = DictVectors.UniformProjector() ⋅ v
pp = DictVectors.UniformProjector() ⋅ pv # could be cached
twn = (1 - s.α) * tnorm + s.α * tp
pwn = (1 - s.α) * pnorm + s.α * pp
# return new shift
shift -= s.ξ / dτ * log(twn / s.target_walkers) + s.ζ / dτ * log(twn / pwn)
pnorm = tnorm
@pack! sp = shift, pnorm
time_step = new_time_step
@pack! sp = shift, pnorm, time_step
return (; shift, norm=tnorm, up=tp), true
end

Expand Down Expand Up @@ -287,14 +299,18 @@ function DoubleLogProjected(; target, projector, ζ = 0.08, ξ = ζ^2/4)
return DoubleLogProjected(target, freeze(projector), ζ, ξ)
end

function update_shift_parameters!(s::DoubleLogProjected, sp, tnorm, v_new, v_old, _...)
function update_shift_parameters!(
s::DoubleLogProjected, sp, new_time_step, tnorm, s_state, _...
)
@unpack shift, pnorm, time_step = sp
@unpack v, pv = s_state
dτ = time_step
tp = s.projector ⋅ v_new
pp = s.projector ⋅ v_old
tp = s.projector ⋅ v
pp = s.projector ⋅ pv
# return new shift
shift -= s.ξ / dτ * log(tp / s.target) + s.ζ / dτ * log(tp / pp)
pnorm = tnorm
@pack! sp = shift, pnorm
time_step = new_time_step
@pack! sp = shift, pnorm, time_step
return (; shift, norm=tnorm, tp, pp), true
end
18 changes: 17 additions & 1 deletion src/strategies_and_params/timestepstrategy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,23 @@ Keep `time_step` constant.
struct ConstantTimeStep <: TimeStepStrategy end

"""
update_time_step(s<:TimeStepStrategy, time_step, tnorm) -> new_time_step
update_time_step(s<:TimeStepStrategy, time_step, deaths, clones, zombies, tnorm, len)
-> new_time_step
Update the time step according to the strategy `s`.
"""
update_time_step(::ConstantTimeStep, time_step, args...) = time_step

"""
AdaptiveTimeStep <: TimeStepStrategy

Adapt the time step to avoid zombies.
"""
struct AdaptiveTimeStep <: TimeStepStrategy end

function update_time_step(::AdaptiveTimeStep, time_step, _, _, zombies, _...)
if zombies > 0
return time_step * 0.9^zombies
else
return time_step * 1.01 # increase by 1%
end
end
4 changes: 2 additions & 2 deletions test/RMPI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ using Test
dv = DVec(starting_address(ham) => 10; style=IsDynamicSemistochastic())
v = MPIData(dv; setup)
df, state = lomc!(ham,v)
@test size(df) == (100, 9)
@test size(df) == (100, 12)
end
# need to do mpi_one_sided separately
dv = DVec(starting_address(ham)=>10; style=IsDynamicSemistochastic())
v = RMPI.mpi_one_sided(dv; capacity = 1000)
df, state = lomc!(ham,v)
@test size(df) == (100, 9)
@test size(df) == (100, 12)
end

@testset "sort_and_count!" begin
Expand Down
14 changes: 0 additions & 14 deletions test/StochasticStyles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,6 @@ end
deposit!(vec, 1, 2, 1 => 2)
deposit!(vec, 4, -2, 1 => 2)
@test vec == [3, 2, 3, 2, 5]

@test StochasticStyle(vec) == IsStochasticInteger{Int64}()
@test StochasticStyle(Float32.(vec)) == IsDeterministic{Float32}()

names, values = step_stats(vec)
@test names == (:spawn_attempts, :spawns, :deaths, :clones, :zombies)
@test values == Rimu.MultiScalar((0, 0, 0, 0, 0))

w = [1.0, 2.0, 3.0]

@test apply_column!(w, matrix, 1, 2) == (1, )
@test w[1] == 1.0 + 2 * matrix[1, 1]
@test w[2] == 2.0 + 2 * matrix[2, 1]
@test w[3] == 3.0 + 2 * matrix[3, 1]
end

@testset "projected_deposit!" begin
Expand Down
Loading
Loading