Skip to content

Commit

Permalink
First pass on ARHMM using AbstractHMM. I have not changed the samplin…
Browse files Browse the repository at this point in the history
…g function yet, and the results are not checked for correctness.
  • Loading branch information
fausto-mpj committed Nov 19, 2024
1 parent 8ee1439 commit 5888f79
Show file tree
Hide file tree
Showing 6 changed files with 1,801 additions and 65 deletions.
4 changes: 2 additions & 2 deletions src/inference/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function initialize_forward(
control_seq::AbstractVector;
seq_ends::AbstractVectorOrNTuple{Int},
)
N, T, K = length(hmm), length(obs_seq), length(seq_ends)
N, T, K = length(hmm, control_seq[1]), length(obs_seq), length(seq_ends)
R = eltype(hmm, obs_seq[1], control_seq[1])
α = Matrix{R}(undef, N, T)
logL = Vector{R}(undef, K)
Expand Down Expand Up @@ -100,7 +100,7 @@ function _forward!(
αₜ = view(α, :, t)
Bₜ = view(B, :, t)
if t == t1
copyto!(αₜ, initialization(hmm))
copyto!(αₜ, initialization(hmm, control_seq[t]))
else
αₜ₋₁ = view(α, :, t - 1)
predict_next_state!(αₜ, hmm, αₜ₋₁, control_seq[t - 1])
Expand Down
2 changes: 1 addition & 1 deletion src/inference/forward_backward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function initialize_forward_backward(
seq_ends::AbstractVectorOrNTuple{Int},
transition_marginals=true,
)
N, T, K = length(hmm), length(obs_seq), length(seq_ends)
N, T, K = length(hmm, control_seq[1]), length(obs_seq), length(seq_ends)
R = eltype(hmm, obs_seq[1], control_seq[1])
trans = transition_matrix(hmm, control_seq[1])
M = typeof(similar(trans, R))
Expand Down
2 changes: 1 addition & 1 deletion src/inference/logdensity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function joint_logdensityof(
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
# Initialization
init = initialization(hmm)
init = initialization(hmm, control_seq[t1])
logL += log(init[state_seq[t1]])
# Transitions
for t in t1:(t2 - 1)
Expand Down
4 changes: 2 additions & 2 deletions src/inference/viterbi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function initialize_viterbi(
control_seq::AbstractVector;
seq_ends::AbstractVectorOrNTuple{Int},
)
N, T, K = length(hmm), length(obs_seq), length(seq_ends)
N, T, K = length(hmm, control_seq[1]), length(obs_seq), length(seq_ends)
R = eltype(hmm, obs_seq[1], control_seq[1])
q = Vector{Int}(undef, T)
logL = Vector{R}(undef, K)
Expand All @@ -49,7 +49,7 @@ function _viterbi!(

logBₜ₁ = view(logB, :, t1)
obs_logdensities!(logBₜ₁, hmm, obs_seq[t1], control_seq[t1], missing)
loginit = log_initialization(hmm)
loginit = log_initialization(hmm, control_seq[t1])
ϕ[:, t1] .= loginit .+ logBₜ₁

for t in (t1 + 1):t2
Expand Down
216 changes: 157 additions & 59 deletions src/types/abstract_hmm.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
############################################################################
# 1. TYPE #
############################################################################
"""
AbstractHMM
Expand All @@ -23,19 +26,29 @@ Any `AbstractHMM` which satisfies the interface can be given to the following fu
- [`forward_backward`](@ref)
- [`baum_welch`](@ref) (if `[fit!](@ref)` is implemented)
"""
abstract type AbstractHMM{ar} end
abstract type AbstractHMM{T} end

@inline DensityInterface.DensityKind(::AbstractHMM) = HasDensity()

## Interface

############################################################################
# 2. INTERFACE #
############################################################################

#------------------------------ 2.1. length -------------------------------#
"""
length(hmm)
Return the number of states of `hmm`.
"""
Base.length(hmm::AbstractHMM) = length(initialization(hmm))

Base.length(hmm::AbstractHMM, control) = length(initialization(hmm, control))

Base.length(hmm::AbstractHMM, ::Nothing) = length(initialization(hmm))


#------------------------------ 2.2. eltype -------------------------------#
"""
eltype(hmm, obs, control)
Expand All @@ -44,20 +57,30 @@ Return a type that can accommodate forward-backward computations for `hmm` on ob
It is typically a promotion between the element type of the initialization, the element type of the transition matrix, and the type of an observation logdensity evaluated at `obs`.
"""
function Base.eltype(hmm::AbstractHMM, obs, control)
init_type = eltype(initialization(hmm))
trans_type = eltype(transition_matrix(hmm, control))
dist = obs_distributions(hmm, control, obs)[1]
logdensity_type = typeof(logdensityof(dist, obs))
return promote_type(init_type, trans_type, logdensity_type)
init_type = eltype(initialization(hmm, control))
trans_type = eltype(transition_matrix(hmm, control))
dist = obs_distributions(hmm, control, obs)[1]
logdensity_type = typeof(logdensityof(dist, obs))
return promote_type(init_type, trans_type, logdensity_type)
end


#--------------------------- 2.3. initialization --------------------------#
"""
initialization(hmm)
Return the vector of initial state probabilities for `hmm`.
"""
function initialization end

initialization(hmm::AbstractHMM) = hmm.init

initialization(hmm::AbstractHMM, control) = hmm.init

initialization(hmm::AbstractHMM, ::Nothing) = hmm.init


#------------------------ 2.4. log_initialization -------------------------#
"""
log_initialization(hmm)
Expand All @@ -66,18 +89,32 @@ Return the vector of initial state log-probabilities for `hmm`.
Falls back on `initialization`.
"""
log_initialization(hmm::AbstractHMM) = elementwise_log(initialization(hmm))

log_initialization(hmm::AbstractHMM, control) = elementwise_log(initialization(hmm, control))

log_initialization(hmm::AbstractHMM, ::Nothing) = elementwise_log(initialization(hmm))


#------------------------- 2.5. transition_matrix -------------------------#
"""
transition_matrix(hmm)
transition_matrix(hmm, control)
Return the matrix of state transition probabilities for `hmm` (possibly when `control` is applied).
!!! note
When processing sequences, the control at time `t` influences the transition from time `t` to `t+1` (and not from time `t-1` to `t`).
When processing sequences, the control at time `t` influences the transition from time `t` to `t+1` (and not from time `t-1` to `t`).
"""
function transition_matrix end

transition_matrix(hmm::AbstractHMM) = hmm.trans

transition_matrix(hmm::AbstractHMM, control) = transition_matrix(hmm)

transition_matrix(hmm::AbstractHMM, ::Nothing) = transition_matrix(hmm)


#----------------------- 2.6. log_transition_matrix -----------------------#
"""
log_transition_matrix(hmm)
log_transition_matrix(hmm, control)
Expand All @@ -89,10 +126,15 @@ Falls back on `transition_matrix`.
!!! note
When processing sequences, the control at time `t` influences the transition from time `t` to `t+1` (and not from time `t-1` to `t`).
"""
function log_transition_matrix(hmm::AbstractHMM, control)
return elementwise_log(transition_matrix(hmm, control))
end
log_transition_matrix(hmm::AbstractHMM) = elementwise_log(transition_matrix(hmm))

log_transition_matrix(hmm::AbstractHMM, control) = elementwise_log(transition_matrix(hmm, control))

log_transition_matrix(hmm::AbstractHMM, ::Nothing) = elementwise_log(transition_matrix(hmm))



#------------------------- 2.7. obs_distributions -------------------------#
"""
obs_distributions(hmm)
obs_distributions(hmm, control)
Expand All @@ -107,18 +149,27 @@ These distribution objects should implement
"""
function obs_distributions end

## Fallbacks for no control
obs_distributions(hmm::AbstractHMM) = [hmm.dists[i] for i 1:length(hmm)]

transition_matrix(hmm::AbstractHMM, ::Nothing) = transition_matrix(hmm)
log_transition_matrix(hmm::AbstractHMM, ::Nothing) = log_transition_matrix(hmm)
obs_distributions(hmm::AbstractHMM, control, prev_obs) = obs_distributions(hmm, control, prev_obs)

### Fallback when it is not autoregressive and there is no control
obs_distributions(hmm::AbstractHMM, ::Nothing) = obs_distributions(hmm)
function obs_distributions(hmm::AbstractHMM, control, ::Union{Nothing,Missing})
return obs_distributions(hmm, control)
end

### Fallback when it is autoregressive, but previous observation is Missing
obs_distributions(hmm::AbstractHMM, control, ::Missing) = obs_distributions(hmm, control)

### Fallback when it is autoregressive and there is no control, but observation is missing
obs_distributions(hmm::AbstractHMM, ::Nothing, ::Missing) = obs_distributions(hmm)


#---------------------------- 2.8. previous_obs ---------------------------#
previous_obs(::AbstractHMM{false}, obs_seq::AbstractVector, t::Integer) = nothing

previous_obs(::AbstractHMM{true}, obs_seq::AbstractVector, t::Integer) = obs_seq[t - 1]


#--------------------------- 2.9. StatsAPI.fit! ---------------------------#
"""
StatsAPI.fit!(
hmm, fb_storage::ForwardBackwardStorage,
Expand All @@ -131,21 +182,25 @@ This function is allowed to reuse `fb_storage` as a scratch space, so its conten
"""
StatsAPI.fit!

## Fill logdensities

#------------------------- 2.10. obs_logdensities! ------------------------#
function obs_logdensities!(
logb::AbstractVector{T}, hmm::AbstractHMM, obs, control, prev_obs
) where {T}
dists = obs_distributions(hmm, control, prev_obs)
@simd for i in eachindex(logb, dists)
logb[i] = logdensityof(dists[i], obs)
end
@argcheck maximum(logb) < typemax(T)
return nothing
logb::AbstractVector{T}, hmm::AbstractHMM, obs, control, prev_obs
) where {T}
dists = obs_distributions(hmm, control, prev_obs)
@simd for i in eachindex(logb, dists)
logb[i] = logdensityof(dists[i], obs)
end
@argcheck maximum(logb) < typemax(T)
return nothing
end

## Sampling


############################################################################
# 3. SAMPLING #
############################################################################

# <------------- Didn't touch it yet!
"""
rand([rng,] hmm, T)
rand([rng,] hmm, control_seq)
Expand All @@ -155,50 +210,93 @@ Simulate `hmm` for `T` time steps, or when the sequence `control_seq` is applied
Return a named tuple `(; state_seq, obs_seq)`.
"""
function Random.rand(rng::AbstractRNG, hmm::AbstractHMM, control_seq::AbstractVector)
T = length(control_seq)
dummy_log_probas = fill(-Inf, length(hmm))

init = initialization(hmm)
state_seq = Vector{Int}(undef, T)
state1 = rand(rng, LightCategorical(init, dummy_log_probas))
state_seq[1] = state1

@views for t in 1:(T - 1)
trans = transition_matrix(hmm, control_seq[t])
state_seq[t + 1] = rand(
rng, LightCategorical(trans[state_seq[t], :], dummy_log_probas)
)
end

dists1 = obs_distributions(hmm, control_seq[1], missing)
obs1 = rand(rng, dists1[state1])
obs_seq = Vector{typeof(obs1)}(undef, T)
obs_seq[1] = obs1

for t in 2:T
dists = obs_distributions(hmm, control_seq[t], previous_obs(hmm, obs_seq, t))
obs_seq[t] = rand(rng, dists[state_seq[t]])
end
return (; state_seq=state_seq, obs_seq=obs_seq)
T = length(control_seq)
dummy_log_probas = fill(-Inf, length(hmm))

init = initialization(hmm, control)
state_seq = Vector{Int}(undef, T)
state1 = rand(rng, LightCategorical(init, dummy_log_probas))
state_seq[1] = state1

@views for t in 1:(T - 1)
trans = transition_matrix(hmm, control_seq[t])
state_seq[t + 1] = rand(
rng, LightCategorical(trans[state_seq[t], :], dummy_log_probas)
)
end

dists1 = obs_distributions(hmm, control_seq[1], missing)
obs1 = rand(rng, dists1[state1])
obs_seq = Vector{typeof(obs1)}(undef, T)
obs_seq[1] = obs1

for t in 2:T
dists = obs_distributions(hmm, control_seq[t], previous_obs(hmm, obs_seq, t))
obs_seq[t] = rand(rng, dists[state_seq[t]])
end
return (; state_seq=state_seq, obs_seq=obs_seq)
end

function Random.rand(hmm::AbstractHMM, control_seq::AbstractVector)
return rand(default_rng(), hmm, control_seq)
return rand(default_rng(), hmm, control_seq)
end

function Random.rand(rng::AbstractRNG, hmm::AbstractHMM, T::Integer)
return rand(rng, hmm, Fill(nothing, T))
return rand(rng, hmm, Fill(nothing, T))
end

function Random.rand(hmm::AbstractHMM, T::Integer)
return rand(hmm, Fill(nothing, T))
return rand(hmm, Fill(nothing, T))
end

## Prior


############################################################################
# 4. PRIOR #
############################################################################
"""
logdensityof(hmm)
Return the prior loglikelihood associated with the parameters of `hmm`.
"""
DensityInterface.logdensityof(hmm::AbstractHMM) = false


############################################################################
# 5. ARHMM EXAMPLE #
############################################################################
## Test scruct for Discrete ARHMM with control
struct DiscreteCARHMM{T<:Number} <: AbstractHMM{true}
# Initial distribution P(X_{1}|U_{1}), one vector for each control
init::Vector{Vector{T}}
# Transition matrix P(X_{t}|X_{t-1}, U_{t}), one matrix for each control
trans::Vector{Matrix{T}}
# Emission matriz P(Y_{t}|X_{t}, U_{t}), one matriz for each control and each possible observation
dists::Vector{Vector{Matrix{T}}}
# Prior Distribution for P(Y_{1}|X_{1}, U_{1}), one matriz for each control
prior::Vector{Matrix{T}}
end

initialization(hmm::DiscreteCARHMM, control) = hmm.init[control]

transition_matrix(hmm::DiscreteCARHMM, control) = hmm.trans[control]

obs_distributions(hmm::DiscreteCARHMM, control, prev_obs) = [Categorical(hmm.dists[control][prev_obs][i,:]) for i in 1:length(hmm, control)]

obs_distributions(hmm::DiscreteCARHMM, control, ::Missing) = [Categorical(hmm.prior[control][i,:]) for i in 1:length(hmm, control)]


## Test scruct for Discrete ARHMM without control
struct DiscreteARHMM{T<:Number} <: AbstractHMM{true}
# Initial distribution P(X_{1})
init::Vector{T}
# Transition matrix P(X_{t}|X_{t-1})
trans::Matrix{T}
# Emission matriz P(Y_{t}|X_{t})
dists::Vector{Matrix{T}}
# Prior Distribution for P(Y_{1}|X_{1})
prior::Matrix{T}
end

obs_distributions(hmm::DiscreteARHMM, ::Nothing, prev_obs) = [Categorical(hmm.dists[prev_obs][i,:]) for i in 1:length(hmm)]

obs_distributions(hmm::DiscreteARHMM, ::Nothing, ::Missing) = [Categorical(hmm.prior[i,:]) for i in 1:length(hmm)]
Loading

0 comments on commit 5888f79

Please sign in to comment.