Skip to content

Commit

Permalink
Merge pull request #139 from mrc-ide/mrc-6058
Browse files Browse the repository at this point in the history
Get coefficients from likelihood objects
  • Loading branch information
weshinsley authored Nov 22, 2024
2 parents 4f02d7c + 76d3506 commit 0313d68
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 3 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: dust2
Title: Next Generation dust
Version: 0.3.7
Version: 0.3.8
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Imperial College of Science, Technology and Medicine",
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ S3method("[<-",dust_system_generator)
S3method("[[<-",dust_likelihood)
S3method("[[<-",dust_system)
S3method("[[<-",dust_system_generator)
S3method(coef,dust_likelihood)
S3method(coef,dust_system)
S3method(coef,dust_system_generator)
S3method(dim,dust_system)
Expand Down Expand Up @@ -50,4 +51,5 @@ export(dust_system_update_pars)
export(dust_unfilter_create)
export(dust_unpack_index)
export(dust_unpack_state)
importFrom(stats,coef)
useDynLib(dust2, .registration = TRUE)
8 changes: 8 additions & 0 deletions R/interface-likelihood.R
Original file line number Diff line number Diff line change
Expand Up @@ -295,5 +295,13 @@ print.dust_likelihood <- function(x, ...) {
time_type <- attr(x$generator, "properties")$time_type
cli::cli_bullets(
c(i = describe_time(time_type, NULL, x$time_control$dt)))
cli::cli_alert_info(
"Use {.help [coef()](stats::coef)} to get more information on parameters")
invisible(x)
}


##' @export
coef.dust_likelihood <- function(object, ...) {
coef(object$generator)
}
5 changes: 3 additions & 2 deletions R/interface.R
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ describe_time <- function(time_type, default_dt, dt) {
}
}

##' @importFrom stats coef
##' @export
coef.dust_system_generator <- function(object, ...) {
attr(object, "parameters")
Expand All @@ -635,8 +636,8 @@ dim.dust_system <- function(x, ...) {
}

##' @export
coef.dust_system <- function(x, ...) {
x$parameters
coef.dust_system <- function(object, ...) {
object$parameters
}


Expand Down
13 changes: 13 additions & 0 deletions tests/testthat/test-filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -800,3 +800,16 @@ test_that("prevent access to the filter after serialisation", {
"Pointer has been serialised, cannot continue safely (filter_rng_state)",
fixed = TRUE)
})


test_that("can get coefficients from filter", {
pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6)
time_start <- 0
data <- data.frame(time = c(4, 8, 12, 16), incidence = 1:4)
dt <- 1
n_particles <- 100
seed <- 42
obj <- dust_filter_create(sir(), time_start, data, dt = dt,
n_particles = n_particles, seed = seed)
expect_equal(coef(obj), coef(sir))
})

0 comments on commit 0313d68

Please sign in to comment.