Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for stochastic models #9

Merged
merged 8 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@ LinkingTo:
cpp11
Suggests:
cpp11,
decor,
decor,
dust,
mvtnorm,
numDeriv,
pkgload,
testthat (>= 3.0.0)
Config/testthat/edition: 3
Language: en-GB
Remotes:
mrc-ide/dust
61 changes: 58 additions & 3 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,20 @@
##' `direct_sample` method. Use `NULL` (the default) to detect this
##' from the model.
##'
##' @param is_stochastic Logical, indicating if the model is
##' stochastic. Stochastic models must supply a `set_rng_state`
##' method and we might support a `get_rng_state` method later.
##'
##' @return A list of class `mcstate_model_properties` which should
##' not be modified.
##'
##' @export
mcstate_model_properties <- function(has_gradient = NULL,
has_direct_sample = NULL) {
has_direct_sample = NULL,
is_stochastic = NULL) {
ret <- list(has_gradient = has_gradient,
has_direct_sample = has_direct_sample)
has_direct_sample = has_direct_sample,
is_stochastic = is_stochastic)
class(ret) <- "mcstate_model_properties"
ret
}
Expand Down Expand Up @@ -76,6 +82,27 @@ mcstate_model_properties <- function(has_gradient = NULL,
##' function is optional (and may not be well defined or possible to
##' define).
##'
##' * `set_rng_state`: A function to set the state (this is in
##' contrast to the `rng` that is passed through to `direct_sample`
##' as that is the _sampler's_ rng stream, but we assume models will
##' look after their own stream, and that they may need many
##' streams). Models that provide this method are assumed to be
##' stochastic; however, you can use the `is_stochastic` property
##' (via [mcstate_model_properties()]) to override this (e.g., to
##' run a stochastic model with its deterministic expectation).
##' This function takes an [mcstate_rng] object and uses it to seed
##' the random number state for your model. You have two options
##' here (1) hold a copy of the provided object and draw samples
##' from it as needed (in effect sharing the random number stream
##' with the sampler) or create a new rng stream from a jump with
##' this stream (we'll provide a utility for doing this but at
##' present doing `mcstate_rng$new(rng$state(),
##' n_streams)$jump()$state()` will do). The main reason you'd do
##' that is if you need multiple (perhaps parallel) streams of
##' random numbers in your model. The `$jump()` is very important,
##' otherwise you'll end up correlated with the draws from the
##' sampler.
##'
##' @title Create basic model
##'
##' @param model A list or environment with elements as described in
Expand All @@ -95,6 +122,7 @@ mcstate_model_properties <- function(has_gradient = NULL,
##' see [mcstate_model_properties()]. Currently this contains:
##' * `has_gradient`: the model can compute its gradient
##' * `has_direct_sample`: the model can sample from parameters space
##' * `is_stochastic`: the model will behave stochastically
##'
##' @export
mcstate_model <- function(model, properties = NULL) {
Expand All @@ -106,10 +134,12 @@ mcstate_model <- function(model, properties = NULL) {
properties <- validate_model_properties(properties, call)
gradient <- validate_model_gradient(model, properties, call)
direct_sample <- validate_model_direct_sample(model, properties, call)
rng_state <- validate_model_rng_state(model, properties, call)

## Update properties based on what we found:
properties$has_gradient <- !is.null(gradient)
properties$has_direct_sample <- !is.null(direct_sample)
properties$is_stochastic <- !is.null(rng_state$set)

ret <- list(model = model,
parameters = parameters,
Expand All @@ -123,7 +153,6 @@ mcstate_model <- function(model, properties = NULL) {
}



validate_model_properties <- function(properties, call = NULL) {
if (is.null(properties)) {
return(mcstate_model_properties())
Expand Down Expand Up @@ -215,6 +244,32 @@ validate_model_direct_sample <- function(model, properties, call) {
}


validate_model_rng_state <- function(model, properties, call) {
if (isFALSE(properties$is_stochastic)) {
return(NULL)
}
if (is.null(properties$is_stochastic) && is.null(model$set_rng_state)) {
return(NULL)
}
if (!is.function(model$set_rng_state)) {
if (isTRUE(properties$is_stochastic)) {
hint <- paste("You have specified 'is_stochastic = TRUE', so in order",
"to use your stochastic model we need a way of setting",
"its state")
} else {
hint <- paste("I found a non-function element 'set_rng_state' within",
"your model and you have not set the 'is_stochastic'",
"property")
}
cli::cli_abort(
c("Expected 'model$set_rng_state' to be a function",
i = hint),
arg = "model", call = call)
}
list(set = model$set_rng_state)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just for clarification, why is this returned as list(set = model$set_rng_state) instead of returning model$set_rng_state?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because I anticipate that later we'll need get = model$get_rng_state too for cases where we have a stochastic model that we want to be able to restart

}


require_direct_sample <- function(model, message, ...) {
if (!model$properties$has_direct_sample) {
cli::cli_abort(
Expand Down
7 changes: 6 additions & 1 deletion R/sampler-random-walk.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
##' Create a simple random walk sampler, which uses a symmetric
##' proposal to move around parameter space.
##' proposal to move around parameter space. This sampler supports
##' sampling from models where the likelihood is only computable
##' randomly (e.g., for pmcmc).
##'
##' @title Random Walk Sampler
##'
Expand Down Expand Up @@ -38,6 +40,9 @@ mcstate_sampler_random_walk <- function(proposal = NULL, vcv = NULL) {
"Incompatible length parameters ({n_pars}) and vcv ({n_vcv})")
}
}
if (isTRUE(model$properties$is_stochastic)) {
model$model$set_rng_state(rng)
}
Comment on lines +43 to +45
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will this be applied to different samplers? if so perhaps we should put this in mcstate_run_chain right after sampler$initialise so that this is a general functionality instead of tied to this particular initialisation function

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None of the other samplers can cope with stochastic models; they will be throwing errors!

}

step <- function(state, model, rng) {
Expand Down
20 changes: 20 additions & 0 deletions man/mcstate_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion man/mcstate_model_properties.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion man/mcstate_sampler_random_walk.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

62 changes: 62 additions & 0 deletions tests/testthat/helper-mcstate2.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,65 @@ ex_simple_gamma1 <- function(shape = 1, rate = 1) {
gradient = function(x) (shape - 1) / x - rate,
domain = rbind(c(0, Inf)))))
}


ex_dust_sir <- function(n_particles = 100, n_threads = 1,
deterministic = FALSE) {
testthat::skip_if_not_installed("dust")
sir <- dust::dust_example("sir")

np <- 10
end <- 150 * 4
times <- seq(0, end, by = 4)
ans <- sir$new(list(), 0, np, seed = 1L)$simulate(times)
dat <- data.frame(time = times[-1], incidence = ans[5, 1, -1])

## TODO: an upshot here is that our dust models are always going to
## need to be initialisable; we might need to sample from the
## statistical parameters, or set things up to allow two-phases of
## initialsation (which is I think where we are heading, so that's
## fine).
model <- sir$new(list(), 0, n_particles, seed = 1L, n_threads = n_threads,
deterministic = deterministic)
model$set_data(dust::dust_data(dat))

prior_beta_shape <- 1
prior_beta_rate <- 1 / 0.5
prior_gamma_shape <- 1
prior_gamma_rate <- 1 / 0.5

density <- function(x) {
beta <- x[[1]]
gamma <- x[[2]]
prior <- dgamma(beta, prior_beta_shape, prior_beta_rate, log = TRUE) +
dgamma(gamma, prior_gamma_shape, prior_gamma_rate, log = TRUE)
if (is.finite(prior)) {
model$update_state(
pars = list(beta = x[[1]], gamma = x[[2]]),
time = 0,
set_initial_state = TRUE)
ll <- model$filter()$log_likelihood
} else {
ll <- -Inf
}
ll + prior
}

direct_sample <- function(rng) {
c(rng$gamma(1, prior_beta_shape, 1 / prior_beta_rate),
rng$gamma(1, prior_gamma_shape, 1 / prior_gamma_rate))
}

set_rng_state <- function(rng) {
state <- mcstate_rng$new(rng$state(), n_particles + 1)$jump()$state()
model$set_rng_state(state)
}

mcstate_model(
list(density = density,
direct_sample = direct_sample,
parameters = c("beta", "gamma"),
domain = cbind(c(0, 0), c(Inf, Inf)),
set_rng_state = set_rng_state),
mcstate_model_properties(is_stochastic = !deterministic))
}
32 changes: 30 additions & 2 deletions tests/testthat/test-model.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ test_that("can create a minimal model", {
expect_s3_class(m, "mcstate_model")
expect_equal(m$properties,
mcstate_model_properties(has_gradient = FALSE,
has_direct_sample = FALSE))
has_direct_sample = FALSE,
is_stochastic = FALSE))
expect_equal(m$domain, cbind(-Inf, Inf))
expect_equal(m$parameters, "a")
expect_equal(m$density(0), dnorm(0, log = TRUE))
Expand All @@ -15,7 +16,8 @@ test_that("can create a more interesting model", {
m <- ex_simple_gamma1()
expect_equal(m$properties,
mcstate_model_properties(has_gradient = TRUE,
has_direct_sample = TRUE))
has_direct_sample = TRUE,
is_stochastic = FALSE))
expect_equal(m$domain, cbind(0, Inf))
expect_equal(m$parameters, "gamma")
expect_equal(m$density(1), dgamma(1, 1, 1, log = TRUE))
Expand Down Expand Up @@ -117,3 +119,29 @@ test_that("require properties are correct type", {
list(has_gradient = FALSE)),
"Expected 'properties' to be a 'mcstate_model_properties' object")
})


test_that("stochastic models need an rng setting function", {
expect_error(
mcstate_model(
list(density = identity, parameters = "a"),
mcstate_model_properties(is_stochastic = TRUE)),
"Expected 'model$set_rng_state' to be a function",
fixed = TRUE)
expect_error(
mcstate_model(
list(density = identity, parameters = "a", set_rng_state = TRUE)),
"Expected 'model$set_rng_state' to be a function",
fixed = TRUE)
expect_error(
mcstate_model(
list(density = identity, parameters = "a", set_rng_state = TRUE),
mcstate_model_properties(is_stochastic = TRUE)),
"Expected 'model$set_rng_state' to be a function",
fixed = TRUE)
expect_no_error(
res <- mcstate_model(
list(density = identity, parameters = "a", set_rng_state = TRUE),
mcstate_model_properties(is_stochastic = FALSE)))
expect_null(res$set_rng_state)
})
10 changes: 10 additions & 0 deletions tests/testthat/test-sampler-random-walk.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,13 @@ test_that("validate sampler against model on initialisation", {
"Incompatible length parameters (1) and vcv (2)",
fixed = TRUE)
})


test_that("can draw samples from a random model", {
set.seed(1)
m <- ex_dust_sir()
vcv <- matrix(c(0.0006405, 0.0005628, 0.0005628, 0.0006641), 2, 2)
sampler <- mcstate_sampler_random_walk(vcv = vcv)
res <- mcstate_sample(m, sampler, 20)
expect_setequal(names(res), c("pars", "density", "details", "chain"))
})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have not understood this comment well yet!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, it's out of date now after #10

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(and now it's gone)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would like to definitely chat through this and that sir model on call!

Loading