Merge pull request #1 from mrc-ide/mrc-5075
Basic sampler implementation
richfitz authored Feb 21, 2024
2 parents b1aeeb4 + 3919f38 commit 89b17a8
Showing 19 changed files with 525 additions and 3 deletions.
1 change: 1 addition & 0 deletions .lintr
@@ -1,4 +1,5 @@
linters: linters_with_defaults(
indentation_linter = NULL,
object_length_linter = NULL,
object_usage_linter = NULL,
cyclocomp_linter = NULL
7 changes: 5 additions & 2 deletions DESCRIPTION
@@ -8,13 +8,16 @@ Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
 Description: Experimental sources for the next generation of mcstate,
     which will support much of the old mcstate functionality but new
     things like better parameter interfaces, Hamiltonian Monte Carlo,
-    etc.
+    and other features.
 License: MIT + file LICENSE
 Encoding: UTF-8
 Roxygen: list(markdown = TRUE)
-RoxygenNote: 7.1.1
+RoxygenNote: 7.3.1
 URL: https://github.com/mrc-ide/mcstate2
 BugReports: https://github.com/mrc-ide/mcstate2/issues
+Imports:
+    cli
 Suggests:
     testthat (>= 3.0.0)
 Config/testthat/edition: 3
+Language: en-GB
6 changes: 6 additions & 0 deletions NAMESPACE
@@ -1 +1,7 @@
# Generated by roxygen2: do not edit by hand

export(mcstate_model)
export(mcstate_sample)
export(mcstate_sampler_random_walk)
importFrom(stats,rnorm)
importFrom(stats,runif)
15 changes: 15 additions & 0 deletions R/distributions.R
@@ -0,0 +1,15 @@
## Not in base R
rmvnorm <- function(x, vcv) {
  make_rmvnorm(vcv)(x)
}


##' @importFrom stats rnorm
make_rmvnorm <- function(vcv) {
  n <- ncol(vcv)
  r <- chol(vcv, pivot = TRUE)
  r <- r[, order(attr(r, "pivot", exact = TRUE))]
  function(x) {
    x + drop(rnorm(n) %*% r)
  }
}
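A minimal sketch of what `make_rmvnorm` relies on, assuming only base R and the `stats` package. The helper is copied here so the example runs on its own; the matrix `vcv` and the names `reconstructed`, `draws`, and `empirical` are illustrative and not part of the commit. It checks that the un-pivoted Cholesky factor reproduces the input matrix, and that proposals centred on the origin have roughly the requested covariance.

```r
## Standalone copy of the helper from the diff, so the sketch runs by itself.
make_rmvnorm <- function(vcv) {
  n <- ncol(vcv)
  r <- chol(vcv, pivot = TRUE)
  r <- r[, order(attr(r, "pivot", exact = TRUE))]
  function(x) {
    x + drop(rnorm(n) %*% r)
  }
}

## Illustrative covariance matrix (not from the commit).
vcv <- matrix(c(1, 0.5, 0.5, 2), 2, 2)

## After undoing the pivot, crossprod(r) gives back the original matrix.
r <- chol(vcv, pivot = TRUE)
r <- r[, order(attr(r, "pivot", exact = TRUE))]
reconstructed <- crossprod(r)

## Draw many proposals around the origin and compare the sample covariance.
set.seed(1)
proposal <- make_rmvnorm(vcv)
draws <- t(replicate(5000, proposal(c(0, 0))))
empirical <- cov(draws)
```

Building the factor once and returning a closure means the decomposition is not repeated on every proposal, which is presumably why the sampler calls `make_rmvnorm` rather than `rmvnorm`.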
63 changes: 63 additions & 0 deletions R/model.R
@@ -0,0 +1,63 @@
##' Create a basic `mcstate` model. Currently nothing here is
##' validated, and it's likely that users will never actually use this
##' directly. Contains data and methods that define a basic model
##' object, so that we can implement samplers against it. Not all models
##' will support everything here, and we'll add additional
##' fields/traits over time to advertise what a model can do. For
##' example, models will need to advertise that they are capable of
##' being differentiated, or that they are stochastic in order to be
##' used with different methods.
##'
##' @title Create basic model
##'
##' @param parameters Names of the parameters. Every parameter is
##' named, and for now every parameter is a scalar. We might relax
##' this later to support an `odin`-style structured parameter list,
##' though that might simply generate a suitable vector of parameter
##' names. In any case, once we start doing inference we are
##' naturally working in R^n, where n is the length of this vector
##' of names.
##'
##' @param direct_sample A function to sample directly from the
##' parameter space. In the case where a model returns a posterior
##' (e.g., in Bayesian inference), this is assumed to be sampling
##' from the prior. We'll use this for generating initial
##' conditions for MCMC where those are not given, and possibly
##' other uses.
##'
##' @param density Compute the model density for a vector of parameter
##' values; this is the posterior probability in the case of
##' Bayesian inference, but it could be anything really. Models can
##' return `-Inf` if things are impossible, and we'll try and cope
##' gracefully with that wherever possible.
##'
##' @param gradient Compute the gradient of `density` with respect to
##' the parameter vector; takes a parameter vector and returns a
##' vector of the same length. For efficiency, the model may want to
##' be stateful so that gradients can be efficiently calculated
##' after a density calculation, or density after gradient, where
##' these are called with the same parameters.
##'
##' @param domain Information on the parameter domain. This is a two
##' column matrix with `length(parameters)` rows representing each
##' parameter. The parameter minimum and maximum bounds are given
##' as the first and second column. Infinite values (`-Inf` or
##' `Inf`) should be used where the parameter has infinite domain up
##' or down. Currently used to translate from a bounded to
##' unbounded space for HMC, but we might also use this for
##' reflecting proposals in MCMC too.
##'
##' @return An object of class `mcstate_model`, which can be used with
##' a sampler.
##'
##' @export
mcstate_model <- function(parameters, direct_sample, density, gradient,
                          domain) {
  ret <- list(parameters = parameters,
              direct_sample = direct_sample,
              density = density,
              gradient = gradient,
              domain = domain)
  class(ret) <- "mcstate_model"
  ret
}
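As a hedged illustration of the constructor, the sketch below builds a model for two independent standard normal parameters. The constructor is copied so the block is self-contained; the parameter names `a` and `b` and the closures are made up for the example. One assumption worth flagging: the random walk sampler in this commit compares a difference of densities against `log(runif(1))`, which suggests that `density` should return a log density, so the example does that.

```r
## Standalone copy of the constructor from the diff.
mcstate_model <- function(parameters, direct_sample, density, gradient,
                          domain) {
  ret <- list(parameters = parameters,
              direct_sample = direct_sample,
              density = density,
              gradient = gradient,
              domain = domain)
  class(ret) <- "mcstate_model"
  ret
}

## Illustrative model: two independent N(0, 1) parameters.
model <- mcstate_model(
  parameters = c("a", "b"),
  ## Draw a starting point (here from the same distribution).
  direct_sample = function() rnorm(2),
  ## Log density (assumed scale, see the note above).
  density = function(x) sum(dnorm(x, log = TRUE)),
  ## Gradient of the log density of a standard normal is -x.
  gradient = function(x) -x,
  ## Two-column matrix of lower and upper bounds, one row per parameter.
  domain = cbind(rep(-Inf, 2), rep(Inf, 2)))
```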
86 changes: 86 additions & 0 deletions R/sample.R
@@ -0,0 +1,86 @@
##' Sample from a model. Uses a Monte Carlo method (or possibly
##' something else in future) to generate samples from your
##' distribution. This is going to change a lot in future, as we add
##' support for distributing over workers, and for things like
##' parallel reproducible streams of random numbers. For now it just
##' runs a single chain as a proof of concept.
##'
##' @title Sample from a model
##'
##' @param model The model to sample from; this should be an
##' `mcstate_model` for now, but we might change this in future to
##' test to see if things match an interface rather than a
##' particular class attribute.
##'
##' @param sampler A sampler to use. These will be described later,
##' but we hope to make these reasonably easy to implement so that
##' we can try out different sampling ideas. For now, the only
##' sampler implemented is [mcstate_sampler_random_walk()].
##'
##' @param n_steps The number of steps to run the sampler for.
##'
##' @param initial Optionally, initial parameter values for the
##' sampling. If not given, we sample from the model (or its prior).
##'
##' @return A list of parameters and densities.
##'
##' @export
mcstate_sample <- function(model, sampler, n_steps, initial = NULL) {
  if (!inherits(model, "mcstate_model")) {
    cli::cli_abort("Expected 'model' to be an 'mcstate_model'",
                   arg = "model")
  }
  if (!inherits(sampler, "mcstate_sampler")) {
    cli::cli_abort("Expected 'sampler' to be an 'mcstate_sampler'",
                   arg = "sampler")
  }

  if (is.null(initial)) {
    ## Really this would just be from the prior; we can't directly
    ## sample from the posterior!
    pars <- model$direct_sample()
  } else {
    pars <- initial
    if (length(pars) != length(model$parameters)) {
      cli::cli_abort(
        paste("Unexpected initial parameter length {length(pars)};",
              "expected {length(model$parameters)}"),
        arg = "initial")
    }
  }

  density <- model$density(pars)
  state <- list(pars = pars, density = density)
  sampler$initialise(state, model)

  history_pars <- matrix(NA_real_, n_steps + 1, length(pars))
  history_pars[1, ] <- pars
  history_density <- rep(NA_real_, n_steps + 1)
  history_density[[1]] <- density

  for (i in seq_len(n_steps)) {
    state <- sampler$step(state, model)
    history_pars[i + 1, ] <- state$pars
    history_density[[i + 1]] <- state$density
  }

  ## Pop the parameter names on last
  colnames(history_pars) <- model$parameters

  ## I'm not sure about the best name for this
  details <- sampler$finalise(state, model)

  list(pars = history_pars,
       density = history_density,
       details = details)
}


mcstate_sampler <- function(name, initialise, step, finalise) {
  ret <- list(name = name,
              initialise = initialise,
              step = step,
              finalise = finalise)
  class(ret) <- "mcstate_sampler"
  ret
}
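The sampler interface above is three closures: `initialise`, `step`, and `finalise`. The sketch below is an illustration rather than code from the commit. `make_counting_sampler` and `run_chain` are hypothetical names, and `run_chain` is a condensed copy of the loop in `mcstate_sample` with the validation and density history removed. It shows that the history has `n_steps + 1` rows (the first row is the initial state) and that whatever `finalise` returns ends up in `details`.

```r
## Hypothetical sampler that never moves; it only counts calls to step().
make_counting_sampler <- function() {
  n_calls <- 0
  list(
    name = "Counting (illustrative)",
    initialise = function(state, model) {
      n_calls <<- 0
    },
    step = function(state, model) {
      n_calls <<- n_calls + 1
      state
    },
    finalise = function(state, model) {
      list(n_calls = n_calls)
    })
}

## Condensed version of the main loop, without argument checks.
run_chain <- function(model, sampler, n_steps, initial) {
  state <- list(pars = initial, density = model$density(initial))
  sampler$initialise(state, model)
  history_pars <- matrix(NA_real_, n_steps + 1, length(initial))
  history_pars[1, ] <- initial
  for (i in seq_len(n_steps)) {
    state <- sampler$step(state, model)
    history_pars[i + 1, ] <- state$pars
  }
  colnames(history_pars) <- model$parameters
  list(pars = history_pars, details = sampler$finalise(state, model))
}

## A plain list stands in for an mcstate_model here.
model <- list(parameters = c("a", "b"),
              density = function(x) sum(dnorm(x, log = TRUE)))
res <- run_chain(model, make_counting_sampler(), 10, c(0.5, -0.5))
```

Because the sampler keeps its state in a closure, `initialise` can reset it without returning anything, which matches how the driver ignores the return value of `sampler$initialise`.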
60 changes: 60 additions & 0 deletions R/sampler-random-walk.R
@@ -0,0 +1,60 @@
##' Create a simple random walk sampler, which uses a symmetric
##' proposal to move around parameter space.
##'
##' @title Random Walk Sampler
##'
##' @param proposal A proposal function; must take a vector of
##' parameters and produce a new vector of proposed parameters.
##'
##' @param vcv A variance-covariance matrix to generate a `proposal`
##' function from. If you want a multivariate Gaussian proposal, this
##' is likely simpler than supplying your own `proposal`, and
##' generally more efficient too.
##'
##' @return An `mcstate_sampler` object, which can be used with
##' [mcstate_sample()]
##'
##' @importFrom stats runif
##' @export
mcstate_sampler_random_walk <- function(proposal = NULL, vcv = NULL) {
  if (is.null(proposal) && is.null(vcv)) {
    cli::cli_abort("One of 'proposal' or 'vcv' must be given")
  }
  if (!is.null(proposal) && !is.null(vcv)) {
    cli::cli_abort("Only one of 'proposal' or 'vcv' may be given")
  }
  if (!is.null(vcv)) { # proposal is null
    check_vcv(vcv, call = environment())
    proposal <- make_rmvnorm(vcv)
  }

  initialise <- function(state, model) {
    if (!is.null(vcv)) {
      n_pars <- length(state$pars)
      n_vcv <- nrow(vcv)
      if (n_pars != n_vcv) {
        cli::cli_abort(
          "Incompatible length parameters ({n_pars}) and vcv ({n_vcv})")
      }
    }
  }

  step <- function(state, model) {
    pars_next <- proposal(state$pars)
    density_next <- model$density(pars_next)
    if (density_next - state$density > log(runif(1))) {
      state$pars <- pars_next
      state$density <- density_next
    }
    state
  }

  finalise <- function(state, model) {
    NULL
  }

  mcstate_sampler("Random walk",
                  initialise,
                  step,
                  finalise)
}
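The `step` function above is a Metropolis update written on the log scale: with a symmetric proposal, accepting when `density_next - state$density > log(runif(1))` is equivalent to accepting with probability `min(1, exp(density_next - state$density))`. The sketch below isolates that rule for a one-dimensional standard normal target. The names `log_density`, `metropolis_step`, and `proposal_sd` are illustrative assumptions, not part of the package.

```r
set.seed(42)

## Log density of the illustrative target, N(0, 1).
log_density <- function(x) dnorm(x, log = TRUE)

## One Metropolis update with a symmetric Gaussian proposal.
metropolis_step <- function(state, proposal_sd = 1) {
  pars_next <- state$pars + rnorm(1, sd = proposal_sd)
  density_next <- log_density(pars_next)
  ## Accept if the log ratio exceeds log(u), u ~ Uniform(0, 1).
  if (density_next - state$density > log(runif(1))) {
    state$pars <- pars_next
    state$density <- density_next
  }
  state
}

n_steps <- 20000
state <- list(pars = 0, density = log_density(0))
samples <- numeric(n_steps)
for (i in seq_len(n_steps)) {
  state <- metropolis_step(state)
  samples[[i]] <- state$pars
}

## Fraction of steps on which the chain moved.
acceptance_rate <- mean(diff(samples) != 0)
```

A proposal whose density is `-Inf` gives a log ratio of `-Inf`, which is never greater than `log(runif(1))`, so impossible parameter values are always rejected when the current density is finite.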
22 changes: 22 additions & 0 deletions R/util.R
@@ -1,3 +1,25 @@
`%||%` <- function(x, y) { # nolint
  if (is.null(x)) y else x
}


check_vcv <- function(vcv, name = deparse(substitute(vcv)), call = NULL) {
  if (!is.matrix(vcv)) {
    cli::cli_abort("Expected '{name}' to be a matrix",
                   arg = name, call = call)
  }
  if (!isSymmetric(vcv)) {
    cli::cli_abort("Expected '{name}' to be symmetric",
                   arg = name, call = call)
  }
  if (!is_positive_definite(vcv)) {
    cli::cli_abort("Expected '{name}' to be positive definite",
                   arg = name, call = call)
  }
}


is_positive_definite <- function(x, tol = sqrt(.Machine$double.eps)) {
  ev <- eigen(x, symmetric = TRUE)
  all(ev$values >= -tol * abs(ev$values[1]))
}
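A short sketch of how `is_positive_definite` behaves on three illustrative matrices (the matrices and result names are not from the commit). Because the test allows eigenvalues down to a small negative tolerance relative to the largest one, a rank-deficient matrix passes even though the error message says "positive definite". This looks compatible with the pivoted Cholesky used in `make_rmvnorm`, though that reading is an inference from the code rather than something the commit states.

```r
## Standalone copy of the helper from the diff.
is_positive_definite <- function(x, tol = sqrt(.Machine$double.eps)) {
  ev <- eigen(x, symmetric = TRUE)
  all(ev$values >= -tol * abs(ev$values[1]))
}

## Identity matrix: eigenvalues 1 and 1.
identity_ok <- is_positive_definite(diag(2))

## All-ones matrix: eigenvalues 2 and 0, so only semi-definite.
rank_deficient_ok <- is_positive_definite(matrix(1, 2, 2))

## Eigenvalues 3 and -1, so indefinite.
indefinite_ok <- is_positive_definite(matrix(c(1, 2, 2, 1), 2, 2))
```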
2 changes: 1 addition & 1 deletion README.md
@@ -2,7 +2,7 @@

<!-- badges: start -->
[![Project Status: Concept – Minimal or no implementation has been done yet, or the repository is only intended to be a limited example, demo, or proof-of-concept.](https://www.repostatus.org/badges/latest/concept.svg)](https://www.repostatus.org/#concept)
-[![R build status](https://github.com/mrc-ide/mcstate2/workflows/R-CMD-check/badge.svg)](https://github.com/mrc-ide/mcstate2/actions)
+[![R-CMD-check](https://github.com/mrc-ide/mcstate2/actions/workflows/R-CMD-check.yaml/badge.svg?branch=main)](https://github.com/mrc-ide/mcstate2/actions/workflows/R-CMD-check.yaml)
[![codecov.io](https://codecov.io/github/mrc-ide/mcstate2/coverage.svg?branch=main)](https://codecov.io/github/mrc-ide/mcstate2?branch=main)
<!-- badges: end -->

5 changes: 5 additions & 0 deletions inst/WORDLIST
@@ -0,0 +1,5 @@
CMD
HMC
codecov
io
mcstate
61 changes: 61 additions & 0 deletions man/mcstate_model.Rd

35 changes: 35 additions & 0 deletions man/mcstate_sample.Rd
