Skip to content

Commit

Permalink
Merge pull request #115 from mrc-ide/mrc-6055
Browse files Browse the repository at this point in the history
Support for restoring models after deserialisation
  • Loading branch information
weshinsley authored Nov 21, 2024
2 parents c037212 + fb22022 commit 4bf0373
Show file tree
Hide file tree
Showing 15 changed files with 288 additions and 31 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: monty
Title: Monte Carlo Models
Version: 0.3.4
Version: 0.3.5
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Wes", "Hinsley", role = "aut"),
Expand Down
13 changes: 13 additions & 0 deletions R/combine.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ monty_model_combine <- function(a, b, properties = NULL,
parts, properties)
observer <- model_combine_observer(
parts, parameters, properties, call)
restore <- model_combine_restore(
parts)

data <- list(
parts = unname(parts),
Expand All @@ -133,6 +135,7 @@ monty_model_combine <- function(a, b, properties = NULL,
gradient = gradient,
get_rng_state = stochastic$get_rng_state,
set_rng_state = stochastic$set_rng_state,
restore = restore,
observer = observer,
direct_sample = direct_sample),
properties)
Expand Down Expand Up @@ -455,3 +458,13 @@ model_combine_allow_multiple_parameters <- function(parts, properties,
"not supported by both of your models"),
call = call)
}


model_combine_restore <- function(parts) {
a <- parts[[1]]
b <- parts[[2]]
function() {
a$restore()
b$restore()
}
}
4 changes: 4 additions & 0 deletions R/cpp11.R

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ monty_model <- function(model, properties = NULL) {
observer <- validate_model_observer(model, properties, call)
rng_state <- validate_model_rng_state(model, properties, call)
parameter_groups <- validate_model_parameter_groups(model, properties, call)
restore <- validate_model_restore(model, properties, call)

## Update properties based on what we found:
properties$has_gradient <- !is.null(gradient)
Expand All @@ -217,6 +218,7 @@ monty_model <- function(model, properties = NULL) {
gradient = gradient,
direct_sample = direct_sample,
observer = observer,
restore = restore,
rng_state = rng_state,
properties = properties)
class(ret) <- "monty_model"
Expand Down Expand Up @@ -539,6 +541,16 @@ validate_model_parameter_groups <- function(model, properties, call) {
}


validate_model_restore <- function(model, propertioes, call) {
restore <- model$restore
if (is.function(restore)) {
restore
} else {
function() {}
}
}


require_monty_model <- function(model, arg = deparse(substitute(model)),
call = parent.frame()) {
if (!inherits(model, "monty_model")) {
Expand Down
11 changes: 8 additions & 3 deletions R/rng.R
Original file line number Diff line number Diff line change
Expand Up @@ -464,10 +464,15 @@ monty_rng <- R6::R6Class(

##' @description
##' Returns the state of the random number stream. This returns a
##' raw vector of length 32 * n_streams. It is primarily intended for
##' debugging as one cannot (yet) initialise a monty_rng object with this
##' state.
##' raw vector of length 32 * `n_streams`.
state = function() {
monty_rng_state(private$ptr)
},

##' @description
##' Sets the state of the random number stream.
##' @param state Raw vector of state, with length 32 * `n_streams`.
set_state = function(state) {
monty_rng_set_state(private$ptr, state)
}
))
6 changes: 6 additions & 0 deletions R/sample-manual.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,12 @@ monty_sample_manual_run <- function(chain_id, path, progress = NULL) {
restart <- inputs$restart
is_continue <- is.list(restart)

if (is_continue) {
restart$model$restore()
} else {
inputs$model$restore()
}

pb <- progress_bar(n_chains, steps$total, progress,
show_overall = FALSE, single_chain = TRUE)

Expand Down
5 changes: 5 additions & 0 deletions inst/include/monty/random/prng.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ class prng {
return state_[i];
}

/// Return total number of elements in state
size_t state_size() const {
return size() * rng_state::size();
}

/// Convert the random number state of all generators into a single
/// vector. This can be used to save the state to restore using
/// `import_state()`
Expand Down
58 changes: 53 additions & 5 deletions man/monty_rng.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions src/cpp11.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 4bf0373

Please sign in to comment.