Commit
Merge pull request #1 from mrc-ide/mrc-5075
Basic sampler implementation
Showing 19 changed files with 525 additions and 3 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,7 @@ | ||
# Generated by roxygen2: do not edit by hand | ||
|
||
export(mcstate_model) | ||
export(mcstate_sample) | ||
export(mcstate_sampler_random_walk) | ||
importFrom(stats,rnorm) | ||
importFrom(stats,runif) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
## Not in base R | ||
rmvnorm <- function(x, vcv) { | ||
make_rmvnorm(vcv)(x) | ||
} | ||
|
||
|
||
##' @importFrom stats rnorm | ||
make_rmvnorm <- function(vcv) { | ||
n <- ncol(vcv) | ||
r <- chol(vcv, pivot = TRUE) | ||
r <- r[, order(attr(r, "pivot", exact = TRUE))] | ||
function(x) { | ||
x + drop(rnorm(n) %*% r) | ||
} | ||
} |
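
The pivoted Cholesky factor used in `make_rmvnorm` can be checked numerically. The sketch below copies the helper from the diff so that it runs on its own; the covariance matrix, seed and sample size are illustrative assumptions, not values from the package.

```r
# Copy of the helper from the diff above, so the example is self-contained.
make_rmvnorm <- function(vcv) {
  n <- ncol(vcv)
  r <- chol(vcv, pivot = TRUE)
  # Undo the column pivoting so that t(r) %*% r reproduces vcv
  r <- r[, order(attr(r, "pivot", exact = TRUE))]
  function(x) {
    x + drop(rnorm(n) %*% r)
  }
}

set.seed(1)
vcv <- matrix(c(1, 0.5, 0.5, 2), 2, 2)  # illustrative covariance matrix
propose <- make_rmvnorm(vcv)

# Draw many proposals centred on the origin; rows are draws
draws <- t(replicate(20000, propose(c(0, 0))))
emp <- cov(draws)

# Rebuild vcv from the un-pivoted factor
r <- chol(vcv, pivot = TRUE)
r <- r[, order(attr(r, "pivot", exact = TRUE))]
recon <- crossprod(r)
```

Here `recon` should match `vcv` up to floating point error, and `emp` should be close to `vcv` up to Monte Carlo error. Using `pivot = TRUE` also lets `chol` factor matrices that are only positive semi-definite.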
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
##' Create a basic `mcstate` model. Currently nothing here is | ||
##' validated, and it's likely that users will never actually use this | ||
##' directly. Contains data and methods that define a basic model | ||
##' object, so that we can implement samplers against it. Not all | ||
##' models will support everything here, and we'll add additional | ||
##' fields/traits over time to advertise what a model can do. For | ||
##' example, models will need to advertise that they are capable of | ||
##' being differentiated, or that they are stochastic in order to be | ||
##' used with different methods. | ||
##' | ||
##' @title Create basic model | ||
##' | ||
##' @param parameters Names of the parameters. Every parameter is | ||
##' named, and for now every parameter is a scalar. We might relax | ||
##' this later to support an `odin`-style structured parameter list, | ||
##' though that might simply generate a suitable vector of parameter | ||
##' names. In any case, inference naturally takes place in R^n, | ||
##' where n is the length of this vector of names. | ||
##' | ||
##' @param direct_sample A function to sample directly from the | ||
##' parameter space. In the case where a model returns a posterior | ||
##' (e.g., in Bayesian inference), this is assumed to be sampling | ||
##' from the prior. We'll use this for generating initial | ||
##' conditions for MCMC where those are not given, and possibly | ||
##' other uses. | ||
##' | ||
##' @param density Compute the model density, on the log scale, for a | ||
##' vector of parameter values; this is the log posterior probability | ||
##' in the case of Bayesian inference, but it could be any log | ||
##' density. Models can return `-Inf` if the parameters are | ||
##' impossible, and we'll try to cope gracefully with that wherever | ||
##' possible. | ||
##' | ||
##' @param gradient Compute the gradient of `density` with respect to | ||
##' the parameter vector; takes a parameter vector and returns a | ||
##' vector of the same length. For efficiency, the model may want to | ||
##' be stateful so that gradients can be efficiently calculated | ||
##' after a density calculation, or density after gradient, where | ||
##' these are called with the same parameters. | ||
##' | ||
##' @param domain Information on the parameter domain. This is a two | ||
##' column matrix with `length(parameters)` rows representing each | ||
##' parameter. The parameter minimum and maximum bounds are given | ||
##' as the first and second column. Infinite values (`-Inf` or | ||
##' `Inf`) should be used where the parameter is unbounded below or | ||
##' above. Currently used to translate from a bounded to | ||
##' unbounded space for HMC, but we might also use this for | ||
##' reflecting proposals in MCMC too. | ||
##' | ||
##' @return An object of class `mcstate_model`, which can be used with | ||
##' a sampler. | ||
##' | ||
##' @export | ||
mcstate_model <- function(parameters, direct_sample, density, gradient, | ||
domain) { | ||
ret <- list(parameters = parameters, | ||
direct_sample = direct_sample, | ||
density = density, | ||
gradient = gradient, | ||
domain = domain) | ||
class(ret) <- "mcstate_model" | ||
ret | ||
} |
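
As a minimal sketch of how the constructor might be called, the example below builds a model for two independent standard normal parameters. The constructor is copied from the diff so the block is self-contained; the model itself (`gaussian_model`, the parameter names `a` and `b`) is a hypothetical illustration and not part of the package.

```r
# Constructor copied from the diff above so the sketch runs on its own.
mcstate_model <- function(parameters, direct_sample, density, gradient,
                          domain) {
  ret <- list(parameters = parameters,
              direct_sample = direct_sample,
              density = density,
              gradient = gradient,
              domain = domain)
  class(ret) <- "mcstate_model"
  ret
}

# Hypothetical example: two independent standard normal parameters.
gaussian_model <- mcstate_model(
  parameters = c("a", "b"),
  # Draw a starting point directly from the distribution
  direct_sample = function() rnorm(2),
  # Log density, summed over the two independent components
  density = function(x) sum(dnorm(x, log = TRUE)),
  # Gradient of the log density of a standard normal is -x
  gradient = function(x) -x,
  # Both parameters are unbounded: one row per parameter, (min, max)
  domain = cbind(rep(-Inf, 2), rep(Inf, 2)))

d0 <- gaussian_model$density(c(0, 0))
g <- gaussian_model$gradient(c(1, -2))
```

The `density` function returns a log density, which is what the random walk sampler's acceptance rule expects.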
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
##' Sample from a model. Uses a Monte Carlo method (or possibly | ||
##' something else in future) to generate samples from your | ||
##' distribution. This is going to change a lot in future, as we add | ||
##' support for distributing over workers, and for things like | ||
##' parallel reproducible streams of random numbers. For now it just | ||
##' runs a single chain as a proof of concept. | ||
##' | ||
##' @title Sample from a model | ||
##' | ||
##' @param model The model to sample from; this should be a | ||
##' `mcstate_model` for now, but we might change this in future to | ||
##' test to see if things match an interface rather than a | ||
##' particular class attribute. | ||
##' | ||
##' @param sampler A sampler to use. These will be described later, | ||
##' but we hope to make these reasonably easy to implement so that | ||
##' we can try out different sampling ideas. For now, the only | ||
##' sampler implemented is [mcstate_sampler_random_walk()]. | ||
##' | ||
##' @param n_steps The number of steps to run the sampler for. | ||
##' | ||
##' @param initial Optionally, initial parameter values for the | ||
##' sampling. If not given, we sample from the model (or its prior). | ||
##' | ||
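##' @return A list with elements `pars` (a matrix of parameters with | ||
##' one row for the initial state and one per step), `density` (the | ||
##' corresponding densities) and `details` (sampler-specific output | ||
##' from `finalise`). | ||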
##' | ||
##' @export | ||
mcstate_sample <- function(model, sampler, n_steps, initial = NULL) { | ||
if (!inherits(model, "mcstate_model")) { | ||
cli::cli_abort("Expected 'model' to be an 'mcstate_model'", | ||
arg = "model") | ||
} | ||
if (!inherits(sampler, "mcstate_sampler")) { | ||
cli::cli_abort("Expected 'sampler' to be an 'mcstate_sampler'", | ||
arg = "sampler") | ||
} | ||
|
||
if (is.null(initial)) { | ||
## Really this would just be from the prior; we can't directly | ||
## sample from the posterior! | ||
pars <- model$direct_sample() | ||
} else { | ||
pars <- initial | ||
if (length(pars) != length(model$parameters)) { | ||
cli::cli_abort( | ||
paste("Unexpected initial parameter length {length(pars)};", | ||
"expected {length(model$parameters)}"), | ||
arg = "initial") | ||
} | ||
} | ||
|
||
density <- model$density(pars) | ||
state <- list(pars = pars, density = density) | ||
sampler$initialise(state, model) | ||
|
||
history_pars <- matrix(NA_real_, n_steps + 1, length(pars)) | ||
history_pars[1, ] <- pars | ||
history_density <- rep(NA_real_, n_steps + 1) | ||
history_density[[1]] <- density | ||
|
||
for (i in seq_len(n_steps)) { | ||
state <- sampler$step(state, model) | ||
history_pars[i + 1, ] <- state$pars | ||
history_density[[i + 1]] <- state$density | ||
} | ||
|
||
## Pop the parameter names on last | ||
colnames(history_pars) <- model$parameters | ||
|
||
## I'm not sure about the best name for this | ||
details <- sampler$finalise(state, model) | ||
|
||
list(pars = history_pars, | ||
density = history_density, | ||
details = details) | ||
} | ||
|
||
|
||
mcstate_sampler <- function(name, initialise, step, finalise) { | ||
ret <- list(name = name, | ||
initialise = initialise, | ||
step = step, | ||
finalise = finalise) | ||
class(ret) <- "mcstate_sampler" | ||
ret | ||
} |
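
The `initialise` / `step` / `finalise` contract can be illustrated with a condensed version of the loop and a toy sampler. Everything named here (`run_chain`, `counting_sampler`, the stand-in `model` list) is hypothetical and written only for this sketch; it omits the validation and density history that `mcstate_sample` has.

```r
# Condensed, hypothetical version of the sampling loop (no validation).
run_chain <- function(model, sampler, n_steps, initial) {
  state <- list(pars = initial, density = model$density(initial))
  sampler$initialise(state, model)
  pars <- matrix(NA_real_, n_steps + 1, length(initial))
  pars[1, ] <- state$pars
  for (i in seq_len(n_steps)) {
    state <- sampler$step(state, model)
    pars[i + 1, ] <- state$pars
  }
  list(pars = pars, details = sampler$finalise(state, model))
}

# Hypothetical sampler that keeps counters in its enclosing environment
# and reports the acceptance rate from finalise().
counting_sampler <- function(sd) {
  n_accept <- 0
  n_total <- 0
  list(
    initialise = function(state, model) NULL,
    step = function(state, model) {
      pars_next <- state$pars + rnorm(length(state$pars), 0, sd)
      density_next <- model$density(pars_next)
      n_total <<- n_total + 1
      if (isTRUE(density_next - state$density > log(runif(1)))) {
        n_accept <<- n_accept + 1
        state <- list(pars = pars_next, density = density_next)
      }
      state
    },
    finalise = function(state, model) {
      list(acceptance_rate = n_accept / n_total)
    })
}

# Stand-in model: only the log density is needed for this sketch.
model <- list(density = function(x) sum(dnorm(x, log = TRUE)))

set.seed(42)
res <- run_chain(model, counting_sampler(sd = 1), n_steps = 500, initial = 0)
```

`res$pars` has `n_steps + 1` rows because the initial state is stored first, and `res$details` carries whatever the sampler chose to return from `finalise`.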
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
##' Create a simple random walk sampler, which uses a symmetric | ||
##' proposal to move around parameter space. | ||
##' | ||
##' @title Random Walk Sampler | ||
##' | ||
##' @param proposal A proposal function; must take a vector of | ||
##' parameters and produce a new vector of proposed parameters. | ||
##' | ||
##' @param vcv A variance-covariance matrix from which to generate a | ||
##' `proposal` function. If you want a multivariate Gaussian | ||
##' proposal, this is likely simpler than supplying your own | ||
##' `proposal`, and generally more efficient too. Exactly one of | ||
##' `proposal` and `vcv` must be given. | ||
##' | ||
##' @return A `mcstate_sampler` object, which can be used with | ||
##' [mcstate_sample()]. | ||
##' | ||
##' @importFrom stats runif | ||
##' @export | ||
mcstate_sampler_random_walk <- function(proposal = NULL, vcv = NULL) { | ||
if (is.null(proposal) && is.null(vcv)) { | ||
cli::cli_abort("One of 'proposal' or 'vcv' must be given") | ||
} | ||
if (!is.null(proposal) && !is.null(vcv)) { | ||
cli::cli_abort("Only one of 'proposal' or 'vcv' may be given") | ||
} | ||
if (!is.null(vcv)) { # proposal is null | ||
check_vcv(vcv, call = environment()) | ||
proposal <- make_rmvnorm(vcv) | ||
} | ||
|
||
initialise <- function(state, model) { | ||
if (!is.null(vcv)) { | ||
n_pars <- length(state$pars) | ||
n_vcv <- nrow(vcv) | ||
if (n_pars != n_vcv) { | ||
cli::cli_abort( | ||
"Incompatible length parameters ({n_pars}) and vcv ({n_vcv})") | ||
} | ||
} | ||
} | ||
|
||
step <- function(state, model) { | ||
pars_next <- proposal(state$pars) | ||
density_next <- model$density(pars_next) | ||
## isTRUE() guards against NaN when both densities are -Inf | ||
if (isTRUE(density_next - state$density > log(runif(1)))) { | ||
state$pars <- pars_next | ||
state$density <- density_next | ||
} | ||
state | ||
} | ||
|
||
finalise <- function(state, model) { | ||
NULL | ||
} | ||
|
||
mcstate_sampler("Random walk", | ||
initialise, | ||
step, | ||
finalise) | ||
} |
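
The acceptance test in `step` works on the log scale: comparing the log density difference with `log(runif(1))` accepts a move with probability `min(1, exp(difference))`. The sketch below checks this empirically for an assumed difference of `log(0.3)`; the seed and replicate counts are arbitrary.

```r
set.seed(1)

# Assumed log density difference: the proposed point is 0.3 times as
# dense as the current point.
delta <- log(0.3)

# Apply the same comparison as in step() many times
accepted <- replicate(50000, delta > log(runif(1)))
accept_rate <- mean(accepted)

# Moves that do not decrease the density (difference >= 0) are always
# accepted, because log(runif(1)) is strictly negative.
uphill <- all(replicate(1000, 0 > log(runif(1))))
```

`accept_rate` should be close to 0.3 and `uphill` should be `TRUE`. Because the proposal is symmetric, no proposal density ratio appears in the comparison.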
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,25 @@ | ||
`%||%` <- function(x, y) { # nolint | ||
if (is.null(x)) y else x | ||
} | ||
|
||
|
||
check_vcv <- function(vcv, name = deparse(substitute(vcv)), call = NULL) { | ||
if (!is.matrix(vcv)) { | ||
cli::cli_abort("Expected '{name}' to be a matrix", | ||
arg = name, call = call) | ||
} | ||
if (!isSymmetric(vcv)) { | ||
cli::cli_abort("Expected '{name}' to be symmetric", | ||
arg = name, call = call) | ||
} | ||
if (!is_positive_definite(vcv)) { | ||
cli::cli_abort("Expected '{name}' to be positive definite", | ||
arg = name, call = call) | ||
} | ||
} | ||
|
||
|
||
is_positive_definite <- function(x, tol = sqrt(.Machine$double.eps)) { | ||
ev <- eigen(x, symmetric = TRUE) | ||
all(ev$values >= -tol * abs(ev$values[1])) | ||
} |
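
The eigenvalue check can be exercised on a few small matrices. The helper is copied from the diff; the three test matrices are illustrative assumptions. Note that, because of the tolerance, a singular (positive semi-definite) matrix also passes the check.

```r
# Copy of the helper from the diff above.
is_positive_definite <- function(x, tol = sqrt(.Machine$double.eps)) {
  ev <- eigen(x, symmetric = TRUE)
  all(ev$values >= -tol * abs(ev$values[1]))
}

pd <- matrix(c(2, 0.5, 0.5, 1), 2, 2)  # both eigenvalues positive
psd <- matrix(1, 2, 2)                 # eigenvalues 2 and 0 (singular)
indef <- matrix(c(1, 2, 2, 1), 2, 2)   # eigenvalues 3 and -1

ok_pd <- is_positive_definite(pd)
ok_psd <- is_positive_definite(psd)
ok_indef <- is_positive_definite(indef)
```

Accepting the singular case is compatible with `make_rmvnorm`, since `chol(vcv, pivot = TRUE)` can factor semi-definite matrices, though R may warn that the matrix is rank-deficient.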
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
CMD | ||
HMC | ||
codecov | ||
io | ||
mcstate |