diff --git a/DESCRIPTION b/DESCRIPTION index 4671869..19d0b05 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -16,7 +16,7 @@ URL: https://github.com/ScottClaessens/coevolve BugReports: https://github.com/ScottClaessens/coevolve/issues Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.1 +RoxygenNote: 7.3.2 Suggests: knitr, rmarkdown, @@ -38,6 +38,7 @@ Imports: purrr, readr, rlang, + stats, stringr, tidyr VignetteBuilder: knitr diff --git a/R/extract_samples.R b/R/extract_samples.R index 570a92e..21150bf 100644 --- a/R/extract_samples.R +++ b/R/extract_samples.R @@ -5,15 +5,23 @@ #' @return Samples in 'rethinking' style list format #' @export extract_samples <- function(object) { - UseMethod("extract_samples") + UseMethod("extract_samples") } #' @export extract_samples.coevfit <- function(object) { + # get variables and rvars draws vars <- object$fit$metadata()$stan_variables draws <- posterior::as_draws_rvars(object$fit$draws()) - - return(lapply(vars, \(var_name){ - posterior::draws_of(draws[[var_name]], with_chains = FALSE) - }) |> setNames(vars)) -} \ No newline at end of file + # reshape to 'rethinking' style list format + lapply( + vars, + \(var_name){ + posterior::draws_of( + draws[[var_name]], + with_chains = FALSE + ) + } + ) |> + stats::setNames(vars) +} diff --git a/tests/testthat/test-coev_fit.R b/tests/testthat/test-coev_fit.R index c617f89..aca22a5 100644 --- a/tests/testthat/test-coev_fit.R +++ b/tests/testthat/test-coev_fit.R @@ -527,6 +527,15 @@ test_that("coev_fit() fits the model without error", { # expect error if prob for summary is outside of range 0 - 1 expect_error(SW(summary(m1, prob = -0.01))) expect_error(SW(summary(m1, prob = 1.01))) + # expect no errors for extract_samples method + expect_no_error(SW(extract_samples(m1))) + expect_no_error(SW(extract_samples(m2))) + expect_no_error(SW(extract_samples(m3))) + expect_no_error(SW(extract_samples(m4))) + expect_true(SW(methods::is(extract_samples(m1), "list"))) + expect_true(SW(methods::is(extract_samples(m2), "list"))) + expect_true(SW(methods::is(extract_samples(m3), "list"))) + expect_true(SW(methods::is(extract_samples(m4), "list"))) }) test_that("effects_mat argument to coev_fit() works as expected", { @@ -540,6 +549,9 @@ test_that("effects_mat argument to coev_fit() works as expected", { expect_no_error(SW(summary(m))) expect_output(SW(print(m))) expect_output(SW(print(summary(m)))) + # expect no errors for extract_samples method + expect_no_error(SW(extract_samples(m))) + expect_true(SW(methods::is(extract_samples(m), "list"))) # expect effects_mat correct in model output effects_mat <- matrix( c(TRUE, TRUE, @@ -571,6 +583,9 @@ test_that("coev_fit() works with missing data", { expect_no_error(SW(summary(m))) expect_output(SW(print(m))) expect_output(SW(print(summary(m)))) + # expect no errors for extract_samples method + expect_no_error(SW(extract_samples(m))) + expect_true(SW(methods::is(extract_samples(m), "list"))) # expect warning in summary output capture.output( SW(