From 955a5a059606ba822c4026efa5b7f351c0b0a0d2 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Thu, 21 Nov 2024 10:24:00 +0000 Subject: [PATCH 1/7] Make pointer access safe --- src/random.cpp | 53 +++++++++++++++++++++++++++++--------------------- 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/src/random.cpp b/src/random.cpp index 69d60524..14c92c92 100644 --- a/src/random.cpp +++ b/src/random.cpp @@ -15,6 +15,15 @@ using default_rng = monty::random::prng>; +template +T* safely_read_externalptr(SEXP ptr, const char * context) { + if (!R_ExternalPtrAddr(ptr)) { + cpp11::stop("Pointer has been serialised, cannot continue safely (%s)", + context); + } + return cpp11::as_cpp>(ptr).get(); +} + template SEXP monty_rng_alloc(cpp11::sexp r_seed, int n_streams, bool deterministic) { auto seed = monty::random::r::as_rng_seed(r_seed); @@ -24,13 +33,13 @@ SEXP monty_rng_alloc(cpp11::sexp r_seed, int n_streams, bool deterministic) { template void monty_rng_jump(SEXP ptr) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "jump"); rng->jump(); } template void monty_rng_long_jump(SEXP ptr) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "long_jump"); rng->long_jump(); } @@ -45,7 +54,7 @@ cpp11::sexp sexp_matrix(cpp11::sexp x, int n, int m) { template cpp11::sexp monty_rng_random_real(SEXP ptr, int n, int n_threads) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "random_real"); const int n_streams = rng->size(); cpp11::writable::doubles ret = cpp11::writable::doubles(n * n_streams); @@ -67,7 +76,7 @@ cpp11::sexp monty_rng_random_real(SEXP ptr, int n, int n_threads) { template cpp11::sexp monty_rng_random_normal(SEXP ptr, int n, int n_threads) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "random_normal"); const int n_streams = rng->size(); cpp11::writable::doubles ret = cpp11::writable::doubles(n * n_streams); @@ -178,7 +187,7 @@ cpp11::sexp monty_rng_uniform(SEXP ptr, int n, cpp11::doubles r_min, cpp11::doubles r_max, int n_threads) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "uniform"); const int n_streams = rng->size(); cpp11::writable::doubles ret = cpp11::writable::doubles(n * n_streams); double * y = REAL(ret); @@ -209,7 +218,7 @@ cpp11::sexp monty_rng_uniform(SEXP ptr, int n, template cpp11::sexp monty_rng_exponential_rate(SEXP ptr, int n, cpp11::doubles r_rate, int n_threads) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "exponential_rate"); const int n_streams = rng->size(); cpp11::writable::doubles ret = cpp11::writable::doubles(n * n_streams); double * y = REAL(ret); @@ -237,7 +246,7 @@ cpp11::sexp monty_rng_exponential_rate(SEXP ptr, int n, cpp11::doubles r_rate, template cpp11::sexp monty_rng_exponential_mean(SEXP ptr, int n, cpp11::doubles r_mean, int n_threads) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "exponential_mean"); const int n_streams = rng->size(); cpp11::writable::doubles ret = cpp11::writable::doubles(n * n_streams); double * y = REAL(ret); @@ -265,7 +274,7 @@ template cpp11::sexp monty_rng_normal(SEXP ptr, int n, cpp11::doubles r_mean, cpp11::doubles r_sd, int n_threads) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "normal"); const int n_streams = rng->size(); cpp11::writable::doubles ret = cpp11::writable::doubles(n * n_streams); double * y = REAL(ret); @@ -297,7 +306,7 @@ template cpp11::sexp monty_rng_binomial(SEXP ptr, int n, cpp11::doubles r_size, cpp11::doubles r_prob, int n_threads) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "binomial"); const int n_streams = rng->size(); cpp11::writable::doubles ret = cpp11::writable::doubles(n * n_streams); double * y = REAL(ret); @@ -339,7 +348,7 @@ cpp11::sexp monty_rng_beta_binomial_prob(SEXP ptr, int n, cpp11::doubles r_size, cpp11::doubles r_prob, cpp11::doubles r_rho, int n_threads) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "beta_binomial_prob"); const int n_streams = rng->size(); cpp11::writable::doubles ret = cpp11::writable::doubles(n * n_streams); double * y = REAL(ret); @@ -385,7 +394,7 @@ cpp11::sexp monty_rng_beta_binomial_ab(SEXP ptr, int n, cpp11::doubles r_size, cpp11::doubles r_a, cpp11::doubles r_b, int n_threads) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "beta_binomial_ab"); const int n_streams = rng->size(); cpp11::writable::doubles ret = cpp11::writable::doubles(n * n_streams); double * y = REAL(ret); @@ -430,7 +439,7 @@ template cpp11::sexp monty_rng_negative_binomial_prob(SEXP ptr, int n, cpp11::doubles r_size, cpp11::doubles r_prob, int n_threads) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "negative_binomial_prob"); const int n_streams = rng->size(); cpp11::writable::doubles ret = cpp11::writable::doubles(n * n_streams); double * y = REAL(ret); @@ -471,7 +480,7 @@ template cpp11::sexp monty_rng_negative_binomial_mu(SEXP ptr, int n, cpp11::doubles r_size, cpp11::doubles r_mu, int n_threads) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "negative_binomial_mu"); const int n_streams = rng->size(); cpp11::writable::doubles ret = cpp11::writable::doubles(n * n_streams); double * y = REAL(ret); @@ -511,7 +520,7 @@ cpp11::sexp monty_rng_negative_binomial_mu(SEXP ptr, int n, template cpp11::sexp monty_rng_poisson(SEXP ptr, int n, cpp11::doubles r_lambda, int n_threads) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "poisson"); const int n_streams = rng->size(); cpp11::writable::doubles ret = cpp11::writable::doubles(n * n_streams); double * y = REAL(ret); @@ -549,7 +558,7 @@ cpp11::sexp monty_rng_multinomial(SEXP ptr, int n, cpp11::doubles r_size, cpp11::doubles r_prob, int n_threads) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "multinomial"); const int n_streams = rng->size(); const double * size = REAL(r_size); @@ -602,7 +611,7 @@ template cpp11::sexp monty_rng_hypergeometric(SEXP ptr, int n, cpp11::doubles r_n1, cpp11::doubles r_n2, cpp11::doubles r_k, int n_threads) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "hypergeometric"); const int n_streams = rng->size(); cpp11::writable::doubles ret = cpp11::writable::doubles(n * n_streams); double * y = REAL(ret); @@ -648,7 +657,7 @@ cpp11::sexp monty_rng_gamma_scale(SEXP ptr, int n, cpp11::doubles r_shape, cpp11::doubles r_scale, int n_threads) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "gamma_scale"); const int n_streams = rng->size(); cpp11::writable::doubles ret = cpp11::writable::doubles(n * n_streams); double * y = REAL(ret); @@ -689,7 +698,7 @@ cpp11::sexp monty_rng_gamma_rate(SEXP ptr, int n, cpp11::doubles r_shape, cpp11::doubles r_rate, int n_threads) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "gamma_rate"); const int n_streams = rng->size(); cpp11::writable::doubles ret = cpp11::writable::doubles(n * n_streams); double * y = REAL(ret); @@ -734,7 +743,7 @@ cpp11::sexp monty_rng_cauchy(SEXP ptr, int n, cpp11::doubles r_location, cpp11::doubles r_scale, int n_threads) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "cauchy"); const int n_streams = rng->size(); cpp11::writable::doubles ret = cpp11::writable::doubles(n * n_streams); double * y = REAL(ret); @@ -775,7 +784,7 @@ cpp11::sexp monty_rng_beta(SEXP ptr, int n, cpp11::doubles r_a, cpp11::doubles r_b, int n_threads) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "beta"); const int n_streams = rng->size(); cpp11::writable::doubles ret = cpp11::writable::doubles(n * n_streams); double * y = REAL(ret); @@ -818,7 +827,7 @@ cpp11::sexp monty_rng_truncated_normal(SEXP ptr, int n, cpp11::doubles r_min, cpp11::doubles r_max, int n_threads) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "truncated_normal"); const int n_streams = rng->size(); cpp11::writable::doubles ret = cpp11::writable::doubles(n * n_streams); double * y = REAL(ret); @@ -865,7 +874,7 @@ cpp11::sexp monty_rng_truncated_normal(SEXP ptr, int n, template cpp11::sexp monty_rng_state(SEXP ptr) { - T *rng = cpp11::as_cpp>(ptr).get(); + T *rng = safely_read_externalptr(ptr, "rng_state"); auto state = rng->export_state(); size_t len = sizeof(typename T::int_type) * state.size(); cpp11::writable::raws ret(len); From ca75b543c77891fceb49fbec7cf4da3caf377fdd Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Thu, 21 Nov 2024 10:39:06 +0000 Subject: [PATCH 2/7] Add undocumented restore functionality --- R/combine.R | 13 +++++++++++++ R/model.R | 12 ++++++++++++ 2 files changed, 25 insertions(+) diff --git a/R/combine.R b/R/combine.R index 2796b071..e0822d63 100644 --- a/R/combine.R +++ b/R/combine.R @@ -119,6 +119,8 @@ monty_model_combine <- function(a, b, properties = NULL, parts, properties) observer <- model_combine_observer( parts, parameters, properties, call) + restore <- model_combine_restore( + parts) data <- list( parts = unname(parts), @@ -133,6 +135,7 @@ monty_model_combine <- function(a, b, properties = NULL, gradient = gradient, get_rng_state = stochastic$get_rng_state, set_rng_state = stochastic$set_rng_state, + restore = restore, observer = observer, direct_sample = direct_sample), properties) @@ -455,3 +458,13 @@ model_combine_allow_multiple_parameters <- function(parts, properties, "not supported by both of your models"), call = call) } + + +model_combine_restore <- function(parts) { + a <- parts[[1]] + b <- parts[[2]] + function() { + a$restore() + b$restore() + } +} diff --git a/R/model.R b/R/model.R index 507bf491..7d323671 100644 --- a/R/model.R +++ b/R/model.R @@ -199,6 +199,7 @@ monty_model <- function(model, properties = NULL) { observer <- validate_model_observer(model, properties, call) rng_state <- validate_model_rng_state(model, properties, call) parameter_groups <- validate_model_parameter_groups(model, properties, call) + restore <- validate_model_restore(model, properties, call) ## Update properties based on what we found: properties$has_gradient <- !is.null(gradient) @@ -217,6 +218,7 @@ monty_model <- function(model, properties = NULL) { gradient = gradient, direct_sample = direct_sample, observer = observer, + restore = restore, rng_state = rng_state, properties = properties) class(ret) <- "monty_model" @@ -539,6 +541,16 @@ validate_model_parameter_groups <- function(model, properties, call) { } +validate_model_restore <- function(model, propertioes, call) { + restore <- model$restore + if (is.function(restore)) { + restore + } else { + function() {} + } +} + + require_monty_model <- function(model, arg = deparse(substitute(model)), call = parent.frame()) { if (!inherits(model, "monty_model")) { From a70fe559a7f04d1dfa5fa94356e49ac223fa25af Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Thu, 21 Nov 2024 10:46:05 +0000 Subject: [PATCH 3/7] Add tests --- tests/testthat/helper-monty.R | 30 +++++++++++++++++++++++++++ tests/testthat/test-combine.R | 25 ++++++++++++++++++++++ tests/testthat/test-model-serialise.R | 30 +++++++++++++++++++++++++++ 3 files changed, 85 insertions(+) create mode 100644 tests/testthat/test-model-serialise.R diff --git a/tests/testthat/helper-monty.R b/tests/testthat/helper-monty.R index 5308e6da..500d502a 100644 --- a/tests/testthat/helper-monty.R +++ b/tests/testthat/helper-monty.R @@ -112,6 +112,36 @@ ex_sir_filter_posterior <- function(...) { } +## A silly stochastic model: +ex_stochastic <- function(n = 10, sd_sample = 1, sd_measure = 1) { + env <- new.env() + env$rng <- monty_rng$new() + + get_rng_state <- function() { + env$rng$state() + } + set_rng_state <- function(rng_state) { + env$rng$state_set(rng_state) + } + density <- function(x) { + sum(dnorm(env$rng$normal(n, x, sd_sample), sd_measure, log = TRUE)) + } + + restore <- function() { + env$rng <- monty_rng$new() + } + + monty_model( + list(env = env, + density = density, + restore = restore, + parameters = "x", + set_rng_state = set_rng_state, + get_rng_state = get_rng_state), + monty_model_properties(is_stochastic = TRUE)) +} + + scrub_manual_info <- function(x) { x <- sub("Manual monty sampling at '.+", "Manual monty sampling at ''", diff --git a/tests/testthat/test-combine.R b/tests/testthat/test-combine.R index 00a68c09..6ac37a20 100644 --- a/tests/testthat/test-combine.R +++ b/tests/testthat/test-combine.R @@ -363,3 +363,28 @@ test_that("can split prior from likelihood", { monty_model_split(a), "Cannot split this model as it is not a combined model") }) + + +test_that("can combine restore functions", { + a <- monty_model(list(parameters = "x", + density = identity, + restore = function(x) message("a"))) + b <- monty_model(list(parameters = "x", + density = identity, + restore = function(x) message("b"))) + c <- monty_model(list(parameters = "x", + density = identity)) + d <- monty_model(list(parameters = "x", + density = identity)) + + ab <- monty_model_combine(a, b) + expect_equal(capture_messages(ab$restore()), c("a\n", "b\n")) + ba <- monty_model_combine(b, a) + expect_equal(capture_messages(ba$restore()), c("b\n", "a\n")) + ac <- monty_model_combine(a, c) + expect_equal(capture_messages(ac$restore()), "a\n") + ca <- monty_model_combine(c, a) + expect_equal(capture_messages(ca$restore()), "a\n") + cd <- monty_model_combine(c, d) + expect_equal(capture_messages(cd$restore()), character()) +}) diff --git a/tests/testthat/test-model-serialise.R b/tests/testthat/test-model-serialise.R new file mode 100644 index 00000000..2cc75405 --- /dev/null +++ b/tests/testthat/test-model-serialise.R @@ -0,0 +1,30 @@ +test_that("can repair a model", { + m <- ex_stochastic() + ll <- m$density(0) + + m2 <- unserialize(suppressWarnings(serialize(m, NULL))) + expect_error( + m2$density(0), + "Pointer has been serialised") + + m2$restore() + expect_no_error(m2$density(0)) +}) + + +test_that("can repair a combined model", { + m <- ex_stochastic() + p <- monty_dsl({ + x ~ Normal(0, 1) + }) + + m2 <- m + p + m2$density(0) + + m3 <- unserialize(suppressWarnings(serialize(m2, NULL))) + expect_error( + m3$density(0), + "Pointer has been serialised") + m3$restore() + expect_no_error(m3$density(0)) +}) From 87ebff95b102d91b4a53462fc7774657cc1325ac Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Thu, 21 Nov 2024 12:35:23 +0000 Subject: [PATCH 4/7] Add support for setting rng state --- R/cpp11.R | 4 +++ R/rng.R | 11 ++++-- R/sample-manual.R | 1 + inst/include/monty/random/prng.hpp | 5 +++ man/monty_rng.Rd | 58 +++++++++++++++++++++++++++--- src/cpp11.cpp | 9 +++++ src/random.cpp | 22 ++++++++++++ tests/testthat/test-rng.R | 23 ++++++++++++ 8 files changed, 125 insertions(+), 8 deletions(-) diff --git a/R/cpp11.R b/R/cpp11.R index 8271a5f8..fe6faff5 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -92,6 +92,10 @@ monty_rng_state <- function(ptr) { .Call(`_monty_monty_rng_state`, ptr) } +monty_rng_set_state <- function(ptr, r_state) { + invisible(.Call(`_monty_monty_rng_set_state`, ptr, r_state)) +} + cpp_monty_random_real <- function(ptr) { .Call(`_monty_cpp_monty_random_real`, ptr) } diff --git a/R/rng.R b/R/rng.R index 8b8f3726..dce76a27 100644 --- a/R/rng.R +++ b/R/rng.R @@ -464,10 +464,15 @@ monty_rng <- R6::R6Class( ##' @description ##' Returns the state of the random number stream. This returns a - ##' raw vector of length 32 * n_streams. It is primarily intended for - ##' debugging as one cannot (yet) initialise a monty_rng object with this - ##' state. + ##' raw vector of length 32 * `n_streams`. state = function() { monty_rng_state(private$ptr) + }, + + ##' @description + ##' Sets the state of the random number stream. + ##' @param state Raw vector of state, with length 32 * `n_streams`. + set_state = function(state) { + monty_rng_set_state(private$ptr, state) } )) diff --git a/R/sample-manual.R b/R/sample-manual.R index 2281293c..ca854a2d 100644 --- a/R/sample-manual.R +++ b/R/sample-manual.R @@ -116,6 +116,7 @@ monty_sample_manual_run <- function(chain_id, path, progress = NULL) { path <- sample_manual_path(path, chain_id) inputs <- readRDS(path$inputs) + inputs$model$restore() if (chain_id > inputs$n_chains) { cli::cli_abort("'chain_id' must be an integer in 1..{inputs$n_chains}") } diff --git a/inst/include/monty/random/prng.hpp b/inst/include/monty/random/prng.hpp index 523055de..96d0ed90 100644 --- a/inst/include/monty/random/prng.hpp +++ b/inst/include/monty/random/prng.hpp @@ -81,6 +81,11 @@ class prng { return state_[i]; } + /// Return total number of elements in state + size_t state_size() const { + return size() * rng_state::size(); + } + /// Convert the random number state of all generators into a single /// vector. This can be used to save the state to restore using /// `import_state()` diff --git a/man/monty_rng.Rd b/man/monty_rng.Rd index 86f19426..f00ef473 100644 --- a/man/monty_rng.Rd +++ b/man/monty_rng.Rd @@ -13,9 +13,12 @@ numbers with the same RNG as monty uses internally. This is primarily meant for debugging and testing the underlying C++ rather than a source of random numbers from R. } -\section{Running multiple streams, perhaps in parallel}{ - +\section{Warning}{ +This interface is subject to change in the near future, we do not +recommend its use in user code. +} +\section{Running multiple streams, perhaps in parallel}{ The underlying random number generators are designed to work in parallel, and with random access to parameters (see \code{vignette("rng")} for more details). However, this is usually @@ -156,7 +159,9 @@ rng$multinomial(5, 10, c(0.1, 0.3, 0.5, 0.1)) \item \href{#method-monty_rng-cauchy}{\code{monty_rng$cauchy()}} \item \href{#method-monty_rng-multinomial}{\code{monty_rng$multinomial()}} \item \href{#method-monty_rng-beta}{\code{monty_rng$beta()}} +\item \href{#method-monty_rng-truncated_normal}{\code{monty_rng$truncated_normal()}} \item \href{#method-monty_rng-state}{\code{monty_rng$state()}} +\item \href{#method-monty_rng-set_state}{\code{monty_rng$set_state()}} } } \if{html}{\out{
}} @@ -646,6 +651,34 @@ Generate \code{n} numbers from a beta distribution \item{\code{b}}{The second shape parameter} +\item{\code{n_threads}}{Number of threads to use; see Details} +} +\if{html}{\out{}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-monty_rng-truncated_normal}{}}} +\subsection{Method \code{truncated_normal()}}{ +Generate \code{n} numbers from a truncated normal distribution +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{monty_rng$truncated_normal(n, mean, sd, min, max, n_threads = 1L)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{n}}{Number of samples to draw (per stream)} + +\item{\code{mean}}{The mean of the parent (untruncated) normal distribution} + +\item{\code{sd}}{The standard deviation of the parent (untruncated) +normal distribution.} + +\item{\code{min}}{The lower bound} + +\item{\code{max}}{The upper bound} + \item{\code{n_threads}}{Number of threads to use; see Details} } \if{html}{\out{
}} @@ -656,12 +689,27 @@ Generate \code{n} numbers from a beta distribution \if{latex}{\out{\hypertarget{method-monty_rng-state}{}}} \subsection{Method \code{state()}}{ Returns the state of the random number stream. This returns a -raw vector of length 32 * n_streams. It is primarily intended for -debugging as one cannot (yet) initialise a monty_rng object with this -state. +raw vector of length 32 * \code{n_streams}. \subsection{Usage}{ \if{html}{\out{
}}\preformatted{monty_rng$state()}\if{html}{\out{
}} } +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-monty_rng-set_state}{}}} +\subsection{Method \code{set_state()}}{ +Sets the state of the random number stream. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{monty_rng$set_state(state)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{state}}{Raw vector of state, with length 32 * \code{n_streams}.} +} +\if{html}{\out{
}} +} } } diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 4de8773a..325d10e8 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -168,6 +168,14 @@ extern "C" SEXP _monty_monty_rng_state(SEXP ptr) { return cpp11::as_sexp(monty_rng_state(cpp11::as_cpp>(ptr))); END_CPP11 } +// random.cpp +void monty_rng_set_state(SEXP ptr, cpp11::raws r_state); +extern "C" SEXP _monty_monty_rng_set_state(SEXP ptr, SEXP r_state) { + BEGIN_CPP11 + monty_rng_set_state(cpp11::as_cpp>(ptr), cpp11::as_cpp>(r_state)); + return R_NilValue; + END_CPP11 +} // random2.cpp cpp11::doubles cpp_monty_random_real(SEXP ptr); extern "C" SEXP _monty_cpp_monty_random_real(SEXP ptr) { @@ -246,6 +254,7 @@ static const R_CallMethodDef CallEntries[] = { {"_monty_monty_rng_poisson", (DL_FUNC) &_monty_monty_rng_poisson, 4}, {"_monty_monty_rng_random_normal", (DL_FUNC) &_monty_monty_rng_random_normal, 4}, {"_monty_monty_rng_random_real", (DL_FUNC) &_monty_monty_rng_random_real, 3}, + {"_monty_monty_rng_set_state", (DL_FUNC) &_monty_monty_rng_set_state, 2}, {"_monty_monty_rng_state", (DL_FUNC) &_monty_monty_rng_state, 1}, {"_monty_monty_rng_truncated_normal", (DL_FUNC) &_monty_monty_rng_truncated_normal, 7}, {"_monty_monty_rng_uniform", (DL_FUNC) &_monty_monty_rng_uniform, 5}, diff --git a/src/random.cpp b/src/random.cpp index 14c92c92..7bb5ec3b 100644 --- a/src/random.cpp +++ b/src/random.cpp @@ -882,6 +882,23 @@ cpp11::sexp monty_rng_state(SEXP ptr) { return ret; } + +template +void monty_rng_set_state(SEXP ptr, cpp11::raws r_state) { + T *rng = safely_read_externalptr(ptr, "rng_set_state"); + + using int_type = typename T::int_type; + const auto len = rng->state_size() * sizeof(int_type); + if ((size_t)r_state.size() != len) { + cpp11::stop("'state' must be a raw vector of length %d (but was %d)", + len, r_state.size()); + } + std::vector state(len); + std::memcpy(state.data(), RAW(r_state), len); + rng->import_state(state); +} + + [[cpp11::register]] SEXP monty_rng_alloc(cpp11::sexp r_seed, int n_streams, bool deterministic) { return monty_rng_alloc(r_seed, n_streams, deterministic); @@ -1067,3 +1084,8 @@ cpp11::sexp monty_rng_truncated_normal(SEXP ptr, int n, cpp11::sexp monty_rng_state(SEXP ptr) { return monty_rng_state(ptr); } + +[[cpp11::register]] +void monty_rng_set_state(SEXP ptr, cpp11::raws r_state) { + return monty_rng_set_state(ptr, r_state); +} diff --git a/tests/testthat/test-rng.R b/tests/testthat/test-rng.R index 6ccb042e..abcec25a 100644 --- a/tests/testthat/test-rng.R +++ b/tests/testthat/test-rng.R @@ -1354,3 +1354,26 @@ test_that("can generate from truncated normal from lower tail", { }) expect_gt(sum(res > 0.05), 5) }) + + +test_that("can set state into rng", { + r1 <- monty_rng$new() + r2 <- monty_rng$new() + + s1 <- r1$state() + r2$set_state(s1) + + expect_equal(r2$random_real(10), + r1$random_real(10)) +}) + + +test_that("error if rng state wrong length", { + r <- monty_rng$new() + expect_error( + r$set_state(raw(4)), + "'state' must be a raw vector of length 32 (but was 4)") + + expect_equal(r2$random_real(10), + r1$random_real(10)) +}) From 1a2a708626c7ef28902fca359664bd34e6833c81 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Thu, 21 Nov 2024 12:40:13 +0000 Subject: [PATCH 5/7] Restore for manual sample --- tests/testthat/helper-monty.R | 2 +- tests/testthat/test-sample-manual.R | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/testthat/helper-monty.R b/tests/testthat/helper-monty.R index 500d502a..fa7aac08 100644 --- a/tests/testthat/helper-monty.R +++ b/tests/testthat/helper-monty.R @@ -121,7 +121,7 @@ ex_stochastic <- function(n = 10, sd_sample = 1, sd_measure = 1) { env$rng$state() } set_rng_state <- function(rng_state) { - env$rng$state_set(rng_state) + env$rng$set_state(rng_state) } density <- function(x) { sum(dnorm(env$rng$normal(n, x, sd_sample), sd_measure, log = TRUE)) diff --git a/tests/testthat/test-sample-manual.R b/tests/testthat/test-sample-manual.R index 69c22c07..6fae9d68 100644 --- a/tests/testthat/test-sample-manual.R +++ b/tests/testthat/test-sample-manual.R @@ -324,3 +324,21 @@ test_that("can continue a manually run manual chain", { expect_equal(res1b, res2) }) + + +test_that("can sample from models requiring restore", { + path <- withr::local_tempdir() + model <- ex_stochastic() + sampler <- monty_sampler_random_walk(vcv = diag(1) * 0.1) + + set.seed(1) + res1 <- monty_sample(model, sampler, 100, + n_chains = 1, initial = 0) + + set.seed(1) + monty_sample_manual_prepare(model, sampler, 100, path, + n_chains = 1, initial = 0) + monty_sample_manual_run(1, path) + res2 <- monty_sample_manual_collect(path) + expect_equal(res2, res1) +}) From d97e4cd79dffe9f0f892204fb0d385124955d67c Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Thu, 21 Nov 2024 12:40:21 +0000 Subject: [PATCH 6/7] Bump version --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index dd636dd8..22ff9ecc 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: monty Title: Monte Carlo Models -Version: 0.3.4 +Version: 0.3.5 Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), email = "rich.fitzjohn@gmail.com"), person("Wes", "Hinsley", role = "aut"), From fb22022108ada0791f771d753a46837cb3b300c0 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Thu, 21 Nov 2024 13:19:13 +0000 Subject: [PATCH 7/7] Fix tests --- R/sample-manual.R | 7 ++++++- tests/testthat/test-rng.R | 6 ++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/R/sample-manual.R b/R/sample-manual.R index ca854a2d..c9225fdf 100644 --- a/R/sample-manual.R +++ b/R/sample-manual.R @@ -116,7 +116,6 @@ monty_sample_manual_run <- function(chain_id, path, progress = NULL) { path <- sample_manual_path(path, chain_id) inputs <- readRDS(path$inputs) - inputs$model$restore() if (chain_id > inputs$n_chains) { cli::cli_abort("'chain_id' must be an integer in 1..{inputs$n_chains}") } @@ -127,6 +126,12 @@ monty_sample_manual_run <- function(chain_id, path, progress = NULL) { restart <- inputs$restart is_continue <- is.list(restart) + if (is_continue) { + restart$model$restore() + } else { + inputs$model$restore() + } + pb <- progress_bar(n_chains, steps$total, progress, show_overall = FALSE, single_chain = TRUE) diff --git a/tests/testthat/test-rng.R b/tests/testthat/test-rng.R index abcec25a..bea37115 100644 --- a/tests/testthat/test-rng.R +++ b/tests/testthat/test-rng.R @@ -1372,8 +1372,6 @@ test_that("error if rng state wrong length", { r <- monty_rng$new() expect_error( r$set_state(raw(4)), - "'state' must be a raw vector of length 32 (but was 4)") - - expect_equal(r2$random_real(10), - r1$random_real(10)) + "'state' must be a raw vector of length 32 (but was 4)", + fixed = TRUE) })