Skip to content

Commit

Permalink
Add tests for extract_samples
Browse files Browse the repository at this point in the history
  • Loading branch information
Scott Claessens authored and Scott Claessens committed Aug 8, 2024
1 parent fd0e063 commit 8cc0b6a
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 7 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ URL: https://github.com/ScottClaessens/coevolve
BugReports: https://github.com/ScottClaessens/coevolve/issues
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
Suggests:
knitr,
rmarkdown,
Expand All @@ -38,6 +38,7 @@ Imports:
purrr,
readr,
rlang,
stats,
stringr,
tidyr
VignetteBuilder: knitr
Expand Down
20 changes: 14 additions & 6 deletions R/extract_samples.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,23 @@
#' @return Samples in 'rethinking' style list format
#' @export
extract_samples <- function(object) {
UseMethod("extract_samples")
UseMethod("extract_samples")
}

#' @export
extract_samples.coevfit <- function(object) {
# get variables and rvars draws
vars <- object$fit$metadata()$stan_variables
draws <- posterior::as_draws_rvars(object$fit$draws())

return(lapply(vars, \(var_name){
posterior::draws_of(draws[[var_name]], with_chains = FALSE)
}) |> setNames(vars))
}
# reshape to 'rethinking' style list format
lapply(
vars,
\(var_name){
posterior::draws_of(
draws[[var_name]],
with_chains = FALSE
)
}
) |>
stats::setNames(vars)
}
15 changes: 15 additions & 0 deletions tests/testthat/test-coev_fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,15 @@ test_that("coev_fit() fits the model without error", {
# expect error if prob for summary is outside of range 0 - 1
expect_error(SW(summary(m1, prob = -0.01)))
expect_error(SW(summary(m1, prob = 1.01)))
# expect no errors for extract_samples method
expect_no_error(SW(extract_samples(m1)))
expect_no_error(SW(extract_samples(m2)))
expect_no_error(SW(extract_samples(m3)))
expect_no_error(SW(extract_samples(m4)))
expect_true(SW(methods::is(extract_samples(m1), "list")))
expect_true(SW(methods::is(extract_samples(m2), "list")))
expect_true(SW(methods::is(extract_samples(m3), "list")))
expect_true(SW(methods::is(extract_samples(m4), "list")))
})

test_that("effects_mat argument to coev_fit() works as expected", {
Expand All @@ -540,6 +549,9 @@ test_that("effects_mat argument to coev_fit() works as expected", {
expect_no_error(SW(summary(m)))
expect_output(SW(print(m)))
expect_output(SW(print(summary(m))))
# expect no errors for extract_samples method
expect_no_error(SW(extract_samples(m)))
expect_true(SW(methods::is(extract_samples(m), "list")))
# expect effects_mat correct in model output
effects_mat <- matrix(
c(TRUE, TRUE,
Expand Down Expand Up @@ -571,6 +583,9 @@ test_that("coev_fit() works with missing data", {
expect_no_error(SW(summary(m)))
expect_output(SW(print(m)))
expect_output(SW(print(summary(m))))
# expect no errors for extract_samples method
expect_no_error(SW(extract_samples(m)))
expect_true(SW(methods::is(extract_samples(m), "list")))
# expect warning in summary output
capture.output(
SW(
Expand Down

0 comments on commit 8cc0b6a

Please sign in to comment.