Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow simulation to be restored with new interventions. #286

Merged
merged 2 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ Remotes:
Additional_repositories:
https://mrc-ide.r-universe.dev
Imports:
individual (>= 0.1.15),
individual (>= 0.1.16),
malariaEquilibrium (>= 1.0.1),
Rcpp,
statmod,
Expand Down
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ solver_get_states <- function(solver) {
.Call(`_malariasimulation_solver_get_states`, solver)
}

solver_set_states <- function(solver, state) {
invisible(.Call(`_malariasimulation_solver_set_states`, solver, state))
solver_set_states <- function(solver, t, state) {
invisible(.Call(`_malariasimulation_solver_set_states`, solver, t, state))
}

solver_step <- function(solver) {
Expand Down
8 changes: 4 additions & 4 deletions R/compartmental.R
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ Solver <- R6::R6Class(
save_state = function() {
solver_get_states(private$.solver)
},
restore_state = function(state) {
solver_set_states(private$.solver, state)
restore_state = function(t, state) {
solver_set_states(private$.solver, t, state)
}
)
)
Expand All @@ -173,7 +173,7 @@ AquaticMosquitoModel <- R6::R6Class(
# state of the ODE is stored separately). We still provide these methods to
# conform to the expected interface.
save_state = function() { NULL },
restore_state = function(state) { }
restore_state = function(t, state) { }
)
)

Expand All @@ -187,7 +187,7 @@ AdultMosquitoModel <- R6::R6Class(
save_state = function() {
adult_mosquito_model_save_state(self$.model)
},
restore_state = function(state) {
restore_state = function(t, state) {
adult_mosquito_model_restore_state(self$.model, state)
}
)
Expand Down
171 changes: 151 additions & 20 deletions R/correlation.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,26 @@ INTS <- c(
)

#' Class: Correlation parameters
#' Describes an event in the simulation
#'
#' This class implements functionality that allows interventions to be
#' correlated, positively or negatively. By default, interventions are applied
#' independently and an individual's probability of receiving two interventions
#' (either two separate interventions or two rounds of the same one) is the
#' product of the probability of receiving each one.
#'
#' By setting a positive correlation between two interventions, we can make it
#' so that the individuals that receive intervention A are more likely to
#' receive intervention B. Conversely, a negative correlation will make it such
#' that individuals that receive intervention A are less likely to also receive
#' intervention B.
#'
#' Broadly speaking, the implementation works by assigning at startup a weight
#' to each individual and intervention pair, reflecting how likely an individual
#' is to receive that intervention. Those weights are derived stochastically
#' from the configured correlation parameters.
#'
#' For a detailed breakdown of the calculations, see Protocol S2 of
#' Griffin et al. (2010).
CorrelationParameters <- R6::R6Class(
'CorrelationParameters',
private = list(
Expand All @@ -19,7 +38,40 @@ CorrelationParameters <- R6::R6Class(
rho_matrix = NULL,
rho = function() diag(private$rho_matrix),
.sigma = NULL,
.mvnorm = NULL
.mvnorm = NULL,

#' Derive the mvnorm from the configured correlations.
#'
#' If a \code{restored_mvnorm} is specified, its columns (corresponding to
#' restored interventions) will be re-used as is. Missing columns (for new
#' interventions) are derived in accordance with the restored data.
calculate_mvnorm = function(restored_mvnorm = matrix(ncol=0, nrow=private$population)) {
sigma <- self$sigma()
V <- outer(sigma, sigma) * private$rho_matrix
diag(V) <- sigma ^ 2

restored_interventions <- match(colnames(restored_mvnorm), private$interventions)
new_interventions <- setdiff(seq_along(private$interventions), restored_interventions)

mvnorm <- matrix(
nrow = private$population,
ncol = length(private$interventions),
dimnames = list(NULL, private$interventions)
)
mvnorm[,restored_interventions] <- restored_mvnorm
if (length(new_interventions) > 0) {
mvnorm[,new_interventions] <- rcondmvnorm(
private$population,
mean = rep(0, length(private$interventions)),
sigma = V,
given = restored_mvnorm,
dependent.ind = new_interventions,
given.ind = restored_interventions
)
}

mvnorm
}
),
public = list(

Expand All @@ -45,6 +97,8 @@ CorrelationParameters <- R6::R6Class(
#' @param rho value between 0 and 1 representing the correlation between rounds of
#' the intervention
inter_round_rho = function(int, rho) {
stopifnot(is.null(private$.sigma) && is.null(private$.mvnorm))

if (!(int %in% private$interventions)) {
stop(paste0('invalid intervention name: ', int))
}
Expand All @@ -55,8 +109,6 @@ CorrelationParameters <- R6::R6Class(
rho <- 1 - .Machine$double.eps
}
private$rho_matrix[[int, int]] <- rho
private$.sigma <- NULL
private$.mvnorm <- NULL
},

#' @description Add rho between interventions
Expand All @@ -66,6 +118,8 @@ CorrelationParameters <- R6::R6Class(
#' @param rho value between -1 and 1 representing the correlation between rounds of
#' the intervention
inter_intervention_rho = function(int_1, int_2, rho) {
stopifnot(is.null(private$.sigma) && is.null(private$.mvnorm))

if (!(int_1 %in% private$interventions)) {
stop(paste0('invalid intervention name: ', int_1))
}
Expand All @@ -83,8 +137,6 @@ CorrelationParameters <- R6::R6Class(
}
private$rho_matrix[[int_1, int_2]] <- rho
private$rho_matrix[[int_2, int_1]] <- rho
private$.sigma <- NULL
private$.mvnorm <- NULL
},

#' @description Standard deviation of each intervention between rounds
Expand All @@ -98,18 +150,9 @@ CorrelationParameters <- R6::R6Class(
},

#' @description multivariate norm draws for these parameters
#' @importFrom MASS mvrnorm
mvnorm = function() {
if (is.null(private$.mvnorm)) {
sigma <- self$sigma()
V <- outer(sigma, sigma) * private$rho_matrix
diag(V) <- sigma ^ 2
private$.mvnorm <- mvrnorm(
private$population,
rep(0, length(private$interventions)),
V
)
dimnames(private$.mvnorm)[[2]] <- private$interventions
private$.mvnorm <- private$calculate_mvnorm()
}
private$.mvnorm
},
Expand All @@ -121,16 +164,22 @@ CorrelationParameters <- R6::R6Class(
# otherwise we would be drawing a new, probably different, value.
# The rest of the object is derived deterministically from the parameters
# and does not need saving.
list(mvnorm=private$.mvnorm)
list(mvnorm=self$mvnorm())
},

#' @description Restore the correlation state.
#'
#' Only the randomly drawn weights are restored. The object needs to be
#' initialized with the same rhos.
#'
#' @param timestep the timestep at which simulation is resumed. This
#' parameter's value is ignored, it only exists to conform to a uniform
#' interface.
#' @param state a previously saved correlation state, as returned by the
#' save_state method.
restore_state = function(state) {
private$.mvnorm <- state$mvnorm
#' save_state method.
restore_state = function(timestep, state) {
stopifnot(is.null(private$.sigma) && is.null(private$.mvnorm))
private$.mvnorm <- private$calculate_mvnorm(state$mvnorm)
}
)
)
Expand Down Expand Up @@ -200,3 +249,85 @@ sample_intervention <- function(target, intervention, p, correlations) {
z <- rnorm(length(target))
u0 + correlations$mvnorm()[target, intervention] + z < 0
}

#' Simulate from a conditional multivariate normal distribution.
#'
#' Given a multidimensional variable Z which follows a multivariate normal
#' distribution, this function allows one to draw samples for a subset of Z,
#' while putting conditions on the values of the rest of Z.
#'
#' This effectively allows one to grow a MVN distributed matrix (with columns as
#' the dimensions and a row per sampled vector), adding new dimensions after the
#' fact. The existing columns are used as the condition set on the distribution,
#' and the values returned by this function are used as the new dimensions.
#'
#' The maths behind the implementation are described in various online sources:
#' - https://statproofbook.github.io/P/mvn-cond.html
#' - https://www.stats.ox.ac.uk/~doucet/doucet_simulationconditionalgaussian.pdf
#' - https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Conditional_distributions
#'
#' @param n the number of samples to simulate
#' @param mean the mean vector of the distribution, including both given and
#' dependent variables
#' @param sigma the variance-covariance matrix of the distribution, including
#' both given and dependent variables
#' @param given a matrix of given values used as conditions when simulating the
#' distribution. The matrix should have \code{n} rows, each one specifying a
#' different set of values for the given variables.
#' @param dependent.ind the indices within \code{mean} and \code{sigma} of the
#' variables to simulate.
#' @param given.ind the indices within \code{mean} and \code{sigma} of the
#' variables for which conditions are given. The length of this vector must be
#' equal to the number of columns of the \code{given} matrix. If empty or NULL,
#' this function is equivalent to simulating from an unconditional multivariate
#' normal distribution.
#' @return a matrix with \code{n} rows and \code{length(dependent.ind)} columns,
#' containing the simulated value.
#' @importFrom MASS mvrnorm
#' @noRd
rcondmvnorm <- function(n, mean, sigma, given, dependent.ind, given.ind) {
stopifnot(length(mean) == nrow(sigma))
stopifnot(length(mean) == ncol(sigma))
stopifnot(nrow(given) == n)
stopifnot(ncol(given) == length(given.ind))

sigma11 <- sigma[dependent.ind, dependent.ind, drop=FALSE]
sigma12 <- sigma[dependent.ind, given.ind, drop=FALSE]
sigma21 <- sigma[given.ind, dependent.ind, drop=FALSE]
sigma22 <- sigma[given.ind, given.ind, drop=FALSE]

if (all(sigma22 == 0)) {
# This covers two cases: there were no given variables and therefore their
# variance-covariance matrix is empty, or there were given variables but
# they had a variance of zero. The general formula can't support the latter
# case since it tries to invert the matrix, but we can safely ignore the
# values since they are all equal to their mean and don't influence the
# dependent variables.
#
# In both cases we revert to a standard MVN with no condition.
mvrnorm(n, mean[dependent.ind], sigma11)
} else {
# Available implementations of the conditional multivariate normal assume
# every sample is drawn using the same condition on the given variables.
# This is not true in our usecase, where every individual has already had an
# independent vector of values drawn for the given variable. We are
# effectively drawing from as many different distributions as there are
# individuals. Thankfully the same conditional covariance matrix can be
# used for all the distributions, only the mean vector needs to be
# different. We draw the underlying samples from the MVN at mean 0, and
# offset that later on a per-individual basis.
#
# To work over all the vectors directly they need to be as columns, which
# is why we start by transposing `given`. R will recycle the `m` matrix and
# `mean` vectors across all the columns. The last step is to transpose the
# result back into the expected configuration.

m <- sigma12 %*% solve(sigma22)
residual <- t(given) - mean[given.ind]
cond_mu <- t(m %*% residual + mean[dependent.ind])
cond_sigma <- sigma11 - m %*% sigma21

samples <- mvrnorm(n, rep(0, length(dependent.ind)), cond_sigma)
samples + cond_mu
}
}
18 changes: 8 additions & 10 deletions R/events.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
create_events <- function(parameters) {
events <- list(
# MDA events
mda_administer = individual::Event$new(),
smc_administer = individual::Event$new(),
mda_administer = individual::Event$new(restore=FALSE),
smc_administer = individual::Event$new(restore=FALSE),

# TBV event
tbv_vaccination = individual::Event$new(),
tbv_vaccination = individual::Event$new(restore=FALSE),

# Bednet events
throw_away_net = individual::TargetedEvent$new(parameters$human_population)
Expand All @@ -21,7 +21,7 @@ create_events <- function(parameters) {
seq_along(parameters$mass_pev_booster_spacing),
function(.) individual::TargetedEvent$new(parameters$human_population)
)
events$mass_pev <- individual::Event$new()
events$mass_pev <- individual::Event$new(restore=FALSE)
events$mass_pev_doses <- mass_pev_doses
events$mass_pev_boosters <- mass_pev_boosters
}
Expand Down Expand Up @@ -63,16 +63,16 @@ initialise_events <- function(events, variables, parameters) {

# Initialise scheduled interventions
if (!is.null(parameters$mass_pev_timesteps)) {
events$mass_pev$schedule(parameters$mass_pev_timesteps[[1]] - 1)
events$mass_pev$schedule(parameters$mass_pev_timesteps - 1)
}
if (parameters$mda) {
events$mda_administer$schedule(parameters$mda_timesteps[[1]] - 1)
events$mda_administer$schedule(parameters$mda_timesteps - 1)
}
if (parameters$smc) {
events$smc_administer$schedule(parameters$smc_timesteps[[1]] - 1)
events$smc_administer$schedule(parameters$smc_timesteps - 1)
}
if (parameters$tbv) {
events$tbv_vaccination$schedule(parameters$tbv_timesteps[[1]] - 1)
events$tbv_vaccination$schedule(parameters$tbv_timesteps - 1)
}
}

Expand Down Expand Up @@ -158,7 +158,6 @@ attach_event_listeners <- function(
if (parameters$mda == 1) {
events$mda_administer$add_listener(create_mda_listeners(
variables,
events$mda_administer,
parameters$mda_drug,
parameters$mda_timesteps,
parameters$mda_coverages,
Expand All @@ -174,7 +173,6 @@ attach_event_listeners <- function(
if (parameters$smc == 1) {
events$smc_administer$add_listener(create_mda_listeners(
variables,
events$smc_administer,
parameters$smc_drug,
parameters$smc_timesteps,
parameters$smc_coverages,
Expand Down
2 changes: 1 addition & 1 deletion R/lag.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ LaggedValue <- R6::R6Class(
timeseries_save_state(private$history)
},

restore_state = function(state) {
restore_state = function(t, state) {
timeseries_restore_state(private$history, state)
}
)
Expand Down
7 changes: 0 additions & 7 deletions R/mda_processes.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#' @title Create listeners for MDA events
#' @param variables the variables available in the model
#' @param administer_event the event schedule for drug administration
#' @param drug the drug to administer
#' @param timesteps timesteps for each round
#' @param coverages the coverage for each round
Expand All @@ -14,7 +13,6 @@
#' @noRd
create_mda_listeners <- function(
variables,
administer_event,
drug,
timesteps,
coverages,
Expand Down Expand Up @@ -78,11 +76,6 @@ create_mda_listeners <- function(
variables$drug$queue_update(drug, to_move)
variables$drug_time$queue_update(timestep, to_move)
}

# Schedule next round
if (time_index < length(timesteps)) {
administer_event$schedule(timesteps[[time_index + 1]] - timestep)
}
}
}

Expand Down
Loading
Loading